为什么tensorflow要有Variable的对象?

编程语言中,都有变量之概念,用于保存中间计算结果,如计算100以内正整数的累加和.

int sum = 0;
for (int i = 0 ;i<=100;i++){sum += i;
}

上述c++代码中,sum定义变量用于保存累加和.在内部实现时,会为sum变量分配一块固定内存,每次循环,该内存值会变化,但内存地址不变.sum变量和地址是绑定的.

以下为python语言的实现

sum = 0
for i in range(0,101):sum += iprint(id(sum))

也定义了变量sum,观察输出,会发现输出id(sum)在变化,即变量sum的地址在变化.在python中,变量不是和内存的地址绑定的.变量只是对象的一个名称.赋值语句将不同对象赋给某个变量,会改变变量的地址,其实相当于一个新的变量,只是名称没变.

tensorflow中有很多求变量的梯度操作,即变量变化对函数值的影响.该操作要根据变量变化前后的值来计算,如果变量的地址变化,则会增加实现难度.tensorflow通过定义Variable对象来简化梯度的计算.

        with tf.GradientTape() as tape:  # with结构记录梯度信息y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算y = tf.nn.softmax(y)  # 使输出y符合概率分布(此操作后与独热码同量级,可相减求loss)y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss和accuracyloss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-out)^2)loss_all += loss.numpy()  # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确# 计算loss对各个参数的梯度grads = tape.gradient(loss, [w1, b1])# 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_gradw1.assign_sub(lr * grads[0])  # 参数w1自更新b1.assign_sub(lr * grads[1])  # 参数b自更新

上述代码用梯度下降算法训练网络参数.GradientTape对象用于跟踪参数的变化对损失函数值的影响(梯度),网络参数定义为Variable对象,GradientTape会记录Variable的变化,从而计算出梯度.如果用普通变量来定义网络参数,那么每次迭代更新,网络参数就是一个新的对象.GradientTape无法跟踪这些新变量.因此,将网络参数定义成Variable,就是为了让GradientTape跟踪变化,从而计算梯度.在更新网络参数时,用Variable.assign_xxx方法,而不能用赋值操作符,赋值操作符(=)会返回一个新的Variable.

 # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad#w1.assign_sub(lr * grads[0])  # 参数w1自更新#b1.assign_sub(lr * grads[1])  # 参数b自更新w1 = w1 - lr*grads[0]#w1就是一个新的对象!!!!b1 = b1 - lr*grads[1] #b1就是一个新的对象!!!!

上述代码会在第二次迭代时出错.

tensorflow 之 Variable的理解相关推荐

  1. 【Tensorflow】op的理解和自定义损失函数

    tensorflow中的基本概念 本文是在阅读官方文档后的一些个人理解. 官方文档地址:https://www.tensorflow.org/versions/r0.12/get_started/ba ...

  2. TensorFlow 中 identity 函数理解

    理解 identity: n.身份; 本身; 本体; 特征; 特有的感觉(或信仰); 同一性; 相同; 一致; identity的意思是自身的意思,简单说就是赋值. x = tf.Variable(0 ...

  3. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  4. tensorflow || 滑动平均的理解--tf.train.ExponentialMovingAverage

    1 滑动平均的理解 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以用来估计变 ...

  5. tensorflow tf.Variable 的用法

    import tensorflow as tf #导入模块 import numpy as np tf.Variable(3) # 数字输入 <tf.Variable 'Variable:0' ...

  6. sklearn、theano、TensorFlow 以及 theras 的理解

    sklearn ⇒ 机器学习算法和模型: theras theano TensorFlow 1. 理解模型以及函数,参数返回值的实际意义 一定要注意模型的构造函数,接收的参数列表,以及该模型本身所要解 ...

  7. TensorFlow中Variable()和get_variable()

      tf.Variable()和tf.get_variable()都可以用来创建变量,但是前者会自动保证唯一性,而后者不能保证唯一性. 1 tf.Variable: Variable(initial_ ...

  8. tensorflow 里metrics_深入理解TensorFlow中的tf.metrics算子

    [IT168 技术]01 概述 本文将深入介绍Tensorflow内置的评估指标算子,以避免出现令人头疼的问题. tf.metrics.accuracy() tf.metrics.precision( ...

  9. TensorFlow从入门到理解(六):可视化梯度下降

    运行代码: import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.m ...

最新文章

  1. 基于TI TMS320C6678 + Xilinx Kintex-7 的高性能信号处理方案
  2. oracle中escape关键字用法
  3. 润乾V5手机报表说明文档
  4. 时代天使点燃口腔赛道,瑞尔集团离下一只“牙茅”还有多远?
  5. 易语言 服务器抓包,易语言抓包获得地址实现TP路由器登陆的代码
  6. UI5_INFO_FETCH_FROM_DB
  7. centos7中无法确定光盘权限怎么办_图解KVM安装CentOS7.6操作系统
  8. java socket smtp_JAVA Socket实现smtp发送邮件
  9. linux下挂载iso镜像的方法
  10. mysql5.1安装失败_解决MySQL5.1安装时出现Cannot create windows service for mysql.error:0
  11. PHP数组合并的常见问题
  12. 指纹对比软件_杰恩世软件平台钢片AOI检测应用
  13. CSS规范(OOCSS SMACSS BEM)
  14. 联想外接键盘fn热键取消
  15. 机器人动力学与控制学习笔记(十一)————机器人凯恩方程动力学建模
  16. 公网远程Everything快速搜索私有云资料【内网穿透】
  17. java基础-File类与IO流
  18. 华为云宣布将在全球范围内推出区块链服务
  19. 【Android 进阶】开发APP常见的错误
  20. 【嵌入式开发教程6】手把手教你做平板电脑-触摸屏驱动实验教程

热门文章

  1. 内存管理 —— 地址翻译
  2. 小学数学与计算机整合课优质教案,小学数学与信息技术整合教案
  3. 【AD小知识】原理图的导出元器件清单方法及设置
  4. 【GRUB】GRUB2基本操作
  5. [Effective C++]条款14:在资源管理类中小心copying行为
  6. Cet6高频词汇汇总
  7. linux cdn服务器,wdcdn系统,CDN缓存系统,CDN加速系统,多节点CDN自架系统,CDN安装配置部署--Linux解决方案,技术支持与培训,服务器架构,性能优化,负载均衡,集群分流...
  8. 实锤了,不仅5G基站是电老虎,5G手机耗电同样严重
  9. docker(11、Docker Swarm4)11、副本数量(replicated 和 global )12、Label 控制 Service 的位置 13、如何配置 Health Check
  10. 光学背投屏幕安装方法的技术建议