TorchScript学习笔记

TorchScript是一种可从python代码中创建序列化模型的方法。可以从python代码中保存,并在非python环境中加载模型。注:TorchScript 主要实现的是在 PyTorch 中表示神经网络模型所需的 Python 功能,并不适用于所有的Python特性。

torch.jit

TorchScript是Pytorch的JIT实现。
JIT ,全称是 Just In Time Compilation(即时编译)。
JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而可以使用C++把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等。(多线程执行和性能原因,一般Python代码并不适合做部署。)
导出模型主要有两种方式,Scripting和Tracing。

1.Tracing方式

tracing是相对简单的方式,输入向量,追踪向量在forward函数的流动来获得模型结构。
必须要有输入。
只适用于比较简单的模型,如果forward函数中有控制流结构,向量一次无法遍历所有的分支,这时就要借助script方式。

torch.jit.trace(func, example_inputs)

torch.jit.trace会将torch::jit::Module 转成 torch::jit::Graph 。
如果trace的是Python 函数,那么返回ScriptFunction , 如果是nn.Module.forward或者nn.Module,返回的就是ScriptModule。 如果trace的时候是eval/train模式,那么返回的ScriptModule就是eval/train模式。

  • 例子:trace一个函数
def sigmoid(z):s = 1 / (1 + 1 / np.exp(z))return sx = torch.full((2, 2), 1)
print("x: ",x)
traced_func = torch.jit.trace(sigmoid, x)
print(traced_func.graph)
print(traced_func(x))

输出:

x:  tensor([[1, 1],[1, 1]])
graph(%0 : Long(2:2, 2:1, requires_grad=0, device=cpu)):%1 : Double(2:2, 2:1, requires_grad=0, device=cpu) = prim::Constant[value= 2.7183  2.7183  2.7183  2.7183 [ CPUDoubleType{2,2} ]]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0%2 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::reciprocal(%1) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0%3 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0%4 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::mul(%2, %3) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0%5 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # C:/Users/Administrator/Desktop/frcnn_1.7/lt_test.py:28:0%6 : int = prim::Constant[value=1]() # C:/Users/Administrator/Desktop/frcnn_1.7/lt_test.py:28:0%7 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::add(%4, %5, %6) # C:/Users/Administrator/Desktop/frcnn_1.7/lt_test.py:28:0%8 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::reciprocal(%7) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0%9 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0%10 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::mul(%8, %9) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0return (%10)tensor([[0.7311, 0.7311],[0.7311, 0.7311]], dtype=torch.float64)
  • 例子:trace一个nn.Module的子类
import torch
from torchvision.models import resnet50model = resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
print(resnet.graph)

输出:

graph(%self.1 : __torch__.torchvision.models.resnet.ResNet,%input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):%2664 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc"](%self.1)%2661 : __torch__.torch.nn.modules.pooling.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)%2660 : __torch__.torch.nn.modules.container.___torch_mangle_141.Sequential = prim::GetAttr[name="layer4"](%self.1)%2582 : __torch__.torch.nn.modules.container.___torch_mangle_113.Sequential = prim::GetAttr[name="layer3"](%self.1)%2435 : __torch__.torch.nn.modules.container.___torch_mangle_61.Sequential = prim::GetAttr[name="layer2"](%self.1)%2334 : __torch__.torch.nn.modules.container.___torch_mangle_25.Sequential = prim::GetAttr[name="layer1"](%self.1)%2256 : __torch__.torch.nn.modules.pooling.MaxPool2d = prim::GetAttr[name="maxpool"](%self.1)%2255 : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="relu"](%self.1)%2254 : __torch__.torch.nn.modules.batchnorm.BatchNorm2d = prim::GetAttr[name="bn1"](%self.1)%2249 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self.1)%2847 : Tensor = prim::CallMethod[name="forward"](%2249, %input.1)%2848 : Tensor = prim::CallMethod[name="forward"](%2254, %2847)%2849 : Tensor = prim::CallMethod[name="forward"](%2255, %2848)%2850 : Tensor = prim::CallMethod[name="forward"](%2256, %2849)%2851 : Tensor = prim::CallMethod[name="forward"](%2334, %2850)%2852 : Tensor = prim::CallMethod[name="forward"](%2435, %2851)%2853 : Tensor = prim::CallMethod[name="forward"](%2582, %2852)%2854 : Tensor = prim::CallMethod[name="forward"](%2660, %2853)%2855 : Tensor = prim::CallMethod[name="forward"](%2661, %2854)%2059 : int = prim::Constant[value=1]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torchvision\models\resnet.py:214:0%2060 : int = prim::Constant[value=-1]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torchvision\models\resnet.py:214:0%input : Float(1:2048, 2048:1, requires_grad=1, device=cpu) = aten::flatten(%2855, %2059, %2060) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torchvision\models\resnet.py:214:0%2856 : Tensor = prim::CallMethod[name="forward"](%2664, %input)return (%2856)

2. Scripting方式

script方式通过解析AST的方式生成静态图。不需要有输入。Python ast官方文档

torch.jit.script

  • 例子:script一个resnet50+FPN的backbone
from backbone import resnet50_fpn_backbone
# device = torch.device('cpu')
model = resnet50_fpn_backbone()
# model.to(device).eval()
traced_model = torch.jit.script(model)
print(traced_model.graph)
# torch.jit.save(traced_model, 'saved_cpu.pt')

输出:

graph(%self : __torch__.backbone.resnet50_fpn_model.BackboneWithFPN,%x.1 : Tensor):%2 : __torch__.backbone.resnet50_fpn_model.IntermediateLayerGetter = prim::GetAttr[name="body"](%self)%x.3 : Dict(str, Tensor) = prim::CallMethod[name="forward"](%2, %x.1) # C:\Users\Administrator\Desktop\frcnn_1.7\backbone\resnet50_fpn_model.py:228:12%5 : __torch__.backbone.feature_pyramid_network.FeaturePyramidNetwork = prim::GetAttr[name="fpn"](%self)%x.5 : Dict(str, Tensor) = prim::CallMethod[name="forward"](%5, %x.3) # C:\Users\Administrator\Desktop\frcnn_1.7\backbone\resnet50_fpn_model.py:229:12return (%x.5)

TorchScript Type 类型解释

TorchScript 类型系统划分为 TSType 和 TSModuleType

  • TSType :
    Meta Types,如 Any。更像是类型约束,可以表示任何类型的类型。
    Primitive Types,如 int,float, str
    Structural Types,如 TSTuple ,TSNamedTuple,TSList ,TSDict ,TSOptional,TSFuture,TSRRef
    Nominal Types (Python classes),如 MyClass (自定义), torch.tensor (built-in)
  • TSModuleType:
    表示torch.nn.Module及其子类。因为它的定义部分来自对象,部分来自类定义,不属于静态类型,因此不能用作TorchScript type annotation,也不能够和TSType进行组合使用。

TorchScript学习笔记相关推荐

  1. PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call

    您的位置 首页 PyTorch 学习笔记系列 PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 发布: 2017年8月4日 7,195阅读 ...

  2. 容器云原生DevOps学习笔记——第三期:从零搭建CI/CD系统标准化交付流程

    暑期实习期间,所在的技术中台-效能研发团队规划设计并结合公司开源协同实现符合DevOps理念的研发工具平台,实现研发过程自动化.标准化: 实习期间对DevOps的理解一直懵懵懂懂,最近观看了阿里专家带 ...

  3. 容器云原生DevOps学习笔记——第二期:如何快速高质量的应用容器化迁移

    暑期实习期间,所在的技术中台-效能研发团队规划设计并结合公司开源协同实现符合DevOps理念的研发工具平台,实现研发过程自动化.标准化: 实习期间对DevOps的理解一直懵懵懂懂,最近观看了阿里专家带 ...

  4. 2020年Yann Lecun深度学习笔记(下)

    2020年Yann Lecun深度学习笔记(下)

  5. 2020年Yann Lecun深度学习笔记(上)

    2020年Yann Lecun深度学习笔记(上)

  6. 知识图谱学习笔记(1)

    知识图谱学习笔记第一部分,包含RDF介绍,以及Jena RDF API使用 知识图谱的基石:RDF RDF(Resource Description Framework),即资源描述框架,其本质是一个 ...

  7. 计算机基础知识第十讲,计算机文化基础(第十讲)学习笔记

    计算机文化基础(第十讲)学习笔记 采样和量化PictureElement Pixel(像素)(链接: 采样的实质就是要用多少点(这个点我们叫像素)来描述一张图像,比如,一幅420x570的图像,就表示 ...

  8. Go 学习推荐 —(Go by example 中文版、Go 构建 Web 应用、Go 学习笔记、Golang常见错误、Go 语言四十二章经、Go 语言高级编程)

    Go by example 中文版 Go 构建 Web 应用 Go 学习笔记:无痕 Go 标准库中文文档 Golang开发新手常犯的50个错误 50 Shades of Go: Traps, Gotc ...

  9. MongoDB学习笔记(入门)

    MongoDB学习笔记(入门) 一.文档的注意事项: 1.  键值对是有序的,如:{ "name" : "stephen", "genda" ...

最新文章

  1. linux软中断分析,linux操作系统下的软中断问题分析_linux教程
  2. esxi时区设置 +8_Go语言MySQL时区问题
  3. nm命令中符号类型详解
  4. Android之解决androidx.appcompat.widget.Toolbar去掉左边距
  5. 普通话测试系统_普通话
  6. Chapter7-11_Deep Learning for Question Answering (2/2)
  7. python中classmethod与staticmethod的差异及应用
  8. myeclipse 10.7安装过程与初次启动
  9. MVC中modelstate的使用
  10. 查看openfrie是否连接mysql_openfire连接mysql数据库的字符集问题解决
  11. PHP SQL注入攻击与防御
  12. Android 三类框架的理解以及MVVM框架的使用
  13. Git 连接码云 上传本地项目
  14. 解决打开一个excel文件,却出现两个窗口的办法
  15. 六爪机器人_六爪机器人
  16. 学习方法和学习经验总结
  17. 苹果公司不给iPhone配大电池的原因
  18. 局域网bs虚拟服务器怎么创建,搭建局域网地图服务器
  19. win10系统的计算机C盘在哪,win10系统只有一个C盘怎么解决
  20. 分享48个Go源码,总有一款适合您

热门文章

  1. 缓慢且无语的《赵氏孤儿》
  2. 6. 创业者与电子商务创业
  3. 【错误记录】Android Studio 打包 apk 文件报错 ( The destination folder does not exist or is not writeable )
  4. #第六届立创电赛#离线语音控制的空调智能插座
  5. Tomcat部署失败的原因
  6. 【高级UI】【009】贝塞尔曲线图形原理和公式推导
  7. 特征选择方法之互信息
  8. 城区现有5种共享单车 ofo率先加入智慧城市建设
  9. Qml 透明窗口,设置不规则等透明窗口,鼠标可穿透到桌面
  10. C语言字符数组的输入和输出