GAN的简单实现--MNIST数据集+tensorflow
生成对抗网络(GAN)及其变种已经成为最近十年以来机器学习领域最为重要的思想。--2018图灵奖得主 Yann LeCun
GAN的基础知识复习:click here
GAN模型的挑战即训练优化:click here
1、模型简介及代码
本程序主要是采用最初GAN的基本原理,选择简单的二层神经网络以及MNIST数据集并基于tensorflow平台来实现,以求得对GAN的原理以及实现过程有一个更深入的理解。
程序框架及注释:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspecsess = tf.InteractiveSession()
z_dim = 100
batchs= 128
mnist = input_data.read_data_sets("/home/zhaocq/桌面/tensorflow/mnist/raw/",one_hot=True)def weight_variable(shape,name):initial = tf.random_normal(shape, stddev=0.01)return tf.Variable(initial,name = name)
def bias_variable(shape,name):initial = tf.zeros(shape) #给偏置增加小的正值用来避免死亡节点比如#inital = tf.constant(0.1, tf.float32, shape)return tf.Variable(initial,name = name)
#生成器随机噪声100维
z = tf.placeholder(tf.float32,shape=[None,100],name = 'z')
#鉴别器准备MNIST图像输入设置
x = tf.placeholder(tf.float32,shape=[None,784],name = 'x')
#生成器参数定义
g_w1 = weight_variable([100,128],'g_w1')
g_b1 = bias_variable([128],'g_b1')
g_w2 = weight_variable([128,784],'g_w2')
g_b2 = bias_variable([784],'g_b2')
generator_dict = [g_w1,g_b1,g_w2,g_b2]
#鉴别器参数定义
d_w1 = weight_variable([784,128],'d_w1')
d_b1 = bias_variable([128],'d_b1')
d_w2 = weight_variable([128,1],'d_w2')
d_b2 = bias_variable([1],'d_b2')
discriminator_dict = [d_w1,d_b1,d_w2,d_b2]#生成器网络定义
def generator(z):g_h1 = tf.nn.relu(tf.matmul(z,g_w1) + g_b1)g_h2 = tf.nn.sigmoid(tf.matmul(g_h1,g_w2) + g_b2)return g_h2
#定义鉴别器
def discrimnator(x):d_h1 = tf.nn.relu(tf.matmul(x,d_w1)+d_b1)d_logit = tf.matmul(d_h1,d_w2)+d_b2d_prob = tf.nn.sigmoid(d_logit)return d_prob,d_logitg_sample = generator(z)
d_real,d_logit_real = discrimnator(x)
d_fake,d_logit_fake = discrimnator(g_sample)
#定义损失
d_loss = - tf.reduce_mean(tf.log(d_real) + tf.log(1.- d_fake))
g_loss = - tf.reduce_mean(tf.log(d_fake))
#定义优化器,仅优化相关参数
d_slover = tf.train.AdamOptimizer().minimize(d_loss,var_list = discriminator_dict)
g_slover = tf.train.AdamOptimizer().minimize(g_loss,var_list = generator_dict)def sample_z(m,n):'''Uniform prior for G(z)'''return np.random.uniform(-1.,1.,size=[m,n])
def plot(samples):fig = plt.figure(figsize=(4, 4))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return fig#训练过程
sess.run(tf.global_variables_initializer())
i=0
for it in range(500000):#输出imageif it % 10000 == 0:samples = sess.run(g_sample,feed_dict={z: sample_z(16,100)})fig = plot(samples)plt.savefig('out/{}.png'.format(str(i).zfill(3)),bbox_inches = 'tight')i += 1plt.close(fig)x_md,_ =mnist.train.next_batch(batchs)_,d_loss_curr = sess.run([d_slover,d_loss],feed_dict={x: x_md, z: sample_z(batchs,z_dim)})_,g_loss_curr = sess.run([g_slover,g_loss],feed_dict={z: sample_z(batchs,z_dim)})if it % 10000 ==0:print('iter:{}'.format(it))print('d loss : {:.4}'.format(d_loss_curr))print('g loss : {:.4}'.format(g_loss_curr))print()#测试test
sampl = sess.run(g_sample,feed_dict={z: sample_z(5,100)})
I=np.reshape(sampl[1],(28,28))
#plt.imshow(np.reshape(sampl[1],(28,28)))
plt.imshow(I)return tf.Variable(initial,name = name)
#生成器随机噪声100维
z = tf.placeholder(tf.float32,shape=[None,100],name = 'z')
#鉴别器准备MNIST图像输入设置
x = tf.placeholder(tf.float32,shape=[None,784],name = 'x')
#生成器参数定义
g_w1 = weight_variable([100,128],'g_w1')
g_b1 = bias_variable([128],'g_b1')
g_w2 = weight_variable([128,784],'g_w2')
g_b2 = bias_variable([784],'g_b2')
generator_dict = [g_w1,g_b1,g_w2,g_b2]
#鉴别器参数定义
d_w1 = weight_variable([784,128],'d_w1')
d_b1 = bias_variable([128],'d_b1')
d_w2 = weight_variable([128,1],'d_w2')
d_b2 = bias_variable([1],'d_b2')
discriminator_dict = [d_w1,d_b1,d_w2,d_b2]#生成器网络定义
def generator(z):g_h1 = tf.nn.relu(tf.matmul(z,g_w1) + g_b1)g_h2 = tf.nn.sigmoid(tf.matmul(g_h1,g_w2) + g_b2)return g_h2
#定义鉴别器
def discrimnator(x):d_h1 = tf.nn.relu(tf.matmul(x,d_w1)+d_b1)d_logit = tf.matmul(d_h1,d_w2)+d_b2d_prob = tf.nn.sigmoid(d_logit)return d_prob,d_logitg_sample = generator(z)
d_real,d_logit_real = discrimnator(x)
d_fake,d_logit_fake = discrimnator(g_sample)
#定义损失
d_loss = - tf.reduce_mean(tf.log(d_real) + tf.log(1.- d_fake))
g_loss = - tf.reduce_mean(tf.log(d_fake))
#定义优化器,仅优化相关参数
d_slover = tf.train.AdamOptimizer().minimize(d_loss,var_list = discriminator_dict)
g_slover = tf.train.AdamOptimizer().minimize(g_loss,var_list = generator_dict)def sample_z(m,n):'''Uniform prior for G(z)'''return np.random.uniform(-1.,1.,size=[m,n])
def plot(samples):fig = plt.figure(figsize=(4, 4))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return fig#训练过程
sess.run(tf.global_variables_initializer())
i=0
for it in range(500000):#输出imageif it % 10000 == 0:samples = sess.run(g_sample,feed_dict={z: sample_z(16,100)})fig = plot(samples)plt.savefig('out/{}.png'.format(str(i).zfill(3)),bbox_inches = 'tight')i += 1plt.close(fig)x_md,_ =mnist.train.next_batch(batchs)_,d_loss_curr = sess.run([d_slover,d_loss],feed_dict={x: x_md, z: sample_z(batchs,z_dim)})_,g_loss_curr = sess.run([g_slover,g_loss],feed_dict={z: sample_z(batchs,z_dim)})if it % 10000 ==0:print('iter:{}'.format(it))print('d loss : {:.4}'.format(d_loss_curr))print('g loss : {:.4}'.format(g_loss_curr))print()#测试test
sampl = sess.run(g_sample,feed_dict={z: sample_z(5,100)})
I=np.reshape(sampl[1],(28,28))
#plt.imshow(np.reshape(sampl[1],(28,28)))
plt.imshow(I)
第100000次迭代结果:
GAN的简单实现--MNIST数据集+tensorflow相关推荐
- 残差神经网络Resnet(MNIST数据集tensorflow实现)
简述: 残差神经网络(ResNet)主要是用于搭建深度的网络结构模型 (一)优势: 与传统的神经网络相比残差神经网络具有更好的深度网络构建能力,能避免因为网络层次过深而造成的梯度弥散和梯度爆炸. (二 ...
- 基于Keras搭建mnist数据集训练识别的Pipeline
搭建模型 import tensorflow as tf from tensorflow import keras# get data (train_images, train_labels), (t ...
- TensorFlow笔记(3)——利用TensorFlow和MNIST数据集训练一个最简单的手写数字识别模型...
前言 当我们开始学习编程的时候,第一件事往往是学习打印"Hello World".就好比编程入门有Hello World,机器学习入门有MNIST. MNIST是一个入门级的计算机 ...
- 深度学习之利用TensorFlow实现简单的卷积神经网络(MNIST数据集)
卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深度学习 ...
- 深度学习之利用TensorFlow实现简单的全连接层网络(MNIST数据集)
Tensorflow是一个基于数据流编程(Dataflow Programming)的符号数学系统,被广泛应用于各类机器学习(Machine Learning)算法的编程实现,其前身是谷歌的神经网络算 ...
- GAN生成对抗网络基本概念及基于mnist数据集的代码实现
本文主要总结了GAN(Generative Adversarial Networks) 生成对抗网络的基本原理并通过mnist数据集展示GAN网络的应用. GAN网络是由两个目标相对立的网络构成的,在 ...
- [深度学习-实践]GAN基于手写体Mnist数据集生成新图片
系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之基于CIFAR10数据集的例子 深度学习GAN(三)之基于手写体Mnist数据集的例子 深度学习GAN(四)之PIX2PIX GAN ...
- Tensorflow 笔记 XIII——“百无聊赖”:深挖 mnist 数据集与 fashion-mnist 数据集的读取原理,经典数据的读取你真的懂了吗?
文章目录 数据集简介 Mnist 出门右转 Fashion-Mnist 数据集制作需求来源 写给专业的机器学习研究者 获取数据 类别标注 读取原理 原理获取 TRAINING SET LABEL FI ...
- 使用MNIST数据集,在TensorFlow上实现基础LSTM网络
使用MNIST数据集,在TensorFlow上实现基础LSTM网络 By 路雪2017年9月29日 13:39 本文介绍了如何在 TensorFlow 上实现基础 LSTM 网络的详细过程.作者选用了 ...
最新文章
- vim删除多行_Vim 可视化模式入门 | Linux 中国
- Publons:文章审稿、编辑工作认证平台
- 仓库管理员怎样做台账_工作日志之仓库管理员与会计之间的对接工作
- lintcode-49-字符大小写排序
- python:实现简单的web开发demo
- 怎么重置blockinput的锁_AppleID被锁如何解决 AppleID被锁激活方法介绍【图文】
- 如何查找UI5应用对应在ABAP Netweaver服务器上的BSP应用名称
- Redis和数据库的结合
- scroll jquery
- 【杭电ACM】1.2.6 decimal system
- JSON的C代码示例
- Python遍历文件夹下所有文件及目录
- android CTS GTS 环境搭建
- Essay Writing Guide
- git上传代码的账户名不是本人的问题
- 共享手机 馅饼还是陷阱
- cai鸡——处女作博客“横空出世”
- android—性能优化2—内存优化
- 【单片机笔记】运放电流检测实用电路
- excel工具栏隐藏了怎么办_你会用 Excel照相机吗?
热门文章
- Unity Shader - Custom SSSM(Screen Space Shadow Map) 自定义屏幕空间阴影图
- 网络战利器——“网络安全态势感知”
- AMP 功放 类别 - class A, B ,AB, D
- 警告: Exception encountered during context initialization - cancelling refresh attempt
- Android提取字符串中的特殊字符(以手机号为例)并修改样式和添加点击事件
- 体重指数计算器 (Body Mass Index Calculator)
- 全球及中国小型风力涡轮机叶片行业运行走势及投资战略决策咨询报告2021-2027年版
- 从零开始建设Discuz论坛
- Android permission
- 图片过大无法发送怎么办?分享三种图片压缩工具