微信公众号:OpenCV学堂
关注获取更多计算机视觉与深度学习知识
TorchScript介绍
trace使用
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
print(traced_cell.graph)
运行结果如下:
MyCell(
original_name=MyCell
(linear): Linear(original_name=Linear)
)
跟踪执行结果
graph(%self.1 : __torch__.MyCell,
%input : Float(3:4, 4:1, requires_grad=0, device=cpu),
%h : Float(3:4, 4:1, requires_grad=0, device=cpu)):
%19 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
%21 : Tensor = prim::CallMethod[name="forward"](%19, %input)
%12 : int = prim::Constant[value=1]() # D:/python/pytorch_openvino_demo/ch5/faster_rcnn2onnx.py:112:0
%13 : Float(3:4, 4:1, requires_grad=1, device=cpu) = aten::add(%21, %h, %12) # D:/python/pytorch_openvino_demo/ch5/faster_rcnn2onnx.py:112:0
%14 : Float(3:4, 4:1, requires_grad=1, device=cpu) = aten::tanh(%13) # D:/python/pytorch_openvino_demo/ch5/faster_rcnn2onnx.py:112:0
%15 : (Float(3:4, 4:1, requires_grad=1, device=cpu), Float(3:4, 4:1, requires_grad=1, device=cpu)) = prim::TupleConstruct(%14, %14)
return (%15)
script部分使用
script是导出模型为中间IR格式文件,支持高性能libtorch C++部署,我们以torchvision中Mask-RCNN导出中间格式IR为例,代码演示如下:
import torch
import torchvision as tv
num_classes = 3
model = tv.models.detection.maskrcnn_resnet50_fpn(
pretrained=False, progress=True,
num_classes=num_classes,
pretrained_backbone=True)
im = torch.zeros(1, 3, *(1333, 800)).to("cpu")
model.load_state_dict(torch.load("D:/gaobao_model.pth"))
model = model.to("cpu")
model.eval()
ts = torch.jit.script(model)
ts.save("gaobao.ts")
loaded_trace = torch.jit.load("gaobao.ts")
loaded_trace.eval()
with torch.no_grad():
print(loaded_trace(list(im)))
#include // One-stop header.
#include
#include
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app \n" ;
return -1;
}
// Deserialize the ScriptModule from a file using torch::jit::load().
torch::jit::script::Module module = torch::jit::load(argv[1]);
std::vector inputs;
inputs.push_back(torch::randn({4, 8}));
inputs.push_back(torch::randn({8, 5}));
torch::Tensor output = module.forward(std::move(inputs)).toTensor();
std::cout << output << std::endl;
}
上面代码来自官方测试程序,特别说明一下!
扫码查看OpenCV+OpenVIO+Pytorch系统化学习路线图
推荐阅读
CV全栈开发者说 - 从传统算法到深度学习怎么修炼
2022入坑深度学习,我选择Pytorch框架!
Pytorch轻松实现经典视觉任务
教程推荐 | Pytorch框架CV开发-从入门到实战
OpenCV4 C++学习 必备基础语法知识三
OpenCV4 C++学习 必备基础语法知识二
OpenCV4.5.4 人脸检测+五点landmark新功能测试
OpenCV4.5.4人脸识别详解与代码演示
OpenCV二值图象分析之Blob分析找圆
OpenCV4.5.x DNN + YOLOv5 C++推理
OpenCV4.5.4 直接支持YOLOv5 6.1版本模型推理
OpenVINO2021.4+YOLOX目标检测模型部署测试
比YOLOv5还厉害的YOLOX来了,官方支持OpenVINO推理