static_rnn和dynamic_rnn的区别主要在于实现不同。

  • static_rnn会把RNN展平,用空间换时间。 gpu会吃不消(个人测试结果)

  • dynamic_rnn则是使用for或者while循环。

调用static_rnn实际上是生成了rnn按时间序列展开之后的图。打开tensorboard你会看到sequence_length个rnn_cell
stack在一起,只不过这些cell是share
weight的。因此,sequence_length就和图的拓扑结构绑定在了一起,因此也就限制了每个batch的sequence_length必须是一致。

调用dynamic_rnn不会将rnn展开,而是利用tf.while_loop这个api,通过Enter, Switch, Merge,
LoopCondition, NextIteration等这些control
flow的节点,生成一个可以执行循环的图(这个图应该还是静态图,因为图的拓扑结构在执行时是不会变化的)。在tensorboard上,你只会看到一个rnn_cell,
外面被一群control
flow节点包围着。对于dynamic_rnn来说,sequence_length仅仅代表着循环的次数,而和图本身的拓扑没有关系,所以每个batch可以有不同sequence_length。

static_rnn

导包、加载数据、定义变量
import tensorflow as tf
tf.reset_default_graph() #流式计算图形graph  循环神经网络 将名字相同重置了图
import datetime #打印时间
import os   #保存文件
from tensorflow.examples.tutorials.mnist import input_data# minst测试集
mnist = input_data.read_data_sets('../', one_hot=True)# 每次使用100条数据进行训练
batch_size = 100
# 图像向量
width = 28
height = 28
# LSTM隐藏神经元数量
rnn_size = 256
# 输出层one-hot向量长度的
out_size = 10

声明变量

def weight_variable(shape, w_alpha=0.01):initial = w_alpha * tf.random_normal(shape)return tf.Variable(initial)def bias_variable(shape, b_alpha=0.1):initial = b_alpha * tf.random_normal(shape)return tf.Variable(initial)# 权重及偏置
w = weight_variable([rnn_size, out_size])
b = bias_variable([out_size])

将数据转化成RNN所要求的数据

# 按照图片大小申请占位符
X = tf.placeholder(tf.float32, [None, height, width])
# 原排列[0,1,2]transpose为[1,0,2]代表前两维装置,如shape=(1,2,3)转为shape=(2,1,3)
# 这里的实际意义是把所有图像向量的相同行号向量转到一起,如x1的第一行与x2的第一行
x = tf.transpose(X, [1, 0, 2])
# reshape -1 代表自适应,这里按照图像每一列的长度为reshape后的列长度
x = tf.reshape(x, [-1, width])
# split默任在第一维即0 dimension进行分割,分割成height份,这里实际指把所有图片向量按对应行号进行重组
x = tf.split(x, height)

构建静态的循环神经网络

# LSTM
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
# 这里RNN会有与输入层相同数量的输出层,我们只需要最后一个输出
outputs, status = tf.nn.static_rnn(lstm_cell, x, dtype=tf.float32)#取最后一个进行矩阵乘法
y_conv = tf.add(tf.matmul(outputs[-1], w), b)
# 最小化损失优化
Y = tf.placeholder(dtype=tf.float32,shape = [None,10])
#损失使用的交叉熵
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_conv, labels=Y))
optimizer = tf.train.AdamOptimizer(0.01).minimize(loss)
# 计算准确率
correct = tf.equal(tf.argmax(y_conv, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

模型的训练

# 启动会话.开始训练
saver = tf.train.Saver()
session = tf.Session()
session.run(tf.global_variables_initializer())
step = 0
acc_rate = 0.90
while 1:batch_x, batch_y = mnist.train.next_batch(batch_size)batch_x = batch_x.reshape((batch_size, height, width))session.run(optimizer, feed_dict={X:batch_x,Y:batch_y})# 每训练10次测试一次if step % 10 == 0:batch_x_test = mnist.test.imagesbatch_y_test = mnist.test.labelsbatch_x_test = batch_x_test.reshape([-1, height, width])acc = session.run(accuracy, feed_dict={X: batch_x_test, Y: batch_y_test})print(datetime.datetime.now().strftime('%c'), ' step:', step, ' accuracy:', acc)# 偏差满足要求,保存模型if acc >= acc_rate:
#             os.sep = ‘/’model_path = os.getcwd() + os.sep + str(acc_rate) + "mnist.model"saver.save(session, model_path, global_step=step)breakstep += 1
session.close()

Wed Dec 18 10:08:45 2019 step: 0 accuracy: 0.1006
Wed Dec 18 10:08:46 2019 step: 10 accuracy: 0.1009
Wed Dec 18 10:08:46 2019 step: 20 accuracy: 0.1028

Wed Dec 18 10:08:57 2019 step: 190 accuracy: 0.9164

dynamic_rnn

加载数据,声明变量
import tensorflow as tf
tf.reset_default_graph()
from tensorflow.examples.tutorials.mnist import input_data# 载入数据
mnist = input_data.read_data_sets("../", one_hot=True)# 输入图片是28
n_input = 28
max_time = 28
lstm_size = 100  # 隐藏单元 可调
n_class = 10  # 10个分类
batch_size = 100   # 每次50个样本 可调
n_batch_size = mnist.train.num_examples // batch_size    # 计算一共有多少批次

Extracting …/train-images-idx3-ubyte.gz
Extracting …/train-labels-idx1-ubyte.gz
Extracting …/t10k-images-idx3-ubyte.gz
Extracting …/t10k-labels-idx1-ubyte.gz

占位符、权重

# 这里None表示第一个维度可以是任意长度
# 创建占位符
x = tf.placeholder(tf.float32,[None, 28*28])
# 正确的标签
y = tf.placeholder(tf.float32,[None, 10])# 初始化权重 ,stddev为标准差
weight = tf.Variable(tf.truncated_normal([lstm_size, n_class], stddev=0.1))
# 初始化偏置层
biases = tf.Variable(tf.constant(0.1, shape=[n_class]))

构建动态RNN、损失函数、准确率

# 定义RNN网络
def RNN(X, weights, biases):#  原始数据为[batch_size,28*28]# input = [batch_size, max_time, n_input]input_ = tf.reshape(X,[-1, max_time, n_input])# 定义LSTM的基本单元
#     lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)# final_state[0] 是cell state# final_state[1] 是hidden statoutputs, final_state = tf.nn.dynamic_rnn(lstm_cell, input_, dtype=tf.float32)display(final_state)results = tf.nn.softmax(tf.matmul(final_state[1],weights)+biases)return results
# 计算RNN的返回结果
prediction = RNN(x, weight, biases)
# 损失函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
# 使用AdamOptimizer进行优化
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
# 将结果存下来
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# 计算正确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

LSTMStateTuple(c=<tf.Tensor ‘rnn/while/Exit_3:0’ shape=(?, 100) dtype=float32>, h=<tf.Tensor ‘rnn/while/Exit_4:0’ shape=(?, 100) dtype=float32>)

训练数据

saver = tf.train.Saver()with tf.Session() as sess:sess.run(tf.global_variables_initializer())for epoch in range(6):for batch in range(n_batch_size):# 取出下一批次数据batch_xs,batch_ys = mnist.train.next_batch(batch_size)sess.run(train_step, feed_dict={x: batch_xs,y: batch_ys})if(batch%100==0):print(str(batch)+"/" + str(n_batch_size))acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})print("Iter" + str(epoch) + " ,Testing Accuracy = " + str(acc))if acc >0.9:saver.save(sess,'./rnn_dynamic')break

0/550
100/550
200/550
300/550
400/550
500/550
Iter0 ,Testing Accuracy = 0.5903

Iter5 ,Testing Accuracy = 0.9103

【tensorflow】static_rnn与dynamic_rnn的区别相关推荐

  1. Keras vs. tf.keras: 在TensorFlow 2.0中有什么区别?

    在本教程中,您将发现Keras和tf.keras之间的区别,包括TensorFlow 2.0中的新增功能. Keras vs. tf.keras: 在TensorFlow 2.0中有什么区别? htt ...

  2. Tensorflow:interactivesession和session的区别。

    目录: 目录: 前言 正文 总结 前言 在练习tensorflow的时候发现了很多很有意思的基本问题,写个帖子记录一下,既方便了回顾,又方便了同学学习/ 正文 tf.Session()和tf.Inte ...

  3. Keras vs tf.keras: 在TensorFlow 2.0中有什么区别?

    导读 在本文中,您将发现Keras和tf.keras之间的区别,包括TensorFlow 2.0中的新增功能. 万众期待的TensorFlow 2.0于9月30日正式发布. 虽然肯定是值得庆祝的时刻, ...

  4. Tensorflow get_variable和Varialbe的区别

    import tensorflow as tf""" tf.get_variable()和Variable有很多不同点 * 它们对重名操作的处理不同 * 它们受name_ ...

  5. Keras与tf.keras:TensorFlow 2.0有什么区别?

      在本教程的第一部分中,我们将讨论Keras和TensorFlow之间相互交织的历史,包括他们共同的受欢迎程度如何相互滋养,相互促进和滋养,使我们走向今天.   然后,我将讨论为什么您应该在以后的所 ...

  6. TensorFlow与PyTorch模型部署性能比较

    TensorFlow与PyTorch模型部署性能比较 前言 2022了,选 PyTorch 还是 TensorFlow?之前有一种说法:TensorFlow 适合业界,PyTorch 适合学界.这种说 ...

  7. 掌握深度学习,为什么要用PyTorch、TensorFlow框架?

    作者 | Martin Heller 译者 | 弯月 责编 | 屠敏 来源 | CSDN(ID:CSDNnews) [导读]如果你需要深度学习模型,那么 PyTorch 和 TensorFlow 都是 ...

  8. 2022年了,PyTorch和TensorFlow选哪个?

    Datawhale推荐 作者:Ryan O'Connor,来源:机器之心 坊间传闻:「TensorFlow 适合业界,PyTorch 适合学界」.都 2022 年了,还是这样吗? 2022年了,你是选 ...

  9. tensorflow与keras关系

    tensorflow简介以及与Keras的关系 - eyesfree - 博客园 TensorFlow 和keras有什么区别? - 知乎

最新文章

  1. python里面temp是啥-python temp file:如何打开多次临时文件?
  2. 内嵌在客户端的网页出现刷新问题
  3. tensorflow从入门到精通100讲(六)-在TensorFlow Serving/Docker中做keras 模型部署
  4. int转base64编码
  5. 手写简版spring --9--对象作用域和FactoryBean
  6. Java中sum和Sum相同吗,Java认为变量Sum 和sum相同。
  7. 推流地址 java_Java实现腾讯云直播生成推流地址和播放地址
  8. 700. 二叉搜索树中的搜索
  9. 电工结业试卷_电工技术基础结业考试试卷
  10. php创建mysql分区,MySql创建分区表
  11. 3d激光雷达开发(从halcon看点云pcl库)
  12. ++库 照片风格转换风格_如何用神经网络实现照片的风格转换
  13. mylyn提交到JIRA的日期格式错误
  14. ubuntu 个人常用的命令
  15. NXP JN5169 USB Dongle 原理图
  16. Linux压缩命令gzip, bzip2和tar
  17. 数据库作业6——嵌套查询
  18. CS269I:Incentives in Computer Science 学习笔记 Lecture 17 评分规则和同辈预测(诚实预报和反馈激励)
  19. 垃圾回收之如何判断对象可以回收、四种引用以及实际案例操作
  20. 部署kubernetes官网博客

热门文章

  1. 从一次换机器的过程谈软硬件的分离
  2. 视频领域的Instagram:Viddy用户突破2600万
  3. 数据挖掘:如何寻找相关项
  4. Vue.js 极简小例: 点击事件
  5. FreeSql (三十)读写分离
  6. [转载]基于Aaf的数据拆分
  7. pyqt5 + pyinstaller 制作爬虫小程序
  8. 函数对象 函数嵌套 名称空间与作用域
  9. BZOJ.2741.[FOTILE模拟赛]L(分块 可持久化Trie)
  10. TextTree - 文本资料收集轻量级工具