​​​​

序言

做图像分类、检测任务时,为了提高模型精度,在数据处理方面,我尝试了很多数据增强tricks(包括了简单的裁切、变形、明暗、颜色调整,也包括了MixUp图像融合以及SMOTE这样的解决类别样本平衡的插值方法),取得了不错的精度提升。在查找资料时,我发现在人脸识别、行人重识别等任务中,还可以通过GAN(生成对抗网络)来生成具有多样高级语义特征的样本来充实训练集数据,以帮助提升模型精度。随着对GAN的了解逐步深入,我发现GAN是一种既有用,又好玩的深度学习模型,横空出世了许多让人眼前一亮的应用。有了好玩儿的模型,还得有趁手的工具,才能尽情玩耍。飞桨作为国内首个开源深度学习框架,有着丰富的开发部署工具和课程、社区的支持,至今已经更新到了2.0版本。“动态图编写模型”,“高级API支持”等特性非常便于上手。更不用说AI Studio上丰富的课程(理论实践全都有),尤其是8卡Tesla v100的算力“金羊毛”,让人怎好意思拒绝~~。自此,我便开始“架着飞桨”,学习在GAN的海洋里“乘风破浪”(向姐姐们致敬)。

生成对抗网络的介绍(GAN)

GAN的全称是Generative Adversarial Networks,即生成对抗网络,由Ian J. Goodfellow等人于2014年10月发表在NIPS大会上的论文《Generative Adversarial Nets》中提出。此后各种花式变体Pix2Pix、CYCLEGAN、STARGAN、StyleGAN等层出不穷,在“换脸”、“换衣”、“换天地”等应用场景下生成的图像、视频以假乱真,好不热闹。前段时间PaddleGAN实现的First Order Motion表情迁移模型,能用一张照片生成一段唱歌视频。各种搞笑鬼畜视频火遍全网。用的就是一种GAN模型哦。深度学习三巨神之一的LeCun也对GAN大加赞赏,称“adversarial training is the coolest thing since sliced bread”。关于GAN网络的研究也呈井喷态势,下面是2014年到2018年命名为GAN的论文数量图表:

GAN的前世今生

  1. 判别模型与生成模型

对抗生成模型GAN首先是一个生成模型,和大家比较熟悉的、用于分类的判别模型不同。

判别模型的数学表示是y=f(x),也可以表示为条件概率分布p(y|x)。当输入一张训练集图片x时,判别模型输出分类标签y。模型学习的是输入图片x与输出的类别标签的映射关系。即学习的目的是在输入图片x的条件下,尽量增大模型输出分类标签y的概率。

而生成模型的数学表示是概率分布p(x)。没有约束条件的生成模型是无监督模型,将给定的简单先验分布π(z)(通常是高斯分布),映射为训练集图片的像素概率分布p(x),即输出一张服从p(x)分布的具有训练集特征的图片。模型学习的是先验分布π(z)与训练集像素概率分布p(x)的映射关系。

  1. 其他生成网络简介

生成网络并非只有GAN,介绍下其他几种:

  • 自回归模型(Autoregressive model)是从回归分析中的线性回归发展而来,只是不用x预测y,而是用x预测 x(自己),所以叫做自回归。多用于序列数据生成如文本、语音。PixelRNN/CNN则使用这种方法生成图片,效果还不错。但是由于是按照像素点去生成图像导致计算成本高, 在可并行性上受限,在处理大型数据如大型图像或视频是具有一定麻烦的。
  • 变分自编码器(VAE):VAE是在AE(Autoencoder自编码器)的基础上让图像编码的潜在向量服从高斯分布从而实现图像的生成,优化了数据对数似然的下界,VAE在图像生成上是可并行的, 但是VAE存在着生成图像模糊的问题。
  • 基于流的模型(Flow-based Model)包括Glow、RealNVP、NICE等。流模型思想很直观:寻找一种变换 y = f(x)(f 可逆,且 y 与 x 的维度相同) 将数据空间映射到另一个空间,新空间各个维度相互独立。这些年,看着GAN一直出风头,流模型表示各种不服,自从2016年问世以来,一直在“不服中…”。

下面就该介绍生成模型中的“明星”——GAN模型了。

GAN的原理

生成对抗网络一般由一个生成器(生成网络),和一个判别器(判别网络)组成。生成器的作用是,通过学习训练集数据的特征,在判别器的指导下,将随机噪声分布尽量拟合为训练数据的真实分布,从而生成具有训练集特征的相似数据。而判别器则负责区分输入的数据是真实的还是生成器生成的假数据,并反馈给生成器。两个网络交替训练,能力同步提高,直到生成网络生成的数据能够以假乱真,并与与判别网络的能力达到一定均衡。

本着“深度愉悦学习”的宗旨,正儿八经介绍GAN流程不如给大伙讲一个“终成一代大师”的励志故事。

故事里的方学芹同学就是GAN网络里的生成器,而文谈同学就是判别器。故事的发展过程就是GAN网络的训练过程。

从“文坛大佬的往事”说起…

方学芹同学和文谈同学从小就是一对热爱文学的诤友。小方爱讲,小文爱听后发表意见。(GAN网络由两个网络组成,一个是生成器,一个是判别器。)

上小学时,小方给小文推荐了《孟母三迁》、《司马光砸缸》和自己照着前两篇写的《司马光砸锅》。(将真数据和生成器生成的假数据一起送给判别器判别真假。)

小文看后说:“《司马光砸锅》是你编的吧,故事讲的不够流畅。”说完,小文赶紧拿小本记下鉴别心得。(判别器通过鉴别真假数据的训练,提高判别能力。)

小方红着脸,去练习如何流畅叙事了。(生成器通过学习判别器的判别结果,提高生成假数据的逼真程度,以获得骗过判别器的能力。)

中学时代,文笔已褪去青涩的方同学推荐了《庆余年》、《海棠依旧》和自己写的《海棠朵朵》给文同学。(将真数据和生成器生成的假数据一起送给判别器判别真假。)

文同学也已刷剧无数不可与小学时同日而语,看后评价:“这个《海棠朵朵》不如前两篇写得引人入胜,又是出自你手吧。”鉴定完毕,文同学的信心又增加了不少。(判别器通过鉴别真假数据的训练,提高判别能力。)

方同学坦然一笑,继续去练习叙事结构与情节渲染。(生成器通过学习判别器的判别结果,提高生成假数据的逼真程度,以获得骗过判别器的能力。)

方同学和文同学就这样“在文学的蒙蔽与反蒙蔽斗争”中度过了他们的中学时代、大学时代、找工作时代,一路共同进步,来到了属于他们的大师时代。(判别器与生成器按前面的套路交替训练,逐步分别提高各自的判别能力和生成以假乱真的数据的能力。)

文学造诣已经炉火纯青方先生终于向多年亦对手亦良师的文先生推荐了《金瓶梅》、《红楼梦》和《青楼梦》三部终极作品。(将真数据和生成器生成的假数据一起送给判别器判别真假。)

文先生这些年来阅人无数,也已是文坛大佬,细细品鉴这些作品后觉得:“这些作品都是出自大师之手,无论古今。”评价第三部作品采前两部之所长,乃“清流之金瓶,烟火之红楼”也。各位文坛名宿也都公允这个评价。(判别器无论再怎么训练,也无法区分真数据和生成器生成的假数据。而且,生成的数据足够逼真,人类也难以分辨了。)

此时方先生坦言,第三部乃是自己的拙作。众人惊呼:“已得曹先生之真传也!”(生成器已经完美的拟合了训练数据的分布特征,GAN训练完成。)

至此,写《司马光砸锅》的小方终成一代文坛大佬,故事圆满。实际上这个故事的结局还有其他版本。

如果小学时的小文就已练就一副火眼金睛,无论小方如何努力也无法取得一点能跟上小文的进步,导致小方根本不知如何着力改进,最终只得放弃文学了。反之,如果当时小文比小方还naive,连《司马光砸锅》也看不出破绽,没了鞭策和方向的小方只好接着写《司马光补锅》、《司马光打铁》、《铁匠的自我修养》…所以,要想打通“完美结局”,需要始终在整个过程中让小文同学比小方同学高明一点点,在前面不远处给方同学指明努力的方向。也就是说要想GAN能稳定的继续训练,要始终让判别器的能力强于生成器一点点。判别器太强,则梯度消失,太弱,则生成器的梯度是错误的梯度。两种情况GAN都无法正常训练。

GAN的本质

其实GAN模型以及所有的生成模型都一样,做的事情只有一件:拟合训练数据的分布。对图片生成任务来说就是拟合训练集图片的像素概率分布。下面我们从原理的角度演示一下GAN的训练过程:

上图中:
黑色点线为训练集数据分布曲线
蓝色点线为判别器输出的分布曲线
绿色实线为生成器输出的分布曲线
z展示的是生成器映射前的简单概率分布(一般是高斯分布)的范围和密度
x展示的是生成器映射后学到的训练集的概率分布的范围和密度

  • (a)判别器与生成器均未训练呈随机分布
  • (b)判别器经过训练,输出的分布在靠近训练集“真”数据分布的区间趋近于1(真),在靠近生成器生成的“假”数据分布的区间趋近于0(假)
  • (c)生成器根据判别器输出的(真假)分布,更新参数,使自己的输出分布趋近于训练集“真”数据的分布。
    经过(b)(c)(b)(c)…步骤的循环交替。判别器的输出分布随着生成器输出的分布与训练集分布的接近而更加平缓;生成器输出的分布则在判别器输出分布的指引下逐渐趋近于训练集“真”数据的分布。
  • (d)训练完成时,生成器输出的分布完美拟合了训练集数据的分布,判别器的输出由于生成器的完美拟合而无法判别生成器输出的真伪而呈一条取值约为0.5(真假之间)的直线。

GAN的组成

  1. 解读GAN的loss函数

GAN网络的训练优化目标就是如下公式:

公式出自Goodfellow在2014年发表的论文Generative Adversarial Nets。这里简单介绍下公式的含义和如何应用到代码中。上式中等号左边的部分:
V(D,G)表示的是生成样本和真实样本的差异度,可以使用二分类(真、假两个类别)的交叉熵损失。

maxV(D, G)表示在生成器固定的情况下,通过最大化交叉熵损失V(D,G)来更新判别器D的参数。

min maxV(D, G)表示生成器要在判别器最大化真、假图片交叉熵损失V(D,G)的情况下,最小化这个交叉熵损失。

等式的右边其实就是将等式左边的交叉熵损失公式展开,并写成概率分布的期望形式。详细的推导请参见原论文《Generative Adversarial Nets》。

  1. 解读GAN的结构与训练流程

如上图所示GAN由一个判别器(Discriminator)和一个生成器(Generator)两个网络组成。

训练时先训练判别器:将训练集数据(Training Set)打上真标签(1)和生成器(Generator)生成的假图片(Fake image)打上假标签(0)一同组成batch送入判别器(Discriminator),对判别器进行训练。计算loss时使判别器对真数据(Training Set)输入的判别趋近于真(1),对生成器(Generator)生成的假图片(Fake image)的判别趋近于假(0)。此过程中只更新判别器(Discriminator)的参数,不更新生成器(Generator)的参数。

然后再训练生成器:将高斯分布的噪声z(Random noise)送入生成器(Generator),然后将生成器(Generator)生成的假图片(Fake image)打上真标签(1)送入判别器(Discriminator)。计算loss时使判别器对生成器(Generator)生成的假图片(Fake image)的判别趋近于真(1)。此过程中只更新生成器(Generator)的参数,不更新判别器(Discriminator)的参数。

下面我们就用飞桨深度学习框架写一下这个GAN的代码

用飞桨动态图实现生成手写字符的GAN

Paddle支持静态图和动态图两种编写模型的方式。

静态图模式(声明式编程范式):先编译后执行的方式。用户需预先定义完整的网络结构,再编译优化网络后,才能执行获得计算结果。
动态图模式(命令式编程范式):解析式的执行方式。用户无需预先定义完整的网络结构,每写一行网络代码,即可同时获得计算结果。
相比之下,静态图模式能够更方便进行全局优化,所以一般情况下执行效率更高;而动态图模式更加直观、灵活,便于调试模型。

为了更加灵活的试验网络配置,方便的观察网络各个模块的实时输出,我们选取所见即所得的动态图模式演示GAN的结构和原理。而且,即使是看中效率的工业应用场景下,动态图模式也获得越来越多的认可。毕竟多年来码农们也没因为执行效率的原因弃用更加友好高级语言,而采用汇编语言去编写应用程序。更何况,Paddle团队的小姐姐、小哥哥们正夜以继日的努力,以期在新版本中(2.0版),赋予广大用Paddle开发项目的小伙伴们“用动态图开发,用静态图部署”的能力。这样就能兼得开发和部署的效率了,真香不是?

1.数据读取模块

要喂入生成器高斯分布的噪声隐变量z的维度设置为100。训练集数据使用Paddle框架内置函数paddle.dataset.mnist.train()、paddle.reader.shuffle()和paddle.batch()进行读取、打乱和划分batch。读取图片数据处理为 [N,W,H] 格式。

import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, Pool2D, Linear
import numpy as npimport matplotlib.pyplot as plt
# 噪声维度Z_DIM = 100BATCH_SIZE = 128
# 读取真实图片的数据集,这里去除了数据集中的label数据,因为label在这里使用不上,这里不考虑标签分类问题。def mnist_reader(reader):def r():for img, label in reader():yield img.reshape(1, 28, 28)return r
# 噪声生成,通过由噪声来生成假的图片数据输入。def z_reader():while True:yield np.random.normal(0.0, 1.0, (Z_DIM, 1, 1)).astype('float32')
# 生成真实图片readermnist_generator = paddle.batch(paddle.reader.shuffle(mnist_reader(paddle.dataset.mnist.train()), 30000), batch_size=BATCH_SIZE)
# 生成假图片的readerz_generator = paddle.batch(z_reader, batch_size=BATCH_SIZE)

测试下数据读取器和高斯噪声生成器。

import matplotlib.pyplot as plt%matplotlib inline
pics_tmp = next(mnist_generator())print('一个batch图片数据的形状:batch_size =', len(pics_tmp), ', data_shape =', pics_tmp[0].shape)
plt.imshow(pics_tmp[0][0])plt.show()
z_tmp = next(z_generator())print('一个batch噪声z的形状:batch_size =', len(z_tmp), ', data_shape =', z_tmp[0].shape)

一个batch图片数据的形状:batch_size = 128 , data_shape = (1, 28, 28)
一个batch噪声z的形状:batch_size = 128 , data_shape = (100, 1, 1)

2.GAN的判别器D和生成器G

GAN性能的提升从生成器G和判别器D进行左右互搏、交替完善的过程得到的。所以其G网络和D网络的能力应该设计得相近,复杂度也差不多。这个项目中的生成器,采用了两个全链接层接两组上采样和转置卷积层,将输入的噪声z逐渐转化为1×28×28的单通道图片输出。判别器的结构正好相反,先通过两组卷积和池化层将输入的图片转化为越来越小的特征图,再经过两层全链接层,输出图片是真是假的二分类结果。

实际上,本项目实现的是一个DCGAN(深度卷积生成对抗网络)。原版的GAN的判别器和生成器使用的都是全连接层,在DCGAN中使用卷积层代替。这样做的好处是卷积网络能够提取图片数据的二维特征,提高图片的生成质量。

判别器结构:


生成器结构:

# 判别器Dclass D(fluid.dygraph.Layer):def __init__(self, name_scope):super(D, self).__init__(name_scope)name_scope = self.full_name()# 第一组卷积池化self.conv1 = Conv2D(num_channels=1, num_filters=64, filter_size=3)self.bn1 = fluid.dygraph.BatchNorm(num_channels=64, act='relu')self.pool1 = Pool2D(pool_size=2, pool_stride=2)# 第二组卷积池化self.conv2 = Conv2D(num_channels=64, num_filters=128, filter_size=3)self.bn2 = fluid.dygraph.BatchNorm(num_channels=128, act='relu')self.pool2 = Pool2D(pool_size=2, pool_stride=2)# 全连接输出层self.fc1 = Linear(input_dim=128*5*5, output_dim=1024)self.bnfc1 = fluid.dygraph.BatchNorm(num_channels=1024, act='relu')self.fc2 = Linear(input_dim=1024, output_dim=1)def forward(self, img):y = self.conv1(img)y = self.bn1(y)y = self.pool1(y)y = self.conv2(y)y = self.bn2(y)y = self.pool2(y)y = fluid.layers.reshape(y, shape=[-1, 128*5*5])y = self.fc1(y)y = self.bnfc1(y)y = self.fc2(y)return y
# 下面分别实现了“上采样”和“转置卷积”两种方式实现的生成网络G。注释掉其中一个版本可测试另一个。# 通过上采样扩大特征图的版本class G(fluid.dygraph.Layer):def __init__(self, name_scope):super(G, self).__init__(name_scope)name_scope = self.full_name()# 第一组全连接和BN层self.fc1 = Linear(input_dim=100, output_dim=1024)self.bn1 = fluid.dygraph.BatchNorm(num_channels=1024, act='tanh')# 第二组全连接和BN层self.fc2 = Linear(input_dim=1024, output_dim=128*7*7)self.bn2 = fluid.dygraph.BatchNorm(num_channels=128*7*7, act='tanh')# 第一组卷积运算(卷积前进行上采样,以扩大特征图)# 注:此处使用转置卷积的效果似乎不如上采样后直接用卷积,转置卷积生成的图片噪点较多self.conv1 = Conv2D(num_channels=128, num_filters=64, filter_size=5, padding=2)self.bn3 = fluid.dygraph.BatchNorm(num_channels=64, act='tanh')# 第二组卷积运算(卷积前进行上采样,以扩大特征图)self.conv2 = Conv2D(num_channels=64, num_filters=1, filter_size=5, padding=2, act='tanh')def forward(self, z):z = fluid.layers.reshape(z, shape=[-1, 100])y = self.fc1(z)y = self.bn1(y)y = self.fc2(y)y = self.bn2(y)y = fluid.layers.reshape(y, shape=[-1, 128, 7, 7])# 第一组卷积前进行上采样以扩大特征图y = fluid.layers.image_resize(y, scale=2)y = self.conv1(y)y = self.bn3(y)# 第二组卷积前进行上采样以扩大特征图y = fluid.layers.image_resize(y, scale=2)y = self.conv2(y)return y
测试生成器G网络和判别器D网络的前向计算结果。一个batch的数据,输出一张图片。# 测试生成网络G和判别网络Dwith fluid.dygraph.guard():g_tmp = G('G')tmp_g = g_tmp(fluid.dygraph.to_variable(np.array(z_tmp))).numpy()print('生成器G生成图片数据的形状:', tmp_g.shape)plt.imshow(tmp_g[0][0])plt.show()d_tmp = D('D')tmp_d = d_tmp(fluid.dygraph.to_variable(tmp_g)).numpy()print('判别器D判别生成的图片的概率数据形状:', tmp_d.shape)

测试生成器G网络和判别器D网络的前向计算结果。一个batch的数据,输出一张图片。

# 测试生成网络G和判别网络Dwith fluid.dygraph.guard():
g_tmp = G('G')
tmp_g = g_tmp(fluid.dygraph.to_variable(np.array(z_tmp))).numpy()
print('生成器G生成图片数据的形状:', tmp_g.shape)
plt.imshow(tmp_g[0][0])
plt.show()d_tmp = D('D')
tmp_d = d_tmp(fluid.dygraph.to_variable(tmp_g)).numpy()
print('判别器D判别生成的图片的概率数据形状:', tmp_d.shape)

生成器G生成图片数据的形状:(128, 1, 28, 28)
判别器D判别生成的图片的概率数据形状:(128, 1)

3.辅助函数(用于训练过程图片打印,和VisualDL图片打印)

# 显示图片,构建一个18*n大小(n=batch_size/16)的图片阵列,把预测的图片打印到note中。import matplotlib.pyplot as plt%matplotlib inline
def show_image_grid(images, batch_size=128, pass_id=None):fig = plt.figure(figsize=(8, batch_size/32))fig.suptitle("Pass {}".format(pass_id))gs = plt.GridSpec(int(batch_size/16), 16)gs.update(wspace=0.05, hspace=0.05)for i, image in enumerate(images):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(image[0], cmap='Greys_r')plt.show()
show_image_grid(tmp_g, BATCH_SIZE)

# 拼接一个batch图像用于VisualDL可视化def concatenate_img(input_img):img_arr_broadcasted = ((np.zeros([BATCH_SIZE,3,28,28]) + input_img) * 255).astype('uint8').transpose((0,2,3,1)).reshape([-1,16,28,28,3])# print(img_arr_broadcasted.shape)img_concatenated = np.concatenate(tuple(img_arr_broadcasted), axis=1)# print(img_concatenated.shape)img_concatenated = np.concatenate(tuple(img_concatenated), axis=1)# print(img_concatenated.shape)return img_concatenated
plt.figure(figsize=(12,BATCH_SIZE/32),dpi=80)plt.imshow(concatenate_img(tmp_g))

4.训练过程

训练过程主要有以下几部分:

1)定义判别器与生成器对象
定义一个判别器D和一个生成器G并设置为训练模式。

2)定义优化器对象
由于本项目的GAN在训练时每轮先更新两次D(真假样本各一次)再更新一次G,所以要定义两个判别器优化器对象(真假判别器各一个)。用Adam优化策略,lr设为“迷之2e-4”~~。

3)读取上次保存的模型
用于继续训练。

4)训练判别器和生成器的epoch循环
用一个batch的真数据和一个batch的假数据各更新一次判别器(判别器为同一个),然后更新一次生成器。loss使用的是带sigmoid的交叉熵损失函数“fluid.layers.sigmoid_cross_entropy_with_logits()”。这部分代码是模型的重要部分,做了详细的注释。

5)打印输出、写VisualDL的log
用于监视模型的训练进程。

6)训练结束保存模型
如果保存的模型只需要用于推理,只需保存模型参数;如果模型也要要用于下次继续训练,还需保存优化器参数。

from visualdl import LogWriter
def train(mnist_generator, epoch_num=10, batch_size=128, use_gpu=True, load_model=False):# with fluid.dygraph.guard():place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()with fluid.dygraph.guard(place):# 模型存储路径model_path = './output/'# 定义判别器与生成器对象d = D('D')d.train()g = G('G')g.train()# 定义优化器对象real_d_optimizer = fluid.optimizer.AdamOptimizer(learning_rate=2e-4, parameter_list=d.parameters())fake_d_optimizer = fluid.optimizer.AdamOptimizer(learning_rate=2e-4, parameter_list=d.parameters())g_optimizer = fluid.optimizer.AdamOptimizer(learning_rate=2e-4, parameter_list=g.parameters())# 读取上次保存的模型if load_model == True:g_para, g_opt = fluid.load_dygraph(model_path+'g')d_para, d_r_opt = fluid.load_dygraph(model_path+'d_o_r')# 上面判别器的参数已经读取到d_para了,此处无需再次读取_, d_f_opt = fluid.load_dygraph(model_path+'d_o_f')g.load_dict(g_para)g_optimizer.set_dict(g_opt)d.load_dict(d_para)real_d_optimizer.set_dict(d_r_opt)fake_d_optimizer.set_dict(d_f_opt)# 定义日志写入(先清空日志文件夹)if load_model == False:!rm -rf /home/aistudio/log/real_loss_wrt = LogWriter(logdir='./log/d_real_loss')fake_loss_wrt = LogWriter(logdir='./log/d_fake_loss')g_loss_wrt = LogWriter(logdir='./log/g_loss')image_wrt = LogWriter(logdir='./log/imgs')iteration_num = 0for epoch in range(epoch_num):for i, real_image in enumerate(mnist_generator()):# 丢弃不满整个batch_size的数据if(len(real_image) != BATCH_SIZE):continueiteration_num += 1'''判别器d通过最小化输入真实图片时判别器d的输出与真值标签ones的交叉熵损失,来优化判别器的参数,以增加判别器d识别真实图片real_image为真值标签ones的概率。                '''# 将MNIST数据集里的图片读入real_image,将真值标签ones用数字1初始化real_image = fluid.dygraph.to_variable(np.array(real_image))ones = fluid.dygraph.to_variable(np.ones([len(real_image), 1]).astype('float32'))# 计算判别器d判断真实图片的概率p_real = d(real_image)# 计算判别真图片为真的损失real_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_real, ones)real_avg_cost = fluid.layers.mean(real_cost)# 反向传播更新判别器d的参数real_avg_cost.backward()real_d_optimizer.minimize(real_avg_cost)d.clear_gradients()'''判别器d通过最小化输入生成器g生成的假图片g(z)时判别器的输出与假值标签zeros的交叉熵损失,来优化判别器d的参数,以增加判别器d识别生成器g生成的假图片g(z)为假值标签zeros的概率。'''# 创建高斯分布的噪声z,将假值标签zeros初始化为0z = next(z_generator())z = fluid.dygraph.to_variable(np.array(z))zeros = fluid.dygraph.to_variable(np.zeros([len(real_image), 1]).astype('float32'))# 判别器d判断生成器g生成的假图片的概率p_fake = d(g(z))# 计算判别生成器g生成的假图片为假的损失fake_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_fake, zeros)fake_avg_cost = fluid.layers.mean(fake_cost)# 反向传播更新判别器d的参数fake_avg_cost.backward()fake_d_optimizer.minimize(fake_avg_cost)d.clear_gradients()'''生成器g通过最小化判别器d判别生成器生成的假图片g(z)为真的概率d(fake)与真值标签ones的交叉熵损失,来优化生成器g的参数,以增加生成器g使判别器d判别其生成的假图片g(z)为真值标签ones的概率。'''# 生成器用输入的高斯噪声z生成假图片fake = g(z)# 计算判别器d判断生成器g生成的假图片的概率p_confused = d(fake)# 使用判别器d判断生成器g生成的假图片的概率与真值ones的交叉熵计算损失g_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_confused, ones)g_avg_cost = fluid.layers.mean(g_cost)# 反向传播更新生成器g的参数g_avg_cost.backward()g_optimizer.minimize(g_avg_cost)g.clear_gradients()# 打印输出if(iteration_num % 1000 == 0):print('epoch =', epoch, ', batch =', i, ', real_d_loss =', real_avg_cost.numpy(), ', fake_d_loss =', fake_avg_cost.numpy(), 'g_loss =', g_avg_cost.numpy())show_image_grid(fake.numpy(), BATCH_SIZE, epoch)# 写VisualDL日志real_loss_wrt.add_scalar(tag='loss', step=iteration_num, value=real_avg_cost.numpy())fake_loss_wrt.add_scalar(tag='loss', step=iteration_num, value=fake_avg_cost.numpy())g_loss_wrt.add_scalar(tag='loss', step=iteration_num, value=g_avg_cost.numpy())image_wrt.add_image(tag='numbers', img=concatenate_img(fake.numpy()), step=iteration_num)# 存储模型fluid.save_dygraph(g.state_dict(), model_path+'g')fluid.save_dygraph(g_optimizer.state_dict(), model_path+'g')fluid.save_dygraph(d.state_dict(), model_path+'d_o_r')fluid.save_dygraph(real_d_optimizer.state_dict(), model_path+'d_o_r')fluid.save_dygraph(d.state_dict(), model_path+'d_o_f')fluid.save_dygraph(fake_d_optimizer.state_dict(), model_path+'d_o_f')
train(mnist_generator, epoch_num=20, batch_size=BATCH_SIZE, use_gpu=True)

5.用VisualDL2.0观察训练

我们也可以使用Paddle框架的VisualDL组件更方便的观察训练过程。VisualDL是深度学习模型可视化分析工具,以丰富的图表呈现训练参数变化趋势、模型结构、数据样本、高维数据分布等。可帮助用户更清晰直观地理解深度学习模型训练过程及模型结构,进而实现高效的模型优化,支持实时训练参数分析、图结构、数据样本可视化及高维数据降维呈现等诸多功能。VisualDL原生支持Python的使用, 通过在模型的Python配置中添加几行代码,便可为训练过程提供丰富的可视化支持,全面支持Paddle、ONNX、Caffe等市面主流模型结构可视化,广泛支持各类用户进行可视化分析。

VisualDL使用非常便捷,使用流程只有两个步骤:

1)将loss、image等数据写入log文件

# 导入LogWriter对象
from visualdl import LogWriter
...
# 声明一个loss记录的专用log写入器
real_loss_wrt = LogWriter(logdir='./log/d_real_loss')
...
# 添加loss数据的记录
real_loss_wrt.add_scalar(tag='loss', step=iteration_num, value=real_avg_cost.numpy())

图片数据的写入也是类似的,上面源码中有展示。

2)启动VisualDL服务,在浏览器打开查看页面
首先,在终端输入:visualdl --logdir ./log --port 8081

其中 --logdir ./log 参数指定log文件的存储目录为当前目录下的log文件夹,–port 8081 参数指定VisualDL服务占用的端口,如果其已被占用,可以使用其他端口,如8082、8083等。

然后,再打开这个网址(
https://aistudio.baidu.com/bdcpu3/user/76563/551962/visualdl)即可进入VisualDL页面查看模型训练情况。在AI Studio中,这个网址就是用VisualDL替换原项目运行网址Notebook后面的内容得来的。在自己的主机上,这个网址就是运行VisualDL服务的地址加上端口号,如 http://127.0.0.1:8081 。

查看生成器、判别器的loss曲线:

查看训练过程中生成的图片:



总结与思考

1.经典GAN存在的一些问题

1)训练不稳定
正如前面“文坛大佬”们的传说所寓意的那样,经典GAN的训练是不稳定的。判别器步子迈得大了,容易扯着生成器的蛋~~。为了能够稳定地训练GAN,大伙只能人工调整判别器与生成器的训练进程很不方便。所以油菜花(有才华)的大神们又造了lsgan、wgan以及其改进版wgan-gp,基本上解决了稳定性问题。

2)大尺寸图像生成质量不高
虽然本项目的DCGAN采用了卷积层代替了经典GAN中全连接层,已经提升的生成图片的质量。但是在生成大尺寸图片时还是显得力不从心。为了解决这一问题,我们采用了判别器采用PatchGAN、使用多尺度特征融合、逐层训练等方法改进。其中,BigGAN便是集众tricks于一身,用够Big的数据集,生成尺寸够Big的图片的。

3)无法控制生成的字符类别
正如本项目中演示的那样,经典GAN生成的手写字符是随机的。咱只能是“给啥要啥”,做不到“要啥给啥”。要想做个听话的GAN,就得给它装个控制按钮。这个带按钮的GAN就是CGAN(条件生成对抗网络)。

2.下一个项目预告

下一个项目,我们就介绍“带按钮”的CGAN。CGAN拟合的是条件概率分布,所以可以通过输入的控制变量,控制输出图片的类别。后来这一思想被发扬光大,才有了Pix2Pix、CycleGAN等有趣的风格迁移网络。

本文介绍项目的可执行版本可到AI Studio运行
https://aistudio.baidu.com/aistudio/projectdetail/551962

本文的理论、概念,理解有歪的,欢迎各位少侠、大佬们“拍”正。正所谓“众人拍砖,盖高楼…”~~


如在使用过程中有问题,可加入飞桨官方QQ群进行交流:1108045677。

如果您想详细了解更多飞桨的相关内容,请参阅以下文档。

  • 飞桨PaddleGAN项目地址(欢迎Star)·
    GitHub:
    https://github.com/PaddlePaddle/PaddleGAN
    Gitee:
    https://Gitee.com/PaddlePaddle/PaddleGAN

  • 飞桨官网地址·
    https://www.paddlepaddle.org.cn/

【飞桨PaddlePaddle】四天搞懂生成对抗网络(一)——通俗理解经典GAN相关推荐

  1. 【飞桨PaddlePaddle】四天搞懂生成对抗网络(二)——风格迁移的“精神始祖”Conditional GAN

    从"自由挥洒"到"有的放矢" 1.给GAN加个"按钮" 上一篇<四天搞懂生成对抗网络(一)--通俗理解经典GAN>中,我们实现了 ...

  2. 四天搞懂生成对抗网络(二)——风格迁移的“精神始祖”Conditional GAN

    点击左上方蓝字关注我们 [飞桨开发者说]吕坤,唐山广播电视台,算法工程师,喜欢研究GAN等深度学习技术在媒体.教育上的应用. 从"自由挥洒"到"有的放矢" 1. ...

  3. 四天搞懂生成对抗网络(三)——用CGAN做图像转换的鼻祖pix2pix

    点击左上方蓝字关注我们 [飞桨开发者说]吕坤,唐山广播电视台算法工程师,PPDE飞桨开发者技术专家,喜欢研究GAN等深度学习技术在媒体.教育上的应用. Pix2Pix的不甘の野望 也许是CycleGA ...

  4. 四天搞懂生成对抗网络(一)——通俗理解经典GAN

    点击左上方蓝字关注我们 [飞桨开发者说]吕坤,唐山广播电视台,算法工程师,喜欢研究GAN等深度学习技术在媒体.教育上的应用. 序言 做图像分类.检测任务时,为了提高模型精度,在数据处理方面,我尝试了很 ...

  5. 7天搞定生成对抗网络!百度高级工程师组队来袭

    7天搞定生成对抗网络!百度高级工程师组队来袭 原理+实战|7天学会GAN 课程大纲 讲师介绍 学习收获 万元奖品池等你来战 开课时间 免费报名方式 深度学习中最有趣的方法是什么?GAN! 最近最火的A ...

  6. 七夕礼物没送对?飞桨PaddlePaddle帮你读懂女朋友的小心思

    本文作者:飞桨工程师 量子位 转载 | 公众号 QbitAI 七夕节,广大的钢铁直男们,你们给女朋友的礼物买对了么? "女孩儿的心思男孩你别猜,你猜来猜去也猜不明白.不知道她为什么掉眼泪,也 ...

  7. 干货丨一文看懂生成对抗网络:从架构到训练技巧

    文章来源:机器之心 论文地址:https://arxiv.org/pdf/1710.07035.pdf 生成对抗网络(GAN)提供了一种不需要大量标注训练数据就能学习深度表征的方式.它们通过反向传播算 ...

  8. 判别两棵树是否相等 设计算法_一文看懂生成对抗网络 - GANs?(附:10种典型算法+13种应用)...

    生成对抗网络 – GANs 是最近2年很热门的一种无监督算法,他能生成出非常逼真的照片,图像甚至视频.我们手机里的照片处理软件中就会使用到它. 本文将详细介绍生成对抗网络 – GANs 的设计初衷.基 ...

  9. 一文读懂生成对抗网络GANs(附学习资源)

    原文标题:AnIntuitive Introduction to Generative Adversarial Networks 作者:KeshavDhandhania.ArashDelijani 翻 ...

最新文章

  1. virtualbox 安装ubuntu 时,看不到继续、退出按钮?共享文件无权限?
  2. VS2010配置opencv2.4.9
  3. 服务器指纹识别之 DNS TXT
  4. 七、深入JavaScript的DOM(三)
  5. python 菜鸟入门
  6. [vue] 如何在子组件中访问父组件的实例?
  7. Oracle中row_number()、rank()、dense_rank() 的区别
  8. java、c语言、python、c++的不同之处_总结几点C/C++、Java与Python的区别
  9. mysql经典主从复制
  10. 带负荷测试要求二次最小电流_电流回路基础知识(15):带负荷测试
  11. 来来来!java页面导出数据到excel
  12. 短链接生成接口、长链接转换短链接,可根据ip归属地个性化跳转、随机跳转
  13. WebRoot 与 WEB-INF 相关问题学习整理
  14. 2019年美国大联盟美国总决赛小学组获奖牌名单
  15. 网络拓扑学习之SLB
  16. 小布机器人怎么断网_小布壳Q1,用人工智能重新定义儿童阅读
  17. 深圳金证股份面试的经历
  18. 使用Wifi pineapple(菠萝派)进行Wi-Fi钓鱼攻击
  19. 释放自我。回归本性。要成功。
  20. 从GPT-1到GPT-4看ChatGPT的崛起

热门文章

  1. 华为鸿蒙背后:中国首个自己的开源基金会来了!
  2. 饼状图(pie chart)
  3. 梯度下降法中为什么梯度的反方向是函数下降最快的方向?
  4. 用HTML做一个属于你的 “世界“
  5. linux shell 发送 微信消息
  6. 一款可以对接多用户商城系统的客服系统
  7. 金融行业组织架构及岗位分布
  8. iOS面试合集+答案(一)
  9. 消费电子世界采访联众总裁鲍岳桥完全版
  10. c语言算数运算,C语言:算数运算符