|
1 | 1 | # pytorch-cpp
|
2 |
| -Just messing around with PyTorch 1.0 (currently pre-release 1.0rc1) and their new C++ API Libtorch. |
| 2 | +In this repo I experiment with PyTorch 1.0 and their new JIT compiler, as well as their C++ API Libtorch. |
| 3 | + |
| 4 | +Currently, the repo contains a VGG16 based network implementation in PyTorch for CIFAR-10 classification (based on my [previous experiment](https://github.com/laggui/NN_compress)), and the C++ source for inference. |
| 5 | + |
| 6 | +## pytorch/ |
| 7 | +This subdirectory includes the network's [architecture definition](pytorch/vgg.py), the [training script](pytorch/train.py), the [test script](pytorch/test.py) on the CIFAR-10 dataset, a [prediction script](pytorch/predict.py) for inference and, most importantly, the [script to convert the model to Torch Script](pytorch/to_torch_script.py). |
| 8 | + |
| 9 | +## libtorch/ |
| 10 | +This is where you'll find the source for the network's inference in C++. In [predict.cpp](libtorch/predict.cpp), we load the Torch Script module generated in PyTorch, read the input image and pre-process it in order to feed it to our network for inference. |
| 11 | + |
| 12 | +## Example Usage |
| 13 | + |
| 14 | +### PyTorch Predict |
| 15 | + |
| 16 | +```sh |
| 17 | +pytorch$ python predict.py pytorch --model=../data/VGG16model.pth --image=../data/dog.png |
| 18 | +==> Building model... |
| 19 | +==> Loading PyTorch model... |
| 20 | +Predicted: dog | 10.056212425231934 |
| 21 | +Time: 0.06844925880432129 seconds |
| 22 | +``` |
| 23 | + |
| 24 | +### Libtorch |
| 25 | +Before running our prediction, we need to compile the source. In your `libtorch` directory, create a build directory and compile+build the application from source. |
| 26 | + |
| 27 | +```sh |
| 28 | +libtorch$ mkdir build |
| 29 | +libtorch$ cd build |
| 30 | +libtorch/build$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. |
| 31 | +-- The C compiler identification is GNU 5.4.0 |
| 32 | +-- The CXX compiler identification is GNU 5.4.0 |
| 33 | +-- Check for working C compiler: /usr/bin/cc |
| 34 | +-- Check for working C compiler: /usr/bin/cc -- works |
| 35 | +-- Detecting C compiler ABI info |
| 36 | +. |
| 37 | +. |
| 38 | +. |
| 39 | +-- Configuring done |
| 40 | +-- Generating done |
| 41 | +-- Build files have been written to: libtorch/build |
| 42 | +libtorch/build$ make |
| 43 | +Scanning dependencies of target vgg-predict |
| 44 | +[ 50%] Building CXX object CMakeFiles/vgg-predict.dir/predict.cpp.o |
| 45 | +[100%] Linking CXX executable vgg-predict |
| 46 | +[100%] Built target vgg-predict |
| 47 | +``` |
| 48 | + |
| 49 | +You're now ready to run the application. |
| 50 | + |
| 51 | +```sh |
| 52 | +libtorch/build$ ./vgg-predict ../../data/VGG16model.pth ../../data/dog.png |
| 53 | +Model loaded |
| 54 | +Moving model to GPU |
| 55 | +Predicted: dog | 10.0562 |
| 56 | +Time: 0.009481 seconds |
| 57 | +``` |
| 58 | + |
| 59 | + |
| 60 | + |
0 commit comments