tensorflow(2) TensorFlow Mechanics 101

标签(空格分隔): tensorflow


参考 TF英文社区TF中文社区

fully_connected_feed.py 是总体的运行过程
mnist.py中定义了四个函数,inference,training,loss,evaluation

mnist.py

一、inference

就是网络结构函数,mnist.py中的inference定义的网络有一对全连接层,和一个有10个线性节点的线性层

  • input:inference输入placeholder和第一层,第二层网络hidden units的个数
  • 每一层都有唯一的name_scope,所有的item都创建在这个namescope下,相当于给这一层的所有item加了一个前缀
with tf.name_scope('hidden1'):
  • 在每一个scope中,weight和biase由tf.Variable生成,大小根据(输入输出)的维度设置
    weight=[connect from,connect to]
    biase=[connect to]
  • 每个变量在创建时,都会被给予一个初始化操作
weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units],stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),name='weights')
biases = tf.Variable(tf.zeros([hidden1_units]),name='biases')

比如weights会用tf.truncated_normal初始化器,根据给定的均值和标准差生成一个随机分布
biase根据tf.zeros保证它们的初始值都是0。

graph中主要有三个operation,两个tf.nn.relu和一个tf.matmul
最后,程序会返回包含了输出结果的logits Tensor。

二、loss

loss() 也是graph的一部分,输入两个参数,神经网络的分类结果和labels正确结果。进行比较,计算损失。

def loss(logits, labels):labels = tf.to_int64(labels)cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits, name='xentropy')return tf.reduce_mean(cross_entropy, name='xentropy_mean')

这个函数分为三步

  • 1.先将labels转换成所需要的格式
    tf.to_int64(labels)这个操作可以将labels抓换成指定的格式1-hot labels,
    1-hot labels:例如,如果类标识符为“3”,那么该值就会被转换为:
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
    将inference的判断结果和labels进行比较
  • 2.利用函数tf.nn.sparse_softmax_cross_entropy_with_logits 计算交叉熵
  • 3.计算一个batch的平均loss
loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')

tf.reduce_mean函数(可跨维度的计算平均值),计算batch维度(第一维度)下交叉熵(cross entropy)的平均值,将将该值作为总损失

三、training

  • input: loss tensor , learning rate
    主要分为四步
  • 1、 创建一个summarizer, 用来更新损失,summary的值会被写在events file里面
    tf.summary.scalar('loss', loss)
  • 2、创建一个optimizer优化器对象tf.train.GradientDescentOptimizer(设置学习率)
  • 3、创建global_step变量 ,用于记录全局训练步骤的单值
  • 4、开始优化 optimizer.minimize(输入loss 和 global step)
  • return train_op

四、evaluation

输入网络的分类结果和labels,和loss函数的输入一样

def evaluation(logits, labels):"""Evaluate the quality of the logits at predicting the label.Args:logits: Logits tensor, float - [batch_size, NUM_CLASSES].labels: Labels tensor, int32 - [batch_size], with values in therange [0, NUM_CLASSES).Returns:A scalar int32 tensor with the number of examples (out of batch_size)that were predicted correctly."""# For a classifier model, we can use the in_top_k Op.# It returns a bool tensor with shape [batch_size] that is true for# the examples where the label is in the top k (here k=1)# of all logits for that example.correct = tf.nn.in_top_k(logits, labels, 1)# Return the number of true entries.return tf.reduce_sum(tf.cast(correct, tf.int32))

fully_connected_feed.py

一旦图建立完之后,就可以在循环训练和评估
tensorflow/tensorflow/examples/tutorials/mnist/fully_connected_feed.py

总体步骤

1、设置输入
定义placeholder,函数def placeholder_inputs(batch_size)
2、开始训练run_rainning

  • 读入数据集
  • 建立图
  • 创建session
  • 初始化
  • 开始循环训练
    • check status
    • do evaluation

1.place holder

2.the graph

建图中所有的操作都是在with tf.Graph().as_default()下进行的
tf.graph可能会执行所有的ops,可以包含多个图,创建多个线程
我们只需要一个single graph

3.session

在定义完图后,需要创建一个会话session来开启这个图
- 创建session sess=tf.session()
- 创建initializer, initializer=tf.global_variables_initializer
- sess.run(initializer) 会自动初始化所有的变量

4.training loop

在变量初始化完成之后,就可以开始训练了
最简单的训练过程就以下两行代码

with step in xrange(FLAGS.max_Step)sess.run(train_op)

但是本例子要复杂一点,读入的数据每一步都要进行切分,以适应之前生成的place_holder

(1).fill_feed_dict

先让image_feed和labels_feed去向dataset索要下一次训练的一个batchsize的数据

images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,FLAGS.fake_data)

再讲这个数据整合成一个python字典的形式,image_placeholder 和labels_placeholder作为字典的key, image_feed和labels_feed作为字典的value

feed_dict = {images_placeholder: images_feed,labels_placeholder: labels_feed,
}

(2).检查训练状态

see.run 在每一步训练之后都会取得两个值,loss 和train_op,(train_op不返回任何值,discard)
,所以会得到每一步的loss
每训练100次,check一下,输出loss
每训练1000次,进行evaluation,将生成的model保存一下

(3).do_eval

计算整个epoch的精度

  true_count = 0  # Counts the number of correct predictions.steps_per_epoch = data_set.num_examples // FLAGS.batch_sizenum_examples = steps_per_epoch * FLAGS.batch_sizefor step in xrange(steps_per_epoch):feed_dict = fill_feed_dict(data_set,images_placeholder,labels_placeholder)true_count += sess.run(eval_correct, feed_dict=feed_dict)precision = float(true_count) / num_examplesprint('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %(num_examples, true_count, precision))

代码

建图步骤

  • Generate placeholders for the images and labels.
  • Build a Graph that computes predictions from the inference model.
  • Add to the Graph the Ops for loss calculation.
  • Add to the Graph the Ops that calculate and apply gradients.
  • Add the Op to compare the logits to the labels during evaluation.
  • Build the summary Tensor based on the TF collection of Summaries.
  • Add the variable initializer Op.
  • Create a saver for writing training checkpoints.
  • Create a session for running Ops on the Graph.
  • Instantiate a SummaryWriter to output summaries and the Graph.
  • And then after everything is built:Run the Op to initialize the variables.
    Start the training loop.

读取数据

def run_training():"""Train MNIST for a number of steps."""# Get the sets of images and labels for training, validation, and# test on MNIST.data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

开始建图

  # Tell TensorFlow that the model will be built into the default Graph.with tf.Graph().as_default():# Generate placeholders for the images and labels.images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)# Build a Graph that computes predictions from the inference model.logits = mnist.inference(images_placeholder,FLAGS.hidden1,FLAGS.hidden2)# Add to the Graph the Ops for loss calculation.loss = mnist.loss(logits, labels_placeholder)# Add to the Graph the Ops that calculate and apply gradients.train_op = mnist.training(loss, FLAGS.learning_rate)# Add the Op to compare the logits to the labels during evaluation.eval_correct = mnist.evaluation(logits, labels_placeholder)# Build the summary Tensor based on the TF collection of Summaries.summary = tf.summary.merge_all()# Add the variable initializer Op.init = tf.global_variables_initializer()# Create a saver for writing training checkpoints.saver = tf.train.Saver()# Create a session for running Ops on the Graph.sess = tf.Session()# Instantiate a SummaryWriter to output summaries and the Graph.summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)# And then after everything is built:# Run the Op to initialize the variables.sess.run(init)

开始循环训练

    # Start the training loop.for step in xrange(FLAGS.max_steps):start_time = time.time()# Fill a feed dictionary with the actual set of images and labels# for this particular training step.feed_dict = fill_feed_dict(data_sets.train,images_placeholder,labels_placeholder)# Run one step of the model.  The return values are the activations# from the `train_op` (which is discarded) and the `loss` Op.  To# inspect the values of your Ops or variables, you may include them# in the list passed to sess.run() and the value tensors will be# returned in the tuple from the call._, loss_value = sess.run([train_op, loss],feed_dict=feed_dict)duration = time.time() - start_time# Write the summaries and print an overview fairly often.if step % 100 == 0:# Print status to stdout.print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))# Update the events file.summary_str = sess.run(summary, feed_dict=feed_dict)summary_writer.add_summary(summary_str, step)summary_writer.flush()# Save a checkpoint and evaluate the model periodically.if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')saver.save(sess, checkpoint_file, global_step=step)# Evaluate against the training set.print('Training Data Eval:')do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.train)# Evaluate against the validation set.print('Validation Data Eval:')do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.validation)# Evaluate against the test set.print('Test Data Eval:')do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.test)

tensorflow学习(2)TensorFlow Mechanics 101相关推荐

  1. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  2. 深度学习与TensorFlow

    深度学习与TensorFlow DNN(深度神经网络算法)现在是AI社区的流行词.最近,DNN 在许多数据科学竞赛/Kaggle 竞赛中获得了多次冠军. 自从 1962 年 Rosenblat 提出感 ...

  3. 深度学习调用TensorFlow、PyTorch等框架

    深度学习调用TensorFlow.PyTorch等框架 一.开发目标目标 提供统一接口的库,它可以从C++和Python中的多个框架中运行深度学习模型.欧米诺使研究人员能够在自己选择的框架内轻松建立模 ...

  4. TensorFlow学习笔记——实现经典LeNet5模型

    TensorFlow实现LeNet-5模型 文章目录 TensorFlow实现LeNet-5模型 前言 一.什么是TensorFlow? 计算图 Session 二.什么是LeNet-5? INPUT ...

  5. 在浏览器中进行深度学习:TensorFlow.js (四)用基本模型对MNIST数据进行识别

    2019独角兽企业重金招聘Python工程师标准>>> 在了解了TensorflowJS的一些基本模型的后,大家会问,这究竟有什么用呢?我们就用深度学习中被广泛使用的MINST数据集 ...

  6. Tensorflow学习资源

    欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习.深度学习的知识! 作者:AI小昕 在之前的Tensorflow系列文章中,我们教大家 ...

  7. TensorFlow 深度学习笔记 TensorFlow实现与优化深度神经网络

    TensorFlow 深度学习笔记 TensorFlow实现与优化深度神经网络 转载请注明作者:梦里风林 Github工程地址:https://github.com/ahangchen/GDLnote ...

  8. 【干货】史上最全的Tensorflow学习资源汇总,速藏!

    一 .Tensorflow教程资源: 1)适合初学者的Tensorflow教程和代码示例:(https://github.com/aymericdamien/TensorFlow-Examples)该 ...

  9. tensorflow学习函数笔记

    为什么80%的码农都做不了架构师?>>>    [TensorFlow教程资源](https://my.oschina.net/u/3787228/blog/1794868](htt ...

最新文章

  1. python二分法求解_Python使用二分法求平方根的简单示例
  2. SQLSERVER查看sql语句的执行时间
  3. Xamarin Android SDK无法更新的解决办法
  4. 32边界的链码表示MPP算法MATLAB实现
  5. mysql互为主从利弊_MySQL互为主从复制常见问题
  6. 【算法随记一】Canny边缘检测算法实现和优化分析。
  7. Android自定义View——可以设置最大宽高的FrameLayout
  8. 深入了解java虚拟机(JVM) 第四章 对象的创建
  9. Nginx 部署 Django
  10. 数据库水平拆分和垂直拆分区别(以mysql为例)
  11. Ioc容器beanDefinition-Spring 源码系列(1)
  12. linux文件编程(二)
  13. 借着酒劲儿,是真敢说!程序员酒后吐真言
  14. 【解决方案】GB28181/RTSP/Onvif/HikSDK/Ehome协议视频共享平台EasyCVR人脸识别助力打造智慧安检
  15. C++ 捕获程序异常奔溃minidump
  16. 设置导航栏字体大小,颜色和加粗字体的方法
  17. 前端优化首屏加载速度
  18. 都2022年了,还在争论编程语言?
  19. HTTP/2和HTTP/3
  20. php 短信验证 云之讯,python3.7实现云之讯、聚合短信平台的短信发送功能

热门文章

  1. C语言中‘a‘和“a“有什么区别?
  2. AXURE表白原型(拼图+心形照片墙+表白信)
  3. 计算机简单门电路和加法运算
  4. 浏览器上登录堡垒机_登录堡垒机.doc
  5. Toast自定义,图片加文字
  6. android utf-8 转 gbk编码,Golang GBK与UTF-8互转
  7. 韩乔生逐条点评“韩乔生语录”:那是宋世雄老师说的
  8. Redis 编译安装 基础命令 服务优化 持久化
  9. win7开机突然变得很慢_Win7开机慢的原因及其解决方法
  10. C++实现 L1-038 新世界 (5分)