DL/LSTM: Understanding the tf.contrib.rnn.BasicLSTMCell(rnn_unit) Function

Contents

Understanding the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function

What the function does

Function implementation


Understanding the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function

What the function does

"""Basic LSTM recurrent network cell.

The implementation is based on: http://arxiv.org/abs/1409.2329.

We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training.

It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} that follows.

"""

def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    """Initialize the basic LSTM cell.

Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`.  If False, they are concatenated along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables in an existing scope.  If not `True`, and the existing scope already has the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`.

When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead.
    """

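Before looking at the source, here is how the cell is typically used. This is a minimal usage sketch, assuming a TensorFlow 1.x environment where tf.contrib.rnn is still available; the batch/time/feature sizes and the placeholder name X are illustrative, not from the original post.

import tensorflow as tf

rnn_unit = 128                 # number of hidden units in the cell
batch_size = 32                # illustrative values
time_steps, input_size = 10, 8

# A batch of input sequences: [batch_size, time_steps, input_size].
X = tf.placeholder(tf.float32, [batch_size, time_steps, input_size])

# Build the cell; forget_bias=1.0 and state_is_tuple=True are the defaults.
cell = tf.contrib.rnn.BasicLSTMCell(rnn_unit)

# Zero initial state: an LSTMStateTuple (c, h), each [batch_size, rnn_unit].
init_state = cell.zero_state(batch_size, dtype=tf.float32)

# Unroll the cell over time: outputs is [batch_size, time_steps, rnn_unit],
# final_state is an LSTMStateTuple.
outputs, final_state = tf.nn.dynamic_rnn(cell, X, initial_state=init_state)

With state_is_tuple=True (the default), final_state.c and final_state.h can be read separately; with state_is_tuple=False they would come back concatenated as a single [batch_size, 2 * rnn_unit] tensor, a behavior the docstring marks as soon to be deprecated.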

Function implementation


@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):"""Basic LSTM recurrent network cell.The implementation is based on: http://arxiv.org/abs/1409.2329.We add forget_bias (default: 1) to the biases of the forget gate in order toreduce the scale of forgetting in the beginning of the training.It does not allow cell clipping, a projection layer, and does notuse peep-hole connections: it is the basic baseline.For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}that follows."""def __init__(self,num_units,forget_bias=1.0,state_is_tuple=True,activation=None,reuse=None,name=None,dtype=None):"""Initialize the basic LSTM cell.Args:num_units: int, The number of units in the LSTM cell.forget_bias: float, The bias added to forget gates (see above).Must set to `0.0` manually when restoring from CudnnLSTM-trainedcheckpoints.state_is_tuple: If True, accepted and returned states are 2-tuples ofthe `c_state` and `m_state`.  If False, they are concatenatedalong the column axis.  The latter behavior will soon be deprecated.activation: Activation function of the inner states.  Default: `tanh`.reuse: (optional) Python boolean describing whether to reuse variablesin an existing scope.  If not `True`, and the existing scope already hasthe given variables, an error is raised.name: String, the name of the layer. Layers with the same name willshare weights, but to avoid mistakes we require reuse=True in suchcases.dtype: Default dtype of the layer (default of `None` means use the typeof the first input). Required when `build` is called before `call`.When restoring from CudnnLSTM-trained checkpoints, must use`CudnnCompatibleLSTMCell` instead."""super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)if not state_is_tuple:logging.warn("%s: Using a concatenated state is slower and will soon be ""deprecated.  Use state_is_tuple=True.", self)# Inputs must be 2-dimensional.self.input_spec = base_layer.InputSpec(ndim=2)self._num_units = num_unitsself._forget_bias = forget_biasself._state_is_tuple = state_is_tupleself._activation = activation or math_ops.tanh@propertydef state_size(self):return (LSTMStateTuple(self._num_units, self._num_units)if self._state_is_tuple else 2 * self._num_units)@propertydef output_size(self):return self._num_unitsdef build(self, inputs_shape):if inputs_shape[1].value is None:raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"% inputs_shape)input_depth = inputs_shape[1].valueh_depth = self._num_unitsself._kernel = self.add_variable(_WEIGHTS_VARIABLE_NAME,shape=[input_depth + h_depth, 4 * self._num_units])self._bias = self.add_variable(_BIAS_VARIABLE_NAME,shape=[4 * self._num_units],initializer=init_ops.zeros_initializer(dtype=self.dtype))self.built = Truedef call(self, inputs, state):"""Long short-term memory cell (LSTM).Args:inputs: `2-D` tensor with shape `[batch_size, input_size]`.state: An `LSTMStateTuple` of state tensors, each shaped`[batch_size, num_units]`, if `state_is_tuple` has been set to`True`.  
Otherwise, a `Tensor` shaped`[batch_size, 2 * num_units]`.Returns:A pair containing the new hidden state, and the new state (either a`LSTMStateTuple` or a concatenated state, depending on`state_is_tuple`)."""sigmoid = math_ops.sigmoidone = constant_op.constant(1, dtype=dtypes.int32)# Parameters of gates are concatenated into one multiply for efficiency.if self._state_is_tuple:c, h = stateelse:c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)gate_inputs = math_ops.matmul(array_ops.concat([inputs, h], 1), self._kernel)gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)# i = input_gate, j = new_input, f = forget_gate, o = output_gatei, j, f, o = array_ops.split(value=gate_inputs, num_or_size_splits=4, axis=one)forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)# Note that using `add` and `multiply` instead of `+` and `*` gives a# performance improvement. So using those at the cost of readability.add = math_ops.addmultiply = math_ops.multiplynew_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),multiply(sigmoid(i), self._activation(j)))new_h = multiply(self._activation(new_c), sigmoid(o))if self._state_is_tuple:new_state = LSTMStateTuple(new_c, new_h)else:new_state = array_ops.concat([new_c, new_h], 1)return new_h, new_state
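To see that call implements the standard LSTM update, here is a hypothetical NumPy sketch (the helper name basic_lstm_step is ours, not TensorFlow's) that reproduces one step with the same fused [input_depth + num_units, 4 * num_units] kernel and the same i/j/f/o split order as the code above.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def basic_lstm_step(inputs, c, h, kernel, bias, forget_bias=1.0):
    """One LSTM step mirroring BasicLSTMCell.call.

    inputs: [batch, input_size]; c, h: [batch, num_units]
    kernel: [input_size + num_units, 4 * num_units]; bias: [4 * num_units]
    """
    # One fused matmul for all four gates, as in the TF code.
    gate_inputs = np.concatenate([inputs, h], axis=1) @ kernel + bias
    # Split order matches the TF code:
    # i = input gate, j = new input, f = forget gate, o = output gate.
    i, j, f, o = np.split(gate_inputs, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, new_c

In equation form: new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j), and new_h = tanh(new_c) * sigmoid(o). The forget_bias added to f is exactly the constant the docstring describes: it keeps the forget gate close to 1 early in training so the cell does not forget too aggressively.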
