PyTorch Text Classifier

This tutorial shows how to train and deploy a PyTorch model using the Qwak platform. We follow the steps from one of the official PyTorch tutorials and show how to run it on the Qwak platform.

We will use the built-in AG_NEWS dataset to train a text classifier that predicts which of four categories a given text belongs to: World News, Sports, Business, or Science and Technology.
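AG_NEWS stores each category as an integer label from 1 to 4. For reference, here is a minimal lookup sketch, assuming the standard AG_NEWS class order (World, Sports, Business, Sci/Tech). `AG_NEWS_LABELS` and `label_name` are hypothetical helpers, not part of torchtext or Qwak:

```python
# Hypothetical lookup table; assumes the standard AG_NEWS class order.
AG_NEWS_LABELS = {
    1: "World",
    2: "Sports",
    3: "Business",
    4: "Sci/Tech",
}


def label_name(label: int) -> str:
    """Return the human-readable category for an AG_NEWS label (1-4)."""
    try:
        return AG_NEWS_LABELS[label]
    except KeyError:
        raise ValueError(f"AG_NEWS labels are 1-4, got {label}") from None


print(label_name(4))  # Sci/Tech
```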

Creating an empty project

Before we start, we have to create an empty Qwak project using the Qwak CLI:

qwak models init --model-class-name PyTorchExample --model-directory model-example .

In the model-example directory, you will find the automatically generated Python classes.

Defining dependencies

First, we have to define the dependencies. Let's use a conda.yml file with the following content:

channels:
  - pytorch
  - conda-forge
  
dependencies:
  - python=3.9
  - pip
  - pandas=1.5.3
  - torchdata
  - pytorch
  - pip:
      - torchtext
      - tokenizer
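Before building, it can help to confirm that the environment actually resolves the packages listed above. A minimal standard-library sketch; the module names are assumptions about each package's import name (the conda package `pytorch` imports as `torch`), and `missing_modules` is a hypothetical helper:

```python
import importlib.util

# Assumed import names for the packages in conda.yml ("pytorch" imports as "torch").
REQUIRED_MODULES = ["pandas", "torch", "torchdata", "torchtext"]


def missing_modules(modules=REQUIRED_MODULES):
    """Return the top-level modules that cannot be found in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("Missing:", ", ".join(missing) if missing else "none")
```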

Preparing data

In pre_processing_data.py, we import the dependencies used during training, build the vocabulary from the AG_NEWS training split, and define the collate_batch function that turns a batch of raw samples into tensors:

from torchtext.data.utils import get_tokenizer
from torchtext.datasets import AG_NEWS
from torchtext.vocab import build_vocab_from_iterator
import torch


tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')


def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)


vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])


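device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Convert raw text to a list of token ids; unknown tokens map to the "<unk>" index.
text_pipeline = lambda x: vocab(tokenizer(x))
# AG_NEWS labels are 1-4; shift them to 0-3 for CrossEntropyLoss.
label_pipeline = lambda x: int(x) - 1


def collate_batch(batch):
    # Convert a batch of (label, text) pairs into the flat tensors EmbeddingBag expects.
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        # Record each text's length; the cumulative sum below yields its start offset.
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    # Concatenate all token ids into one 1-D tensor; offsets mark where each text begins.
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)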
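To make the preprocessing concrete, here is the same logic in plain Python on a toy input. This is a simplified sketch, not the torchtext implementation: the whitespace tokenizer only approximates `basic_english` (which also splits off punctuation), and the tiny vocabulary is hypothetical. It shows how unknown words fall back to the `<unk>` index, how labels shift from 1-4 to 0-3, and how `collate_batch` turns text lengths into start offsets:

```python
from itertools import accumulate

# Hypothetical toy vocabulary; index 0 is reserved for "<unk>" as in the tutorial.
toy_vocab = {"<unk>": 0, "the": 1, "moon": 2, "rocket": 3, "launch": 4}


def toy_tokenizer(text):
    # Rough stand-in for get_tokenizer('basic_english'): lowercase and split on whitespace.
    return text.lower().split()


def text_pipeline(text):
    # Unknown tokens map to the "<unk>" index, mirroring vocab.set_default_index(...).
    return [toy_vocab.get(token, toy_vocab["<unk>"]) for token in toy_tokenizer(text)]


def label_pipeline(label):
    # AG_NEWS labels 1-4 become class indices 0-3.
    return int(label) - 1


def collate(batch):
    labels, flat_tokens, lengths = [], [], []
    for label, text in batch:
        tokens = text_pipeline(text)
        labels.append(label_pipeline(label))
        flat_tokens.extend(tokens)
        lengths.append(len(tokens))
    # Each offset is the start of a text in the flat token list:
    # the running sum of all previous lengths, i.e. cumsum([0] + lengths[:-1]).
    offsets = list(accumulate([0] + lengths[:-1]))
    return labels, flat_tokens, offsets


labels, tokens, offsets = collate([(4, "The rocket launch"), (1, "Hello moon")])
print(labels, tokens, offsets)  # [3, 0] [1, 3, 4, 0, 2] [0, 3]
```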

In text_classification_model.py, we define the text classification model together with the train and evaluate helper functions:

import time
import torch
from torch import nn


class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)


def train(dataloader, model: TextClassificationModel, optimizer, criterion, epoch):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()


def evaluate(dataloader, model: TextClassificationModel, criterion):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)

    return total_acc / total_count
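`nn.EmbeddingBag` uses `mode='mean'` by default, so the model averages the embeddings of each text's tokens (each "bag" starts at one of the offsets) and passes that single vector to the linear layer. A dependency-free sketch of that forward pass; the tiny weights are purely illustrative, and the helper names are hypothetical:

```python
def embedding_bag_mean(weights, tokens, offsets):
    """Average the embedding rows in each bag; bag i spans tokens[offsets[i]:offsets[i + 1]]."""
    embed_dim = len(weights[0])
    bounds = list(offsets) + [len(tokens)]
    bags = []
    for start, end in zip(bounds, bounds[1:]):
        rows = [weights[t] for t in tokens[start:end]]
        if not rows:
            # Empty bags produce a zero vector, as documented for nn.EmbeddingBag.
            bags.append([0.0] * embed_dim)
            continue
        bags.append([sum(column) / len(rows) for column in zip(*rows)])
    return bags


def linear(x, fc_weight, fc_bias):
    """y = W x + b, with W stored as [num_class][embed_dim] like nn.Linear.weight."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(fc_weight, fc_bias)]


# Illustrative values: vocab_size=3, embed_dim=2, num_class=2.
emb = [[0.0, 0.0], [1.0, 3.0], [3.0, 1.0]]
fc_w = [[1.0, 0.0], [0.0, 1.0]]
fc_b = [0.0, 0.5]

bags = embedding_bag_mean(emb, tokens=[1, 2, 2], offsets=[0, 2])
logits = [linear(bag, fc_w, fc_b) for bag in bags]
print(bags)    # [[2.0, 2.0], [3.0, 1.0]]
print(logits)  # [[2.0, 2.5], [3.0, 1.5]]
```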

Defining a new model

Next, we wrap the PyTorch model in a Qwak model class whose build method runs the training. Let's put this implementation in the model.py file:


Note: We log the hyperparameters with the qwak.log_param function and the test accuracy with qwak.log_metric. You will see both later in the Qwak UI, where they can be used for experiment tracking.

import time
import qwak
import torch
from qwak.model.base import QwakModel
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data import to_map_style_dataset
from torchtext.datasets import AG_NEWS

from pre_processing_data import collate_batch, vocab, device
from text_classification_model import TextClassificationModel, train, evaluate

# Hyperparameters
EPOCHS = 10 # epoch
LR = 5  # learning rate
BATCH_SIZE = 64 # batch size for training


class TextClassifierQwak(QwakModel):
    def __init__(self):
        super().__init__()
        train_iter = AG_NEWS(split='train')
        num_class = len(set([label for (label, text) in train_iter]))
        vocab_size = len(vocab)
        emsize = 64
        self.model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
        qwak.log_param({"epoch": EPOCHS, "lr": LR, "batch size": BATCH_SIZE})

    def build(self):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=LR)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        total_accu = None
        train_iter, test_iter = AG_NEWS()
        train_dataset = to_map_style_dataset(train_iter)
        test_dataset = to_map_style_dataset(test_iter)
        num_train = int(len(train_dataset) * 0.95)
        split_train_, split_valid_ = \
            random_split(train_dataset, [num_train, len(train_dataset) - num_train])

        train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                                      shuffle=True, collate_fn=collate_batch)
        valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                                      shuffle=True, collate_fn=collate_batch)
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                     shuffle=True, collate_fn=collate_batch)

        for epoch in range(1, EPOCHS + 1):
            epoch_start_time = time.time()
            train(dataloader=train_dataloader, model=self.model, optimizer=optimizer, criterion=criterion, epoch=epoch)
            accu_val = evaluate(dataloader=valid_dataloader,model=self.model, criterion=criterion)
            if total_accu is not None and total_accu > accu_val:
                scheduler.step()
            else:
                total_accu = accu_val
            print('-' * 59)
            print('| end of epoch {:3d} | time: {:5.2f}s | '
                  'valid accuracy {:8.3f} '.format(epoch,
                                                   time.time() - epoch_start_time,
                                                   accu_val))
            print('-' * 59)

        print('Checking the results of the test dataset.')
        accu_test = evaluate(dataloader=test_dataloader, model=self.model, criterion=criterion)
        print('test accuracy {:8.3f}'.format(accu_test))
        qwak.log_metric({"accuracy": accu_test})


if __name__ == '__main__':
    model = TextClassifierQwak()
    model.build()
    model.initialize_model()
    # model.predict()

Now, the model is ready to use.

Testing your model

Following good software development practices, we define a test that verifies the model returns the expected response before we use it in production.

In the tests directory, let's open the test_qwak_model.py file.

In the file, we will use a test client to get a prediction from the model we have just built:

import pandas as pd
from qwak.testing.fixtures import real_time_client


def test_realtime_api(real_time_client):
    feature_vector = [
        {
            'text': 'To the Moon!'
        }]

    result: pd.DataFrame = real_time_client.predict(feature_vector)
    print(result)
    assert result['label'].values[0] == 4 # category 4 = science and technology
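Because label_pipeline shifts labels to 0-3 for training, the assertion above expects 4 only if the prediction step adds 1 back to the arg-max class index. The predict method is not shown in this section, so the following is a hedged, dependency-free sketch of that conversion under that assumption; `logits_to_label` is a hypothetical helper:

```python
def logits_to_label(logits):
    """Map one row of model scores (classes 0-3) back to an AG_NEWS label (1-4)."""
    best_class = max(range(len(logits)), key=lambda i: logits[i])  # arg-max index
    return best_class + 1  # undo label_pipeline's "- 1"


print(logits_to_label([0.1, -1.2, 0.3, 2.7]))  # 4 (Sci/Tech)
```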