该代码实现了通过神经网络来计算两个三位数的相加

先生成一堆训练数据,打印一下

print(questions[:10])
print(expected[:10])

结果为:

[' 31+991', ' 46+154', '    0+2', '    9+9', '    1+7', '  827+2', '  97+09', '    0+8', '    5+3', '  5+239']
['212 ', '515 ', '2   ', '18  ', '8   ', '730 ', '169 ', '8   ', '8   ', '937 ']

编码的时候,questions是前面加空格,后面是真实的计算字符串,也就是右对齐

expected是后面加空格,也就是说expected字符串是左对齐

然后进行编码,参考下面的questions编码方式

 31+991
[[ True False False False False False False False False False False False][False False False False False  True False False False False False False][False False False  True False False False False False False False False][False  True False False False False False False False False False False][False False False False False False False False False False False  True][False False False False False False False False False False False  True][False False False  True False False False False False False False False]]46+154
[[ True False False False False False False False False False False False][False False False False False False  True False False False False False][False False False False False False False False  True False False False][False  True False False False False False False False False False False][False False False  True False False False False False False False False][False False False False False False False  True False False False False][False False False False False False  True False False False False False]]0+2
[[ True False False False False False False False False False False False][ True False False False False False False False False False False False][ True False False False False False False False False False False False][ True False False False False False False False False False False False][False False  True False False False False False False False False False][False  True False False False False False False False False False False][False False False False  True False False False False False False False]]

上面的一行,分别对应[空格, +, 0,1,2,3,4,5,6,7,8,9],所以字符串进行了类似的one-hot编码

expected也是一样:

212
[[False False False False  True False False False False False False False][False False False  True False False False False False False False False][False False False False  True False False False False False False False][ True False False False False False False False False False False False]]
515
[[False False False False False False False  True False False False False][False False False  True False False False False False False False False][False False False False False False False  True False False False False][ True False False False False False False False False False False False]]
2
[[False False False False  True False False False False False False False][ True False False False False False False False False False False False][ True False False False False False False False False False False False][ True False False False False False False False False False False False]]

因为expected中没有加号,所以第二列永远为False

x_train.shape和y_train.shape分别为(45000, 7, 12) (45000, 4, 12)

神经网络模型为:

__________________________________________________________________________________________
Layer (type)                            Output Shape                        Param #
==========================================================================================
lstm_1 (LSTM)                           (None, 128)                         72192
__________________________________________________________________________________________
repeat_vector_1 (RepeatVector)          (None, 4, 128)                      0
__________________________________________________________________________________________
lstm_2 (LSTM)                           (None, 4, 128)                      131584
__________________________________________________________________________________________
time_distributed_1 (TimeDistributed)    (None, 4, 12)                       1548
==========================================================================================
Total params: 205,324
Trainable params: 205,324
Non-trainable params: 0
__________________________________________________________________________________________

上面可以看到,两个LSTM的输出shape不一样,一个是(None, 128),另一个是(None, 4, 128),这是因为第一个RNN的return_sequences为False,而第一个RNN的return_sequences为True

代码解释参考官方教程:

https://keras.io/zh/examples/addition_rnn/

——————————————————————

总目录

keras的example文件解析

keras 的 example 文件 addition_rnn.py 解析相关推荐

  1. keras 的 example 文件 cnn_seq2seq.py 解析

    该代码是实现一个翻译功能,好像是英语翻译为法语,嗯,我看不懂法语 首先这个代码有一个bug,本人提交了一个pull request来修复, https://github.com/keras-team/ ...

  2. keras 的 example 文件 cifar10_resnet.py 解析

    该代码功能是卷积神经网络进行图像识别,数据集是cifar10 同时演示了回调函数 ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau 的 ...

  3. keras 的 example 文件 babi_rnn.py 解析

    该代码的目的和 https://blog.csdn.net/zhqh100/article/details/105193991 类似 数据集也是同一个数据集,只不过这个是从 qa2_two-suppo ...

  4. keras 的 example 文件 mnist_hierarchical_rnn.py 解析

    很显然,我没有看懂 HRNN 是啥意思,没有去看论文,应该就是一种RNN结构的变形吧 网络结构如下: _________________________________________________ ...

  5. keras 的 example 文件 mnist_denoising_autoencoder.py 解析

    mnist_denoising_autoencoder.py 是一个编解码神经网络,其意义就是如果图片中有噪点的话,可以去除噪点,还原图片 其编码网络为: ______________________ ...

  6. keras 的 example 文件 mnist_cnn.py 解析

    mnist_cnn.py 基本上就是最简单的一个卷积神经网络了,其结构如下: _____________________________________________________________ ...

  7. keras 的 example 文件 imdb_bidirectional_lstm.py 解析

    imdb是一个文本情感分析的数据集,通过评论来分析观众对电影是好评还是差评 其网络结构比较简单 ____________________________________________________ ...

  8. keras 的 example 文件 lstm_text_generation.py 解析

    该程序是学习现有的文章,然后学习预测下个字符,这样一个字符一个字符的学会写文章 先打印下char_indices {'\n': 0, ' ': 1, '!': 2, '"': 3, &quo ...

  9. keras 的 example 文件 lstm_stateful.py 解析

    该程序要通过一个LSTM来实现拟合窗口平均数的功能 先看输入输出数据, print(x_train[:10]) [[[-0.08453234]][[ 0.02169589]][[ 0.07949955 ...

最新文章

  1. html5 点击事件委托,jquery事件委托
  2. sql server数据库定时自动备份
  3. 简单的留言板 php,php 简单留言板教程一
  4. 请问华为三层交换机里面的那个从IP是个什么意思? -
  5. 排序算法之选择法排序(C/C++)
  6. python互斥锁原理_Linux 互斥锁的实现原理(pthread_mutex_t)
  7. vue动态发布到线上_Vue 2.6 发布了
  8. 10万伪原创同义词替代词库ACCESS/EXCELL数据库
  9. React-native开发-Unrecognized font family ‘Ionicons’
  10. 提高计算机CPU运行速度,提高cpu运行速度的方法
  11. 计算机组成原理平均cpi怎么算_计算机组成原理-计算机的性能指标及计算题
  12. 现身说法:37岁老码农找工作!
  13. 2022年系统分析师综合知识考点整理
  14. 计算机主机usb端口使用不了,电脑usb接口不能用,教您电脑usb接口不能用怎么办...
  15. Impala 在网易有数 BI 应用场景的优化经验
  16. 使用IO完善快递管理系统
  17. 设备常用网管配置举例
  18. 2022已加载100%,请查收!
  19. 第一节 、MPC5744P之系统集成单元总结 SIUL2(System Integration Unit Lite2 )介绍
  20. 小米网关+HomeAssistant获取智能硬件数据

热门文章

  1. linux创建一个交换分区,如何创建linux交换分区
  2. Error: Gradle project sync failed. Please fix your project and try again.
  3. Manifest merger failed Suggestion: add 'tools:replace=“Android:value”' to meta-data element at And
  4. java增强for循环
  5. 数据库Mysql的学习(八)-储存过程和事务和导入导出
  6. 一分钟了解负载均衡的一切
  7. 2022-2028年中国普鲁兰多糖行业市场分析及投资前景研究报告
  8. Go 学习笔记(47)— Go 标准库之 strconv(string/int 互相转换、Parse 字符串转换为指定类型、Format 指定类型格式化为字符串)
  9. 每个程序员都需要学习 JavaScript 的7个理由
  10. [高中作文赏析]感受冬天