搬来了定型设置的方法,深度学习在训练过程中,由于随机初始化,样本读取的随机性,导致重复的实验结果会有差别,个别情况甚至波动较大。一般论文为了严谨,实验结论能够复现/可重复,通常采取固定随机种子使得结果确定

确定性设置

1 随机种子设置

随机函数是最大的不确定性来源,包含了模型参数的随机初始化,样本的shuffle。

  • PyTorch 随机种子

  • python 随机种子

  • numpy 随机种子

# PyTorch
import torch
torch.manual_seed(0)# python
import random
random.seed(0)# Third part libraries
import numpy as np
np.random.seed(0)

CPU版本下,上述随机种子设置完成之后,基本就可实现实验的可复现了。

对于GPU版本,存在大量算法实现为不确定结果的算法,这种算法实现效率很高,但是每次返回的值会不完全一样。主要是由于浮点精度舍弃,不同浮点数以不同顺序相加,值可能会有很小的差异(小数点最末位)。

2 GPU算法确定性实现

GPU算法的不确定来源有两个

  • CUDA convolution benchmarking

  • nondeterministic algorithms

CUDA convolution benchmarking 是为了提升运行效率,对模型参数试运行后,选取最优实现。不同硬件以及benchmarking本身存在噪音,导致不确定性

nondeterministic algorithms:GPU最大优势就是并行计算,如果能够忽略顺序,就避免了同步要求,能够大大提升运行效率,所以很多算法都有非确定性结果的算法实现。通过设置use_deterministic_algorithms,就可以使得pytorch选择确定性算法。

# 不需要benchmarking
torch.backends.cudnn.benchmark=False# 选择确定性算法
torch.use_deterministic_algorithms()

RUNTIME ERROR

对于一个PyTorch 的函数接口,没有确定性算法实现,只有非确定性算法实现,同时设置了use_deterministic_algorithms(),那么会导致运行时错误。比如:

>>> import torch
>>> torch.use_deterministic_algorithms(True)
>>> torch.randn(2, 2).cuda().index_add_(0, torch.tensor([0, 1]), torch.randn(2, 2))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: index_add_cuda_ does not have a deterministic implementation, but you set
'torch.use_deterministic_algorithms(True)'. ...

错误原因:

index_add没有确定性的实现,出现这种错误,一般都是因为调用了torch.index_select 这个api接口,或者直接调用tensor.index_add_。

解决方案:

自己定义一个确定性的实现,替换调用的接口。对于torch.index_select 这个接口,可以有如下的实现。

def deterministic_index_select(input_tensor, dim, indices):"""input_tensor: Tensordim: dim indices: 1D tensor"""tensor_transpose = torch.transpose(x, 0, dim)return tensor_transpose[indices].transpose(dim, 0)

样本读取随机

  1. 多线程情况下,设置每个线程读取的随机种子

  2. 设置样本generator

# 设置每个读取线程的随机种子
def seed_worker(worker_id):worker_seed = torch.initial_seed() % 2**32numpy.random.seed(worker_seed)random.seed(worker_seed)g = torch.Generator()
# 设置样本shuffle随机种子,作为DataLoader的参数
g.manual_seed(0)DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers,worker_init_fn=seed_worker,generator=g,
)

有点短哦~~   whaosoft aiot http://143ai.com

PyTorch设置可复现/重复实验相关推荐

  1. Ecol. Lett. | 生态学实验设计中“梯度实验”和“重复实验”的抉择

    本文转载自"生态学文献分享",已获授权 To replicate, or not to replicate – that is the question: how to tackl ...

  2. Android Studio 如何导出和导入自己的常用设置,避免重复制造轮子。加快开发速度...

    Android Studio 如何导出和导入自己的常用设置,避免重复制造轮子.加快开发速度 作者:程序员小冰,CSDN博客:http://blog.csdn.net/qq_21376985 在使用 A ...

  3. CSS设置图片的重复

    CSS设置图片的重复 CSS通过设置background-repeat属性来设置图片的重复方式,包括水平重复.竖直重复和不重复等. 图片的竖直方向重复. <span style="fo ...

  4. JavaScript通过变量设置对象键[重复]

    本文翻译自:JavaScript set object key by variable [duplicate] This question already has answers here : 这个问 ...

  5. 电脑桌面便签怎么设置按天重复提醒每日便签事项?

    支持多端同步功能的云便签支持Windows电脑PC版使用.它不仅可以设置单次提醒,还可以按天.周.月.季度以及年的规则设置重复提醒,那么该电脑版桌面便签分类怎么设置按天重复提醒呢? 一.打开已登录的电 ...

  6. Pytorch设置随机种子

    一.网上方法 # 定义一个可以设置随机种子的函数 def setup_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed ...

  7. 【文本数据挖掘】中文命名实体识别:HMM模型+BiLSTM_CRF模型(Pytorch)【调研与实验分析】

    1️⃣本篇博文是[文本数据挖掘]大作业-中文命名实体识别-调研与实验分析 2️⃣在之前的自然语言课程中也完成过一次命名实体识别的实验 [一起入门NLP]中科院自然语言处理作业三:用BiLSTM+CRF ...

  8. Laravel5.2队列驱动expire参数设置带来的重复执行问题 数据库驱动

    'connections' => [....'database' => ['driver' => 'database','table' => 'jobs','queue' =& ...

  9. 证据积累聚类集成算法(Evidence Accumulation Clustering)代码复现与实验

    1. 基本环境 运行环境: - Python 3.7 + - Jupyter NoteBook - 处理器:2.6 GHz 六核Intel Core i7 2. 聚类集成代码 # 导入包 import ...

  10. Matlab多次重复实验记录结果,MATLAB数据处理实验记录与总结.doc

    MATLAB实验报告 学 号实验名称MATLAB数据处理实验实验目的掌握二维曲线图.三维曲线图.三维曲面图的绘制方法 掌握常用统计图的绘制方法 熟悉三维图形常用编辑方法 了解动画的绘制方法实验记录1. ...

最新文章

  1. javascript中new Date浏览器兼容性处理
  2. 条件随机场CRF HMM,MEMM的区别
  3. 棋盘问题 POJ - 1321
  4. flv 开源 修复_解决开源项目错误和修复的5个步骤
  5. yum 安装 tomcat
  6. 【SaaS - Export项目】用户登录,显示,退出 删除session中的用户信息 销毁session
  7. 真人发音计算机在线用,文字转语音真人发声在线怎么转换?这种操作最简单
  8. 金鹰卡通java面试_两则电视栏目招募通告,来试试?!
  9. 70级圣骑士OK了,纪念下先!
  10. js 格式化prettier配置_使Prettier一键格式化WXSS
  11. 山东理工ACM【1009】Elevator
  12. AVR32单片机 矩阵按键 按键键值函数解析
  13. 安卓手机用AidLux安装Linux免Root,安装到Debian 10不能安装docker
  14. html如何删除网页边框,如何从HTML表中完全删除边框
  15. 华为区块链项目总监: 华为区块链率先于溯源场景落地
  16. 腾讯云OCR(印刷体识别) API使用
  17. 在合并单元格设置编号—“count-a函数”的使用
  18. 奶瓶仔xp主题【主题世界】
  19. 「论文翻译」Graph convolutional networks for computational drug development and discovery
  20. android 虚拟键盘控制

热门文章

  1. Caused by: org.springframework.amqp.AmqpException: No method found for class java.lang.String
  2. java 扩展库_JAVA API的扩展库详解
  3. 计算机辅助几何造型技术知识点,西北工业大学2018博士招生计算机辅助几何造型技术考试大纲...
  4. 深度学习——TensorFlow1.0GPU环境构建
  5. 什么是mCherry?
  6. C#操作CAD-读取和修改数据
  7. ROS机器人020-机器人xacro与rviz室内slam建图仿真
  8. 用VR观看LoL,这到底是如何实现的
  9. 无敌药膏程序员要不要转行?
  10. 有些汽车音响改进意见