引言

我们神经网络跑模型时会发现相同的超参每次的结果都会不同,因为神经网络算法利用了随机性,比如初始化随机权重,因此用同样的数据训练同一个网络会得到不同的结果。初学者可能会有些懵圈,因为算法表现得不太稳定。但实际上它们就是这么设计的。随机初始化可以让网络通过学习,得到一个所学函数的很好的近似。
然而有的时候结果会相差过多,很难复现。所以我们需要保证prtorch的可重复性。

现象

对于同一个模型和同一批训练集以及测试集,我们都赋予同样的超参数,然而对于不同的进程接过去截然不同,下面贴几个结果图:
尽管两个进程都将训练集的Loss训到了0.0,但是在测试集的准确率却有差别。

分析及原因

其实本质上是因为我们神经网络中有很多随机性操作,例如BN、dropout以及在我们选取训练数据时的shuffer和随即裁剪等等。
此外GPU和CPU运算结果有时也不一致。

解决方法

我们对于上述随机的控制可以加入随机种子,具体的随机种子加入可以分为三个部分:

  1. Pthon/Numpy 随机种子
import random
import numpy as np
random.seed(seed)
np.random.seed(seed)
  1. Pytorch种子
torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子`
  1. CUDNN种子(控制GPU)
from torch.backends import cudnn
cudnn.benchmark = False
cudnn.deterministic = True`

自己的实践

对于Section2的实验结果,自己也是设置了种子,设置如下:

// An highlighted blocknp.random.seed(args.seed)torch.manual_seed(args.seed)torch.cuda.manual_seed(args.seed)torch.cuda.manual_seed_all(args.seed)cudnn.benchmark = Truetorch.backends.cudnn.deterministic = True

然而结果也不如人意,想了想进行了以下更改:
增加了对python的随机种子,因为可能读取数据中用了随机化。
另外将benchmark 设为 False,牺牲速度,换取精度,更改如下:

  np.random.seed(args.seed)torch.manual_seed(args.seed)torch.cuda.manual_seed(args.seed)torch.cuda.manual_seed_all(args.seed)random.seed(args.seed)              ##cudnn.benchmark = False             ##torch.backends.cudnn.deterministic = True

目前还不知道训练是否稳定,期待后续。

拓展

其实还有一些小因素影响到了我们模型的重现能力。一个是如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异;另一个可能是数据的shuffer。
另外我们说一些cudnn.benchmark,设置 torch.backends.cudnn.benchmark=True 将会让程序在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。适用场景是网络结构固定(不是动态变化的),网络的输入形状(包括 batch size,图片大小,输入的通道)是不变的,其实也就是一般情况下都比较适用。反之,如果卷积层的设置一直变化,将会导致程序不停地做优化,反而会耗费更多的时间,然而可能会选择训练不稳定,所以我们将torch.backends.cudnn.benchmark=False

Pytorch的坑之 训练结果不太稳定,无法复现训练结果?相关推荐

  1. 太棒了!PyTorch 1.7发布,支持CUDA 11、Windows分布式训练

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习> ...

  2. pytorch .item_从数据到模型,你可能需要1篇详实的pytorch踩坑指南

    原创 · 作者 | Giant 学校 | 浙江大学 研究方向 | 对话系统.text2sql 熟悉DL的朋友应该知道Tensorflow.Pytorch.Caffe这些成熟的框架,它们让广大AI爱好者 ...

  3. PyTorch 入坑七:模块与nn.Module学习

    PyTorch 入坑七 模型创建概述 PyTorch中的模块 torch模块 torch.Tensor模块 torch.sparse模块 torch.cuda模块 torch.nn模块 torch.n ...

  4. PyTorch入坑(一)~(三): Tensor的概念,基本操作和线性回归

    PyTorch 一文入门 PyTorch 入坑一:数据类型与Tensor的概念 PyTorch数据类型 Tensor的概念 Tensor与Variable Variable Tensor Tensor ...

  5. Pytorch踩坑记录:关于用net.eval()和with no grad装饰器计算结果不一样的问题

    Pytorch踩坑记录 相同点 net.eval()和with toch.no_grad()的相同点:都停止反向传播 不同点: 1.net.eval() 用net.eval(),此时BN层会用训练时的 ...

  6. Pytorch实现戴口罩人脸检测和戴口罩识别(含训练代码 戴口罩人脸数据集)

    Pytorch实现戴口罩人脸检测和戴口罩识别(含训练代码 戴口罩人脸数据集) 目录 Pytorch实现戴口罩人脸检测和戴口罩识别(含训练代码 戴口罩人脸数据集) 1.戴口罩识别的方法 (1)基于多类别 ...

  7. PyTorch 1.7发布,支持CUDA 11、Windows分布式训练

    机器之心报道 参与:魔王.小舟 昨日,PyTorch 团队发布 PyTorch 1.7 版本.该版本增添了很多新特性,如支持 CUDA 11.Windows 分布式训练.增加了支持快速傅里叶变换(FF ...

  8. python的自带数据集_解决Keras自带数据集与预训练model下载太慢问题

    keras的数据集源码下载地址太慢.尝试过修改源码中的下载地址,直接报错. 从源码或者网络资源下好数据集,下载好以后放到目录 ~/.keras/datasets/ 下面. 其中:cifar10需要改文 ...

  9. PyTorch 入坑六 数据处理模块Dataloader、Dataset、Transforms

    深度学习中的数据处理概述 深度学习三要素:数据.算力和算法 在工程实践中,数据的重要性越来越引起人们的关注.在数据科学界流传着一种说法,"数据决定了模型的上限,算法决定了模型的下限" ...

最新文章

  1. 拆解交易系统--异地多活
  2. 多文多面阐述HMM很清晰
  3. Gesture Based TableView
  4. Centos7.2部署DHCP服务
  5. python【蓝桥杯vip练习题库】ADV-301 字符串压缩
  6. 《SQL初学者指南(第2版)》——2.4 指定列
  7. matlab画gds图,如何将图片转换为.gds文件?(转)
  8. 怎么用python画房子_怎么用python画小猪佩奇
  9. 开发者经常用到的75 个功能强大的 jQuery插件和教程汇总(上篇)
  10. c++ 文件操作方式
  11. 【优化算法】天牛须搜索优化粒子群算法【含Matlab源码 1256期】
  12. LayUI的后台管理模板
  13. 爱,是尘世间人人追求的人生之最,是生活中无处不在的美
  14. javaweb学习(5)--Cookie
  15. 好书分享、能量传递-《软技能 代码之外的生存指南》自我营销篇
  16. 解开关于人工智能的六个迷思
  17. 测试还是国外的香?走进海外测试开发工程师
  18. 仿新浪微博发布时 @ 及 #某话题# 的效果
  19. ffmpeg实现视频实时动态时间水印
  20. 一篇文章带你快速上手Airtest和Poco

热门文章

  1. 二、肿瘤发展进程、全基因组突变发生顺序(The evolutionary history of 2,658 cancers)
  2. 支付宝踩过的坑sign check fail: check Sign and Data Fail��JSON also��
  3. python3命令需要使用命令行开发者工具_3 个 Python 命令行工具
  4. 沈阳理工大学c语言考研初试题,2020沈阳理工大学C语言程序设计考研考试大纲
  5. 在桌面建立 LNK 快捷方式
  6. error LNK2001: unresolved external symbol _WinMain@16debug/main.exe:fatal
  7. 如何正确的使用Photoshop进行图像的二值化(详细步骤)刘博士
  8. Genesis脚本---自动输出Gberber274格式资料 脚本
  9. 旅游景区景点购票小程序毕业设计毕设(后台java的springboot框架)
  10. coreldraw 导入面料_CDR真强!支持导入的文件格式这么多