Tensorflow

一开始呢,让我们先了解一下tensorflow的运行方式。简单来说,我们使用步骤一共有三个:创建图,运行图,保存图。

Tensorflow的计算是在图(graph)里面计算的,因此我们必须按照自己的需求来设计一张图。当然图的意思不是代表图片,而是代表一种结构。当创建好图之后,我们导入数据(也叫喂数据)来运行这张图。运行的过程中我们需要调整自己的参数。假如结果符合我们的要求,我们就保存这张图和里面的数据。

即使看的不明所以也没关系,接下来我们会用最简单的一种结构来解决MNIST数据集。在使用的途中你会对tensorflow更加了解。

MNIST

MNIST数据集是一个手写数字训练集(handwritten digit database)。里面有0到9的手写数字图片,并帮你打上了标签。打上标签的意思它有一个文件写明了图片代表的数字。


MNIST是一个很有用的数据集,在接下来的时间里,我们会针对它不断提高我们神经网络的复杂度进而提高我们的网络的准确率。

全连接层

全连接层(full-connected layer),顾名思义,是将前面层的节点全部连接然后通过自己之后传入下一层。

前面讲到我们需要创建图,然后喂数据来运行。传入的数据被我们称为输入层。在处理MNIST数据集的时候,我们把每个像素都作为输入的数据,然后分批导入图片。输入层经过网络之后输出的数据作为输出层。本文网络简易结构:

MNIST的每张图片的分辨率都为28*28,那么输入层一共有784个节点(即每个像素都是一个节点)。之所以这样设置,是因为每个像素都包含了图片的信息,它们共同决定了这张图片的数字。

然后我们设置全连接层的形状(shape)为[784,10]。因为我们只有一层全连接层,它接受输入层的784个节点然后输出十个节点(十个分类)。如下图所示,X代表图片的某个像素,经过全连接层层后输出十个值,最大值即是网络的结果。

制作图片时候不是很精确,其实XW1XW1XW_1+b1b1b_1这种形式。

代码解析

导入需要的包
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

相信掌握python的人对于import as的用法不会陌生。Tensorflow可以通过第三句导入MNIST数据集,命名为input_data

处理训练集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
定义批次
batch_size = 100
n_batch = mnist.train.num_examples // batch_size

前面说到要将MNIST分批次处理,在这里我们定义了batch_size=100。即每次将传入100张照片进行处理,batch的数量为全部的照片的数量对batch_size取余。在这里的mnist.train.num_examples是tensorflow为我们准备好的语句了。

需要注意的是,one_hot是一种格式。根据MNIST数据集,我们一共有十个分类。假如一张图片分类为‘0’,那么它的标签格式为:

[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]

可以看的出来,此时下标为0的值为1,而其他全为0。

Tips:假如IDE提示网络连接失败,那就需要你自己上网找MNIST数据集,一共有四个gz文件。假如下载在当前目录,那么需要新建一个’MNIST_data’文件夹放置这四个文件。

构建tensorflow的图
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])create a simple neutral network
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

这里就在构建一个图了。tf.placehoder是创建一个占位符,用来接受输入的数据。在这里我们创建x,y来分别接受传入的图片和对应的标签。tf.float32是tensorflow里面的float类型,而后面的[None, 784]代表了占位符的形状。前文提到,我们将784个像素作为输入,但我们一次性输入100张图片,所以输入会是一个[100, 784]的矩阵。

用None表示数量可以产生变化!

tf.Variable就是创建一个变量。权重的参数都应该设置为变量,因为它在训练的时候需要被更新,在测试的时候又能需要不产生变化。这里有W和b,tf.zeros把他们初始化成形状为[784, 10]和[10]但值全为0的矩阵。

定义需要的变量
prediction = tf.nn.softmax(tf.matmul(x, W)+b)correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))                 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))loss = tf.reduce_mean(tf.square(y-prediction))train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)init = tf.global_variables_initializer()

prediction是预测值。我们网络最后会导出一个[batch_size, 10]的张量出来,利用softmax我们可以得到分类预测。softmax是激活函数的一种,一方面它可以将我们所创建的线性模型转化成非线性,第二方面是它对变化比较敏感。


其中,V的第i个值经过softmax的值,等于e的Vi次方除以e的所有V的值次方之和。例子:

来自:http://t.cn/ReWqrhA

我们可以看出来,经过Softmax之后原本的值的对比会更加明显(从3:1变成0.88:0.12, 三倍变成7.3倍)。即对的更对,错的更错。

tf.argmax可以取张量里面某一维的最大值的下标。那么取出每一张图片标签和预测值里面的分类,再判断是否相等就可以得到准确与否(correct_prediction)。

tf.reduce_mean把准确率平均就能求出平均准确率。tf.cast使得准确率转化成浮点数,因此求平均的时候不会省略小数部分。

loss是损失值。由于神经网络得到的分类并不一定正确,所以不正确的估计我们会传递回去作为一个损失激励权重更新。而如何确定loss的大小就是用损失函数来决定。这里的损失函数是将y减去网络的预测值然后平方取平均。

举个例子,加入我们输入一张’6‘的图片(数据是虚构的):

train_step节点代表利用梯度下降法来降低loss值。换句话说,它告诉我们需要求loss对权重的梯度来更新权重。这方面涉及到权重的更新方法,会在后面详细介绍。

init代表初始化所有变量的操作。这又要重新提一下,我们到这里也只是画好了一个图。我们在图里面放了很多节点,但到这里它都没产生任何值!

图的结构
运行构建好的图
with tf.Session() as sess:sess.run(init)for epoch in range(21):for batch in range(n_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})print('Iter' + str(epoch) + ",Testing Accuracy" + str(acc))

with tf.Session() as sess代表之后我们开始运行。首先我们都会开始sess.run(init)来运行init这个操作,即现在才开始初始化变量的操作。epoch代表迭代的次数。迭代代表跑完一整个数据集。

mnist.train.next_batch是内置的函数,表示下一批(batch_size)的数据。sess.run(train_step)好像只是运行train_step这个节点,但实际上为了运行它,我们将跟它相关联的节点都跑完了,也就是跑完了一整张图。

feed_dict是代表你喂的数据的字典。将batch_xs, batch_ys都放置在对应的占位符x, y上,此时每次运行x, y都是我们得到的新的批次的数据。接着是运行准确率的节点,调用的是测试集的图片。

我们会得到这样的数据:

你会发现准确率到一定的值就上升不了了,这是因为我们的网络过于简陋。在接下来的课程我们会加入卷积层,池化层,正则化等部分来改善识别的能力。

但是下一篇文章我们会继续深入这个网络来讲权重更新的细节。

全连接层解决MNIST相关推荐

  1. keras库的安装及使用,以全连接层和手写数字识别MNIST为例

    1.什么是keras 什么是keras? keras以TensorFlow和Theano作为后端封装,是一个专门用于深度学习的python模块. 包含了全连接层,卷积层,池化层,循环层,嵌入层等等等, ...

  2. 神经网络-全连接层(1)

    写在前面:感谢@夏龙对本文的审阅并提出了宝贵的意见. 接下来聊一聊现在大热的神经网络.最近这几年深度学习发展十分迅速,感觉已经占据了整个机器学习的"半壁江山".各大会议也是被深度学 ...

  3. 神经网络学习(二)Tensorflow-简单神经网络(全连接层神经网络)实现手写字体识别

    神经网络学习(二)神经网络-手写字体识别 框架:Tensorflow 1.10.0 数据集:mnist数据集 策略:交叉熵损失 优化:梯度下降 五个模块:拿数据.搭网络.求损失.优化损失.算准确率 一 ...

  4. CNN 全连接层与卷积层深刻理解

    CNN 全连接层与卷积层 卷积和全连接关系 卷积取的是局部特征,全连接就是把以前的局部特征重新通过权值矩阵组装成完整的图. 因为用到了所有的局部特征,所以叫全连接 什么是全连接层 全连接层(fully ...

  5. 机器学习入门(15)— 全连接层与卷积层的区别、卷积神经网络结构、卷积运算、填充、卷积步幅、三维数据卷积、多维卷积核运算以及批处理

    卷积神经网络(Convolutional Neural Network,CNN)CNN 被用于图像识别.语音识别等各种场合,在图像识别的比赛中,基于深度学习的方法几乎都以 CNN 为基础. 1. 全连 ...

  6. Lesson 16.1016.1116.1216.13 卷积层的参数量计算,1x1卷积核分组卷积与深度可分离卷积全连接层 nn.Sequential全局平均池化,NiN网络复现

    二 架构对参数量/计算量的影响 在自建架构的时候,除了模型效果之外,我们还需要关注模型整体的计算效率.深度学习模型天生就需要大量数据进行训练,因此每次训练中的参数量和计算量就格外关键,因此在设计卷积网 ...

  7. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类

    多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...

  8. 卷积神经网络CNN要点:CNN结构、采样层、全连接层、Zero-padding、激活函数及Dropout

    CNN结构: 卷积层:特征提取: 采样层:特征选择: 全连接层:根据特征进行分类. 采样层(pooling): max-pooling:克服卷积层权值参数误差: average-pooling:克服卷 ...

  9. 深度学习(19)神经网络与全连接层二: 测试(张量)实战

    深度学习(19)神经网络与全连接层二: 测试(张量)实战 1. 传入测试集数据 2. 数据类型转换 3. 创建test_db 4. test/evluation 5. 创建神经网络 6. 输出 7. ...

最新文章

  1. ,改变LI背景颜色与背景图片
  2. 面试又挂了,你理解了 Java 8 的 Consumer、Supplier、Predicate和Function吗?
  3. Node-RED安装图形化节点dashboard实现订阅mqtt主题并在仪表盘中显示温度
  4. Java Web——JavaBean简介
  5. Linux的换网变化IP进行固定IP
  6. SAP Spartacus angular.json 中定义的 serve-ssr
  7. Java 类型转换String,List,Map,Array
  8. 性能报告——使用AOP与DYNAMICProxy的orm性能测试
  9. 华为机试HJ13:句子逆序
  10. c++ mysql 写库 乱码 ??_mysql c++ 乱码 解决方法
  11. Javascript模块化编程(转自阮一峰的网络日志)
  12. 关于unity,Player打包面板的信息(上)
  13. python编程怎么画三角形的外接圆_用MATLAB画三角形外接圆
  14. 老服务器上安装windows server 2016
  15. 解读:学习网络安全自学好还是报培训班好
  16. 提高班—I Belonged To You
  17. Building dependency tree… Done Package aptitude is not available, but is referred to by another pac
  18. Chrome常见黑客插件及用法
  19. 两组数据的偏差率_GWT测试报告 篇七十五:隐患难忽视,RIVAL 3 WIRELESS精准度LOD测试...
  20. java指定图片的dpi和存储大小kb

热门文章

  1. 稀疏Softmax(Sparse Softmax)
  2. 特权账号管理系统是什么?是堡垒机吗?
  3. ubuntu安装AMD显卡驱动后无法进入系统解决办法
  4. 新书出版:步步惊“芯” —软核处理器内部设计分析
  5. Uploader 文件上传
  6. 量子计算机造出时间晶体:跳出热力学第二定律的「永动机」出现了?
  7. redis保存登录用户信息
  8. MOBA游戏战斗服务器设计思路
  9. Realtime DB技术详解
  10. 男子带充电宝过机场安检时突然发生爆炸