点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达

作者丨xxy-zhihu@知乎

来源丨https://zhuanlan.zhihu.com/p/339039943

编辑丨极市平台

导读

本文为一篇实操教程,作者介绍了PyTorch AutoGrad C++层实现中各个概念的解释。

autograd依赖的数据结构

at::Tensor:shared ptr 指向 TensorImpl

TensorImpl:对 at::Tensor 的实现

  • 包含一个类型为 [AutogradMetaInterface](c10::AutogradMetaInterface) 的autograd_meta_,在tensor是需要求导的variable时,会被实例化为 [AutogradMeta](c10::AutogradMetaInterface) ,里面包含了autograd需要的信息

Variable: 就是Tensor,为了向前兼容保留的

  • using Variable = at::Tensor;

  • 概念上有区别, Variable 是需要计算gradient的, Tensor 是不需要计算gradient的

  • VariableAutogradMeta是对 [AutogradMetaInterface](c10::AutogradMetaInterface)的实现,里面包含了一个 Variable,就是该variable的gradient

  • 带有version和view

  • 会实例化 AutogradMeta , autograd需要的关键信息都在这里

AutoGradMeta : 记录 Variable 的autograd历史信息

  • 包含一个叫grad_的 Variable, 即 AutoGradMeta 对应的var的梯度tensor

  • 包含类型为 Node 指针的 grad_fn (var在graph内部时)和 grad_accumulator(var时叶子时), 记录生成grad_的方法

  • 包含 output_nr ,标识var对应 grad_fn的输入编号

  • 构造函数包含一个类型为 Edge的gradient_edge, gradient_edge.function 就是 grad_fn, 另外 gradient_edge.input_nr 记录着对应 grad_fn的输入编号,会赋值给 AutoGradMetaoutput_nr

autograd::Edge: 指向autograd::Node的一个输入

  • 包含类型为 Node 指针,表示edge指向的Node

  • 包含 input_nr, 表示edge指向的Node的输入编号

autograd::Node: 对应AutoGrad Graph中的Op

  • 是所有autograd op的抽象基类,子类重载apply方法

    • next_edges_记录出边

    • input_metadata_记录输入的tensor的metadata

  • 实现的子类一般是可求导的函数和他们的梯度计算op

  • Node in AutoGrad Graph

    • Variable通过Edge关联Node的输入和输出

    • 多个Edge指向同一个Var时,默认做累加

  • call operator

    • 最重要的方法,实现计算

  • next_edge

    • 缝合Node的操作

    • 获取Node的出边,next_edge(index)/next_edges()

    • add_next_edge(),创建

前向计算

PyTorch通过tracing只生成了后向AutoGrad Graph.

代码是生成的,需要编译才能看到对应的生成结果

  • gen_variable_type.py生成可导版本的op

  • 生成的代码在 pytorch/torch/csrc/autograd/generated/

  • 前向计算时,进行了tracing,记录了后向计算图构建需要的信息

  • 这里以relu为例,代码在pytorch/torch/csrc/autograd/generated/VariableType_0.cpp

Tensor relu(const Tensor & self) {                                                                                                                                                                   auto& self_ = unpack(self, "self", 0);                                                                                                                                                             std::shared_ptr<ReluBackward0> grad_fn;                                                                                                                                                            if (compute_requires_grad( self )) { // 如果输入var需要grad    // ReluBackward0的类型是Node                                                                                                                                                                grad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);                                                                                                                          // collect_next_edges(var)返回输入var对应的指向的    // grad_fn(前一个op的backward或者是一个accumulator的)的输入的Edge    // set_next_edges(),在grad_fn中记录这些Edge(这里完成了后向的构图)    grad_fn->set_next_edges(collect_next_edges( self ));     // 记录当前var的一个版本                                                                                                                                              grad_fn->self_ = SavedVariable(self, false);                                                                                                                                                     }                                                                                                                                                                                                  #ifndef NDEBUG                                                                                                                                                                                     c10::optional<Storage> self__storage_saved =                                                                                                                                                         self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;                                                                                                                    c10::intrusive_ptr<TensorImpl> self__impl_saved;                                                                                                                                                   if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();                                                                                                                                   #endif                                                                                                                                                                                             auto tmp = ([&]() {                                                                                                                                                                                  at::AutoNonVariableTypeMode non_var_type_mode(true);                                                                                                                                               return at::relu(self_); // 前向计算                                                                                                                                                                          })();                                                                                                                                                                                              auto result = std::move(tmp);                                                                                                                                                                      #ifndef NDEBUG                                                                                                                                                                                     if (self__storage_saved.has_value())                                                                                                                                                                 AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));                                                                                                                             if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());                                                                                                                      #endif                                                                                                                                                                                             if (grad_fn) {                   // grad_fn增加一个输入,记录输出var的metadata作为grad_fn的输入      // 输出var的AutoGradMeta实例化,输出var的AutoGradMeta指向起grad_fn的输入                                                                                                                                                                            set_history(flatten_tensor_args( result ), grad_fn);                                                                                                                                           }                                                                                                                                                                                                  return result;                                                                                                                                                                                   }
  • 可以看到和 grad_fn 相关的操作trace了一个op的计算,构建了后向计算图.

后向计算

autograd::backward():计算output var的梯度值,调用的 run_backward()

autograd::grad() :计算有output var和到特定input的梯度值,调用的 run_backward()

autograd::run_backward()

  • 对于要求梯度的output var,获取其指向的grad_fn作为roots,是后向图的起点

  • 对于有input var的,获取其指向的grad_fn作为output_edges, 是后向图的终点

  • 调用 autograd::Engine::get_default_engine().execute(...) 执行后向计算

autograd::Engine::execute(...)

  • 创建 GraphTask ,记录了一些配置信息

  • 创建 GraphRoot ,是一个Node,把所有的roots作为其输出边,Node的apply()返回的是roots的grad【这里已经得到一个单起点的图】

  • 计算依赖 compute_dependencies(...)

    • 从GraphRoot开始,广度遍历,记录所有碰到的grad_fn的指针,并统计grad_fn被遇到的次数,这些信息记录到GraphTask中

  • GraphTask 初始化:当有input var时,判断后向图中哪些节点是真正需要计算的

  • GraphTask 执行

    • 选择CPU or GPU线程执行

    • 以CPU为例,调用的 autograd::Engine::thread_main(...)

autograd::Engine::thread_main(...)

  • evaluate_function(...) ,输入输出的处理,调度

    • call_function(...) , 调用对应的Node计算

    • 执行后向过程中的生成的中间grad Tensor,如果不释放,可以用于计算高阶导数;(同构的后向图,之前的grad tensor是新的输出,grad_fn变成之前grad_fn的backward,这些新的输出还可以再backward)

  • 具体的执行机制可以支撑单独开一个Topic分析,在这里讨论到后向图完成构建为止.

点个在看 paper不断!

实操教程|PyTorch AutoGrad C++层实现相关推荐

  1. [转载]你们要的GIF动图制作全攻略!看完就会做!(实操教程)

    非常实用呀 原文地址:你们要的GIF动图制作全攻略!看完就会做!(实操教程)作者:木木老贼 来源:文案匠(ID:sun-work) 作者:一木(授权转载,如需转载请联系文案匠) 文章配图的GIF动图怎 ...

  2. 通过大白菜u盘启动工具备份/还原/重装/激活系统/修复引导 实操教程(上)

    通过大白菜u盘启动工具备份/还原/重装/激活系统/修复引导 实操教程(上) 前言 进入大白菜u盘的pe系统 用GHOST进行系统盘备份/还原 在D盘上安装新系统(以win10-2004为例) 镜像下载 ...

  3. 寻找亚马逊测评师邮箱_美国及欧盟亚马逊产品外观专利查询步骤实操教程(已验证)...

    亚马逊产品外观专利防不胜防:美国及欧盟外观专利查询步骤实操教程(已验证) 欧洲 https://www.tmdn.org/tmdsview-web/dsview-logo-white.15c95da2 ...

  4. 实操教程|火遍全网的剪纸风格究竟是怎么做出来的?

    原文来自公众号:希音的设计笔记 > 添加微信:xiyin0820 获取高质量样机 | C4D教程 | OC渲染教程 | Sketch教程 Adobe2021 | Adobe2020 | LED字 ...

  5. mysql教程乛it教程网_MySQL数据库实操教程(35)——完结篇

    版权声明 专栏概况 从2019年7月21日至今,约莫一个月的时间终于写完了MySQL教程,我已将其集结在专栏<MySQL数据库实操教程>,概述如下: 共计35篇文章 每篇文章均附源码和运行 ...

  6. MySQL数据库实操教程(35)——完结篇

    版权声明 本文原创作者:谷哥的小弟 作者博客地址:http://blog.csdn.net/lfdfhl 专栏概况 从2019年7月21日至今,约莫一个月的时间终于写完了MySQL教程,我已将其集结在 ...

  7. MetagenoNets:在线宏基因组网络分析实操教程

    宏基因组研究中网络分析已经十分普及,但却缺少整合的分析方法,限制了广大同行的使用. 关于网络分析的基本步骤,和现在工具的比较,详见原文解读 - NAR:宏基因组网络分析工具MetagenoNets 本 ...

  8. 网络分析系统_MetagenoNets:在线宏基因组网络分析实操教程

    宏基因组研究中网络分析已经十分普及,但却缺少整合的分析方法,限制了广大同行的使用. 关于网络分析的基本步骤,和现在工具的比较,详见原文解读 - NAR:宏基因组网络分析工具MetagenoNets 本 ...

  9. TensorFlow Probability 概率编程入门级实操教程

    雷锋网 AI 科技评论按:TensorFlow Probability(TFP)是一个基于 TensorFlow 的 Python 库,能够更容易地结合概率模型和深度学习.数据科学家.统计学以及机器学 ...

最新文章

  1. 因缺失log4j.properties 配置文件导致flume无法正常启动。
  2. Visual studio 2010 sp1中文版正式版无法安装Silverlight5_Tools rc1 的解决办法
  3. [导入]在ASP.NET 2.0中使用样式、主题和皮肤
  4. 数据结构与算法笔记(十三)—— 树与树的算法
  5. MFC 中屏蔽CDialog类窗体处理ESC和ESCAPE按键
  6. ooooo123123emabc
  7. 1008 数组元素循环右移问题 (20)
  8. MongoDB基础介绍安装与使用
  9. ueditor chrome bug
  10. 专为人工智能和数据科学而生的Go语言,或将取代Python
  11. Linux的企业-Codis 3集群搭建详解
  12. 机器学习与知识发现_01机器学习算法整体知识体系与学习路线攻略
  13. 八类网线和七类网线的区别_Cat8 八类网线与超五类网线、六类网线、超六类网线及七类/超七类网线的区别...
  14. php5.4安装教程,centos php 5.4 安装教程
  15. case when 效率_采用机械涡轮复合增压系统优化7.8 L柴油机的 稳态效率和排放性能...
  16. 建站之星安装提示无法连接数据库
  17. scratch游戏中背景移动的奥秘
  18. linux加载和卸载驱动模块出现 'XXX': device or resource busy 错误提示
  19. 新盲盒交友源码搭建Soul2.3正版免公众号免备案域名支持个人支付
  20. 【Fracturing amp; Destruction】Unity3D的物体爆裂、炸裂、碎裂效果

热门文章

  1. vs2010设置boost开发环境
  2. 【Codeforces】913C Party Lemonade (贪...)。
  3. 赠书 | 干货!用 Python 动手学强化学习
  4. 清华、北大教授同台激辩:脑科学是否真的能启发AI?
  5. 希捷发布CORTX对象存储软件与开源社区,普惠超大规模数据存储
  6. 应届生失业率或继续上升?别怕,这份秋招指南请收好!
  7. 知识图谱实体链接是什么?一份“由浅入深”的综述
  8. 据说这是大多数人【减肥】的真实写照
  9. AttoNets,一种新型的更快、更高效边缘计算神经网络
  10. 2018年最后几天学什么?给你关注度最高的10篇文章