Migrating a Model from PyTorch to Caffe2 with ONNX

1. Preparing the PyTorch and ONNX environment

To run ONNX properly, we need to install the latest PyTorch, built from source:

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
mkdir build && cd build
sudo cmake .. -DPYTHON_INCLUDE_DIR=/usr/include/python3.5  -DUSE_MPI=OFF
make install
export PYTHONPATH=$PYTHONPATH:/opt/pytorch/build

The "/opt/pytorch/build" above is the directory in which you built PyTorch earlier; replace it with your own build path.

Installing PyTorch from source also installs the ONNX-related libraries that PyTorch supports. They are installed under: /usr/local/lib/python3.5/dist-packages/torch

Run the following command to install the ONNX library:

conda install -c conda-forge onnx

In addition, you need to install onnx-caffe2, a pure-Python library that provides a Caffe2 backend for ONNX. You can install onnx-caffe2 with pip:

pip3 install onnx-caffe2
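Before moving on, it can help to confirm that all three packages are visible to the Python interpreter you will actually run. The following is a minimal sketch that uses only the standard library. The helper name `locate_packages` is made up for illustration, and the top-level module names `torch`, `onnx`, and `onnx_caffe2` are taken from the imports used later in this article.

```python
import importlib.util


def locate_packages(names):
    """Map each top-level module name to its install location, or None if not found."""
    found = {}
    for name in names:
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            spec = None
        if spec is None:
            found[name] = None
        else:
            # origin is the file the module would be loaded from;
            # namespace packages may not have one
            found[name] = spec.origin or "<namespace package>"
    return found


if __name__ == "__main__":
    packages = ["torch", "onnx", "onnx_caffe2"]
    for name, location in locate_packages(packages).items():
        print("{:12s} {}".format(name, location or "NOT FOUND"))
```

If `torch` is reported as missing, or from a location other than the one you expect, the PYTHONPATH export from the build step is probably not active in the current shell.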

2. Preparing the code that converts PyTorch to ONNX

The script pytorch2caffe2.py at https://github.com/lindylin1817/pytorch2caffe2 is a reference implementation that converts a trained DeblurGAN model to ONNX. The code is explained below:

import os
import sys
import torch
import torch.onnx
import torch.utils.model_zoo
from torch.autograd import Variable
sys.path.append("../DeblurGAN")
from models.models import create_model
import models.networks as networks
from options.test_options import TestOptions
import shutil
import onnx
from onnx_caffe2.backend import Caffe2Backend

batch_size = 1    # just a random number

# Load the pretrained model weights
model_path = './model/char_deblur/latest_net_G.pth'
onnx_model_path = "./deblurring.onnx.pb"
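# model_path is a local file, not a URL. load_url appears to work here only because
# the file already exists in model_dir; torch.load(model_path) would be the more direct call.
state_dict = torch.utils.model_zoo.load_url(model_path, model_dir="./model/char_deblur")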

# Load the DeblurGAN neural network
gan_opt = TestOptions().parse()
gan_opt.name = "char_deblur"
gan_opt.checkpoints_dir = "./model/"
gan_opt.model = "test"
gan_opt.dataset_mode = "single"
gan_opt.dataroot = "/tmp/gan/"
# Start from an empty data directory; ignore the error if it does not exist yet
shutil.rmtree(gan_opt.dataroot, ignore_errors=True)
os.mkdir(gan_opt.dataroot)
gan_opt.loadSizeX = 64
gan_opt.loadSizeY = 64
gan_opt.fineSize = 64
gan_opt.learn_residual = True
gan_opt.nThreads = 1  # test code only supports nThreads = 1
gan_opt.batchSize = 1  # test code only supports batchSize = 1
gan_opt.serial_batches = True  # no shuffle
gan_opt.no_flip = True  # no flip
#torch_model = create_model(gan_opt)

gpus = []
torch_model = networks.define_G(gan_opt.input_nc, gan_opt.output_nc, gan_opt.ngf,
                                gan_opt.which_model_netG, gan_opt.norm, not gan_opt.no_dropout, gpus, False,
                                gan_opt.learn_residual)

torch_model.load_state_dict(state_dict)

# set the train mode to false since we will only run the forward pass.
torch_model.train(False)

# Input to the model
x = Variable(torch.randn(batch_size, 3, 60, 60), requires_grad=True)
x = x.float()

# Export the model
torch_out = torch.onnx._export(torch_model,         # model being run
                               x,                   # model input (or a tuple for multiple inputs)
                               onnx_model_path,     # where to save the model (can be a file or file-like object)
                               verbose=True,        # print a description of the exported graph
                               export_params=True,  # store the trained parameter weights inside the model file
                               training=False)      # export in inference mode

onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)
model_name = onnx_model_path.replace('.onnx.pb','')
init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model.graph, device="CUDA")
with open(model_name + "_init.pb", "wb") as f:
    f.write(init_net.SerializeToString())
with open(model_name + "_predict.pb", "wb") as f:
    f.write(predict_net.SerializeToString())

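In this example, the parts that users need to change for their own setup are:

  • The path to the trained PyTorch model: in this example, "./model/char_deblur/latest_net_G.pth" must be changed to the path of your own model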
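One way to avoid editing the script for every model is to read the paths from the environment. This is only a sketch: the variable names `P2C_MODEL_PATH` and `P2C_ONNX_PATH` and the helper `resolve_paths` are hypothetical, not part of the original script. Environment variables are used instead of command-line flags on the assumption that `TestOptions().parse()` reads the command line itself and would reject flags it does not know.

```python
import os

# Defaults are the values hard-coded in the script above.
DEFAULT_MODEL_PATH = "./model/char_deblur/latest_net_G.pth"
DEFAULT_ONNX_PATH = "./deblurring.onnx.pb"


def resolve_paths(env=None):
    """Return (model_path, onnx_model_path), letting environment variables override the defaults.

    P2C_MODEL_PATH and P2C_ONNX_PATH are hypothetical names chosen for this sketch.
    """
    env = os.environ if env is None else env
    model_path = env.get("P2C_MODEL_PATH", DEFAULT_MODEL_PATH)
    onnx_model_path = env.get("P2C_ONNX_PATH", DEFAULT_ONNX_PATH)
    return model_path, onnx_model_path


model_path, onnx_model_path = resolve_paths()
```

Note that the script also hard-codes the model directory in the `model_dir` argument and in `gan_opt.checkpoints_dir`, so those values would need to be kept consistent with whatever model path you choose.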

The code above generates two Caffe2 pb files: deblurring_init.pb and deblurring_predict.pb.
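As a rough sketch of how the two files might be consumed afterwards, the code below derives their names the same way the conversion script does and loads them with Caffe2's `workspace.Predictor`. The helper names are made up for illustration, and this has not been run against the DeblurGAN model. Because the script converts with device="CUDA", it would need a GPU-enabled Caffe2 build; converting with device="CPU" is likely the simpler case to try first.

```python
def caffe2_net_paths(onnx_model_path):
    """Derive the init/predict file names the same way the conversion script does."""
    model_name = onnx_model_path.replace(".onnx.pb", "")
    return model_name + "_init.pb", model_name + "_predict.pb"


def load_predictor(onnx_model_path):
    """Build a Caffe2 predictor from the two generated files.

    Caffe2 is imported lazily so the path helper above works without it.
    """
    from caffe2.python import workspace

    init_path, predict_path = caffe2_net_paths(onnx_model_path)
    with open(init_path, "rb") as f:
        init_net = f.read()
    with open(predict_path, "rb") as f:
        predict_net = f.read()
    return workspace.Predictor(init_net, predict_net)


def run_dummy_inference(onnx_model_path, shape=(1, 3, 60, 60)):
    """Run one forward pass on random data shaped like the dummy export input."""
    import numpy as np

    predictor = load_predictor(onnx_model_path)
    # (batch, channels, height, width), matching the tensor used for the export
    image = np.random.randn(*shape).astype(np.float32)
    return predictor.run([image])
```

Calling `run_dummy_inference("./deblurring.onnx.pb")` should return a list of output arrays if both files load correctly. To check the conversion itself, compare the first output against `torch_out` from the export step.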
