Generative Adversarial Nets

前言

GAN同时训练两个模型:

(1)生成模型G,其主要是用来捕获给定数据的分布特征,依此生成类似的数据样本。

(2)判别模型D,用来判别数据到底是来自原始数据还是由生成模型G生成的伪造数据。

G和D的训练类似一个假币制造者G和验钞者D的对抗游戏。G要制造足够假的假币,D要以最大概率识别出G制造的假币。推理证明这个模型最终的结果就是G生成了一个很像原始分布的数据(即假币足够逼真),而D的判别概率稳定在 1/2 1 / 2 1/2(即直观上看,和瞎猜是不是假币没什么区别)。

最原始的GAN生成器和判别器都是用的多层感知机,就是现在CNN里面常提到的全连接层。

介绍

判别模型(discriminative model)学习分辨样本到底是来自生成的样本还是原始数据。而生成模型(generative model)充当一个伪造者,制造假的样本数据。判别模型判断得越准,生成模型也会根据对手的进步不断提升自己的造假水平。最终两者到达一个平衡点,那就是生成的样本足够逼真,判别模型只能以一半的概率判断是否是假的。

为了得到生成器G在给定数据集 x x x上的分布pg" role="presentation">pgpgp_g,我们首先要定义一个噪声变量 pz(z) p z ( z ) p_z(z) ,这个变量很关键,因为它就是生成器G的输入,通过多层感知机转化为了伪造样本 G(z;θg) G ( z ; θ g ) G(z;\theta_g) ,这里的 θg θ g \theta_g就是生成器G的相关多层感知机参数。

另外还需要定义一个有关判别器的多层感知机 D(x;θd) D ( x ; θ d ) D(x;\theta_d),D仅仅输出一个标量,即样本真假的概率。

D训练就是为了最小化样本来自G的判别为真的概率,G的训练目标则是最大化让D误判的概率。可以参考一下下面这个公式:

minGmaxDEx∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))] min G max D E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ]

\min \limits_{G}\max \limits_{D}E_{x\sim p_{data(x)}}[logD(x)]+E_{z\sim p_{z(z)}}[log(1-D(G(z)))]

后面那一长串公式可以简写成 V(D,G) V ( D , G ) V(D,G),其实理论上这个公式直观上看上去不难理解。为了证明它的收敛性,有一个全局的最优解 pg=pdata p g = p d a t a p_g=p_{data},即G学到的数据分布 pg p g p_g等于给定的样本分布 pdata p d a t a p_{data},作者用了比较长的篇幅来证明Jensen-Shannon divergence在两个数据分布之间是非负的,当且进档他们相等才为0。

算法

下图是作者算法的主要流程:

简单翻译一下里面的步骤:

  • 生成 batch_size=m b a t c h _ s i z e = m batch\_size=m 的噪声向量样本(输入到G,其输出再输入给D)
  • 从原始分布获取 m m m个样本
  • 首先更新k" role="presentation">kkk次的判别器(判别器D的参数有来自G生成的造假数据,也有来自 pdata(x) p d a t a ( x ) p_{data}(x)的原始数据)
  • 生成 batch_size=m b a t c h _ s i z e = m batch\_size=m 的噪声向量样本(输入到G,其输出再输入给D)
  • 更新生成器
  • 如果迭代没有结束就返回第一步继续循环

效果

既然GAN中的G是伪造者,那么伪造水平如何呢,这里参考文中的算法步骤,用MNIST训练出了一个简单的GAN。

一开始的时候,G非常弱,输入任意噪声,它只能产生很随机的效果,如:

一开始生成这个,D几乎以100%的概率拒绝G:你这根本就是假的,算了吧。那么G很不甘心,它就努力模仿,想成为一个造假界的王者。经过多轮训练迭代,最终G能生成:

动态成长图如下:

看到没有,你有张良计我有过墙梯,你判别器那么厉害,我生成器也是越来越能造假。假到你分辨不出。

代码实战

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as pltimg_height = 28
img_width = 28
img_size = img_height*img_width
batch_size = 128
h1_size = 128
h2_size = 256
max_epoch = 1000000
z_size = 100  # 噪声维度
keep_prob = 0.5
save_path = './gan_output/'z = tf.placeholder(tf.float32,shape=[None,z_size])
x = tf.placeholder(tf.float32,shape=[None,img_size])def xavier_init(shape):'''初始化方法,来源一篇论文,保证每一层都有一致的方差'''in_dim = shape[0]stddev = 1./tf.sqrt(in_dim/2.)return tf.random_normal(shape=shape,stddev=stddev)def get_z(shape):'''生成随机噪声,作为G的输入'''return np.random.uniform(-1.,1.,size=shape).astype(np.float32)def generator(z_prior):'''生成器,两层感知机,L1用ReLU,Out用sigmoid'''# L1w1 = tf.Variable(xavier_init([z_size,h1_size]))b1 = tf.Variable(tf.zeros([h1_size]),dtype=tf.float32)h1 = tf.nn.relu(tf.matmul(z_prior,w1)+b1)# Outw2= tf.Variable(xavier_init([h1_size,img_size]))b2 = tf.Variable(tf.zeros([img_size]),dtype=tf.float32)x_generated = tf.nn.sigmoid(tf.matmul(h1,w2)+b2)# 待训练参数要一并返回params = [w1,b1,w2,b2]return x_generated, paramsdef discriminator(x,x_generated,keep_prob):'''判别器,两层感知机,L1用ReLU,Out用sigmoid注意判别器用同样的w和b去计算原始样本x和G的生成样本'''# L1w1 = tf.Variable(xavier_init([img_size,h1_size]))b1 = tf.Variable(tf.zeros([h1_size]),dtype=tf.float32)h1_x = tf.nn.dropout(tf.nn.relu(tf.matmul(x,w1)+b1),keep_prob)  # 不加dropout迭代到一定次数会挂掉h1_x_generated = tf.nn.dropout(tf.nn.relu(tf.matmul(x_generated,w1)+b1),keep_prob)# Outw2 = tf.Variable(xavier_init([h1_size,1]))b2 = tf.Variable(tf.zeros([1]),dtype=tf.float32)d_prob_x = tf.nn.sigmoid(tf.matmul(h1_x,w2)+b2)d_prob_x_generated = tf.nn.sigmoid(tf.matmul(h1_x_generated,w2)+b2)params = [w1,b1,w2,b2]return d_prob_x,d_prob_x_generated,paramsdef save(samples, index):'''只是用来把图片保存到本地,和训练无关'''fig = plt.figure(figsize=(4,4))gs = gridspec.GridSpec(4,4)gs.update(wspace=0.05,hspace=0.05)for i,sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(img_width,img_height),cmap='Greys_r')plt.savefig(save_path+'{}.png'.format(str(index).zfill(3)))plt.close(fig)x_generated,g_params = generator(z)  # 生产伪造样本
d_prob_real,d_prob_fake,d_params = discriminator(x,x_generated,keep_prob)  # 把伪造样本和生成的一并传入计算各自概率# 这两个是论文里面的那个很长的公式
d_loss = -tf.reduce_mean(tf.log(d_prob_real+1e-30) + tf.log(1.-d_prob_fake+1e-30))  # 不加这个1e-30会出现log(0)
g_loss = -tf.reduce_mean(tf.log(d_prob_fake+1e-30))  # tf有内置的sigmoid_cross_entropy_with_logits可以解决这个问题,但我没用上g_solver = tf.train.AdamOptimizer(0.001).minimize(g_loss,var_list=g_params)
d_solver = tf.train.AdamOptimizer(0.001).minimize(d_loss,var_list=d_params)sess = tf.Session()
sess.run(tf.global_variables_initializer())mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)  # 加载数据集if not os.path.exists(save_path):os.makedirs(save_path)for i in range(max_epoch):if i % 1000 == 0:  # 这个只是用来保存图片,和训练没什么关系samples = sess.run(x_generated, feed_dict = {z:get_z([16,z_size])})index = int(i/1000)save(samples, index)# *主要的训练步骤*x_mb,_ = mnist.train.next_batch(batch_size)_,d_loss_ = sess.run([d_solver,d_loss],feed_dict={x:x_mb,z:get_z([batch_size,z_size])})_,g_loss_ = sess.run([g_solver,g_loss],feed_dict={z:get_z([batch_size,z_size])})if i % 1000 == 0:print('iter: %d, d_loss: %.3f, g_loss: %.3f\n' % (i,d_loss_,g_loss_))

总结

仅仅两层感知机效果不怎么样,不过多层了就要训练比较久。想要获得更好的效果,可以考虑用DCGAN,卷积才是王道。

GAN[1]:原论文介绍及代码实战相关推荐

  1. Classifier-Free Diffusion Guidance【论文精读加代码实战】

    Classifier-Free Diffusion Guidance[论文精读加代码实战] 0.前言 1.Classifier-Free Diffusion Guidance介绍 1.1原理介绍 1. ...

  2. T5的整体介绍【代码实战】

    T5的整体介绍[代码实战] 0.前言 1.Header 2.summary 3 T5 model 3.1 forward 3.2 预训练任务 3.2.1 multi sentence pairs 3. ...

  3. GAN 数学原理简单介绍以及代码实践

    1. GAN 数学原理 1.1 GAN 概述 GAN(Generative Adversarial Network) 是一种深度生成神经网络,它包括 生成模型 与 判别模型 两个部分.其中,生成模型 ...

  4. 第十三讲:textcnn做文本分类任务,基于论文:Relation_Classification_via_Convolutional_Deep_Neural_Network的实战代码

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.现 ...

  5. 【学习打卡05】可解释机器学习笔记之CAM+Captum代码实战

    可解释机器学习笔记之CAM+Captum代码实战 文章目录 可解释机器学习笔记之CAM+Captum代码实战 代码实战介绍 torch-cam工具包 可视化CAM类激活热力图 预训练ImageNet- ...

  6. 玩转NLP论文实战图书馆----附论文地址和代码地址

    一.文本生成概述 1.文本生成概述: https://www.jiqizhixin.com/articles/2017-05-22 任务包括:对话生成.摘要生成.机器翻译. 2.NLP综述:自然语言生 ...

  7. 【Segment Anything Model】论文+代码实战调用SAM模型预训练权重+相关论文

    上篇文章已经全局初步介绍了SAM和其功能,本篇作为进阶使用. 文章目录 0.前言 1.SAM原论文 1️⃣名词:提示分割,分割一切模型,数据标注,零样本,分割一切模型的数据集 2️⃣Introduct ...

  8. 【总结】有三AI所有GAN相关学习资料汇总,有图文、视频、代码实战等......

    GAN无疑是这几年深度学习领域里最酷的技术,不管是理论的研究,还是GAN在图像生成,图像翻译,语音图像等基础领域的应用,都非常的丰富.我们公众号输出过非常多的GAN相关资源,本次做一个简单汇总. 免费 ...

  9. 深度学习代码实战演示_Tensorflow_卷积神经网络CNN_循环神经网络RNN_长短时记忆网络LSTM_对抗生成网络GAN

    前言 经过大半年断断续续的学习和实践,终于将深度学习的基础知识看完了,虽然还有很多比较深入的内容没有涉及到,但也是感觉收获满满.因为是断断续续的学习做笔记写代码跑实验,所以笔记也零零散散的散落在每个角 ...

最新文章

  1. Http协议简单介绍
  2. cron表达式 每隔8小时_cron表达式详解
  3. 操作系统(3) 多处理器编程:从入门到放弃
  4. linux的网络配置有线线缆被拔出
  5. Nginx1.10编译安装
  6. LabView-之1: 串口驱动
  7. php有个schost.exe_windows找不到svchost.exe(附图)
  8. mac android 手机连接打印机,为什么苹果电脑连接打印机打不出来怎么办
  9. 【有利可图网】PS实战教程55:打破次元壁,将照片从三次元跨越到二次元
  10. 一睹64位Windows XP的芳容(也是从网上copy的,扫了一下,没有仔细看)
  11. flume系列之:监控JMX reporter
  12. ufs2.2 协议扫盲(十一)
  13. @EnableConfigurationProperties 的作用
  14. Nginx的部署与配置
  15. Jeesite Login 登录 分析
  16. win7和linux mint双系统安装总结
  17. Resin 与 Tomcat 服务器对比
  18. 【专精特新周报】北交所首份2022年半年报出炉;创北交所最快上会记录 天马新材、华岭股份北交所过会...
  19. C++ map和set
  20. 基于51单片机的洗衣机控制系统

热门文章

  1. Eclipse 和 MyEclipse控制台console不停的自动跳动,跳出来解决方案http://jmhmlu.blog.163.com/blog/static/16161229820124311
  2. 给崩坏三桌面版的启动窗口加个启动音效
  3. 使用UnityHub下载任意版本Unity
  4. Python爬取车票,车次和余票等
  5. ThinkPad电脑系统损坏如何重装Win10系统教学分享
  6. uos应用_统信UOS桌面操作系统V20专业版正式发布:大量自研应用
  7. 三国志战略版:Daniel_S5_PK2逆向克制!菜刀骑兵杀手-桃园箕形陷阵盾
  8. Vysor破解助手for Linux/macOS/Windows
  9. 史上最直白之Attention详解(原理+代码)
  10. js小数点toFixed