Building a PyTorch-based Text Classifier

This tutorial shows you how to train and deploy a PyTorch model using the Qwak platform. We will follow the steps from one of the official PyTorch tutorials and show how to run them on the Qwak platform.

We will use the built-in AG_NEWS dataset to train a text classifier. It will predict which of four categories a given text belongs to: World News, Sports, Business, or Science and Technology.

Creating an Empty Project

Before we start, we have to create an empty Qwak project using the Qwak CLI:

qwak models init --model-class-name PyTorchExample --model-directory model-example .

In the model-example directory, you will find the automatically-generated Python classes.

Defining Dependencies

First, we have to define the dependencies. Let's use a conda.yml file with the following content:

channels:
  - pytorch
  - conda-forge
dependencies:
  - python=3.9
  - pip
  - pandas=1.5.3
  - torchdata
  - pytorch
  - pip:
      - torchtext
      - tokenizer

Preparing the data

In the pre_processing_data.py file, we will prepare the data:

In the file, we have to add imports of all dependencies we use during the training:

from torchtext.data.utils import get_tokenizer
from torchtext.datasets import AG_NEWS
from torchtext.vocab import build_vocab_from_iterator
import torch

# Tokenizer and training split used below to build the vocabulary.
tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')

def yield_tokens(data_iter):
    """Lazily tokenize the text of every ``(label, text)`` pair in *data_iter*.

    The label is ignored; each yielded item is whatever the module-level
    ``tokenizer`` returns for the text.
    """
    yield from (tokenizer(text) for _label, text in data_iter)

# Build the vocabulary from every token produced for the training split.
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
# Without a default index, looking up a token that is not in the vocabulary
# raises an error; map such tokens to "<unk>" instead.
vocab.set_default_index(vocab["<unk>"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def text_pipeline(x):
    """Tokenize *x* and return the vocabulary lookup of its tokens."""
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Convert a label to ``int`` and shift it down by one."""
    return int(x) - 1

def collate_batch(batch):
    """Collate ``(label, text)`` samples into tensors for ``nn.EmbeddingBag``.

    Args:
        batch: iterable of ``(label, text)`` pairs.

    Returns:
        A ``(labels, texts, offsets)`` tuple of tensors moved to ``device``:
        ``labels`` holds one processed label per sample, ``texts`` is the
        concatenation of all token-id sequences, and ``offsets`` holds the
        start position of each sample inside ``texts``.
    """
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # offsets holds [0, len_0, len_1, ...]; dropping the last length and
    # taking the running sum yields the start index of every sample.
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

In the text_classification_model.py file, we will define the text classification model:

import time
import torch
from torch import nn

class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier: ``EmbeddingBag`` followed by ``Linear``."""

    def __init__(self, vocab_size, embed_dim, num_class):
        """Create the layers and initialise their weights.

        Args:
            vocab_size: number of rows in the embedding table.
            embed_dim: size of each embedding vector.
            num_class: number of output classes.
        """
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw weights uniformly from [-0.5, 0.5] and zero the linear bias."""
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """Return class scores for the bags of token ids described by *text* and *offsets*."""
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

def train(dataloader, model: TextClassificationModel, optimizer, criterion, epoch):
    """Run one training pass over *dataloader*, updating *model* in place.

    Args:
        dataloader: sized iterable of ``(label, text, offsets)`` batches.
        model: module called as ``model(text, offsets)``.
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss called as ``criterion(predicted_label, label)``.
        epoch: epoch number, used only in the progress message.

    Prints the running accuracy every 500 batches and then resets the
    running counters.
    """
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500

    for idx, (label, text, offsets) in enumerate(dataloader):
        # Clear gradients left over from the previous batch before the
        # forward/backward pass.
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip after backward (gradients exist) and before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc / total_count))
            total_acc, total_count = 0, 0

def evaluate(dataloader, model: TextClassificationModel, criterion):
    """Return the accuracy of *model* over *dataloader*.

    Args:
        dataloader: iterable of ``(label, text, offsets)`` batches.
        model: module called as ``model(text, offsets)``.
        criterion: accepted for interface compatibility; the loss is not
            needed to compute accuracy and is not evaluated.

    Returns:
        Fraction of correctly predicted labels, or ``0.0`` when the
        dataloader yields no samples.
    """
    total_acc, total_count = 0, 0

    # Evaluate in eval mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for label, text, offsets in dataloader:
                predicted_label = model(text, offsets)
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)
    finally:
        model.train(was_training)

    if total_count == 0:
        return 0.0
    return total_acc / total_count

ML Model Definition

In Pytorch, we will implement the model as a Python class. Let's put the model implementation in the file:



We store the hyper-parameters using the log_param function.

You will see them later in the Qwak UI, and potentially use the parameters or metrics logging feature for experiment tracking.

import time
import qwak
import torch
from qwak.model.base import QwakModelInterface
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.datasets import AG_NEWS

from pre_processing_data import collate_batch, vocab, device
from text_classification_model import TextClassificationModel, train, evaluate

# Hyperparameters (logged through qwak.log_param in TextClassifierQwak.__init__)
EPOCHS = 10 # number of training epochs run by build()
LR = 5  # learning rate passed to the SGD optimizer
BATCH_SIZE = 64 # batch size used by the train/validation/test DataLoaders

class TextClassifierQwak(QwakModelInterface):
    """Qwak model wrapper that trains a ``TextClassificationModel`` on AG_NEWS.

    NOTE(review): no ``predict`` method is defined in this snippet; confirm
    whether ``QwakModelInterface`` requires one before deploying.
    """

    def __init__(self):
        """Create the untrained model and log the hyperparameters to Qwak."""
        train_iter = AG_NEWS(split='train')
        # The number of classes is the number of distinct labels in the data.
        num_class = len({label for label, _ in train_iter})
        vocab_size = len(vocab)
        emsize = 64
        self.model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
        qwak.log_param({"epoch": EPOCHS, "lr": LR, "batch size": BATCH_SIZE})

    def build(self):
        """Train for ``EPOCHS`` epochs and log the test accuracy to Qwak.

        95% of the training split is used for training and the rest for
        validation. The learning rate is decayed whenever the validation
        accuracy falls below the reference accuracy.
        """
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=LR)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
        total_accu = None
        train_iter, test_iter = AG_NEWS()
        train_dataset = to_map_style_dataset(train_iter)
        test_dataset = to_map_style_dataset(test_iter)
        num_train = int(len(train_dataset) * 0.95)
        split_train_, split_valid_ = random_split(
            train_dataset, [num_train, len(train_dataset) - num_train])

        train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                                      shuffle=True, collate_fn=collate_batch)
        valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                                      shuffle=True, collate_fn=collate_batch)
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                     shuffle=True, collate_fn=collate_batch)

        for epoch in range(1, EPOCHS + 1):
            epoch_start_time = time.time()
            train(dataloader=train_dataloader, model=self.model,
                  optimizer=optimizer, criterion=criterion, epoch=epoch)
            accu_val = evaluate(dataloader=valid_dataloader, model=self.model,
                                criterion=criterion)
            # Decay the learning rate when validation accuracy regressed;
            # otherwise remember this accuracy as the new reference.
            if total_accu is not None and total_accu > accu_val:
                scheduler.step()
            else:
                total_accu = accu_val
            print('-' * 59)
            print('| end of epoch {:3d} | time: {:5.2f}s | '
                  'valid accuracy {:8.3f} '.format(epoch,
                                                   time.time() - epoch_start_time,
                                                   accu_val))
            print('-' * 59)

        print('Checking the results of test dataset.')
        accu_test = evaluate(dataloader=test_dataloader, model=self.model,
                             criterion=criterion)
        print('test accuracy {:8.3f}'.format(accu_test))
        qwak.log_metric({"accuracy": accu_test})

if __name__ == '__main__':
    # Running this file as a script only constructs the wrapper, which loads
    # the AG_NEWS training split and logs the hyperparameters; it does not
    # train the model.
    model = TextClassifierQwak()
    # Left disabled: TextClassifierQwak defines no predict() in this snippet.
    # model.predict()

Now, the model is ready to use.


To follow software development good practices, we must define a test to verify whether we get the correct response from the model before we start using it in production.

In the tests directory, let's open the file.

In the file, we will use a test client to get a prediction from the model we have just built:

import pandas as pd
from qwak_mock import real_time_client

def test_realtime_api(real_time_client):
    """Send one sample text to the model and check the predicted category."""
    feature_vector = [
        {
            'text': 'To the Moon!'
        }
    ]

    result: pd.DataFrame = real_time_client.predict(feature_vector)
    # NOTE(review): training labels are shifted to start at 0 in the
    # pre-processing step; confirm predict() maps them back to 1-4.
    assert result['label'].values[0] == 4 # category 4 = science and technology