目录

1、导入必要库

2、加载数据

3、构建网络

4、训练模型

5、保存模型参数

1)、仅仅保存和加载模型参数

2)、保存和加载整个模型

3)、保存多个模型参数


1、导入必要库

import torch
from torch import optim, nn
import torch.utils.data as Data

2、加载数据

x = torch.linspace(1, 10, 10)       # x data (torch tensor)
y = torch.linspace(10, 1, 10)       # y data (torch tensor)# 注意:x的数据类型是 torch.FloatTensor
# y的数据类型是 torch.LongTensor
# x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # FloatTensor = 32-bit floating
# y = torch.cat((y0, y1), ).type(torch.LongTensor)    # LongTensor = 64-bit integer# 先转换成 torch 能识别的 Dataset
torch_dataset = Data.TensorDataset(x, y)# 把 dataset 放入 DataLoader
loader = Data.DataLoader(dataset=torch_dataset,      # torch TensorDataset formatbatch_size=3,      # mini batch sizeshuffle=True,               # 要不要打乱数据 (打乱比较好)num_workers=0,              # 多线程来读数据
)

3、构建网络

# 定义网络结构 build net
class Net(torch.nn.Module):def __init__(self,n_feature,n_hidden,n_output):super(Net, self).__init__()self.fc1 =torch.nn.Linear(n_feature,n_hidden)self.fc2 =torch.nn.Linear(n_hidden,n_output)# 定义一个前向传播过程函数def forward(self, x):x=F.relu(self.fc1(x))x=self.fc2(x)return x
# 实例化一个网络为 model
model = Net(n_feature=1,n_hidden=10,n_output=10)
print(model)

4、训练模型

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss() # 训练模型
model.train()
for epoch in range(5):for step, (b_x, b_y) in enumerate(loader): output = model(b_x)loss = loss_func(output, b_y)optimizer.zero_grad()loss.backward()optimizer.step()# 测试模型
model.eval()
for step, (b_x, b_y) in enumerate(loader):output = model(b_x)loss = loss_func(output, b_y)_, pred_y = torch.max(output.data, 1)correct = (pred_y == b_y).sum()total = b_y.size(0)print('Epoch: ', step, '| test loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % (float(correct)/total))

5、保存模型参数

1)、仅仅保存和加载模型参数

# 保存模型参数
torch.save(model.state_dict(), './path/model.pkl')
# 读取模型参数
model.load_state_dict(torch.load('./path/model.pkl'))

2)、保存和加载整个模型

# 保存整个模型
torch.save(model,  './path/model.pkl')
# 加载整个模型
model = torch.load('./path/model.pkl')

3)、保存多个模型参数

# 多个模型参数保存
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)# 模型参数加载
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

Pytorch 模型训练步骤相关推荐

  1. 《PyTorch模型训练实用教程》—学习笔记

    文章目录 前言 数据 Dataset类 DataLoader类 transform 裁剪-Crop 翻转和旋转-Flip and Rotation 图像变换 对transforms操作,使数据增强更灵 ...

  2. 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...

    (文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...

  3. PyTorch 模型训练实用教程(附代码)

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx PyTorch 能在短时间内被众多研究人员和工程师接受并推崇是因为其有着诸多优点,如采用 Py ...

  4. Pytorch模型训练实用教程学习笔记:四、优化器与学习率调整

    前言 最近在重温Pytorch基础,然而Pytorch官方文档的各种API是根据字母排列的,并不适合学习阅读. 于是在gayhub上找到了这样一份教程<Pytorch模型训练实用教程>,写 ...

  5. 9个让PyTorch模型训练提速的技巧!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 来源:AI公园,译者:ronghuaiyang 作者:William F ...

  6. 9个技巧让你的PyTorch模型训练变得飞快!

    公众号关注 "视学算法" 设为"星标",第一时间知晓最新干货~ 作者丨William Falcon 来源丨AI公园 不要让你的神经网络变成这样 让我们面对现实吧 ...

  7. 加速 PyTorch 模型训练的 9 个技巧

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 一个step by step的指南,非常的实用. 不要让你的 ...

  8. 9 个技巧让你的 PyTorch 模型训练变得飞快!

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 作者 | William Falcon 编译 | ronghuaiyang 来源 | ...

  9. 收藏 | 9 个技巧让你的 PyTorch 模型训练变得飞快!

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者 | William Falcon 编译 | rongh ...

最新文章

  1. 删除指定文件夹下的小于 4K的所用文件...
  2. 四川职业学校计算机专业那个好6,四川排名前六的单招学院那些专业比较好?
  3. FreeBSD学习总结
  4. 1021. Deepest Root (25)
  5. mysql连接不断线_某些小时后MySql连接自动掉线
  6. 生活在别处——“Samsung Cloud Print”云打印体验
  7. 4.10_composite_结构型模式:组合模式
  8. 《TensorFlow 2.0深度学习算法实战教材》学习笔记(三、TensorFlow 基础)
  9. dockerfile如何运行镜像内的脚本_Docker精华问答 | Docker commit如何用?
  10. 学业水平测试计算机考试软件,普通高中学业水平考试系统
  11. 基于C51控制蜂鸣器
  12. 《简明Python教程》学习笔记
  13. STM32库函数模板创建
  14. 区块链项目需要服务器吗,区块链需要服务器吗
  15. 【索引】Rujia Liu's Problems for Beginners
  16. android终端模拟器 apt,借贵吧问个安卓终端模拟器的问题
  17. Pymol获得蛋白中二级结构信息
  18. 拥有一本CISP证书,我的工资会翻倍吗?
  19. Windows 下Nexus搭建Maven私服
  20. MarkDown首行缩进和换行

热门文章

  1. DLookup使用详解
  2. Vue+Echarts+百度地图API
  3. YesPlayMusic:一个高颜值多音频资源的网易云音乐播放器
  4. ASAS-CoMoSpA研究: 评价SpA不同分类标准的表现
  5. @RequestParam和@RequestBody的使用
  6. 转行互联网,零基础应届生应该选择什么样的岗位作为切入点?
  7. 价目表制作,价目表小程序
  8. mysql慢查询导致502_MySQL Statement cancellation timer故障排查分享
  9. 记一次对钓鱼邮件的实地反击
  10. 【计算机网络】常见的HTTP报文头部信息