Intro to Digit Classification with LibTorch

Recently I've been considering deploying a trained digit classification model for inference in a project developed in C++. Here is a brief post on training the model and integrating it with the LibTorch C++ API, using a CNN-based neural network and the MNIST dataset as an example.

Training the Model

For simplicity and convenience, I use PyTorch in Python to train the model here. If you prefer to train in C++ as well, the official example that trains the model with the C++ frontend is MNIST Example with the PyTorch C++ Frontend.

Prepare the Datasets

The MNIST dataset (Modified National Institute of Standards and Technology database) is a large database of handwritten digits. It consists of 60,000 training and 10,000 test grayscale images of $28 \times 28$ pixels in 10 classes, and it is commonly used for training various image processing systems.

If you want to make a custom MNIST-format dataset, I recommend taking a look at the MNIST Database page. I've also written a data_2_mnist Python script that may help with the conversion.

First, import the necessary libraries and download the raw MNIST dataset to the ./data folder using the built-in torchvision utilities. Then set the default batch_size and define the transform function.

```python
import torch
from torchvision import datasets, transforms

# shuffle the training set on every device; only the loader worker/pinning options depend on CUDA
train_kwargs = {'batch_size': 64, 'shuffle': True}
test_kwargs = {'batch_size': 1000}

if torch.cuda.is_available():
    cuda_kwargs = {
        'num_workers': 1,
        'pin_memory': True,
    }
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

# ToTensor scales pixels to [0, 1]; Normalize standardizes them with the MNIST mean/std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
```

Note that batch_size affects two things during training: the time required to complete each epoch, and how smooth the gradient is from one iteration to the next. The example above uses a batch size of 64 for training and 1,000 for testing.

A larger batch gives a more stable estimate of the gradient direction, but that direction is not necessarily correct globally. Even with a more accurate estimate, how far the optimizer moves along it still depends on the learning rate. With larger batches you can therefore try raising the learning rate and increasing the number of iterations. In practice, however, a smaller batch with a smaller learning rate, combined with appropriately more epochs, often generalizes better. The reason is that the higher variance of small-batch gradient estimates acts as a form of regularization.
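To put numbers on the per-epoch side of this trade-off, here is a small plain-Python sketch. It counts how many iterations one pass over each split takes with the batch sizes above. It assumes DataLoader's default drop_last=False, so a final, smaller batch is kept. The helper names are my own, not part of PyTorch.

```python
import math

def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    # With drop_last=False the last partial batch is kept,
    # so the number of iterations is the ceiling of samples / batch_size.
    return math.ceil(num_samples / batch_size)

def last_batch_size(num_samples: int, batch_size: int) -> int:
    # Size of the final batch; equals batch_size when it divides the dataset evenly.
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size

train_steps = batches_per_epoch(60_000, 64)    # optimizer steps per training epoch
test_steps = batches_per_epoch(10_000, 1_000)  # forward passes per test run
print(train_steps, test_steps, last_batch_size(60_000, 64))  # 938 10 32
```

With 938 steps per training epoch, the train function further below (which logs every 30 batches) prints its progress line 32 times per epoch.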

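ToTensor converts each image to a float tensor with values in the range [0, 1]. Normalize then standardizes each channel with the formula below, so that over the dataset the pixel values have roughly zero mean and unit standard deviation. It only shifts and scales the values; it does not turn them into a Gaussian distribution. Here 0.1307 and 0.3081 are the mean and standard deviation of the MNIST training images, so these constants are determined by the dataset itself.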

$$
\mathrm{output}[c] = \frac{\mathrm{input}[c] - \mathrm{mean}[c]}{\mathrm{std}[c]}
$$
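The mean and standard deviation are statistics over every pixel of the training images after ToTensor has scaled them to [0, 1]. The sketch below runs that computation on two made-up $2 \times 2$ "images", which stand in for the 60,000 real ones (those are not loaded here). It then applies the Normalize formula. The function names are hypothetical.

```python
import math

def channel_stats(images):
    # images: list of 2-D lists with pixel values already scaled to [0, 1],
    # i.e. what ToTensor produces; statistics are taken over every pixel of every image
    pixels = [p for img in images for row in img for p in row]
    mean = sum(pixels) / len(pixels)
    std = math.sqrt(sum((p - mean) ** 2 for p in pixels) / len(pixels))
    return mean, std

def normalize(img, mean, std):
    # the Normalize formula for a single channel
    return [[(p - mean) / std for p in row] for row in img]

toy = [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.5], [1.0, 0.0]]]
mean, std = channel_stats(toy)
normed = [normalize(img, mean, std) for img in toy]  # zero mean, unit std over the toy set

# with the MNIST constants: black (0.0) -> about -0.4242, white (1.0) -> about 2.8215
edge = normalize([[0.0, 1.0]], 0.1307, 0.3081)
```

This is why normalized MNIST tensors are not confined to [0, 1].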

Next, visualize some samples from the loaded dataset to make sure the data is in the correct format (only the training set is used here).

```python
import matplotlib.pyplot as plt
import numpy as np

def img_cvt(tensor):
    # undo Normalize((0.1307,), (0.3081,)) and drop the channel dim: (1, 28, 28) -> (28, 28)
    image = tensor.cpu().clone().detach().numpy().squeeze(0)
    image = image * 0.3081 + 0.1307
    return image.clip(0, 1)

data_iter = iter(train_loader)
images, labels = next(data_iter)
fig = plt.figure(figsize=(25, 4))

for idx in np.arange(20):
    ax = fig.add_subplot(2, 10, idx + 1, xticks=[], yticks=[])
    plt.imshow(img_cvt(images[idx]), cmap='gray')
    ax.set_title(str(labels[idx].item()))
plt.show()
```

Build the Neural Network

Here I simply reuse the CNN model from the official PyTorch MNIST example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 1 x 28 x 28 -> 32 x 26 x 26
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 32 x 26 x 26 -> 64 x 24 x 24
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)       # 9216 = 64 * 12 * 12 after pooling
        self.fc2 = nn.Linear(128, 10)         # one score per digit class

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)                # 64 x 24 x 24 -> 64 x 12 x 12
        x = self.dropout1(x)
        x = torch.flatten(x, 1)               # keep the batch dim, flatten the rest
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)      # log-probabilities, paired with nll_loss
        return output
```
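The 9216 in fc1 is the number of features left after the two convolutions and the pooling step. A quick way to check it is the usual output-size formula for a convolution or pooling layer, floor((size + 2 * padding - kernel) / stride) + 1. It is sketched here in plain Python, assuming dilation 1 (the default). conv2d_out is a hypothetical helper.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # output size along one spatial dimension for Conv2d / max_pool2d with dilation 1
    return (size + 2 * padding - kernel) // stride + 1

h = 28                                   # MNIST input: 1 x 28 x 28
h = conv2d_out(h, kernel=3)              # conv1: 32 x 26 x 26
h = conv2d_out(h, kernel=3)              # conv2: 64 x 24 x 24
h = conv2d_out(h, kernel=2, stride=2)    # max_pool2d(x, 2), stride defaults to 2: 64 x 12 x 12
flat_features = 64 * h * h               # what torch.flatten(x, 1) leaves per sample
print(flat_features)                     # 9216, the in_features of fc1
```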

The code above relies on three basic building blocks. Here is a brief description of what each one does:

  • Conv2d implements the 2-D convolution. The four positional arguments used above are, from left to right: the number of input channels, the number of output channels, the kernel size (an int for a square kernel, or a (height, width) tuple), and the stride, i.e. how far the kernel moves at each slide.

  • Dropout randomly drops a fraction of the neurons during training. On each forward pass, every activation is zeroed with probability $p$. A dropped neuron contributes nothing to that pass and passes back no gradient, so its weights are not updated by it. The weights themselves are kept, and the neuron may be active again in the next training step.

  • Linear implements a fully connected layer. It operates on the last dimension of its input, which here is a 2-D tensor of shape (batch, features). That is why the feature maps are flattened with torch.flatten(x, 1) before fc1.
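To make the dropout description concrete, here is a plain-Python sketch of inverted dropout, the variant nn.Dropout uses. Survivors are scaled by 1/(1 - p) during training, and nothing is dropped in eval mode. It illustrates the idea on a flat list and is not PyTorch's actual implementation. The dropout function here is hypothetical.

```python
import random

def dropout(activations, p, training=True, rng=random):
    # Inverted dropout: during training each value is zeroed with probability p and
    # the survivors are multiplied by 1 / (1 - p), so the expected activation is unchanged;
    # in eval mode (training=False) values pass through untouched.
    if not training or p == 0.0:
        return list(activations)
    if p >= 1.0:
        return [0.0] * len(activations)  # every unit is dropped
    keep_scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else a * keep_scale for a in activations]

rng = random.Random(0)
x = [1.0] * 10_000
y = dropout(x, p=0.5, rng=rng)
print(sum(y) / len(y))                         # close to 1.0: the scaling preserves the mean
print(dropout(x[:3], p=0.5, training=False))   # [1.0, 1.0, 1.0]
```

The scaling is why the model needs no extra correction at inference time. The dropout layers simply become identity functions once model.eval() is called.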

Train and Export the Model

First, define the train and test functions:

```python
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item())
            )

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss,
        correct,
        len(test_loader.dataset),
        test_accuracy)
    )
```

Then move the model to the appropriate device and choose an optimizer. I use the Adadelta optimizer here, which is described in the paper Adadelta: An Adaptive Learning Rate Method.

I want the learning rate to be multiplied by a factor gamma every fixed number of epochs (step_size), so the StepLR scheduler is used here. With step_size=1 and gamma=0.7, the learning rate shrinks by 30% after every epoch.

```python
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
epochs = 10
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
```
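The training loop calls the scheduler's step() once per epoch, so StepLR's schedule has a simple closed form: lr = lr0 * gamma ** (epoch // step_size), with a 0-based epoch. The sketch below prints the learning rate each of the 10 epochs would use with the settings above. It is plain Python, and step_lr is a hypothetical helper, not part of torch.optim.

```python
def step_lr(base_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    # learning rate in effect during `epoch` (0-based), assuming scheduler.step()
    # is called once per epoch: the rate is multiplied by gamma every step_size epochs
    return base_lr * gamma ** (epoch // step_size)

# lr=1.0, gamma=0.7, step_size=1, epochs=10 as configured above
lrs = [step_lr(1.0, 0.7, 1, e) for e in range(10)]
for e, lr in enumerate(lrs, start=1):
    print(f"epoch {e:2d}: lr = {lr:.4f}")
```

The last epoch therefore trains with roughly 0.04, about 4% of the initial rate.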

Then we can start training the model.

```python
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()
```

After training, we can visually check some predictions on the test set and then export the model to a file.

```python
model.eval()  # disable dropout for evaluation
data_iter = iter(test_loader)
images, labels = next(data_iter)
images = images.to(device)
labels = labels.to(device)

with torch.no_grad():
    output = model(images)
_, preds = torch.max(output, 1)

fig = plt.figure(figsize=(25, 4))
for idx in np.arange(20):
    ax = fig.add_subplot(2, 10, idx + 1, xticks=[], yticks=[])
    plt.imshow(img_cvt(images[idx]), cmap='gray')
    # title is "prediction (ground truth)": green when correct, red otherwise
    ax.set_title("{} ({})".format(preds[idx].item(), labels[idx].item()),
                 color=("green" if preds[idx] == labels[idx] else "red"))
plt.show()
```

Note that to use the trained model in a non-Python environment, we should export it as a TorchScript module. A Module is the basic unit of composition in PyTorch. It contains the following (from Basics of PyTorch Model Authoring):

  • A constructor, which prepares the module for invocation
  • A set of Parameters and sub-Modules. These are initialized by the constructor and can be used by the module during invocation.
  • A forward function. This is the code that is run when the module is invoked.

The module can also be optimized for mobile devices; see Optimize a TorchScript Model.

Finally, export the traced module to a file called mnist_traced.pt, which can then be loaded with LibTorch.

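```python
# trace on the CPU in eval mode: dropout is disabled in the traced graph,
# and the saved module can be loaded by a CPU-only LibTorch build
cpu_model = model.to("cpu").eval()
blob_input = torch.zeros(1, 1, 28, 28)  # example input: batch 1, 1 channel, 28 x 28
traced_model = torch.jit.trace(cpu_model, blob_input)
traced_model.save("mnist_traced.pt")
```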

Inferring with LibTorch C++ Frontend API

For inference, I prefer to wrap the LibTorch C++ API in a single header file.

Header Module Helper

As the code below shows, the class dnn::ts_mnist is initialized with the model file path and a label vector. The _inputs vector stores the input data for later steps.

```cpp
#pragma once

#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include <torch/torch.h>
#include <torch/script.h>

namespace dnn {
template <typename T = int>
class ts_mnist {
public:
    ts_mnist(std::string_view path, const std::vector<T>& labels)
        : _path(path), _labels(labels) {
        _module = torch::jit::load(_path);  // throws c10::Error if the file cannot be loaded
        _module.eval();                     // make sure dropout is disabled
        _inputs.resize(1);                  // forward() takes a vector of inputs; we need one
    }

private:
    const std::string _path;  // owning copy, so the path does not dangle
    const std::vector<T> _labels;

    torch::jit::script::Module _module;
    std::vector<torch::jit::IValue> _inputs;
};
}  // namespace dnn
```

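The constructor of ts_mnist loads the module by calling torch::jit::load. Next, write the inference function. The snippet below shows the public member function and the private members it uses, all of which belong inside the class body.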

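```cpp
// the minimal OpenCV headers used for image processing (place them at the top of the header)
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

// public member function of dnn::ts_mnist
// (not noexcept: OpenCV and LibTorch calls may throw)
auto inferring(const cv::Mat& image) {
    torch::NoGradGuard no_grad;  // inference only, skip autograd bookkeeping

    cv::cvtColor(image, _gray, cv::COLOR_BGR2GRAY);
    cv::resize(_gray, _gray, cv::Size(28, 28));

    // from_blob wraps _gray's buffer without copying, so _gray must outlive the tensor's use
    _tensor_image = torch::from_blob(_gray.data, { _gray.rows, _gray.cols }, torch::kUInt8);
    // same normalization as the training transform: mean 0.1307, std 0.3081
    _tensor_image_normed = (_tensor_image.to(torch::kFloat32) / 255.f).sub_(0.1307f).div_(0.3081f);

    // 28 x 28 -> 1 x 1 x 28 x 28 (batch, channel, height, width)
    _inputs[0] = _tensor_image_normed.unsqueeze_(0).unsqueeze_(0);
    _output = _module.forward(_inputs).toTensor();  // 1 x 10 log-probabilities
    _index = _output.argmax(1).item<int64_t>();

    return std::pair<T, torch::Tensor>(_labels.at(_index), _output.index({0, _index}));
}

// private members of dnn::ts_mnist
torch::Tensor _tensor_image;
torch::Tensor _tensor_image_normed;
torch::Tensor _output;

int64_t _index;

cv::Mat _gray;
```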

The inferring function accepts a color image, so it first converts the image to grayscale, because the model expects a single-channel input. It then resizes the grayscale image to $28 \times 28$ to match the input shape.

Before calling forward on the module, the image must be converted to a tensor and normalized, using the same mean and standard deviation as the training transform. The normalized tensor is 2-dimensional ($28 \times 28$). The module, however, was traced with an input of shape $1 \times 1 \times 28 \times 28$ (batch, channel, height, width), so we add two dimensions by calling unsqueeze_(0) twice.
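Here is the same preprocessing written out in plain Python. It is a sketch for checking shapes and values, not a replacement for the C++ code. It assumes the training-time constants 0.1307 and 0.3081. preprocess and shape are hypothetical helpers, and nested lists stand in for tensors.

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # the training-time Normalize constants

def preprocess(gray):
    # gray: 28 x 28 nested list of uint8 values (0..255), i.e. the resized grayscale image.
    # Scale to [0, 1] like ToTensor, normalize like transforms.Normalize, then wrap
    # in two more lists for the batch and channel dims (the two unsqueeze_(0) calls).
    normed = [[(p / 255.0 - MNIST_MEAN) / MNIST_STD for p in row] for row in gray]
    return [[normed]]

def shape(nested):
    # shape of a rectangular nested list, e.g. [1, 1, 28, 28]
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims

blank = [[0] * 28 for _ in range(28)]
x = preprocess(blank)
print(shape(x))  # [1, 1, 28, 28]
```

For comparison, a (x / 255 - 0.5) / 0.5 normalization maps a black background pixel to -1.0 instead of about -0.4242, and a white stroke to 1.0 instead of about 2.8215. An inference-side normalization that differs from the training transform therefore feeds the model inputs it never saw during training.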

The output is obtained by calling forward. Applying argmax to the output tensor then gives the index of the predicted class, which is mapped to its label.
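Because the network ends with log_softmax, the output row holds log-probabilities. The plain-Python sketch below shows two things. First, argmax still picks the right class, since log is monotonic. Second, exp turns the winning score back into a probability. classify and the example scores are hypothetical.

```python
import math

def classify(log_probs, labels):
    # log_probs: one row of the model output (log_softmax over the classes).
    # log is monotonic, so argmax over log-probabilities equals argmax over probabilities.
    index = max(range(len(log_probs)), key=lambda i: log_probs[i])
    confidence = math.exp(log_probs[index])  # back to a probability in (0, 1]
    return labels[index], confidence

labels = list(range(10))
probs = (0.01, 0.02, 0.9, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01)  # made-up output, sums to 1
scores = [math.log(p) for p in probs]
label, confidence = classify(scores, labels)
print(label, round(confidence, 2))  # 2 0.9
```

In the C++ helper, the second member of the returned pair is that log-probability as a 0-dimensional tensor. Something like res.second.exp().item<float>() should therefore give the confidence.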

Compile using CMake and Test

Integrating LibTorch into an existing C++ project is easy. On macOS, just install libtorch, torchvision and opencv with Homebrew and add the following lines to the CMakeLists.txt file.

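```cmake
# assumes the target already exists, e.g. add_executable(${PROJECT_NAME} main.cpp)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
# use the keyword signature for every target_link_libraries call on the same target
target_link_libraries("${PROJECT_NAME}" PUBLIC "${TORCH_LIBRARIES}")

find_package(OpenCV REQUIRED)
target_link_libraries("${PROJECT_NAME}" PUBLIC "${OpenCV_LIBS}")

# the code uses C++17 (std::string_view, class template argument deduction)
set_property(TARGET "${PROJECT_NAME}" PROPERTY CXX_STANDARD 17)
```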

On Linux, apt and other package managers may not handle this well, and CMake's find_package may keep reporting that the package was not found. Things get a little more complicated with the official pre-compiled LibTorch, because you may run into linker errors caused by ABI incompatibility (the downloads come in pre-cxx11 ABI and cxx11 ABI variants). Compiling LibTorch manually may fix this problem, and compiling OpenCV manually is also recommended if truly needed. In any case, you can point CMake to the pre-built libraries like this so that it can find them.

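```cmake
# add the unpacked library directory to the package search prefixes
list(APPEND CMAKE_PREFIX_PATH <libraries path>)
# or point find_package directly at the directory containing TorchConfig.cmake
find_package(Torch REQUIRED PATHS <library path containing TorchConfig.cmake>)
```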

For MSVC users, please take a look at the official LibTorch minimal example for more details.

Then run a simple test. The following program takes three arguments:

  • the module file path
  • the path of the image to classify
  • the label string, one character per class index (e.g. 0123456789)

```cpp
#include <chrono>
#include <iostream>
#include <string>
#include <string_view>
#include <vector>

#include "ts_mnist_inferring.hpp" // the header file of dnn::ts_mnist

int main(int argc, char *argv[]) {
    if (argc != 4) {
        std::cerr << "usage: " << argv[0] << " <module.pt> <image> <labels, e.g. 0123456789>\n";
        return -1;
    }
    const auto module_path = std::string_view(argv[1]);
    const auto image_path = std::string_view(argv[2]);

    // each character of the third argument is one label: "0123456789" -> {0, 1, ..., 9}
    std::vector<int> labels;
    for (const auto& ch : std::string_view(argv[3])) labels.push_back(ch - '0');

    auto net = dnn::ts_mnist(module_path, labels);
    auto img = cv::imread(image_path.data(), cv::IMREAD_COLOR);
    if (img.empty()) {
        std::cerr << "failed to read image: " << image_path << '\n';
        return -1;
    }

    auto s = std::chrono::high_resolution_clock::now();
    auto res = net.inferring(img);
    auto e = std::chrono::high_resolution_clock::now();

    using namespace std::chrono_literals;
    cv::resize(img, img, cv::Size(320, 320));
    cv::putText(
        img,
        // FPS estimated from the single inference call timed above
        std::string("Res: ") + std::to_string(res.first) + std::string(" FPS: ") + std::to_string(static_cast<int>(1.0s / (e - s))),
        cv::Point(10, 25),
        cv::FONT_HERSHEY_SIMPLEX,
        1,
        cv::Scalar(255, 255, 255),
        2
    );
    cv::imshow("result", img);
    cv::waitKey(0);

    return 0;
}
```

As this is only an intro to digit classification with LibTorch, I haven't tested the LibTorch C++ frontend API on CUDA devices yet; maybe I'll do that later :D