写在前头
这段时间在学习利用tensorflow编写自编码器代码,在B站上看到了一个教学视频,觉得很有用,但很遗憾视频的画质和声音都有点渣,UP主也没有提供源代码,我在网上搜了下,也没找到,因此自己根据视频把代码整理了出来,给自己做个备份,也给大家提供参考。
视频链接:视频链接
视频的大神是在已经写好的神经网络代码上进行改写的,我在网上有发现这个神经网络的代码,也一并附在这里:神经网络代码
如果教学视频的原作者已经在某处发表了代码,或者有人发现了原作者的代码,还请告诉我一声,为保护原作者的权益,我会删掉这篇文章的~
另外,在完整把代码扒下来之后,运行发现有一两个问题,不能完整运行,因此对里面一些小地方(例如路径等)进行了修改,还对一些代码进行了解释,已经可以直接运行。

tensorflow 自编码网络

import tensorflow as tf
import os
from PIL import Image
import numpy as np
import cv2
from matplotlib import pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data#重置图,这句是另外加的
tf.reset_default_graph()#文件路径名和视频中不一样,按实际的路径即可
mnist=input_data.read_data_sets("./MNIST_data/",one_hot=True)
x=tf.placeholder(tf.float32,[None,784])batch_size=1000def add_layer(input_data,input_num,output_num,activation_function=None):w=tf.Variable(initial_value=tf.random_normal(shape=[input_num,output_num]))b=tf.Variable(initial_value=tf.random_normal(shape=[1,output_num]))output=tf.add(tf.matmul(input_data,w),b)if activation_function:output=activation_function(output)return outputdef build_nn(data):#编码,两个隐含层,隐含层的节点数可以自己调节hidden_layer1=add_layer(data,784,200,activation_function=tf.nn.sigmoid)hidden_layer2=add_layer(hidden_layer1,200,100,activation_function=tf.nn.sigmoid)#解码,节点数和上面的相对应。利用sigmoid函数将数据转换到0~1的范围里hidden_layer3=add_layer(hidden_layer2,100,200,activation_function=tf.nn.sigmoid)output_layer=add_layer(hidden_layer3,200,784,activation_function=tf.nn.sigmoid)return output_layerdef train_nn(data):output=build_nn(data)#损失函数loss=tf.reduce_mean(tf.square(output-data))optimizer=tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)#保存模型saver=tf.train.Saver()with tf.Session() as sess:sess.run(tf.global_variables_initializer())#因为我们想要第一遍运行的时候,能保存我们的参数模型,#然后第二遍的时候,能直接加载参数模型对测试数据进行训练,因此需要对是否有保存好的参数模型进行判断#判断方法:因为保存模型会生成checkpoint文件,因此可以作为判断依据if not os.path.exists("./SDAE_model_save/checkpoint"):for i in range(30):epoch_cost=0for _ in range(int(mnist.train.num_examples/batch_size)):x_data,y_data=mnist.train.next_batch(batch_size)cost,_=sess.run([loss,optimizer],feed_dict={x:x_data})epoch_cost+=costprint("Epoch",i,":",epoch_cost)saver.save(sess,"./SDAE_model_save/model")else:saver.restore(sess,"./SDAE_model_save/model")#最开始原作者是想要第二遍的时候预测自己的手写数字图片1.jpg的,就可以用下面这句代码。#predict("./1.jpg",sess,output)#但是因为后来预测的效果不好,所以换了mnist测试集里的图片result=sess.run(output,feed_dict={x:mnist.test.next_batch(1)[0]})result=np.reshape(result,(28,28)plt.imshow(result*255)plt.show()def read_data(path):image=cv2.imread(path,cv2.IMREAD_GRAYSCALE)processed_image=cv2.resize(image,(28,28))processed_image=np.resize(processed_image,(1,784))processed_image=processed_image/255.0return image,processed_imagedef predict(image_path,sess,output):imgae,processed_image=read_data(image_path)result=sess.run(output,feed_dict={x:processed_image})result=np.reshape(result,(28,28))#因为想要更清晰地展示图片,所以用plt替换了cv2,又因为Plt不能直接显示0~1的数据,因此需要扩展为255的plt.imshow(result*255)plt.show()cv2.imshow("image",result)cv2.waitKey(0)cv2.destroyAllWindows()def reconstract_image():for i in range(10):if not os.path.exists("./{}",format(i)):os.makedirs("./{}",format(i))batch_size=1for j in range(int(mnist.train.num_examples/batch_size)):x_data,y_data=mnist.train.next_batch(batch_size)img=Image.fromarray(np.reshape(np.array(x_data[0]*255,dtype="unit8"),(28,28)))dir=np.argmax(y_data[0])img.save("./{}/{}.bmp",format(dir,j))if __name__=="__main__":train_nn(x)

结果展示:

我只运行了5步,效果肯定不是很好,只是测试能不能正常用,想要更好的结果,可以调高步数

tensorflow 自编码器 MNIST数据集相关推荐

  1. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  2. Tensorflow初探之MNIST数据集学习

    官方文档传送门 MNIST数据集是手写数字0~9的数据集,一般被用作机器学习领域的测试,相当于HelloWorld级别. 本程序先从网上导入数据,再利用最小梯度法进行训练使得样本交叉熵最小,最后给出训 ...

  3. 北京大学曹健——Tensorflow笔记 05 MNIST数据集输出手写数字识别准确率

              # 前向传播:描述了网络结构 minist_forward.py # 反向传播:描述了模型参数的优化方法 mnist_backward.py # 测试输出准确率minist_tes ...

  4. tensorflow(七)实现mnist数据集上图片的训练和测试

    本文使用tensorflow实现在mnist数据集上的图片训练和测试过程,使用了简单的两层神经网络,代码中涉及到的内容,均以备注的形式标出. 关于文中的数据集,大家如果没有下载下来,可以到我的网盘去下 ...

  5. 主成分分析降维(MNIST数据集)

    北京 | 高性能计算之GPU CUDA课程11月24-26日3天密集学习 快速带你晋级阅读全文> 刘凯欣,中国矿业大学在校学生,曾参加过ThoughtWorks举办的结对编程活动. 今天看了用主 ...

  6. 使用MNIST数据集,在TensorFlow上实现基础LSTM网络

    使用MNIST数据集,在TensorFlow上实现基础LSTM网络 By 路雪2017年9月29日 13:39 本文介绍了如何在 TensorFlow 上实现基础 LSTM 网络的详细过程.作者选用了 ...

  7. autoencoder自编码器原理以及在mnist数据集上的实现

    Autoencoder是常见的一种非监督学习的神经网络.它实际由一组相对应的神经网络组成(可以是普通的全连接层,或者是卷积层,亦或者是LSTMRNN等等,取决于项目目的),其目的是将输入数据降维成一个 ...

  8. MNIST数据集实现手写数字识别(基于tensorflow)

    ------------先看看别人的博客--------------------- Tensorflow 实现 MNIST 手写数字识别         用这个的代码跑通了 使用Tensorflow和 ...

  9. Tensorflow mnist 数据集测试代码 + 自己下载数据

    https://blog.csdn.net/weixin_39673686/article/details/81068582 import tensorflow as tf from tensorfl ...

最新文章

  1. android 8.0 ,9.0 静态广播不显示问题处理
  2. Linux之用户组相关操作 groupadd groupdel
  3. DB-Engines:2017 年 2 月份全球数据库排名
  4. python语言设计简单计算器_Python 设计一个简单的计算器-Go语言中文社区
  5. 【机器视觉】机器视觉光源详解
  6. 设计模式:单例模式的写法(基础写法和线程安全写法)
  7. Linux 文件的复制
  8. php员工删除,php+mysql删除指定编号员工信息的方法_PHP
  9. RSA加密、解密、签名、校验签名
  10. 一文看懂Java虚拟机——JVM基础概念整理
  11. 南京林业大学计算机科学技术,南京林业大学信息科学技术学院
  12. linux系统工程师修改打开文件数限制代码教程。服务器运维技术
  13. srs推flv流_SRS流媒体服务器之HLS源码分析(3)
  14. 即时通讯源码/im源码uniapp基于在线聊天系统附完整搭建部署教程
  15. Socket协议脚本编写
  16. 计算机单位厘米 像素,英尺和厘米的换算计算器 显示器的像素解析度可能不是...
  17. CSS解决图片过大撑破DIV的方法
  18. python下载动作电影_Python爬虫实战之取电影天堂,,新手练手项目
  19. windows10装detectron2-0.6,并运行fasterrcnn
  20. POJ 1392 Ouroboros Snake

热门文章

  1. 广东移动:各技术岗简介——FML
  2. 透过浏览器你能知道用户什么,由人肉搜索说开去
  3. 网络测试工具——tcping
  4. Python正则表达式: 元字符/转义/分组/匹配原则/re模块属性方法大全
  5. [推荐] 协同滤波 —— Collaborative Filtering (CF)
  6. 数据分析技能点-多元分析和应用
  7. SAP 实习 面试 工作 心得 体会记录
  8. 功能展示——Android底部导航栏复古风TabHost实现
  9. Jsp服装商城包安装调试
  10. windows 主题壁纸更换