随着深度学习模型参数量的增加,现有GPU加载一个深度模型(尤其是预训练模型)后,剩余显存无法容纳很多的训练数据,甚至会仅能容纳一条训练数据。

梯度累积(Gradient Accumulation)是一种不需要额外硬件资源就可以增加批量样本数量(Batch Size)的训练技巧。这是一个通过时间换空间的优化措施,它将多个Batch训练数据的梯度进行累积,在达到指定累积次数后,使用累积梯度统一更新一次模型参数,以达到一个较大Batch Size的模型训练效果。累积梯度等于多个Batch训练数据的梯度的平均值

TensorFlow 2.0中梯度累积的代码实现:

# 给定累积次数
accum_steps = 5
# 模型训练
for step, x_batch_train in enumerate(x_train):with tf.GradientTape() as tape:y_pred = model(x_batch_train,training=True)loss_value = loss_fn(y_batch_train, y_pred)# 计算梯度if step == 0:accum_grads = [tf.Variable(tf.zeros_like(tv), trainable=False) for tv in self.trainable_variables]accum_grads = [accum_grads[i].assign_add(grad / accum_batch) for i, grad in enumerate(                        self.grad_clipping(  # 梯度裁剪tape.gradient(loss, self.trainable_variables), gard_theta))]# 参数更新if (step + 1) % accum_batch == 0:self.optimizer.apply_gradients(zip(accum_grads, self.trainable_variables))accum_grads = [tv.assign(tf.zeros_like(tv)) for tv in accum_grads]

梯度累积(Gradient Accumulation)相关推荐

  1. 梯度累加(Gradient Accumulation)

    受显存限制,运行一些预训练的large模型时,batch-size往往设置的比较小1-4,否则就会'CUDA out of memory',但一般batch-size越大(一定范围内)模型收敛越稳定效 ...

  2. 通俗理解深度学习梯度累加(Gradient Accumulation)的原理

    首先你得明白什么是梯度,可以看我之前写的一篇博客 : 微分与梯度的概念理解 本质上,梯度是一种方向导数,是一个矢量,因此这里的梯度累加并不是简单的相加,而是类似于初高中物理学的力的合成,梯度作为一种方 ...

  3. pytorch 梯度累积(gradient accumulation)

    梯度累积 - gradient accumulation 在深度学习训练的时候,数据的batch size大小受到GPU内存限制,batch size大小会影响模型最终的准确性和训练过程的性能.在GP ...

  4. Gradient Accumulation 梯度累加 (Pytorch)

    我们在训练神经网络的时候,batch_size的大小会对最终的模型效果产生很大的影响.一定条件下,batch_size设置的越大,模型就会越稳定.batch_size的值通常设置在 8-32 之间,但 ...

  5. PyTorch中的梯度累积

    我们在训练神经网络的时候,超参数batch_size的大小会对模型最终效果产生很大的影响,通常的经验是,batch_size越小效果越差:batch_size越大模型越稳定.理想很丰满,现实很骨感,很 ...

  6. pytorch DDP加速之gradient accumulation设置

    pytorch DDP 参考:https://zhuanlan.zhihu.com/p/250471767 GPU高效通信算法-Ring Allreduce: https://www.zhihu.co ...

  7. [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

    [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 文章目录 [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 0x00 摘要 0x01 概述 1.1 前 ...

  8. AI系统——梯度累积算法

    明天博士论文要答辩了,只有一张12G二手卡,今晚通宵要搞定10个模型实验 挖槽,突然想出一个T9开天霹雳模型,加载不进去我那张12G的二手卡,感觉要错过今年上台Best Paper领奖 上面出现的问题 ...

  9. Tensorflow中的各种梯度处理gradient

    最近其实一直想自己手动创建op,这样的话好像得懂tensorflow自定义api/op的规则,设计前向与反向,注册命名,注意端口以及文件组织,最后可能还要需要重新编译才能使用.这一部分其实记得tens ...

最新文章

  1. php算出明天的日期,PHP获取昨天、今天及明天日期的方法
  2. java8 stream中的惰性求值
  3. 实战SSM_O2O商铺_04自下而上逐步整合SSM
  4. cocos2d menu菜单类
  5. ASP.NET 页面对象模型
  6. 【人生】不管你挣多少, 钱永远是问题
  7. 独家 | 李飞飞亲口跟我们说:离职Google是假新闻
  8. 计算机考试如何添加打印机,如何添加网络打印机?
  9. 《人类简史》《未来简史》读后感作文5000字
  10. 测试人收入情况大曝光,你的收入在什么水平
  11. 大专学历计算机专业可以积分,持有大专紧缺急需专业可直接申请上海居住证积分?...
  12. 使用python-docx实现对word文档里的字符串、图片批量替换
  13. 知识欠缺到沙漠化了吧
  14. Error occurred while trying to proxy request项目突然起不来了
  15. Ionic4--路由跳转
  16. GRUB4DOS详解
  17. 华为公司员工待遇全面揭秘 选择自 CQP 的 Blog
  18. 【STM32Cube_23】使用USART接收GPS数据并解析(L80-R)
  19. CCD和CMOS图像传感器的快门
  20. R语言:ggplot2画带误差棒的组合折线图教程。

热门文章

  1. 我奋斗了18年、不是为了和你喝咖啡
  2. C++11 enable_if 详解
  3. 调用百度ai接口实现图片文字识别详解
  4. 服务器维护配件,服务器维修,服务器升级,服务器配件,磁盘柜维修及维护
  5. 那一年,我与电脑结下了不解之缘
  6. win10 redis集群搭建 ruby
  7. DataFrame按照时间分组然后求平均
  8. QQ拼音直接提权WIN8
  9. ubuntu 16.04 和 18.04 替换apt源为阿里源
  10. [C++]小根堆 插入/删除/初始化