作者 | Elvanth@知乎

来源 | https://zhuanlan.zhihu.com/p/377155682
编辑 | 极市平台
本文仅作学术交流,版权归原作者所有,如有侵权请联系删除。

导读

本文所分析的问题与解决方案将在最近发布的pytorch版本中解决;因此解决所有烦恼的根源是方法,更新pytorch~ >>

一个快捷的解决方案:

def worker_init_fn(worker_id):worker_seed = torch.initial_seed() % 2**32np.random.seed(worker_seed)random.seed(worker_seed)ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

01 关于pytorch数据集随机种子的基本认识

在pytorch中random、torch.random等随机值产生方法一般没有问题,只有少数工人运行也可以保障其不同的最终值.

np.random.seed 会出现问题的原因是,当多处理采用 fork 方式产生子进程时,numpy 不会对不同的子进程产生不同的随机值.

换言之,当没有多处理使用时,numpy 不会出现随机种子的不同的问题;实验代码的可复现性要求一个是工人种子 ,即工人内包括numpy,random,torch.random所有的随机表现;另一个是Base ,即程序运行后的初始随机值,其可以通过以下两种方式产生

  1. torch.manual_seed(base_seed)

  2. 由特定的seed generator设置

generator = torch. Generator()
g.manual_seed(base_seed)
DataLoader(dataset, ..., generator=generator)

使用spawn模式可以斩断以上所有烦恼.

02 直接在网上搜这个问题会得到什么答案

参考很多的解决方案时,往往会提出以下功能:

def worker_init_fn(worker_id):np.random.seed(np.random.get_state()[1][0] + worker_id)

让我们看看它的输出结果:
(第0,3列是索引,第1,4列是np.random的结果,第2,5列是random.randint的结果)

epoch 0
tensor([[    0,  5125, 13588,     0, 15905, 23182],[    1,  7204, 19825,     0, 13653, 25225]])
tensor([[    2,  1709, 11504,     0, 12842, 23238],[    3,  5715, 14058,     0, 15236, 28033]])
tensor([[    4,  1062, 11239,     0, 10142, 29869],[    5,  6574, 15672,     0, 19623, 25600]])
============================================================
epoch 1
tensor([[    0,  5125, 18134,     0, 15905, 28990],[    1,  7204, 13206,     0, 13653, 25106]])
tensor([[    2,  1709, 15512,     0, 12842, 29703],[    3,  5715, 14201,     0, 15236, 27696]])
tensor([[    4,  1062, 13994,     0, 10142, 23411],[    5,  6574, 18532,     0, 19623, 21744]])
============================================================

假设上述方案对一个时代内可以防止不同的工人出现随机值相同的情况,但不同的时代之间,其最终的随机种子仍然是不变的。

03 那应该如何解决

来自pytorch官方的解决方案:

https://github.com/pytorch/pytorch/pull/56488#issuecomment-825128350

def worker_init_fn(worker_id):worker_seed = torch.initial_seed() % 2**32np.random.seed(worker_seed)random.seed(worker_seed)ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

来自numpy.random原作者的解决方案:

https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562

def worker_init_fn(id):process_seed = torch.initial_seed()# Back out the base_seed so we can use all the bits.base_seed = process_seed - idss = np.random.SeedSequence([id, base_seed])# More than 128 bits (4 32-bit words) would be overkill.np.random.seed(ss.generate_state(4))ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

一个更简单但不保证正确性的解决方案:

def worker_init_fn(worker_id):np.random.seed((worker_id + torch.initial_seed()) % np.iinfo(np.int32).max)ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

04 附上可运行的完整文件

import numpy as np
import random
import torch# np.random.seed(0)class Transform(object):def __init__(self):passdef __call__(self, item = None):return [np.random.randint(10000, 20000), random.randint(20000,30000)]class RandomDataset(object):def __init__(self):passdef __getitem__(self, ind):item = [ind, np.random.randint(1, 10000), random.randint(10000, 20000), 0]tsfm =Transform()(item)return np.array(item + tsfm)def __len__(self):return 20from torch.utils.data import DataLoaderdef worker_init_fn(worker_id):np.random.seed(np.random.get_state()[1][0] + worker_id)ds = RandomDataset()
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)for epoch in range(2):print("epoch {}".format(epoch))np.random.seed()for batch in ds:print(batch)

如果觉得有用,就请分享到朋友圈吧!


往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑黄海广老师《机器学习课程》视频课
本站qq群851320808,加入微信群请扫码:

【深度学习】PyTorch 数据集随机值的完美实践相关推荐

  1. 【 数据集加载 DatasetDataLoader 模块实现与源码详解 深度学习 Pytorch笔记 B站刘二大人 (7/10)】

    数据集加载 Dataset&DataLoader 模块实现与源码详解 深度学习 Pytorch笔记 B站刘二大人 (7/10) 模块介绍 在本节中没有关于数学原理的相关介绍,使用的数据集和类型 ...

  2. 纽约大学深度学习PyTorch课程笔记(自用)Week2

    纽约大学深度学习PyTorch课程笔记Week2 2. Week2 2.1 梯度下降和反向传播算法导论 2.1.1 梯度下降优化算法 参数化模型 梯度下降 2.1.2 在传统神经网络中随机梯度下降和反 ...

  3. torch的拼接函数_从零开始深度学习Pytorch笔记(13)—— torch.optim

    前文传送门: 从零开始深度学习Pytorch笔记(1)--安装Pytorch 从零开始深度学习Pytorch笔记(2)--张量的创建(上) 从零开始深度学习Pytorch笔记(3)--张量的创建(下) ...

  4. 伯禹公益AI《动手学深度学习PyTorch版》Task 07 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 07 学习笔记 Task 07:优化算法进阶:word2vec:词嵌入进阶 微信昵称:WarmIce 优化算法进阶 emmmm,讲实 ...

  5. 纽约大学深度学习PyTorch课程笔记(自用)Week3

    纽约大学深度学习PyTorch课程笔记Week3 Week 3 3.1 神经网络参数变换可视化及卷积的基本概念 3.1.1 神经网络的可视化 3.1.2 参数变换 一个简单的参数变换:权重共享 超网络 ...

  6. 【动手学深度学习PyTorch版】27 数据增强

    上一篇请移步[动手学深度学习PyTorch版]23 深度学习硬件CPU 和 GPU_水w的博客-CSDN博客 目录 一.数据增强 1.1 数据增强(主要是关于图像增强) ◼ CES上的真实的故事 ◼ ...

  7. 深度学习PyTorch笔记(12):线性神经网络——softmax回归

    深度学习PyTorch笔记(12):线性神经网络--softmax回归 6 线性神经网络--softmax回归 6.1 softmax回归 6.1.1 概念 6.1.2 softmax运算 6.2 图 ...

  8. 纽约大学深度学习PyTorch课程笔记(自用)Week6

    纽约大学深度学习PyTorch课程笔记Week6 Week 6 6.1 卷积网络的应用 6.1.1 邮政编码识别器 使用CNN进行识别 6.1.2 人脸检测 一个多尺度人脸检测系统 6.1.3 语义分 ...

  9. 伯禹公益AI《动手学深度学习PyTorch版》Task 04 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 04 学习笔记 Task 04:机器翻译及相关技术:注意力机制与Seq2seq模型:Transformer 微信昵称:WarmIce ...

最新文章

  1. 视频直播技术详解(0)开篇
  2. 【第一章】 Spring概述 —— 跟我学Spring3
  3. 反编译中内部类调用外部类成员问题
  4. leetcode 208. Implement Trie (Prefix Tree) | 208. 实现 Trie 前缀树(Java)
  5. NPM包管理器跟换国内镜像CNPM
  6. Scintilla 3 24在MFC中的使用 动态 静态
  7. 【janino】janino 加载自定义函数
  8. 一个ip对应多个域名多个ssl证书配置-Nginx实现多域名证书HTTPS
  9. 在MFC中调用DLL .
  10. 类中的__init__()
  11. live2d动态壁纸android,Live2DViewerEX动态壁纸
  12. 阿里云云原生一体化数仓 - 数据安全能力解读
  13. 如何避免计算机被别人共享,win7如何防止别人偷窥电脑 win7防止别人偷窥电脑操作方法...
  14. 计算机怎么切换到音乐,win10系统如何快速切换到下一首歌曲?
  15. Python实现测量平差数据处理
  16. 图像分类经典卷积神经网络—SENet论文翻译(中英文对照版)—Squeeze-and-Excitation Networks(挤压和激励网络)
  17. 将centos7打造成桌面系统
  18. 计算机系统的性能建模与设计 排队论实战,计算机系统的性能建模与设计:排队论实战(计算机科学丛书)...
  19. 视频抠像边缘模拟真实光照AE/PR插件 Light Wrap Fantastic
  20. 永远不要去依赖别人_经典语录:不要轻易去依赖一个人,它会成为你的习惯

热门文章

  1. Python 零碎信息-基础 02
  2. 一款WP小游戏代码分享
  3. (C#) 调用执行批处理文件
  4. C# .Net中的类型转换
  5. 用react-service做状态管理,适用于react、react native
  6. 【Oracle】PL/SQL Developer使用技巧(持续更新中)
  7. UploadHandleServlet
  8. NoSQL 非关系数据库
  9. 在往sql server 插入数据时 报此错误“ 消息 8152,级别 16,状态 14,第 1 行 将截断字符串或二进制数据。”...
  10. tomcat - JVM 配置