系列文章目录

深度学习GAN(一)之简单介绍
深度学习GAN(二)之DCGAN基于CIFAR10数据集的例子
深度学习GAN(三)之DCGAN基于手写体Mnist数据集的例子
深度学习GAN(四)之cGAN (Conditional GAN)的例子
深度学习GAN(五)之PIX2PIX GAN的例子
深度学习GAN(六)之CycleGAN的例子


Pix2Pix GAN的例子

  • 系列文章目录
  • 1. Pix2Pix介绍
  • 2. 下载卫星地图数据集
  • 3. 数据预处理(Data Reprocessing)
  • 4. 定义判别器
  • 5. 定义生成器
  • 6. 定义GAN模型
  • 7. 加载真实图片以及生成假的图片
  • 8. 用生成器每个几个Epoch生成一些假的图片。看看效果
  • 10. 训练过程
  • 11. 训练后效果
  • 12.完整的代码

1. Pix2Pix介绍

Pix2Pix是一个对抗神经网络(GAN)模型,设计一般用于图像到图像转换。

该方法由Phillip Isola等提出。在其2016年题为“使用条件对抗网络的图像到图像翻译”的论文中,该论文于2017年在CVPR上发表。

GAN架构由用于输出新的合理合成图像的生成器模型和将图像分类为真实(来自数据集)或伪图像(生成)的鉴别器模型组成。鉴别器模型直接更新,而生成器模型通过鉴别器模型更新。这样,在对抗过程中同时训练两个模型,其中生成器试图更好地欺骗鉴别器,而鉴别器试图更好地识别伪造图像。

Pix2Pix模型是一种条件GAN或cGAN,其中输出图像的生成取决于输入(在这种情况下为源图像)。鉴别器既提供源图像又提供目标图像,并且必须确定目标是否是源图像的合理变换。

通过对抗损失训练生成器,这鼓励了生成器在目标域中生成合理的图像。还通过在生成的图像和预期的输出图像之间测量的L1损耗来更新生成器。这种额外的损失鼓励生成器模型创建源图像的合理翻译。

Pix2Pix GAN已在一系列图像到图像转换任务中得到了证明,例如将地图转换为卫星照片,将黑白照片转换为颜色,将产品草图转换为产品照片。

现在我们已经熟悉了Pix2Pix GAN,下面我们准备一个可用于图像到图像转换的数据集。

2. 下载卫星地图数据集

这个数据集由纽约的卫星图像及其相应的Google地图组成。 图像的转换问题涉及将卫星照片转换为Google地图格式,或者将Google地图图像转换为卫星照片。

数据集在pix2pix网站上提供,可以作为255 MB的zip文件下载。
Download Maps Dataset (maps.tar.gz)

下载后解压后目录结构如下:

进入任意一个目录,打开其中一个图片,

3. 数据预处理(Data Reprocessing)

为了让图片在训练的时候加载的快一点,我们把下载的所有的图片都用Numpy保存在maps_256.npz.

from os import listdir
from numpy import asarray
from numpy import vstack
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from numpy import savez_compressed# load all images in a directory into memory
def load_images(path, size=(256,512)):src_list, tar_list = list(), list()# enumerate filenames in directory, assume all are imagesfor filename in listdir(path):# load and resize the imagepixels = load_img(path + filename, target_size=size)# convert to numpy arraypixels = img_to_array(pixels)# split into satellite and mapsat_img, map_img = pixels[:, :256], pixels[:, 256:]src_list.append(sat_img)tar_list.append(map_img)return [asarray(src_list), asarray(tar_list)]# dataset path
path = 'D:/ML/datasets/maps/train/'
# load dataset
[src_images, tar_images] = load_images(path)
print('Loaded: ', src_images.shape, tar_images.shape)
# save as compressed numpy array
filename = 'maps_256.npz'
savez_compressed(filename, src_images, tar_images)
print('Saved dataset: ', filename)

结果是

Loaded:  (1096, 256, 256, 3) (1096, 256, 256, 3)
Saved dataset:  maps_256.npz

然后运行下面代码验证一下是否正确的可以显示图片。

# load the prepared dataset
from numpy import load
from matplotlib import pyplot
# load the dataset
data = load('maps_256.npz')
src_images, tar_images = data['arr_0'], data['arr_1']
print('Loaded: ', src_images.shape, tar_images.shape)
# plot source images
n_samples = 3
for i in range(n_samples):pyplot.subplot(2, n_samples, 1 + i)pyplot.axis('off')pyplot.imshow(src_images[i].astype('uint8'))
# plot target image
for i in range(n_samples):pyplot.subplot(2, n_samples, 1 + n_samples + i)pyplot.axis('off')pyplot.imshow(tar_images[i].astype('uint8'))
pyplot.show()

4. 定义判别器

这个判别器是基于PatchGAN discriminator model实现的。

注意这里的输入是两个图片,in_src_image是卫星图像, in_target_image是谷歌地图。
同过Concatenate方法,合并为6个通道,每天图片是的3个通道(RGB).
激活函数用LeakyReLU, 除了第一层与最后一层,其它都用BatchNormalization.
输出层输出是(16,16,1)

# define the discriminator model
def define_discriminator(image_shape):# weight initializationinit = RandomNormal(stddev=0.02)# source image inputin_src_image = Input(shape=image_shape)# target image inputin_target_image = Input(shape=image_shape)# concatenate images channel-wisemerged = Concatenate()([in_src_image, in_target_image])# C64d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)d = LeakyReLU(alpha=0.2)(d)# C128d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)d = BatchNormalization()(d)d = LeakyReLU(alpha=0.2)(d)# C256d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)d = BatchNormalization()(d)d = LeakyReLU(alpha=0.2)(d)# C512d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)d = BatchNormalization()(d)d = LeakyReLU(alpha=0.2)(d)# second last output layerd = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)d = BatchNormalization()(d)d = LeakyReLU(alpha=0.2)(d)# patch outputd = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)patch_out = Activation('sigmoid')(d)# define modelmodel = Model([in_src_image, in_target_image], patch_out)# compile modelopt = Adam(lr=0.0002, beta_1=0.5)model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])return model
if __name__ == '__main__':d_model = define_discriminator((256,256,3))print(d_model.summary())

它的结构是

5. 定义生成器

生成器是使用U-Net架构的encoder-decoder模型。 该模型获取源图像(例如卫星照片)并生成目标图像(例如Google地图图像)。 它首先通过对输入图像进行下采样或编码到瓶颈层(bottleneck layer),然后对瓶颈(bottleneck layer)表示进行上采样或解码到输出图像的大小来做到这一点。 U-Net体系结构意味着在编码层和相应的解码层之间添加跳过连接(skip-connections),从而形成U形。

下图清楚地显示了跳过连接(skip-connections),显示了编码器的第一层如何连接到解码器的最后一层,依此类推。

生成器的encoder和decoder由convolutional, batch normalization, dropout, and activation layers组成。 这种标准化意味着我们可以开发辅助函数来创建每个图层块,并反复调用它以建立模型的encoder和decoder部分。

下面的define_generator()函数实现了U-Net编码器-解码器生成器模型。 它使用define_encoder_block()帮助函数创建用于编码器的层块,并使用coder_block()函数创建用于解码器的层块。 tanh激活函数在输出层中使用,这意味着生成的图像中的像素值将在[-1,1]范围内。

输入是一个文星图片,经过Encoder-Decoder这个网络结构,最后生成一个谷歌地图
(256,256,3) ->Encoder-> (1,1,512) -> Decoder -> (256,256,3)

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):# weight initializationinit = RandomNormal(stddev=0.02)# add downsampling layerg = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)# conditionally add batch normalizationif batchnorm:g = BatchNormalization()(g, training=True)# leaky relu activationg = LeakyReLU(alpha=0.2)(g)return g# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):# weight initializationinit = RandomNormal(stddev=0.02)# add upsampling layerg = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)# add batch normalizationg = BatchNormalization()(g, training=True)# conditionally add dropoutif dropout:g = Dropout(0.5)(g, training=True)# merge with skip connectiong = Concatenate()([g, skip_in])# relu activationg = Activation('relu')(g)return g# define the standalone generator model
def define_generator(image_shape=(256,256,3)):# weight initializationinit = RandomNormal(stddev=0.02)# image inputin_image = Input(shape=image_shape)# encoder modele1 = define_encoder_block(in_image, 64, batchnorm=False)e2 = define_encoder_block(e1, 128)e3 = define_encoder_block(e2, 256)e4 = define_encoder_block(e3, 512)e5 = define_encoder_block(e4, 512)e6 = define_encoder_block(e5, 512)e7 = define_encoder_block(e6, 512)# bottleneck, no batch norm and relub = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)b = Activation('relu')(b)# decoder modeld1 = decoder_block(b, e7, 512)d2 = decoder_block(d1, e6, 512)d3 = decoder_block(d2, e5, 512)d4 = decoder_block(d3, e4, 512, dropout=False)d5 = decoder_block(d4, e3, 256, dropout=False)d6 = decoder_block(d5, e2, 128, dropout=False)d7 = decoder_block(d6, e1, 64, dropout=False)# outputg = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)out_image = Activation('tanh')(g)# define modelmodel = Model(in_image, out_image)return model
if __name__ == '__main__':g_model = define_generator((256,256,3))print(g_model.summary())

它的结构是

6. 定义GAN模型

GAN的模型主要是训练生成器,所以判别器不训练(d_model.trainable = False)。
输入层是卫星图片(256,256,3),
输出层是 dis_out=(16,16,1)
gen_out = (256,256,3)

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):# make weights in the discriminator not trainabled_model.trainable = False# define the source imagein_src = Input(shape=image_shape)# connect the source image to the generator inputgen_out = g_model(in_src)# connect the source input and generator output to the discriminator inputdis_out = d_model([in_src, gen_out])# src image as input, generated image and classification outputmodel = Model(in_src, [dis_out, gen_out])# compile modelopt = Adam(lr=0.0002, beta_1=0.5)model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])return model
if __name__ == '__main__':d_model = define_discriminator((256,256,3))g_model = define_generator((256,256,3))gan_model = define_gan(g_model, d_model, (256,256,3))print(g_model.summary())
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_3 (InputLayer)            [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 128, 128, 64) 3136        input_3[0][0]
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 128, 128, 64) 0           conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 64, 64, 128)  131200      leaky_re_lu_5[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 128)  512         conv2d_7[0][0]
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 64, 64, 128)  0           batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 256)  524544      leaky_re_lu_6[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 32, 32, 256)  1024        conv2d_8[0][0]
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 32, 32, 256)  0           batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 512)  2097664     leaky_re_lu_7[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 16, 16, 512)  2048        conv2d_9[0][0]
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 16, 16, 512)  0           batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 8, 8, 512)    4194816     leaky_re_lu_8[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 8, 8, 512)    2048        conv2d_10[0][0]
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 8, 8, 512)    0           batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 4, 4, 512)    4194816     leaky_re_lu_9[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 4, 4, 512)    2048        conv2d_11[0][0]
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, 4, 4, 512)    0           batch_normalization_8[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 2, 2, 512)    4194816     leaky_re_lu_10[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 2, 2, 512)    2048        conv2d_12[0][0]
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, 2, 2, 512)    0           batch_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 1, 1, 512)    4194816     leaky_re_lu_11[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 1, 1, 512)    0           conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 2, 2, 512)    4194816     activation_1[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 2, 2, 512)    2048        conv2d_transpose[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 2, 2, 512)    0           batch_normalization_10[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 2, 2, 1024)   0           dropout[0][0]                    leaky_re_lu_11[0][0]
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 2, 2, 1024)   0           concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 4, 4, 512)    8389120     activation_2[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 4, 4, 512)    2048        conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 4, 4, 512)    0           batch_normalization_11[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 4, 4, 1024)   0           dropout_1[0][0]                  leaky_re_lu_10[0][0]
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 4, 4, 1024)   0           concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 8, 8, 512)    8389120     activation_3[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 8, 8, 512)    2048        conv2d_transpose_2[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 8, 8, 512)    0           batch_normalization_12[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 8, 8, 1024)   0           dropout_2[0][0]                  leaky_re_lu_9[0][0]
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 8, 8, 1024)   0           concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 16, 16, 512)  8389120     activation_4[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 16, 16, 512)  2048        conv2d_transpose_3[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 16, 16, 1024) 0           batch_normalization_13[0][0]     leaky_re_lu_8[0][0]
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 16, 16, 1024) 0           concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 32, 32, 256)  4194560     activation_5[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 32, 32, 256)  1024        conv2d_transpose_4[0][0]
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 32, 32, 512)  0           batch_normalization_14[0][0]     leaky_re_lu_7[0][0]
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 32, 32, 512)  0           concatenate_5[0][0]
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 64, 64, 128)  1048704     activation_6[0][0]
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 64, 64, 128)  512         conv2d_transpose_5[0][0]
__________________________________________________________________________________________________
concatenate_6 (Concatenate)     (None, 64, 64, 256)  0           batch_normalization_15[0][0]     leaky_re_lu_6[0][0]
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 64, 64, 256)  0           concatenate_6[0][0]
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 128, 128, 64) 262208      activation_7[0][0]
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 128, 128, 64) 256         conv2d_transpose_6[0][0]
__________________________________________________________________________________________________
concatenate_7 (Concatenate)     (None, 128, 128, 128 0           batch_normalization_16[0][0]     leaky_re_lu_5[0][0]
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 128, 128, 128 0           concatenate_7[0][0]
__________________________________________________________________________________________________
conv2d_transpose_7 (Conv2DTrans (None, 256, 256, 3)  6147        activation_8[0][0]
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 256, 256, 3)  0           conv2d_transpose_7[0][0]
==================================================================================================
Total params: 54,429,315
Trainable params: 54,419,459
Non-trainable params: 9,856

7. 加载真实图片以及生成假的图片

load_real_samples方法是加载真实图片。
generate_real_samples 方法是生成真实图片。每个数组标签都是1, shape是(16,16,1)
generate_fake_samples方法是生成假的图片。每个数组标签都是0,shape是(16,16,1)
标签这里不一样,一般是数字,但是这里是shape为(16,16,1)三维数组。

# load and prepare training images
def load_real_samples(filename):# load compressed arraysdata = load(filename)# unpack arraysX1, X2 = data['arr_0'], data['arr_1']# scale from [0,255] to [-1,1]X1 = (X1 - 127.5) / 127.5X2 = (X2 - 127.5) / 127.5return [X1, X2]# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):# unpack datasettrainA, trainB = dataset# choose random instancesix = randint(0, trainA.shape[0], n_samples)# retrieve selected imagesX1, X2 = trainA[ix], trainB[ix]# generate 'real' class labels (1)y = ones((n_samples, patch_shape, patch_shape, 1))return [X1, X2], y# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):# generate fake instanceX = g_model.predict(samples)# create 'fake' class labels (0)y = zeros((len(X), patch_shape, patch_shape, 1))return X, y

8. 用生成器每个几个Epoch生成一些假的图片。看看效果

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):# select a sample of input images[X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)# generate a batch of fake samplesX_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)# scale all pixels from [-1,1] to [0,1]X_realA = (X_realA + 1) / 2.0X_realB = (X_realB + 1) / 2.0X_fakeB = (X_fakeB + 1) / 2.0# plot real source imagesfor i in range(n_samples):pyplot.subplot(3, n_samples, 1 + i)pyplot.axis('off')pyplot.imshow(X_realA[i])# plot generated target imagefor i in range(n_samples):pyplot.subplot(3, n_samples, 1 + n_samples + i)pyplot.axis('off')pyplot.imshow(X_fakeB[i])# plot real target imagefor i in range(n_samples):pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)pyplot.axis('off')pyplot.imshow(X_realB[i])# save plot to filefilename1 = 'pix2pix_plot_%06d.png' % (step+1)pyplot.savefig(filename1)pyplot.close()# save the generator modelfilename2 = 'pix2pix_model_%06d.h5' % (step+1)g_model.save(filename2)print('>Saved: %s and %s' % (filename1, filename2))

10. 训练过程

# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):# determine the output square shape of the discriminatorn_patch = d_model.output_shape[1]# unpack datasettrainA, trainB = dataset# calculate the number of batches per training epochbat_per_epo = int(len(trainA) / n_batch)# calculate the number of training iterationsn_steps = bat_per_epo * n_epochs# manually enumerate epochsfor i in range(n_steps):# select a batch of real samples[X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)# generate a batch of fake samplesX_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)# update discriminator for real samplesd_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)# update discriminator for generated samplesd_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)# update the generatorg_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])# summarize performanceprint('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))# summarize model performanceif (i+1) % (bat_per_epo * 10) == 0:summarize_performance(i, g_model, dataset)

11. 训练后效果

在前10个时间段之后,尽管街道的线条并不完全笔直且图像中有些模糊,但仍会生成看起来合理的地图图像。 但是,大型结构在正确的位置带有大多数正确的颜色。


经过约50个训练时期后生成的图像开始看起来非常逼真,至少意味着,并且在其余训练过程中质量似乎仍然保持良好。

请注意下面第一个生成的图像示例(右列,中间行),该示例包含比真实Google地图图像更有用的细节。

12.完整的代码

# example of pix2pix gan for satellite to map image-to-image translation
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTransposefrom tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from matplotlib import pyplot# define the discriminator model
def define_discriminator(image_shape):# weight initializationinit = RandomNormal(stddev=0.02)# source image inputin_src_image = Input(shape=image_shape)# target image inputin_target_image = Input(shape=image_shape)# concatenate images channel-wisemerged = Concatenate()([in_src_image, in_target_image])# C64d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)d = LeakyReLU(alpha=0.2)(d)# C128d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)d = BatchNormalization()(d)d = LeakyReLU(alpha=0.2)(d)# C256d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)d = BatchNormalization()(d)d = LeakyReLU(alpha=0.2)(d)# C512d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)d = BatchNormalization()(d)d = LeakyReLU(alpha=0.2)(d)# second last output layerd = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)d = BatchNormalization()(d)d = LeakyReLU(alpha=0.2)(d)# patch outputd = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)patch_out = Activation('sigmoid')(d)# define modelmodel = Model([in_src_image, in_target_image], patch_out)# compile modelopt = Adam(lr=0.0002, beta_1=0.5)model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])return model# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):# weight initializationinit = RandomNormal(stddev=0.02)# add downsampling layerg = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)# conditionally add batch normalizationif batchnorm:g = BatchNormalization()(g, training=True)# leaky relu activationg = LeakyReLU(alpha=0.2)(g)return g# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):# weight initializationinit = RandomNormal(stddev=0.02)# add upsampling layerg = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)# add batch normalizationg = BatchNormalization()(g, training=True)# conditionally add dropoutif dropout:g = Dropout(0.5)(g, training=True)# merge with skip connectiong = Concatenate()([g, skip_in])# relu activationg = Activation('relu')(g)return g# define the standalone generator model
def define_generator(image_shape=(256,256,3)):# weight initializationinit = RandomNormal(stddev=0.02)# image inputin_image = Input(shape=image_shape)# encoder modele1 = define_encoder_block(in_image, 64, batchnorm=False)e2 = define_encoder_block(e1, 128)e3 = define_encoder_block(e2, 256)e4 = define_encoder_block(e3, 512)e5 = define_encoder_block(e4, 512)e6 = define_encoder_block(e5, 512)e7 = define_encoder_block(e6, 512)# bottleneck, no batch norm and relub = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)b = Activation('relu')(b)# decoder modeld1 = decoder_block(b, e7, 512)d2 = decoder_block(d1, e6, 512)d3 = decoder_block(d2, e5, 512)d4 = decoder_block(d3, e4, 512, dropout=False)d5 = decoder_block(d4, e3, 256, dropout=False)d6 = decoder_block(d5, e2, 128, dropout=False)d7 = decoder_block(d6, e1, 64, dropout=False)# outputg = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)out_image = Activation('tanh')(g)# define modelmodel = Model(in_image, out_image)return model# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):# make weights in the discriminator not trainabled_model.trainable = False# define the source imagein_src = Input(shape=image_shape)# connect the source image to the generator inputgen_out = g_model(in_src)# connect the source input and generator output to the discriminator inputdis_out = d_model([in_src, gen_out])print(dis_out)# src image as input, generated image and classification outputmodel = Model(in_src, [dis_out, gen_out])# compile modelopt = Adam(lr=0.0002, beta_1=0.5)model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])return model# load and prepare training images
def load_real_samples(filename):# load compressed arraysdata = load(filename)# unpack arraysX1, X2 = data['arr_0'], data['arr_1']# scale from [0,255] to [-1,1]X1 = (X1 - 127.5) / 127.5X2 = (X2 - 127.5) / 127.5return [X1, X2]# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):# unpack datasettrainA, trainB = dataset# choose random instancesix = randint(0, trainA.shape[0], n_samples)# retrieve selected imagesX1, X2 = trainA[ix], trainB[ix]# generate 'real' class labels (1)y = ones((n_samples, patch_shape, patch_shape, 1))return [X1, X2], y# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):# generate fake instanceX = g_model.predict(samples)# create 'fake' class labels (0)y = zeros((len(X), patch_shape, patch_shape, 1))return X, y# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):# select a sample of input images[X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)# generate a batch of fake samplesX_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)# scale all pixels from [-1,1] to [0,1]X_realA = (X_realA + 1) / 2.0X_realB = (X_realB + 1) / 2.0X_fakeB = (X_fakeB + 1) / 2.0# plot real source imagesfor i in range(n_samples):pyplot.subplot(3, n_samples, 1 + i)pyplot.axis('off')pyplot.imshow(X_realA[i])# plot generated target imagefor i in range(n_samples):pyplot.subplot(3, n_samples, 1 + n_samples + i)pyplot.axis('off')pyplot.imshow(X_fakeB[i])# plot real target imagefor i in range(n_samples):pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)pyplot.axis('off')pyplot.imshow(X_realB[i])# save plot to filefilename1 = 'pix2pix_plot_%06d.png' % (step+1)pyplot.savefig(filename1)pyplot.close()# save the generator modelfilename2 = 'pix2pix_model_%06d.h5' % (step+1)g_model.save(filename2)print('>Saved: %s and %s' % (filename1, filename2))# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):# determine the output square shape of the discriminatorn_patch = d_model.output_shape[1]# unpack datasettrainA, trainB = dataset# calculate the number of batches per training epochbat_per_epo = int(len(trainA) / n_batch)# calculate the number of training iterationsn_steps = bat_per_epo * n_epochs# manually enumerate epochsfor i in range(n_steps):# select a batch of real samples[X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)# generate a batch of fake samplesX_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)# update discriminator for real samplesd_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)# update discriminator for generated samplesd_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)# update the generatorg_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])# summarize performanceprint('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))# summarize model performanceif (i+1) % (bat_per_epo * 10) == 0:summarize_performance(i, g_model, dataset)def start_train():# load image datadataset = load_real_samples('maps_256.npz')print('Loaded', dataset[0].shape, dataset[1].shape)# define input shape based on the loaded datasetimage_shape = dataset[0].shape[1:]# define the modelsd_model = define_discriminator(image_shape)print(image_shape)print(d_model.summary())g_model = define_generator(image_shape)# define the composite modelgan_model = define_gan(g_model, d_model, image_shape)# train modeltrain(d_model, g_model, gan_model, dataset)if __name__ == '__main__':#d_model = define_discriminator((256,256,3))#print(d_model.summary())#g_model = define_generator((256,256,3))#print(g_model.summary())#gan_model = define_gan(g_model, d_model, (256,256,3))#print(g_model.summary())start_train()

[深度学习]生成对抗网络的实践例子相关推荐

  1. 深度学习生成对抗网络(GAN)

    一.概述 生成对抗网络(Generative Adversarial Networks)是一种无监督深度学习模型,用来通过计算机生成数据,由Ian J. Goodfellow等人于2014年提出.模型 ...

  2. 你真的了解深度学习生成对抗网络(GAN)吗?

    生成对抗网络(GANs,https://en.wikipedia.org/wiki/Generative_adversarial_network)是一类具有基于网络本身即可以生成数据能力的神经网络结构 ...

  3. [深度学习-实践]条件生成对抗网络cGAN的例子-Tensorflow2.x Keras

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之DCGAN基于CIFAR10数据集的例子 深度学习GAN(三)之DCGAN基于手写体Mnist数据集的例子 深度学习GAN(四)之c ...

  4. 深度卷积生成对抗网络

    深度卷积生成对抗网络 Deep Convolutional Generative Adversarial Networks GANs如何工作的基本思想.可以从一些简单的,易于抽样的分布,如均匀分布或正 ...

  5. 深度卷积生成对抗网络--DCGAN

    本问转自:https://ask.julyedu.com/question/7681,详情请查看原文 --前言:如何把CNN与GAN结合?DCGAN是这方面最好的尝试之一,DCGAN的原理和GAN是一 ...

  6. DCGAN——深度卷积生成对抗网络

    译文 | 让深度卷积网络对抗:DCGAN--深度卷积生成对抗网络 原文: https://arxiv.org/pdf/1511.06434.pdf -- 前言:如何把CNN与GAN结合?DCGAN是这 ...

  7. 生成对抗网络简介,深度卷积生成对抗网络(DCGAN)简介

    本博客是个人学习的笔记,讲述的是生成对抗网络(generate adversarial network ) 的一种架构:深度生成对抗网络 的简单介绍,下一节将使用 tensorflow 搭建 DCGA ...

  8. 对抗生成网络_深度卷积生成对抗网络

    本教程演示了如何使用深度卷积生成对抗网络(DCGAN)生成手写数字图片.该代码是使用 Keras Sequential API 与 tf.GradientTape 训练循环编写的. 什么是生成对抗网络 ...

  9. 深度卷积生成对抗网络(DCGAN)原理与实现(采用Tensorflow2.x)

    深度卷积生成对抗网络(DCGAN)原理与实现(采用Tensorflow2.x) GAN直观理解 DCGAN网络结构 GAN训练目标 DCGAN实现 数据加载 网络 鉴别网络 生成网络 网络训练 定义损 ...

最新文章

  1. 落谷 P1060 开心的金明
  2. Linux下的QQ截图
  3. 读大话数据结构之二--------算法(上)
  4. Elasticsearch的或且非及其组合
  5. 如何发表自己的第一篇SCI?
  6. 如何为物联网选择微控制器?
  7. 第一:Python+Allure运行报错AttributeError: module ‘allure‘ has no attribute ‘severity_level‘
  8. 精明管理者选人的N种方式
  9. 一般使用作为微型计算机必备,远程教育统考复习之计算机应用基础单选模拟复习题(一)...
  10. 企业架构之道(二)企业架构方法论体系
  11. 记录|深度学习100例-卷积神经网络(CNN)minist数字分类 | 第1天
  12. 2021-08-08在ubuntu上部署nideshop
  13. Statement和PreparedStatement的区别
  14. STM32开发环境配置
  15. 拼多多秒杀活动的谣言
  16. 【手把手】JavaWeb 入门级项目实战 -- 文章发布系统 (作者:剽悍一小兔)第七、八、九节学习随笔
  17. RegSVR32 找不到指定模块问题解决
  18. 谈谈privoxy:关于广告过滤和自动代理切换
  19. android wifi驱动加载失败怎么办,请教WIFI连接失败问题,如何解决
  20. 【Python9】字典与集合

热门文章

  1. Java Virtual Machine Garbage Collection浅析
  2. 《高质量c++/c编程指南》学习摘要
  3. CentOS上使用Docker安装Redis-Cluster (redis6.x)
  4. 【JAVA 第四章 流程控制语句】课后习题 直线斜率 以及判断坐标是否在直线上点到直线的距离
  5. 并行计算教程简介 Introduction to Parallel Computing Tutorial
  6. python标准库之fnmatch,dis,timeit
  7. 将React Native升级到最新版本的最简单方法
  8. javascript入门_您需要一个JavaScript入门工具包
  9. R语言tidyverse数据处理建模案例
  10. 利用循环神经网络生成唐诗_可视化解释11种基本神经网络架构