核心知识讲解:

过程:输入特征=>模型预测=>根据结果计算一下损失(损失就是距离target的差距),然后将参数更新,再放回模型中预测,直至收敛,使得损失变得最小,这时候的参数就是我们想要的参数


  • 梯度 (gradient):偏导数相对于所有自变量的向量。在机器学习中,梯度是模型函数偏导数的向量。梯度指向最速上升的方向。
  • 梯度下降法 (gradient descent):一种通过计算并且减小梯度将损失降至最低的技术,它以训练数据为条件,来计算损失相对于模型参数的梯度。通俗来说,梯度下降法以迭代方式调整参数,逐渐找到权重和偏差的最佳组合,从而将损失降至最低。

假设我们能够将所有种可能情况全都计算一遍,那么得到的一定是一个类似于这样的碗状图,在其中必定有一点是损失最低的点,但是现实种我们肯定不会有那么大的计算能力和时间去计算出每个结果,我们通常采用一种叫做梯度下降法的方式来"快速"的找到损失最低的点(梯度下降法属于一种优化算法,虽然并不是最好的优化算法,但是其方式简单,应用也很多)。

  • 起点是随意选定的,因为在预测的开始时,没有人知道权重(w1,w2,w3..b)该是什么,可以设置为0,也可以设置为1,无所谓。通过模型一次计算,计算得出损失(这时候损失并不重要,肯定极大,没有参考意义),然后计算起点处的偏导数(如果只有一个权重那就是导数了),得出起点处的偏导数,而梯度是偏导数的矢量(即包含了此处偏导数的方向大小),可以想象一下抛物线y=ax²+bx+c 在x0处的导数,其大小的绝对值是随着x0的值而变化的,并且有正负之分,绝对值大小代表大小,正负代表方向,所以依据梯度就可以确定权重值调节的方向。
  • 至此,调节的基本原理说的就差不多了,那么剩下的问题就是如何更好的优化,以便用最少的计算量最快的速度去达到目的。

  • 学习速率(也称为步长)

如果让其按照每个点本身的梯度大小来调节权值,那实在是太慢了,所以我们可以为其乘上一个学习速率,意如其名,这样可以人手动的调节学习速率(或许有的人会担心,当即将逼近损失最小的点时,这样会不会不太准确了,放心好了,我们并不需要那么的准确的权值,99%和98%的区别不是太大,但是所要付出的计算量却是超大的)

附上谷歌提供的:优化学习速率体验

下面是两种个效果更好的梯度下降算法方案,第二种更优
随机梯度下降法 (SGD) :它每次迭代只使用一个样本(批量大小为 1)。“随机”这一术语表示构成各个批量的一个样本都是随机选择的。(假设有10000个样本,每次从中随机选一个来执行梯度下降)

小批量随机梯度下降法小批量 SGD)是介于全批量迭代与 SGD 之间的折衷方案。小批量通常包含 10-1000 个随机选择的样本。小批量 SGD 可以减少 SGD 中的杂乱样本数量,但仍然比全批量更高效。(每次随机选一批)

注意:为了安全起见,我们还会通过 clip_gradients_by_norm 将梯度裁剪应用到我们的优化器。梯度裁剪可确保梯度大小在训练期间不会变得过大,梯度过大会导致梯度下降法失败。

神经网络的常见模板代码如下:

import tensorflow as tf
import numpy as np

"""
这里是一个非常好的大数据验证结果,随着数据量的上升,集合的结果也越来越接近真实值,
这也是反馈神经网络的一个比较好的应用
这里不是很需要各种激励函数
而对于dropout,这里可以看到加上dropout,loss的值更快。
随着数据量的上升,结果就更加接近于真实值。
"""

inputX = np.random.rand(3000,1)
noise = np.random.normal(0, 0.05, inputX.shape)
outputY = inputX * 4 + 1 + noise

#这里是第一层
weight1 = tf.Variable(np.random.rand(inputX.shape[1],4))
bias1 = tf.Variable(np.random.rand(inputX.shape[1],4))
x1 = tf.placeholder(tf.float64, [None, 1])
y1_ = tf.matmul(x1, weight1) + bias1
#这里是第二层
weight2 = tf.Variable(np.random.rand(4,1))
bias2 = tf.Variable(np.random.rand(inputX.shape[1],1))
y2_ = tf.matmul(y1_, weight2) + bias2

y = tf.placeholder(tf.float64, [None, 1])

loss = tf.reduce_mean(tf.reduce_sum(tf.square((y2_ - y)), reduction_indices=[1]))
train = tf.train.GradientDescentOptimizer(0.25).minimize(loss)  # 选择梯度下降法

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

for i in range(1000):
    sess.run(train, feed_dict={x1: inputX, y: outputY})

print(weight1.eval(sess))
print("---------------------")
print(weight2.eval(sess))
print("---------------------")
print(bias1.eval(sess))
print("---------------------")
print(bias2.eval(sess))
print("------------------结果是------------------")

x_data = np.matrix([[1.],[2.],[3.]])
print(sess.run(y2_,feed_dict={x1: x_data}))

TensorFlow的神经网络设计讲解相关推荐

  1. Tensorflow实现神经网络及实现多层神经网络进行时装分类

    Tensorflow实现神经网络及实现多层神经网络进行时装分类 1. tf.keras构建模型训练评估测试API介绍 import tensorflow as tf from tensorflow i ...

  2. 用神经网络例子讲解TF运行方式~人工智能入门编程例子讲解

    #用神经网络例子讲解TF运行方式#import os #防止出现警告 #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #P145 ###########生成及加载数 ...

  3. 图神经网络设计中的算子融合策略

    ©作者 | 刘曜齐 单位 | 北京邮电大学硕士生 来源 | 北邮GAMMA Lab 本文主要的描述基于消息传递机制的图神经网络设计中应用的算子融合策略,带领读者了解有关算子融合的相关问题以及方法. 引 ...

  4. TensorFlow 卷积神经网络实用指南:1~5

    原文:Hands-On Convolutional Neural Networks with TensorFlow 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自[ApacheCN 深度学 ...

  5. tensorflow循环神经网络(RNN)文本生成莎士比亚剧集

    tensorflow循环神经网络(RNN)文本生成莎士比亚剧集 我们将使用 Andrej Karpathy 在<循环神经网络不合理的有效性>一文中提供的莎士比亚作品数据集.给定此数据中的一 ...

  6. sift论文_卷积神经网络设计相关论文

    最近梳理了一下卷积神经网络设计相关的论文(这个repo现在只列出了最重要的一些论文,后面会持续补充): Neural network architecture design​github.com 1. ...

  7. 3.10 程序示例--神经网络设计-机器学习笔记-斯坦福吴恩达教授

    神经网络设计 在神经网络的结构设计方面,往往遵循如下要点: 输入层的单元数等于样本特征数. 输出层的单元数等于分类的类型数. 每个隐层的单元数通常是越多分类精度越高,但是也会带来计算性能的下降,因此, ...

  8. 利用TensorFlow和神经网络来处理文本分类问题

    利用TensorFlow和神经网络来处理文本分类问题 By 机器之心2017年8月23日 10:33 在这篇文章中,机器之心海外分析师对Medium(链接见文后)上的一篇热门博客进行了介绍,讨论了六个 ...

  9. tensorflow训练神经网络时loss出现nan的问题

    tensorflow训练神经网络时loss出现nan的问题 一般情况下原因是由于优化器上的学习比率learning_rate定义值太大,如: train_step = tf.compat.v1.tra ...

最新文章

  1. 使用pydub做静音帧去除
  2. python提取网页数据
  3. boost::hana::and_用法的测试程序
  4. python打包加版本信息_使用pyi-set_version为PyInstaller打包出来的程序附加版本信息...
  5. IOC操作Bean管理XML方式(bean 的生命周期)
  6. coddenomicon工具
  7. java使用哪个类,怎么知道 java类从哪个jar 加载
  8. 如何避免开源安全噩梦?
  9. c语言spi测试代码,C语言程序SPI
  10. 谷歌SEO是什么意思,谷歌搜索引擎优化怎么做
  11. Spring Security系列之基本原理
  12. 微信公众号:服务号、企业订阅号、个人订阅号的差异对比
  13. python列表报错TypeError: list indices must be integers or slices, not str
  14. Straight lines have to be straight
  15. 三维立体动画制作技巧
  16. 汇开优店APP介绍及汇开优店管家区别
  17. chrome浏览器安全检查_为您的Chrome浏览器检查皮肤
  18. xmanager5 + xshell linux 远程
  19. php生成值班表,EXCEL表制作自动排列值班表【excel值班表表格制作教程】
  20. Razer雷蛇7.1声音驱动卸载后无法安装问题

热门文章

  1. c语言教材课后题答案6,C语言谭浩强版6章课后练习题答案.doc
  2. python基础知识三 字典-dict + 菜中菜
  3. 敏捷模式下的团队测试能力构建
  4. (附源码)SSM兴澜幼儿园管理系统JAVA计算机毕业设计项目
  5. 发送邮件 空格 java_java发送邮件 - 困觉的曼巴er的个人空间 - OSCHINA - 中文开源技术交流社区...
  6. Jupyter Notebook 更改默认目录
  7. 记录一次使用JS生成word后端转换PDF功能
  8. PS5上传图片失败,游戏无法推送更新,提示服务器出了点问题,HTTP状态码:403
  9. NISP-信息安全事件与应急响应
  10. HDU 3549 网络流水题