这是Pytorch学习之路的第五篇

遇到问题

虽然已经知道了怎么保存已经训练好的网络模型,但是还是不知道怎么调用。其他博客中讲的有点简略,还需要自己摸索一下:

PyTorch要加载已经训练好的网络模型,需要保留什么代码,增加什么代码?

解决方法(只讨论仅加载参数的方法)

导入的库都不变,且只有测试模型前代码需要做改动:

import torch.nn as nn
import torch.nn.functional as F
#以下为需要保留的代码
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class CNNNet(nn.Module):def __init__(self):super(CNNNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=3,out_channels=16,kernel_size=5,stride=1)self.pool1 = nn.MaxPool2d(kernel_size=2,stride=2)self.conv2 = nn.Conv2d(in_channels=16,out_channels=36,kernel_size=3,stride=1)self.pool2 = nn.MaxPool2d(kernel_size=2,stride=2)#self.aap = nn.AdaptiveAvgPool2d(1)self.fc1 = nn.Linear(1296,128)self.fc2 = nn.Linear(128,10)#self.fc3 = nn.Linear(36,10)def forward(self,x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))#x = self.aap(x)#x = x.view(x.shape[0],-1)#x = self.fc3(x)x = x.view(-1,36*6*6)#print("x.shape:{}".format(x.shape))x = F.relu(self.fc2(F.relu(self.fc1(x))))return xmodel = CNNNet()#以下为新增代码
model.load_state_dict(torch.load('./model/model.pth'))#再加载网络的参数
model = model.to(device)
print("load success")

注意

model = torch.load('./model/model.pth')

会报错

原因未知。

效果

成功

灵感来源

  1. pytorch:无法加载CNN模型并做预测TypeError:'collections. OrderedDict’对象不可调用(转载)
  2. Pytorch文档阅读(五)如何保存、加载网络模型(转载)

PyTorch如何加载已经训练好的网络模型相关推荐

  1. PyTorch 加载预训练权重

    前言  使用PyTorch官方提供的权重或者其他第三方提供的权重对相同模型的参数进行初始化,在数据量较少的前提下,可以帮助模型更快地收敛到最优点,达到更好的效果,即迁移学习.  在大部分的迁移学习场景 ...

  2. 【Pytorch】加载torchvision中预训练好的模型并修改默认下载路径(使用models.__dict__[model_name]()读取)

    说明 使用torchvision.model加载预训练好的模型时,发现默认下载路径在系统盘下面的用户目录下(这个你执行的时候就会发现),即C:\用户名\.cache\torch\.checkpoint ...

  3. Pytorch 词嵌入word_embedding2实例(加载已训练词向量)

    目录 1.加载已训练好的词嵌入 2.是否需要重新训练词嵌入 3.不重新训练词嵌入时优化器设置

  4. torch编程-加载预训练权重-模型冻结-解耦-梯度不反传

    1)加载预训练权重 net = torchvision.models.resnet50(pretrained=False) # 构建模型 pretrained_model = torch.load(p ...

  5. PyTorch中加载模型权重

    在做深度学习项目时,从头训练一个模型是需要大量时间和算力的,我们通常采用加载预训练权重的方法,而我们往往面临以下几种情况: #mermaid-svg-freoBrrdezozjyan {font-fa ...

  6. PyTorch数据加载处理

    PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 • scikit-image:用于图像的IO和变换 • pandas:用于更容易地进行 ...

  7. 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次都特别慢

    欢迎大家关注笔者,你的关注是我持续更博的最大动力 原创文章,转载告知,盗版必究 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次 ...

  8. PyTorch数据加载器

    We'll be covering the PyTorch DataLoader in this tutorial. Large datasets are indispensable in the w ...

  9. java加载tensorflow训练的PB模型记录

    java加载tensorflow训练的PB模型记录 python训练 1. 模型的输入输出定义 2. 训练时保存模型的方法 java加载模型 1.maven依赖 2. Java代码实例 tensor注 ...

最新文章

  1. Java泛型 通配符? extends与super
  2. binary格式和ELF格式区别。用ida打开的样子
  3. Even Parity UVA - 11464 (枚举)
  4. Cinder LVM Oversubscription in thin provisioning
  5. 一些挺不错的visualstudio主题样式
  6. Eureka深入理解
  7. 『设计模式』设计模式--策略模式
  8. “辩者21事”之解读——分析性理性要与辩证理性相结合
  9. 又一任务被Transformer攻陷!NVIDIA开源HORST,用Transformer解决早期动作识别和动作预期任务...
  10. Retrofit:类型安全的REST客户端for 安卓Java
  11. javascript语言扩展:可迭代对象(3)
  12. 学习Spring(一) -- 配置Spring
  13. 开源Scout攻击检测工具
  14. 华为手机 图标消失_华为手机升级EMUI 10后解决Google Play“消失”教程
  15. Python学习第二章:变量和简单类型
  16. 【复】基于 WebRTC 的音视频在线监考模块的设计与实现(下)
  17. 网络工程师职业发展方向和职业前景
  18. Python WEB 开发,什么是 WSGI ?uWSGI、Gunincorn 都是啥玩意儿?
  19. Net-speed 一键安装脚本
  20. PayPal第三方支付

热门文章

  1. libvirt php,libvirt虚拟化开发简介
  2. Adobe Acrobat 给pdf添加多级书签(制作目录)
  3. 搜索引擎下拉食云速捷详细_「seo推广技术」seo关键词软件首要云速捷安全
  4. Kotlin重载操作符和约定声明规则
  5. 5年为山西提供超5万岗位,2000万互联网众包用户,百度智能云数据众包高速增长
  6. 在TTF字体库查找指定的字符
  7. stata 导出 相关系数表_Stata高效输入:搜狗输入法自定义短语
  8. codeforces 1328 C. Ternary XOR(贪心)
  9. 影视后期制作学习(AE)(三维动画)(成果在视频部分)
  10. 中国手持式红外测温仪市场深度研究分析报告(2021)