用Tensorflow1.1搭建一个自编码网络(含有多个隐藏层),在MNIST数据上进行训练
自编码网络的作用是将输入的样本进行压缩到隐藏层,然后解压在输出层重建.所以输入层和输出层神经元的数量是相等的.在压缩的过程当中网络会除去冗余的信息(要限制隐藏层神经元的数量),留下有用的特征.类似与主成分分析PCA,
多个隐藏层能够学到更有意义的特征

# 导入相关的库
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# 设置模型的训练参数
# 学习效率
# 训练的轮数
# 小批量数据大小
# 每隔几轮显示训练结果
# 测试样本数量
learning_rate = 0.001
training_epochs = 40
batch_size = 200
display_step = 1
test_examples = 10
# 网络参数,输入层,第一个隐藏层,第二个隐藏层的节点个数,
# 网络实际是有3个隐藏层(784,256,128,256,784)
n_input = 784
n_hidden1 = 256
n_hidden2 = 128
# 定义输入数据节点(占位节点)
X = tf.placeholder(tf.float32, [None, n_input])   # None的意识是:任意数量
# 在深度模型中,权重初始化的太小,信号会在每层传递的时候逐渐缩小,但是如果权重
# 初始化的太大,信号会在每层传递的时候逐渐放大导致发散和失效,Xavier初始化方
# 方法就是,让权重初始化的不大也不小
# layer1, layer2 为相邻两层神经元的节点数
def xavier_init(layer1, layer2, constant = 1):Min = -constant * np.sqrt(6.0 / (layer1 + layer2))Max = constant * np.sqrt(6.0 / (layer1 + layer2))return tf.random_uniform((layer1, layer2), minval = Min, maxval = Max, dtype = tf.float32)
# 初始化权重(字典)和偏置
weights = {'encoder_1': tf.Variable(xavier_init(n_input, n_hidden1)),'encoder_2': tf.Variable(xavier_init(n_hidden1, n_hidden2)),'decoder_1': tf.Variable(xavier_init(n_hidden2, n_hidden1)),'decoder_2': tf.Variable(xavier_init(n_hidden1, n_input)),
}
# 偏置
biases = {'encoder_1': tf.Variable(tf.random_normal([n_hidden1])),'encoder_2': tf.Variable(tf.random_normal([n_hidden2])),'decoder_1': tf.Variable(tf.random_normal([n_hidden1])),'decoder_2': tf.Variable(tf.random_normal([n_input])),
}
# 定义压缩方法
def encoder(x):h_layer1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_1']), biases['encoder_1']))h_layer2 = tf.nn.sigmoid(tf.add(tf.matmul(h_layer1, weights['encoder_2']), biases['encoder_2']))return h_layer2# 解压重建方法
def decoder(h_layer2):h_layer3 = tf.nn.sigmoid(tf.add(tf.matmul(h_layer2, weights['decoder_1']), biases['decoder_1']))out_layer = tf.nn.sigmoid(tf.add(tf.matmul(h_layer3, weights['decoder_2']), biases['decoder_2']))return out_layer
# 构建模型
encoder_op = encoder(X)
decoder_op = decoder(encoder_op)
# 网络输出结果
pred_y = decoder_op
# 真实值,即输入
y = X
# 定义损失函数和优化器
cost = tf.reduce_mean(tf.pow(y - pred_y, 2))
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost)
# 变量初始化Op
init = tf.global_variables_initializer()
# 加载数据
mnist = input_data.read_data_sets('MNIST/mnist', one_hot=True)
Extracting MNIST/mnist/train-images-idx3-ubyte.gz
Extracting MNIST/mnist/train-labels-idx1-ubyte.gz
Extracting MNIST/mnist/t10k-images-idx3-ubyte.gz
Extracting MNIST/mnist/t10k-labels-idx1-ubyte.gz
# 开启一个回话
with tf.Session() as sess:sess.run(init)total_batch = int(mnist.train.num_examples/batch_size)# 开始Cost = []for epoch in range(training_epochs):for i in range(total_batch):batch_x, batch_y = mnist.train.next_batch(batch_size)# 执行optimizer,和 cost ,返回 loss value_, c = sess.run([optimizer, cost], feed_dict={X:batch_x})Cost.append(c)# 打印训练情况if epoch % 1 ==0:print 'Epoch: %d, cost = %.9f'%(epoch+1, c)print ('Optimization Finished')fig1,ax1 = plt.subplots(figsize=(10,5))plt.plot(Cost)ax1.set_xlabel('Epochs')ax1.set_ylabel('Cost')plt.show()
#---------------------------------------测试-------------------------------------# 选取测试数据(10个样本)进行重建测试test_pred_y = sess.run(pred_y, feed_dict={X: mnist.test.images[10:20]})# 比较测试原始图片和 压缩重建后的图片fig, ax = plt.subplots(2, 10, figsize=(10, 2))for i in range(10):ax[0][i].imshow(np.reshape(mnist.test.images[10+i], (28,28)))ax[1][i].imshow(np.reshape(test_pred_y[i], (28, 28)))plt.show()
Epoch: 1, cost = 0.065628514
Epoch: 2, cost = 0.064358398
Epoch: 3, cost = 0.058146056
Epoch: 4, cost = 0.049762029
Epoch: 5, cost = 0.043949362
Epoch: 6, cost = 0.039976787
Epoch: 7, cost = 0.037451632
Epoch: 8, cost = 0.033604462
Epoch: 9, cost = 0.030377652
Epoch: 10, cost = 0.029405368...............
Epoch: 36, cost = 0.012330795
Epoch: 37, cost = 0.013209986
Epoch: 38, cost = 0.012119931
Epoch: 39, cost = 0.011887702
Epoch: 40, cost = 0.011322427
Optimization Finished

测试数据的原始图片 与 重建测试图片对比

TensorFlow1.1搭建自编码网络相关推荐

  1. Pytorch:基于转置卷积解码的卷积自编码网络

    Pytorch: 图像自编码器-卷积自编码网络(转置卷积解码)和图像去噪 Copyright: Jingmin Wei, Pattern Recognition and Intelligent Sys ...

  2. 利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结。

    利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结. 摘要 一.神经网络与卷积网络的对比 1.数据处理 2.对获取到的数据进行归一化和独热编码 二.开始我们的tensorflow神经 ...

  3. 嵌入式linux开发环境搭建——VirtualBox虚拟机网络环境解析

    嵌入式linux开发环境搭建--VirtualBox虚拟机网络环境解析 本博文转自:Pandoras Box http://blog.csdn.net/yxc135/article/details/8 ...

  4. 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)

     不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN) 生成对抗网络(Generative Adversarial Networks,GAN)最早由 Ian Goodfello ...

  5. CNN结构:序列预测复合DNN结构-AcGANs、 ENN误差编码网络

    前言:模式识别问题 模式函数是一个从问题定义域到模式值域的一个单射. 从简单的贝叶斯方法,到只能支持二分类的原始支持向量机,到十几个类的分类上最好用的随机森林方法,到可以支持ImageNet上海量18 ...

  6. php搭建后台 xampp_你还在用wordpress?|搭建独一无二个人网络日志发布平台

    概述 作为IT人,拥有自己的专属博客会让你显得与众不同.它记录了你的工作和生活的轨迹.搭建博客的程序有很多,常见的有wordpress等,但这些程序往往不支持markdown格式的书写.今天分享一个开 ...

  7. Keras——用Keras搭建自编码神经网络(AutoEncoder)

    文章目录 1.前言 2.用Keras搭建自编码神经网络 2.1.导入必要模块 2.2.数据预处理 2.3.搭建模型 2.4.实例化并激活模型 2.5.训练 2.6.可视化 1.前言 自编码,简单来说就 ...

  8. lrd热加载方式启动本地web服务(我用于从github把别人服务器代码拉倒本地去搭建自己的网络服务)

    作者:吴甜甜 个人博客网站: wutiantian.github.io lrd启动本地web服务(我用于从github把别人服务器代码拉倒本地去搭建自己的网络服务) 主要用于局域网布置物联网项目,当然 ...

  9. 使用CANalyzer搭建LIN通信网络

    使用CANalyzer搭建LIN通信网络 Step 1. 创建LIN工程 Step 2. 配置LDF文件及LIN通信报文 Stpe 2.1 配置报文及信号 Step 2.2 配置调度表 Step 2. ...

  10. Cisco RV320/RV042/RV130产品搭建专网网络

    前阵子,基于工作需求,研究了下专网相关技术,并成功使用RV320/RV042 cisco小型商务专网路由器搭建小型VPN网络.(cisco对中文不友好,该款配置不支持中文) 关于项目的话,主要实现将分 ...

最新文章

  1. golang 字符串分割
  2. Android 性能分析工具dumpsys的使用(自己增加一部分在后面)
  3. kibana一直弹出来报错?
  4. iotop--补齐系统监视工具缺失的一环
  5. Uva536 Tree Recovery二叉树重建(先序和中序确定二叉树,后序输出)
  6. Ubuntu 配置串口信息
  7. 干货|基于深度学习的目标检测算法面试必备(RCNN~YOLOv5)
  8. mysql配置 | 快速上手Linux玩转典型应用
  9. Redis事务与MySQL事务的区别
  10. [终极精简版][图解]Nginx搭建flv mp4流媒体服务器
  11. 基于高德地图的城市区域代码表
  12. Linux基础学习记录
  13. linux查看weblogic安装路径,linux下weblogic安装
  14. fc模拟器安卓版_【SFC】魂斗罗3-异形战争模拟器情怀通关2020_EVOS
  15. 西南石油大学天空教室_学府之旅 | 西南石油大学
  16. !!. 与 ?. 的区别
  17. [转]PCB Layout中的走线策略
  18. python四舍五入round_Python四舍五入及round、Decimal使用
  19. windows 服务程序和桌面程序集成(一)
  20. Web 服务寻址(WS-Addressing)对 SOAP 的隐式影响

热门文章

  1. 注解的引入以及注解的使用
  2. python遇到天猫反爬虫_用Python爬取天猫评价-我的新游戏
  3. win10自带sftp服务器_用于Windows系统的免费SFTP服务器-Free SFTP Servers及各款软件功能对比...
  4. swagger整合springMVC
  5. IIS7 请求筛选模块被配置为拒绝超过请求内容长度的请求
  6. 读书笔记 effective c++ Item 26 尽量推迟变量的定义
  7. Nodejs express中创建ejs项目,解决express下默认创建jade,无法创建ejs问
  8. 生成动态代理并加入系统功能的设计模板
  9. javascript 计时器,消失计时器
  10. 找到所需的产品或服务