最近在做图像分类实验时,在4个gpu上使用pytorch的DataParallel 函数并行跑程序,批次为16时会报如下所示的错误:
  RuntimeError: CUDA out of memory. Tried to allocate 858.00 MiB (GPU 3; 10.92 GiB total capacity; 10.10 GiB already allocated; 150.69 MiB free; 10.13 GiB reserved in total by PyTorch)

  实验发现,每块gpu最多可以跑2条数据,但是我又想设置batch_size=16,参考https://zhuanlan.zhihu.com/p/86441879了解到transformer-XL官方写的BalancedDataParallel 函数,用来解决DataParallel 显存使用不平衡的问题(参考代码见最后)。
  为了理解BalancedDataParallel 函数用法,我们先来弄清楚几个问题。
1,DataParallel 函数是如何工作的?
  首先将模型加载到主 GPU 上,然后再将模型复制到各个指定的从 GPU 中,然后将输入数据按 batch 维度进行划分,具体来说就是每个 GPU 分配到的数据 batch 数量是总输入数据的 batch 除以指定 GPU 个数。每个 GPU 将针对各自的输入数据独立进行 forward计算,之后会把计算结果传到主GPU 上完成梯度计算和参数更新,最后将更新后的参数复制到从 GPU 中,这样就完成了一次迭代计算。参考https://blog.csdn.net/zhjm07054115/article/details/104799661当gpu=2,batch_size=30时,我们可以从下图清楚的看到首先会在两个gpu上分别分配15条数据,进行forward计算,之后汇总结果再进行梯度计算和参数更新。
  我们可以看到反向传播计算和参数更新完全放在主gpu上进行的,这样会造成显存使用不平衡的问题。

2,梯度累加
  参考https://blog.csdn.net/wuzhongqiang/article/details/102572324做的反向传播梯度累加实验,发现pytorch在反向传播的时候,默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

  了解了DataParallel 函数和梯度累加后,我们就可以来解决显存使用不平衡问题以及如何在显存固定的情况下加大训练批次。
  首先,简单介绍BalancedDataParallel 用法【下图截取自https://github.com/Link-Li/Balanced-DataParallel】

  简单解释一下:当我们需要在3个gpu并行跑程序,每个gpu最多一次可以处理3条数据,分配是[3,3,3],那么3个gpu最多可以同时处理9条数据,也就是batch_size最大可设为9,因为主gpu上还要进行反向传播,所以这里我们设置主gpu处理2条数据,分布就是[2,3,3],batch_size=8。
  此时如果我们想加大批次,使得batch_size=16,那么分布应该是[4,6,6],但是我们知道每个gpu最多可以处理3条数据,这里就用到梯度累加的方法了,即上图中的acc_grad,acc_grad参数表示将batch_size分成多少份送入网络,当acc_grad=2,表示我们会先将16个数据分成2份,每份有8条数据,每次输入8条数据分给3个gpu做并行训练,forward计算结果放入主gpu上进行反向传播,由于梯度可以累加,循环两次后,再更新参数。这样做不仅可以缓解显存不平衡问题也可以解决显存不足的问题。
  下面是我根据https://blog.csdn.net/zhjm07054115/article/details/104799661做了修改,加上BalancedDataParallel 完整代码:

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from data_parallel_balance import BalancedDataParallel# Dataset
class RandomDataset(Dataset):def __init__(self, size, length):self.len = lengthself.data = torch.randn(length, size)self.target=np.random.randint(3,size=length)def __getitem__(self, index):label=torch.tensor(self.target[index])return self.data[index],labeldef __len__(self):return self.len# model
class Model(nn.Module):def __init__(self, input_size, output_size):super(Model, self).__init__()self.fc = nn.Linear(input_size, output_size)def forward(self, input):output = self.fc(input)print("\tIn Model: input size", input.size(),"output size", output.size())return output# trian
def train(rand_loader,model,optimizer,criterion):train_loss=0# trainmodel.train()optimizer.zero_grad()for image,target in rand_loader:print('image:',image.shape)if batch_chunk > 0:image_chunks = torch.chunk(image, batch_chunk, 0)target_chunks = torch.chunk(target, batch_chunk, 0)for i in range(len(image_chunks)):print('image_chunks:',i)img=image_chunks[i].to(device)lab=target_chunks[i].to(device)out=model(img)print("Chunks_Outputs: input size", img.size(),"output_size", out.size())loss=criterion(out,lab)# print('{} chunk,loss:{}.'.format(i,loss))train_loss+=loss.item()loss = loss.float().mean().type_as(loss) / len(image_chunks)loss.backward()else:image = image.to(device)target=target.to(device)output = model(image)loss=criterion(output,target)train_loss=loss.item()print("Outside: input size", image.size(),"output_size", output.size())optimizer.step()  return train_lossif __name__=="__main__":input_size = 5output_size = 3batch_size = 32data_size = 70batch_chunk=2gpu0_bsz=8epochs=2device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# datarand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),batch_size=batch_size, shuffle=True)# model                 model = Model(input_size, output_size)if torch.cuda.device_count() > 1:print("Let's use", torch.cuda.device_count(), "GPUs!")if gpu0_bsz >= 0:   model = BalancedDataParallel(gpu0_bsz // batch_chunk, model, dim=0)else:model = nn.DataParallel(model)model.to(device)# optimizeroptimizer= torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)# losscriterion=nn.CrossEntropyLoss()for epoch in range(epochs):print('Epoch:',epoch)train(rand_loader,model,optimizer,criterion)

参考:
pytorch多gpu并行训练
transformer-XL的官方代码
BalancedDataParallel 参考代码
PyTorch-4 nn.DataParallel 数据并行详解
Pytorch反向传播中的细节-计算梯度时的默认累加

欢迎大家留言批评指正!

pytorch多gpu DataParallel 及梯度累加解决显存不平衡和显存不足问题相关推荐

  1. pytorch 多GPU训练总结(DataParallel的使用)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_40087578/arti ...

  2. Pytorch分布式训练/多卡训练(二) —— Data Parallel并行(DDP)(2.2)(代码示例)(BN同步主卡保存梯度累加多卡测试inference随机种子seed)

    DDP的使用非常简单,因为它不需要修改你网络的配置.其精髓只有一句话 model = DistributedDataPrallel(model, device_ids=[local_rank], ou ...

  3. Gradient Accumulation 梯度累加 (Pytorch)

    我们在训练神经网络的时候,batch_size的大小会对最终的模型效果产生很大的影响.一定条件下,batch_size设置的越大,模型就会越稳定.batch_size的值通常设置在 8-32 之间,但 ...

  4. Pytorch的nn.DataParallel详细解析

    前言 pytorch中的GPU操作默认是异步的,当调用一个使用GPU的函数时,这些操作会在特定设备上排队但不一定在稍后执行.这就使得pytorch可以进行并行计算.但是pytorch异步计算的效果对调 ...

  5. pytorch多gpu并行训练操作指南

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 来源:知乎 作者:link-web 链接:https://zhuanlan.zhi ...

  6. pytorch多gpu并行训练

    pytorch多gpu并行训练 link-web 转自:pytorch多gpu并行训练 - 知乎 目录(目录不可点击) 说明 1.和DataParallel的区别 2.如何启动程序的时候 2.1 单机 ...

  7. [源码解析] PyTorch 分布式(2) ----- DataParallel(上)

    [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 文章目录 [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 0x00 摘要 ...

  8. 【计数网络】梯度累加增加LCFCN的BatchSize

    LCFCN是一个以分割网络为基础的专用于计数的网络. LCFCN模型由于loss的特殊性 batch size 目前只能为1 LCFCN代码 https://github.com/ElementAI/ ...

  9. Pytorch多GPU笔记

    Pytorch分布式笔记 Pytorch多GPU计算笔记 DP和DDP的区别 DP DDP Apex amp的使用 apex.parallel.DistributedDataParallel的使用 D ...

最新文章

  1. HDU 2206 IP的计算(字符串处理)
  2. 既可生成点云又可生成网格的超网络方法 ICML
  3. lua打开是二进制代码_物联网的构建:使用Lua高级语言进行嵌入式开发
  4. 一个页面区分管理者和普通用户如何设计_如何从「百度知道」中删除 bai du zhi dao?...
  5. storm-kafka编程指南
  6. 机器学习算法与Python实践之(二)k近邻(KNN)
  7. 用随机整数填充缺失值_输入一个整数值并在C中用零填充进行打印
  8. AOJ 6.Hero In Maze
  9. sklearn 机器学习 Pipeline 模板
  10. 一针一线皆关“云” 报喜鸟以匠心融合科技
  11. 陆正耀神州优车被强制执行超10亿
  12. 给P40让路!华为Mate 30 5G降至这个价,还贵吗?
  13. Asp.net MVC - 使用PRG模式(附源码)
  14. Qt sender()函数
  15. 计算两点间距离C++
  16. 敏捷团队中有效沟通的5种模式
  17. 信能阳光——打造国内体育照明领域的旗舰品牌
  18. 数据结构课程设计之排序综合
  19. 爬取猫眼票房数并数据可视化
  20. 英语口语100之每日十句口语

热门文章

  1. Unity Live2D的接入和使用
  2. softmax分类器 matlab,softmax原理及Matlab实现
  3. Unity3D关于Sprite packer和Packing tag的使用
  4. 【C语言】C语言操作符的分类及应用【超详细讲解】
  5. 如何构建全球实时音视频云及其海外网络传输优化
  6. SpringBoot整合第三方技术学习笔记(自用)
  7. javascript求1~100的素数和
  8. Rx第三部分--深入序列
  9. 华为eNSP配置dhcp 下发ipv4地址
  10. AJAX开发过程中的七宗罪