optimizer.zero_grad()意思是把梯度置零,也就是把loss关于weight的导数变成0.

在学习pytorch的时候注意到,对于每个batch大都执行了这样的操作:

        # zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()

对于这些操作我是把它理解成一种梯度下降法,贴一个自己之前手写的简单梯度下降法作为对照:

    # gradient descentweights = [0] * nalpha = 0.0001max_Iter = 50000for i in range(max_Iter):loss = 0d_weights = [0] * nfor k in range(m):h = dot(input[k], weights)d_weights = [d_weights[j] + (label[k] - h) * input[k][j] for j in range(n)] loss += (label[k] - h) * (label[k] - h) / 2d_weights = [d_weights[k]/m for k in range(n)]weights = [weights[k] + alpha * d_weights[k] for k in range(n)]if i%10000 == 0:print "Iteration %d loss: %f"%(i, loss/m)print weights

可以发现它们实际上是一一对应的:

optimizer.zero_grad()对应d_weights = [0] * n

即将梯度初始化为零(因为一个batch的loss关于weight的导数是所有sample的loss关于weight的导数的累加和)

outputs = net(inputs)对应h = dot(input[k], weights)

即前向传播求出预测的值

loss = criterion(outputs, labels)对应loss += (label[k] - h) * (label[k] - h) / 2

这一步很明显,就是求loss(其实我觉得这一步不用也可以,反向传播时用不到loss值,只是为了让我们知道当前的loss是多少)
loss.backward()对应d_weights = [d_weights[j] + (label[k] - h) * input[k][j] for j in range(n)]

即反向传播求梯度
optimizer.step()对应weights = [weights[k] + alpha * d_weights[k] for k in range(n)]

即更新所有参数

如有不对,敬请指出。欢迎交流

作者:scut_salmon
来源:CSDN
原文:https://blog.csdn.net/scut_salmon/article/details/82414730
版权声明:本文为博主原创文章,转载请附上博文链接!

torch代码解析 为什么要使用optimizer.zero_grad()相关推荐

  1. Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

    引言 一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.s ...

  2. unet模型及代码解析

    什么是unet 一个U型网络结构,2015年在图像分割领域大放异彩,unet被大量应用在分割领域.它是在FCN的基础上构建,它的U型结构解决了FCN无法上下文的信息和位置信息的弊端 Unet网络结构 ...

  3. 梯度值与参数更新optimizer.zero_grad(),loss.backward、和optimizer.step()、lr_scheduler.step原理解析

    在用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward.和optimizer.step().lr_schedule ...

  4. Baidu Apollo代码解析之EM Planner中的QP Speed Optimizer 1

    大家好,我已经把CSDN上的博客迁移到了知乎上,欢迎大家在知乎关注我的专栏慢慢悠悠小马车(https://zhuanlan.zhihu.com/duangduangduang).希望大家可以多多交流, ...

  5. [GCN] 代码解析 of GitHub:Semi-supervised classification with graph convolutional networks

    本文解析的代码是论文Semi-Supervised Classification with Graph Convolutional Networks作者提供的实现代码. 原GitHub:Graph C ...

  6. 自然语言处理(三):传统RNN(NvsN,Nvs1,1vsN,NvsM)pytorch代码解析

    文章目录 1.预备知识:深度神经网络(DNN) 2.RNN出现的意义与基本结构 3.根据输入和输出数量的网络结构分类 3.1 N vs N(输入和输出序列等长) 3.2 N vs 1(多输入单输出) ...

  7. Data-Free Knowledge Distillation for Heterogeneous Federated Learning论文阅读+代码解析

    论文地址点这里 一. 介绍 联邦学习具有广阔的应用前景,但面临着来自数据异构的挑战,因为在现实世界中用户数据均为Non-IID分布的.在这样的情况下,传统的联邦学习算法可能会导致无法收敛到各个客户端的 ...

  8. GraphSAGE算法 和 代码解析

    聚合邻居 GraphSAGE研究了聚合邻居操作所需的性质,并且提出了几种新的聚合操作(aggregator),需满足如下条件: (1)聚合操作必须要对聚合节点的数量做到自适应.不管节点的邻居数量怎么变 ...

  9. pytorch之model.zero_grad() 与 optimizer.zero_grad()

    转自 https://cloud.tencent.com/developer/article/1710864 1. 引言 在PyTorch中,对模型参数的梯度置0时通常使用两种方式:model.zer ...

  10. python grad_PyTorch中model.zero_grad()和optimizer.zero_grad()用法

    废话不多说,直接上代码吧~ model.zero_grad() optimizer.zero_grad() 首先,这两种方式都是把模型中参数的梯度设为0 当optimizer = optim.Opti ...

最新文章

  1. 公开课 | 人脸识别的最新进展以及工业级大规模人脸识别实践探讨
  2. 设计模式 - Strategy
  3. mobaxterm 传文件夹_如何使用MobaXterm上传文件到远程Linux系统-MobaXterm使用教程
  4. CentOS 7安装GNOME图形界面并设置默认启动
  5. 2016年Android主流技术
  6. 大小端、位段(惑位域)和内存对齐
  7. 地域跨度入手的8zsb
  8. sizeof运算符介绍以及常见的坑
  9. openstack 之 kolla安装镜像
  10. 家用nas的过去现在和未来--2008n年
  11. JAVA使用JEP进行动态公式计算
  12. 「微信小程序」有哪些冲击与机会?
  13. python再议装饰器
  14. python实现——视频转桌面壁纸
  15. 在线图片尺寸修改 生成图标
  16. 计算机操作系统的加密与恢复,当在 Windows中设置 FIPS 兼容策略时,BitLocker 的恢复密码Windows...
  17. python面向对象之抽象类
  18. 谭浩强C++ 第八章
  19. 【word2vec】算法原理 公式推导
  20. 【Docker闪退】【解决方法】It looks like there is an error with Docker Desktop, restart it to fix it

热门文章

  1. sox处理mp3_sox :音频文件转换命令
  2. UE4----GC(垃圾回收)
  3. 为什么说衰老先从血管开始?
  4. allt什么意思_all是什么意思_all怎么读_all翻译_用法_发音_词组_同反义词_全部的-新东方在线英语词典...
  5. Project 2013项目管理教程(3):建立任务间的依赖性
  6. 台式计算机 按键盘字母键 没反应6,电脑键盘失灵后的解决办法
  7. 滴滴曹乐:如何成为技术大牛?
  8. 网易视频云:新一代列式存储格式Parquet
  9. 互联网晚报 | 12月25日 星期六 | 小米首款自研充电芯片澎湃P1官宣;抖音电商启动“冬季山货节”;全国首批“千兆城市”出炉...
  10. c语言编程仓鼠吃豆子,动态规划之仓鼠吃豆子 - osc_8quu62cg的个人空间 - OSCHINA - 中文开源技术交流社区...