Pytorch 训练技巧

文章目录

  • Pytorch 训练技巧
    • 1、指定GPU编号
    • 2、查看模型每层输出详情
    • 3、梯度裁剪(Gradient Clipping)
    • 4、扩展单张图片维度
    • 5、独热编码
    • 6、防止验证模型时爆显存
    • 7、学习率衰减
    • 8、冻结某些层的参数

1、指定GPU编号

  • 设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  • 设置当前使用的GPU设备为0,1号两个设备,名称依次为 /gpu:0/gpu:1os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" ,根据顺序表示优先使用0号设备,然后使用1号设备。

指定GPU的命令需要放在和神经网络相关的一系列操作的前面。

2、查看模型每层输出详情

Keras有一个简洁的API来查看模型的每一层输出尺寸,这在调试网络时非常有用。现在在PyTorch中也可以实现这个功能。

使用很简单,如下用法:

from torchsummary import summary
summary(your_model, input_size=(channels, H, W))

input_size 是根据你自己的网络模型的输入尺寸进行设置。

pytorch-summarygithub.com

3、梯度裁剪(Gradient Clipping)

import torch.nn as nnoutputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

nn.utils.clip_grad_norm_ 的参数:

  • parameters – 一个基于变量的迭代器,会进行梯度归一化
  • max_norm – 梯度的最大范数
  • norm_type – 规定范数的类型,默认为L2

4、扩展单张图片维度

因为在训练时的数据维度一般都是 (batch_size, c, h, w),而在测试时只输入一张图片,所以需要扩展维度,扩展维度有多个方法:

import cv2
import torchimage = cv2.imread(img_path)
image = torch.tensor(image)
print(image.size())img = image.view(1, *image.size())
print(img.size())# output:
# torch.Size([h, w, c])
# torch.Size([1, h, w, c])

import cv2
import numpy as npimage = cv2.imread(img_path)
print(image.shape)
img = image[np.newaxis, :, :, :]
print(img.shape)# output:
# (h, w, c)
# (1, h, w, c)

或(感谢知乎用户coldleaf的补充)

import cv2
import torchimage = cv2.imread(img_path)
image = torch.tensor(image)
print(image.size())img = image.unsqueeze(dim=0)
print(img.size())img = img.squeeze(dim=0)
print(img.size())# output:
# torch.Size([(h, w, c)])
# torch.Size([1, h, w, c])
# torch.Size([h, w, c])

tensor.unsqueeze(dim):扩展维度,dim指定扩展哪个维度。

tensor.squeeze(dim):去除dim指定的且size为1的维度,维度大于1时,squeeze()不起作用,不指定dim时,去除所有size为1的维度。

5、独热编码

在PyTorch中使用交叉熵损失函数的时候会自动把label转化成onehot,所以不用手动转化,而使用MSE需要手动转化成onehot编码。

import torch
class_num = 8
batch_size = 4def one_hot(label):"""将一维列表转换为独热编码"""label = label.resize_(batch_size, 1)m_zeros = torch.zeros(batch_size, class_num)# 从 value 中取值,然后根据 dim 和 index 给相应位置赋值onehot = m_zeros.scatter_(1, label, 1)  # (dim,index,value)return onehot.numpy()  # Tensor -> Numpylabel = torch.LongTensor(batch_size).random_() % class_num  # 对随机数取余
print(one_hot(label))# output:
[[0. 0. 0. 1. 0. 0. 0. 0.][0. 0. 0. 0. 1. 0. 0. 0.][0. 0. 1. 0. 0. 0. 0. 0.][0. 1. 0. 0. 0. 0. 0. 0.]]

Convert int into one-hot formatdiscuss.pytorch.org

6、防止验证模型时爆显存

验证模型时不需要求导,即不需要梯度计算,关闭autograd,可以提高速度,节约内存。如果不关闭可能会爆显存。

with torch.no_grad():# 使用model进行预测的代码pass

Pytorch 训练时无用的临时变量可能会越来越多,导致 out of memory ,可以使用下面语句来清理这些不需要的变量。

torch.cuda.empty_cache()

更详细的优化可以查看 优化显存使用 和 显存利用问题。

7、学习率衰减

import torch.optim as optim
from torch.optim import lr_scheduler# 训练前的初始化
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, 10, 0.1)  # # 每过10个epoch,学习率乘以0.1# 训练过程中
for n in n_epoch:scheduler.step()...

8、冻结某些层的参数

参考:Pytorch 冻结预训练模型的某一层

在加载预训练模型的时候,我们有时想冻结前面几层,使其参数在训练过程中不发生变化。

我们需要先知道每一层的名字,通过如下代码打印:

net = Network()  # 获取自定义网络结构
for name, value in net.named_parameters():print('name: {0},\t grad: {1}'.format(name, value.requires_grad))

假设前几层信息如下:

name: cnn.VGG_16.convolution1_1.weight,   grad: True
name: cnn.VGG_16.convolution1_1.bias,    grad: True
name: cnn.VGG_16.convolution1_2.weight,  grad: True
name: cnn.VGG_16.convolution1_2.bias,    grad: True
name: cnn.VGG_16.convolution2_1.weight,  grad: True
name: cnn.VGG_16.convolution2_1.bias,    grad: True
name: cnn.VGG_16.convolution2_2.weight,  grad: True
name: cnn.VGG_16.convolution2_2.bias,    grad: True

后面的True表示该层的参数可训练,然后我们定义一个要冻结的层的列表:

no_grad = ['cnn.VGG_16.convolution1_1.weight','cnn.VGG_16.convolution1_1.bias','cnn.VGG_16.convolution1_2.weight','cnn.VGG_16.convolution1_2.bias'
]

冻结方法如下:

net = Net.CTPN()  # 获取网络结构
for name, value in net.named_parameters():if name in no_grad:value.requires_grad = Falseelse:value.requires_grad = True

冻结后我们再打印每层的信息:

name: cnn.VGG_16.convolution1_1.weight,   grad: False
name: cnn.VGG_16.convolution1_1.bias,    grad: False
name: cnn.VGG_16.convolution1_2.weight,  grad: False
name: cnn.VGG_16.convolution1_2.bias,    grad: False
name: cnn.VGG_16.convolution2_1.weight,  grad: True
name: cnn.VGG_16.convolution2_1.bias,    grad: True
name: cnn.VGG_16.convolution2_2.weight,  grad: True
name: cnn.VGG_16.convolution2_2.bias,    grad: True

可以看到前两层的weight和bias的requires_grad都为False,表示它们不可训练。

最后在定义优化器时,只对requires_grad为True的层的参数进行更新。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01)

Pytorch 训练技巧相关推荐

  1. 送你9个快速使用Pytorch训练解决神经网络的技巧(附代码)

    来源:读芯术 本文约4800字,建议阅读10分钟. 本文为大家介绍9个使用Pytorch训练解决神经网络的技巧 图片来源:unsplash.com/@dulgier 事实上,你的模型可能还停留在石器时 ...

  2. PyTorch训练加速技巧

    PyTorch训练加速技巧 由于最近的程序对速度要求比较高,想要快速出结果,因此特地学习了一下混合精度运算和并行化操作,由于已经有很多的文章介绍相关的原理,因此本篇只讲述如何应用PyTorch实现混合 ...

  3. PyTorch训练加速17种技巧

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 文自 机器之心 作者:LORENZ KUHN 编辑:陈萍 掌握这 ...

  4. 【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧

    1 散度在无监督学习中的应用 在神经网络的损失计算中,最大化和最小化两个数据分布间散度的方法,已经成为无监督模型中有效的训练方法之一. 在无监督模型训练中,不但可以使用K散度JS散度,而且可以使用其他 ...

  5. .mb是什么文件_神经网络长什么样不知道? 这有一份简单的 pytorch可视化技巧(1)

    神经网络长什么样不知道?这有一份简单的 pytorch可视化技巧(1) 深度学习这几年伴随着硬件性能的进一步提升,人们开始着手于设计更深更复杂的神经网络,有时候我们在开源社区拿到网络模型的时候,做客可 ...

  6. 编写高效的PyTorch代码技巧(下)

    点击上方"算法猿的成长",关注公众号,选择加"星标"或"置顶" 总第 133 篇文章,本文大约 3000 字,阅读大约需要 15 分钟 原文 ...

  7. Pytorch常用技巧记录

    Pytorch常用技巧记录 目录 文章目录 Pytorch常用技巧记录 1.指定GPU编号 2.查看模型每层输出详情 3.梯度裁剪(Gradient Clipping) 4.扩展单张图片维度 5.独热 ...

  8. PyTorch学习笔记(六):PyTorch进阶训练技巧

    PyTorch实战:PyTorch进阶训练技巧 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: P ...

  9. 高效又稳定的ChatGPT大模型训练技巧总结,让训练事半功倍!

    文|python 前言 近期,ChatGPT成为了全网热议的话题.ChatGPT是一种基于大规模语言模型技术(LLM, large language model)实现的人机对话工具.现在主流的大规模语 ...

最新文章

  1. jQuery实现用户注册的表单验证
  2. VML编程之------VML语言入门《VML极道教程》原著:沐缘华
  3. 为什么 HashMap 的加载因子是0.75?
  4. java mina unix client
  5. 【渝粤教育】广东开放大学 数据结构 形成性考核 (24)
  6. Windows7 64位下SDK Manager.exe无法运行问题解决方法
  7. Java springboot B2B2C o2o多用户商城 springcloud架构-(十)高可用的服务注册中心
  8. Java代码质量改进之:使用ThreadLocal维护线程内部变量
  9. 2-1 组合优化问题
  10. css td中画斜线,css 模拟表格斜线
  11. unll是什么意思_javascript中null是什么意思?
  12. DHTMLX Gantt 甘特图 使用
  13. CodeForces - 272C Dima and Staircase (线段树区间更新)
  14. bmp图片灰度化和二值化
  15. php 获取当前目录和当前文件夹
  16. 高质量的CAD练习图纸在线分享
  17. 【C/C++】从API学习STL algorithm 001(for_each、find、find_if、find_end、find_first_of 快到碗里来(◕ᴗ◕✿)
  18. can‘t convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, floa
  19. Excel2010的LARGE函数应用详解
  20. Python Argparse 库讲解特别好的

热门文章

  1. 如何使用Selenium IDE浏览器插件轻松完成脚本录制,轻松搞定自动化测试!
  2. 互联网安全技术有哪些
  3. c语言文件组织与多文件gcc命令行编译
  4. [Python] Codecombat攻略 远边的森林 Forest (1-40关)
  5. Java基础之模拟披萨店
  6. tensorflow-keras框架搭建GoogLeNet分类网络(附带注释)
  7. linux 上安装ffmpeg
  8. 动态规划:0-1背包问题
  9. WSL安装z3报错ModuleNotFound
  10. 火狐浏览器所有的快捷键大全