keras 的 example 文件 mnist_net2net.py 解析
该程序是介绍,如何把一个浅层的卷积神经网络,加深,加宽
如先建立一个简单的神经网络,结构如下:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv1 (Conv2D) (None, 28, 28, 64) 640
_________________________________________________________________
pool1 (MaxPooling2D) (None, 14, 14, 64) 0
_________________________________________________________________
conv2 (Conv2D) (None, 14, 14, 64) 36928
_________________________________________________________________
pool2 (MaxPooling2D) (None, 7, 7, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 3136) 0
_________________________________________________________________
fc1 (Dense) (None, 64) 200768
_________________________________________________________________
fc2 (Dense) (None, 10) 650
=================================================================
Total params: 238,986
Trainable params: 238,986
Non-trainable params: 0
_________________________________________________________________
None
训练完成后,想办法把他加宽,成下面这样
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv1 (Conv2D) (None, 28, 28, 128) 1280
_________________________________________________________________
pool1 (MaxPooling2D) (None, 14, 14, 128) 0
_________________________________________________________________
conv2 (Conv2D) (None, 14, 14, 64) 73792
_________________________________________________________________
pool2 (MaxPooling2D) (None, 7, 7, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 3136) 0
_________________________________________________________________
fc1 (Dense) (None, 128) 401536
_________________________________________________________________
fc2 (Dense) (None, 10) 1290
=================================================================
Total params: 477,898
Trainable params: 477,898
Non-trainable params: 0
_________________________________________________________________
None
或者加深,变成下面这样
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv1 (Conv2D) (None, 28, 28, 64) 640
_________________________________________________________________
pool1 (MaxPooling2D) (None, 14, 14, 64) 0
_________________________________________________________________
conv2 (Conv2D) (None, 14, 14, 64) 36928
_________________________________________________________________
conv2-deeper (Conv2D) (None, 14, 14, 64) 36928
_________________________________________________________________
pool2 (MaxPooling2D) (None, 7, 7, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 3136) 0
_________________________________________________________________
fc1 (Dense) (None, 64) 200768
_________________________________________________________________
fc1-deeper (Dense) (None, 64) 4160
_________________________________________________________________
fc2 (Dense) (None, 10) 650
=================================================================
Total params: 280,074
Trainable params: 280,074
Non-trainable params: 0
_________________________________________________________________
None
也就是介绍如何对神经网络参数进行增、改、查
首先是获取参数,获取卷积层参数和全连接层代码就是下面两行:
w_conv1, b_conv1 = teacher_model.get_layer('conv1').get_weights()w_fc1, b_fc1 = teacher_model.get_layer('fc1').get_weights()
加宽的话,修改卷积层和全连接层参数是下面两行:
model.get_layer('conv1').set_weights([new_w_conv1, new_b_conv1])model.get_layer('fc1').set_weights([new_w_fc1, new_b_fc1])
至于改成什么数据,那就自己可以自由发挥了,要么在原来的基础上,拼接随机的一些层,要么把原来的复制一份然后加一些噪音
加深的话,就是新建一个神经网络,把原有的层的参数获取重新拷贝过去就行了,新增加的层的参数,可以自由发挥如何初始化,
修改后的神经网络重新再进行训练
keras 的 example 文件 mnist_net2net.py 解析相关推荐
- keras 的 example 文件 cnn_seq2seq.py 解析
该代码是实现一个翻译功能,好像是英语翻译为法语,嗯,我看不懂法语 首先这个代码有一个bug,本人提交了一个pull request来修复, https://github.com/keras-team/ ...
- keras 的 example 文件 cifar10_resnet.py 解析
该代码功能是卷积神经网络进行图像识别,数据集是cifar10 同时演示了回调函数 ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau 的 ...
- keras 的 example 文件 babi_rnn.py 解析
该代码的目的和 https://blog.csdn.net/zhqh100/article/details/105193991 类似 数据集也是同一个数据集,只不过这个是从 qa2_two-suppo ...
- keras 的 example 文件 mnist_hierarchical_rnn.py 解析
很显然,我没有看懂 HRNN 是啥意思,没有去看论文,应该就是一种RNN结构的变形吧 网络结构如下: _________________________________________________ ...
- keras 的 example 文件 mnist_denoising_autoencoder.py 解析
mnist_denoising_autoencoder.py 是一个编解码神经网络,其意义就是如果图片中有噪点的话,可以去除噪点,还原图片 其编码网络为: ______________________ ...
- keras 的 example 文件 mnist_cnn.py 解析
mnist_cnn.py 基本上就是最简单的一个卷积神经网络了,其结构如下: _____________________________________________________________ ...
- keras 的 example 文件 imdb_bidirectional_lstm.py 解析
imdb是一个文本情感分析的数据集,通过评论来分析观众对电影是好评还是差评 其网络结构比较简单 ____________________________________________________ ...
- keras 的 example 文件 lstm_text_generation.py 解析
该程序是学习现有的文章,然后学习预测下个字符,这样一个字符一个字符的学会写文章 先打印下char_indices {'\n': 0, ' ': 1, '!': 2, '"': 3, &quo ...
- keras 的 example 文件 lstm_stateful.py 解析
该程序要通过一个LSTM来实现拟合窗口平均数的功能 先看输入输出数据, print(x_train[:10]) [[[-0.08453234]][[ 0.02169589]][[ 0.07949955 ...
最新文章
- linux编译mysql报无法将左值_'错误:无法将'std::ostream {aka std::basic_ostream
- linux定义别名出错,Linux自定义别名alias重启失效问题
- 数据库框架的log4j日志配置
- 如何操控输入框中的placeholder属性
- 女性自我的迷宫:看EMI的人体自拍
- 【Flink】 Flink 源码之 Buffer Timeout优化
- Python元组,列表,解构和循环
- 解决虚拟机克隆后eth0不见的问题
- android layout wrap_content,android-如果高度为WRAP_CONTENT,则不显示VideoVi...
- vue app准备学习工作
- SaltStack 拉取和推送文件
- python 高并发 tomcat_TOMCAT 高并发配置
- linux硬盘速率测试,【Linux】测试硬盘读写速度
- Scripting for Testers 测试人员脚本编程教程 Lynda课程中文字幕
- 【软件测试的重要性】
- unity reflect_使用Unity Reflect的不同方法
- Android Parcel数据传输源码解析
- ubuntu16.04 安装微信客户端
- 如何给CSDN博客添加微信公众号二维码或自定义栏目
- Qt实战案例(2)——电子时钟的设计
热门文章
- f是一个python内部变量类型,Python基础变量类型——List浅析
- C++ 类模板的使用
- Caused by: org.greenrobot.eventbus.EventBusException: Subscriber class com.baidu.iov.dueros.film.ui
- Android Studio 3.5 之后导入第三方Library 库的方法
- Can't toast on a thread that has not called Looper.prepare()
- Appium 并发测试基于unitest
- 关于C语言中printf函数“输出歧视”的问题
- 运行在CentOS7.5上的Django项目时间不正确问题
- Go 学习推荐 —(Go by example 中文版、Go 构建 Web 应用、Go 学习笔记、Golang常见错误、Go 语言四十二章经、Go 语言高级编程)
- TCP的三次握手和四次分手