1. Pix2Pix介绍


该方法由Phillip Isola等人在2016年题为“使用条件对抗网络的图像到图像翻译”(Image-to-Image Translation with Conditional Adversarial Networks)的论文中提出,该论文于2017年在CVPR上发表。




Pix2Pix GAN已在一系列图像到图像转换任务中得到了验证,例如将地图转换为卫星照片,将黑白照片转换为彩色照片,将产品草图转换为产品照片。

现在我们已经熟悉了Pix2Pix GAN,下面我们准备一个可用于图像到图像转换的数据集。

2. 下载卫星地图数据集

这个数据集由纽约的卫星图像及其相应的Google地图组成。 图像的转换问题涉及将卫星照片转换为Google地图格式,或者将Google地图图像转换为卫星照片。

数据集在pix2pix网站上提供,可以作为约255 MB的压缩文件(maps.tar.gz)下载。
Download Maps Dataset (maps.tar.gz)



3. 数据预处理(Data Preprocessing)


from os import listdir
from numpy import asarray
from numpy import vstack
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from numpy import savez_compressed# load all images in a directory into memory
def load_images(path, size=(256,512)):
    """Load every paired image in a directory and split it into two halves.

    Each file is expected to hold a satellite photo (left half) and the
    matching map (right half) side by side.

    Args:
        path: directory containing the image files; every entry is assumed
            to be an image.
        size: ``(height, width)`` the combined image is resized to on load.

    Returns:
        ``[src_images, tar_images]``: two numpy arrays holding the left
        halves and the right halves respectively, in directory order.
    """
    # local import keeps the block self-contained; join() tolerates a path
    # given with or without a trailing separator
    from os.path import join

    # column at which the combined image is cut in two (256 for the default size)
    half_width = size[1] // 2
    src_list, tar_list = list(), list()
    # enumerate filenames in directory, assume all are images
    for filename in listdir(path):
        # load and resize the image
        pixels = load_img(join(path, filename), target_size=size)
        # convert to numpy array
        pixels = img_to_array(pixels)
        # split into satellite and map
        sat_img, map_img = pixels[:, :half_width], pixels[:, half_width:]
        src_list.append(sat_img)
        tar_list.append(map_img)
    return [asarray(src_list), asarray(tar_list)]


# dataset path
path = 'D:/ML/datasets/maps/train/'
# read all image pairs from disk into two in-memory arrays
src_images, tar_images = load_images(path)
print('Loaded: ', src_images.shape, tar_images.shape)
# write both arrays into one compressed .npz archive; because they are passed
# positionally they are stored under the keys 'arr_0' and 'arr_1'
filename = 'maps_256.npz'
savez_compressed(filename, src_images, tar_images)
print('Saved dataset: ', filename)


Loaded:  (1096, 256, 256, 3) (1096, 256, 256, 3)
Saved dataset:  maps_256.npz


# load the prepared dataset and preview a few source/target pairs
from numpy import load
from matplotlib import pyplot
# load the dataset
data = load('maps_256.npz')
src_images, tar_images = data['arr_0'], data['arr_1']
print('Loaded: ', src_images.shape, tar_images.shape)
# plot source images (top row)
n_samples = 3
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + i)
    pyplot.axis('off')
    # pixel values are stored as floats; cast to uint8 so imshow treats them as 0-255
    pyplot.imshow(src_images[i].astype('uint8'))
# plot target image (bottom row)
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + n_samples + i)
    pyplot.axis('off')
    pyplot.imshow(tar_images[i].astype('uint8'))
# render the figure; without this call nothing is displayed when run as a script
pyplot.show()

4. 定义判别器

这个判别器是基于PatchGAN discriminator model实现的。

注意这里的输入是两个图片,in_src_image是卫星图像, in_target_image是谷歌地图。
激活函数用LeakyReLU, 除了第一层与最后一层,其它都用BatchNormalization.

# define the discriminator model
def define_discriminator(image_shape):
    """Build and compile the PatchGAN discriminator.

    The model takes a source image and a target image of identical shape,
    concatenates them channel-wise and applies a stack of 4x4 convolutions
    ending in a sigmoid, so the output is a grid of per-patch real/fake
    probabilities rather than a single score. Four stride-2 convolutions
    reduce each spatial dimension by a factor of 16.

    Args:
        image_shape: shape of one input image, e.g. ``(256, 256, 3)``.

    Returns:
        A compiled Keras ``Model`` with inputs ``[source, target]``.
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_src_image = Input(shape=image_shape)
    # target image input
    in_target_image = Input(shape=image_shape)
    # concatenate images channel-wise
    merged = Concatenate()([in_src_image, in_target_image])
    # C64 (no batch normalization on the first layer)
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer (stride 1, keeps the spatial size)
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    # define model
    model = Model([in_src_image, in_target_image], patch_out)
    # compile model; `learning_rate` replaces the deprecated `lr` keyword
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model
if __name__ == '__main__':
    d_model = define_discriminator((256,256,3))
    # summary() prints the table itself and returns None, so wrapping it in
    # print() would only add a stray "None" line
    d_model.summary()


5. 定义生成器

生成器是使用U-Net架构的encoder-decoder模型。 该模型获取源图像(例如卫星照片)并生成目标图像(例如Google地图图像)。 它首先通过对输入图像进行下采样或编码到瓶颈层(bottleneck layer),然后对瓶颈(bottleneck layer)表示进行上采样或解码到输出图像的大小来做到这一点。 U-Net体系结构意味着在编码层和相应的解码层之间添加跳过连接(skip-connections),从而形成U形。


生成器的encoder和decoder由convolutional, batch normalization, dropout, and activation layers组成。 这种标准化意味着我们可以开发辅助函数来创建每个图层块,并反复调用它以建立模型的encoder和decoder部分。

下面的define_generator()函数实现了U-Net编码器-解码器生成器模型。 它使用define_encoder_block()辅助函数创建用于编码器的层块,并使用decoder_block()函数创建用于解码器的层块。 tanh激活函数在输出层中使用,这意味着生成的图像中的像素值将在[-1,1]范围内。

(256,256,3) ->Encoder-> (1,1,512) -> Decoder -> (256,256,3)

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    """Append one downsampling block: Conv2D -> optional BatchNorm -> LeakyReLU.

    Args:
        layer_in: tensor the block is attached to.
        n_filters: number of filters of the 4x4, stride-2 convolution.
        batchnorm: when False the batch-normalization layer is skipped.

    Returns:
        The output tensor of the block (spatial size halved by the stride).
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add downsampling layer
    g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # conditionally add batch normalization; training=True keeps the layer in
    # training mode even when the model is used for prediction
    if batchnorm:
        g = BatchNormalization()(g, training=True)
    # leaky relu activation
    g = LeakyReLU(alpha=0.2)(g)
    return g


# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    """Append one upsampling block with a skip connection.

    Layer order: Conv2DTranspose -> BatchNorm -> optional Dropout ->
    concatenate with ``skip_in`` -> ReLU.

    Args:
        layer_in: tensor the block is attached to.
        skip_in: encoder tensor merged in through the skip connection; it
            must have the same spatial size as the upsampled tensor.
        n_filters: number of filters of the 4x4, stride-2 transposed convolution.
        dropout: when False the dropout layer is skipped.

    Returns:
        The output tensor of the block (spatial size doubled by the stride).
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add upsampling layer
    g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # add batch normalization
    g = BatchNormalization()(g, training=True)
    # conditionally add dropout; training=True keeps it active at prediction time too
    if dropout:
        g = Dropout(0.5)(g, training=True)
    # merge with skip connection
    g = Concatenate()([g, skip_in])
    # relu activation
    g = Activation('relu')(g)
    return g


# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
    """Build the U-Net encoder-decoder generator (not compiled).

    Seven encoder blocks plus the bottleneck convolution each halve the
    spatial size (256 -> 1 for the default shape); seven decoder blocks and
    a final transposed convolution restore it, with a skip connection from
    each encoder block to the decoder block of matching size.

    Args:
        image_shape: shape of the input image.

    Returns:
        A Keras ``Model`` mapping the input image to a 3-channel image with
        tanh-activated values in [-1, 1].
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # encoder model
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    # bottleneck, no batch norm and relu
    b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    # decoder model (dropout only in the first three blocks)
    d1 = decoder_block(b, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    # output
    g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model
if __name__ == '__main__':
    g_model = define_generator((256,256,3))
    # summary() prints the table itself and returns None, so wrapping it in
    # print() would only add a stray "None" line
    g_model.summary()


6. 定义GAN模型

GAN的模型主要是训练生成器,所以判别器不训练(d_model.trainable = False)。
输出层是 dis_out=(16,16,1)
gen_out = (256,256,3)

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
    """Build and compile the composite model used to update the generator.

    The source image is fed to the generator; the source image and the
    generated image are then fed to the discriminator. The discriminator is
    marked non-trainable inside this composite.

    Args:
        g_model: generator model.
        d_model: discriminator model; its ``trainable`` flag is set to False.
        image_shape: shape of the source image input.

    Returns:
        A compiled Keras ``Model`` mapping the source image to
        ``[discriminator_output, generated_image]``.
    """
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # define the source image
    in_src = Input(shape=image_shape)
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    # connect the source input and generator output to the discriminator input
    dis_out = d_model([in_src, gen_out])
    # src image as input, generated image and classification output
    model = Model(in_src, [dis_out, gen_out])
    # compile model; `learning_rate` replaces the deprecated `lr` keyword
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    # adversarial loss on the patch output, L1 (mae) loss on the image, weighted 1:100
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
    return model
if __name__ == '__main__':
    d_model = define_discriminator((256,256,3))
    g_model = define_generator((256,256,3))
    gan_model = define_gan(g_model, d_model, (256,256,3))
    # summary() prints the table itself and returns None, so wrapping it in
    # print() would only add a stray "None" line
    g_model.summary()
Model: "model_1"
Layer (type)                    Output Shape         Param #     Connected to
input_3 (InputLayer)            [(None, 256, 256, 3) 0
conv2d_6 (Conv2D)               (None, 128, 128, 64) 3136        input_3[0][0]
leaky_re_lu_5 (LeakyReLU)       (None, 128, 128, 64) 0           conv2d_6[0][0]
conv2d_7 (Conv2D)               (None, 64, 64, 128)  131200      leaky_re_lu_5[0][0]
batch_normalization_4 (BatchNor (None, 64, 64, 128)  512         conv2d_7[0][0]
leaky_re_lu_6 (LeakyReLU)       (None, 64, 64, 128)  0           batch_normalization_4[0][0]
conv2d_8 (Conv2D)               (None, 32, 32, 256)  524544      leaky_re_lu_6[0][0]
batch_normalization_5 (BatchNor (None, 32, 32, 256)  1024        conv2d_8[0][0]
leaky_re_lu_7 (LeakyReLU)       (None, 32, 32, 256)  0           batch_normalization_5[0][0]
conv2d_9 (Conv2D)               (None, 16, 16, 512)  2097664     leaky_re_lu_7[0][0]
batch_normalization_6 (BatchNor (None, 16, 16, 512)  2048        conv2d_9[0][0]
leaky_re_lu_8 (LeakyReLU)       (None, 16, 16, 512)  0           batch_normalization_6[0][0]
conv2d_10 (Conv2D)              (None, 8, 8, 512)    4194816     leaky_re_lu_8[0][0]
batch_normalization_7 (BatchNor (None, 8, 8, 512)    2048        conv2d_10[0][0]
leaky_re_lu_9 (LeakyReLU)       (None, 8, 8, 512)    0           batch_normalization_7[0][0]
conv2d_11 (Conv2D)              (None, 4, 4, 512)    4194816     leaky_re_lu_9[0][0]
batch_normalization_8 (BatchNor (None, 4, 4, 512)    2048        conv2d_11[0][0]
leaky_re_lu_10 (LeakyReLU)      (None, 4, 4, 512)    0           batch_normalization_8[0][0]
conv2d_12 (Conv2D)              (None, 2, 2, 512)    4194816     leaky_re_lu_10[0][0]
batch_normalization_9 (BatchNor (None, 2, 2, 512)    2048        conv2d_12[0][0]
leaky_re_lu_11 (LeakyReLU)      (None, 2, 2, 512)    0           batch_normalization_9[0][0]
conv2d_13 (Conv2D)              (None, 1, 1, 512)    4194816     leaky_re_lu_11[0][0]
activation_1 (Activation)       (None, 1, 1, 512)    0           conv2d_13[0][0]
conv2d_transpose (Conv2DTranspo (None, 2, 2, 512)    4194816     activation_1[0][0]
batch_normalization_10 (BatchNo (None, 2, 2, 512)    2048        conv2d_transpose[0][0]
dropout (Dropout)               (None, 2, 2, 512)    0           batch_normalization_10[0][0]
concatenate_1 (Concatenate)     (None, 2, 2, 1024)   0           dropout[0][0]                    leaky_re_lu_11[0][0]
activation_2 (Activation)       (None, 2, 2, 1024)   0           concatenate_1[0][0]
conv2d_transpose_1 (Conv2DTrans (None, 4, 4, 512)    8389120     activation_2[0][0]
batch_normalization_11 (BatchNo (None, 4, 4, 512)    2048        conv2d_transpose_1[0][0]
dropout_1 (Dropout)             (None, 4, 4, 512)    0           batch_normalization_11[0][0]
concatenate_2 (Concatenate)     (None, 4, 4, 1024)   0           dropout_1[0][0]                  leaky_re_lu_10[0][0]
activation_3 (Activation)       (None, 4, 4, 1024)   0           concatenate_2[0][0]
conv2d_transpose_2 (Conv2DTrans (None, 8, 8, 512)    8389120     activation_3[0][0]
batch_normalization_12 (BatchNo (None, 8, 8, 512)    2048        conv2d_transpose_2[0][0]
dropout_2 (Dropout)             (None, 8, 8, 512)    0           batch_normalization_12[0][0]
concatenate_3 (Concatenate)     (None, 8, 8, 1024)   0           dropout_2[0][0]                  leaky_re_lu_9[0][0]
activation_4 (Activation)       (None, 8, 8, 1024)   0           concatenate_3[0][0]
conv2d_transpose_3 (Conv2DTrans (None, 16, 16, 512)  8389120     activation_4[0][0]
batch_normalization_13 (BatchNo (None, 16, 16, 512)  2048        conv2d_transpose_3[0][0]
concatenate_4 (Concatenate)     (None, 16, 16, 1024) 0           batch_normalization_13[0][0]     leaky_re_lu_8[0][0]
activation_5 (Activation)       (None, 16, 16, 1024) 0           concatenate_4[0][0]
conv2d_transpose_4 (Conv2DTrans (None, 32, 32, 256)  4194560     activation_5[0][0]
batch_normalization_14 (BatchNo (None, 32, 32, 256)  1024        conv2d_transpose_4[0][0]
concatenate_5 (Concatenate)     (None, 32, 32, 512)  0           batch_normalization_14[0][0]     leaky_re_lu_7[0][0]
activation_6 (Activation)       (None, 32, 32, 512)  0           concatenate_5[0][0]
conv2d_transpose_5 (Conv2DTrans (None, 64, 64, 128)  1048704     activation_6[0][0]
batch_normalization_15 (BatchNo (None, 64, 64, 128)  512         conv2d_transpose_5[0][0]
concatenate_6 (Concatenate)     (None, 64, 64, 256)  0           batch_normalization_15[0][0]     leaky_re_lu_6[0][0]
activation_7 (Activation)       (None, 64, 64, 256)  0           concatenate_6[0][0]
conv2d_transpose_6 (Conv2DTrans (None, 128, 128, 64) 262208      activation_7[0][0]
batch_normalization_16 (BatchNo (None, 128, 128, 64) 256         conv2d_transpose_6[0][0]
concatenate_7 (Concatenate)     (None, 128, 128, 128 0           batch_normalization_16[0][0]     leaky_re_lu_5[0][0]
activation_8 (Activation)       (None, 128, 128, 128 0           concatenate_7[0][0]
conv2d_transpose_7 (Conv2DTrans (None, 256, 256, 3)  6147        activation_8[0][0]
activation_9 (Activation)       (None, 256, 256, 3)  0           conv2d_transpose_7[0][0]
Total params: 54,429,315
Trainable params: 54,419,459
Non-trainable params: 9,856

7. 加载真实图片以及生成假的图片

generate_real_samples 方法从数据集中随机抽取一批真实的图片对,并为其生成标签。标签全部为1,每个样本的标签shape是(16,16,1)。

# load and prepare training images
def load_real_samples(filename):
    """Load the prepared image arrays and rescale them to [-1, 1].

    Args:
        filename: path of an ``.npz`` archive holding the source images
            under ``'arr_0'`` and the target images under ``'arr_1'``, with
            pixel values in [0, 255].

    Returns:
        ``[X1, X2]``: the source and target arrays scaled to [-1, 1].
    """
    # load compressed arrays; the context manager closes the archive file
    with load(filename) as data:
        # unpack arrays
        X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]


# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    """Pick a random batch of real image pairs and build their 'real' labels.

    Args:
        dataset: pair ``(trainA, trainB)`` of arrays indexed along axis 0.
        n_samples: number of pairs to draw (indices are drawn with replacement).
        patch_shape: side length of the square label patch.

    Returns:
        ``([X1, X2], y)`` where ``X1``/``X2`` are the selected images from
        ``trainA``/``trainB`` (same indices for both) and ``y`` is an array
        of ones with shape ``(n_samples, patch_shape, patch_shape, 1)``.
    """
    # unpack dataset
    trainA, trainB = dataset
    # choose random instances
    ix = randint(0, trainA.shape[0], n_samples)
    # retrieve selected images
    X1, X2 = trainA[ix], trainB[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y


# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
    """Run the generator on a batch and build the matching 'fake' labels.

    Args:
        g_model: object exposing ``predict(samples)``.
        samples: batch of source images passed to ``g_model.predict``.
        patch_shape: side length of the square label patch.

    Returns:
        ``(X, y)`` where ``X`` is the generator output and ``y`` is an array
        of zeros with shape ``(len(X), patch_shape, patch_shape, 1)``.
    """
    # generate fake instance
    X = g_model.predict(samples)
    # create 'fake' class labels (0)
    y = zeros((len(X), patch_shape, patch_shape, 1))
    return X, y

8. 用生成器每隔几个Epoch生成一些假的图片,看看效果

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
    """Save a comparison plot and a snapshot of the generator.

    Draws ``n_samples`` random pairs, runs the generator on the source
    images and writes a 3-row figure (source, generated, expected) to
    ``pix2pix_plot_<step+1>.png``; the generator is saved to
    ``pix2pix_model_<step+1>.h5``.

    Args:
        step: zero-based training step; ``step + 1`` is used in the filenames.
        g_model: generator model.
        dataset: pair of arrays as returned by ``load_real_samples``.
        n_samples: number of image columns in the plot.
    """
    # select a sample of input images (labels are not needed, so patch size 1)
    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
    # generate a batch of fake samples
    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
    # scale all pixels from [-1,1] to [0,1]
    X_realA = (X_realA + 1) / 2.0
    X_realB = (X_realB + 1) / 2.0
    X_fakeB = (X_fakeB + 1) / 2.0
    # one row per image set: real source, generated target, real target
    for row, images in enumerate((X_realA, X_fakeB, X_realB)):
        for i in range(n_samples):
            pyplot.subplot(3, n_samples, 1 + n_samples * row + i)
            pyplot.axis('off')
            pyplot.imshow(images[i])
    # save plot to file
    filename1 = 'pix2pix_plot_%06d.png' % (step+1)
    pyplot.savefig(filename1)
    pyplot.close()
    # save the generator model
    filename2 = 'pix2pix_model_%06d.h5' % (step+1)
    g_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))

10. 训练过程

# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
    """Alternate discriminator and generator updates for ``n_epochs`` epochs.

    Each step draws a random real batch, generates the matching fake batch,
    updates the discriminator on both, then updates the generator through
    the composite model. Every 10 epochs ``summarize_performance`` is called.

    Args:
        d_model: compiled discriminator.
        g_model: generator.
        gan_model: compiled composite model from ``define_gan``.
        dataset: pair of arrays as returned by ``load_real_samples``.
        n_epochs: number of passes over the dataset.
        n_batch: number of image pairs per step.
    """
    # determine the output square shape of the discriminator
    n_patch = d_model.output_shape[1]
    # unpack dataset
    trainA, trainB = dataset
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
        # update discriminator for real samples
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        # update discriminator for generated samples
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        # update the generator: it is trained to make the discriminator
        # answer "real" and to match the real target image
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
        # summarize performance
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
        # summarize model performance
        if (i+1) % (bat_per_epo * 10) == 0:
            summarize_performance(i, g_model, dataset)

11. 训练后效果

在训练约10个epoch之后,尽管街道的线条并不完全笔直、图像也有些模糊,但模型已经能生成看起来合理的地图图像。 其中大型结构的位置基本正确,颜色也大多正确。




# example of pix2pix gan for satellite to map image-to-image translation
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from matplotlib import pyplot


# define the discriminator model
def define_discriminator(image_shape):
    """Build and compile the PatchGAN discriminator.

    The model takes a source image and a target image of identical shape,
    concatenates them channel-wise and applies a stack of 4x4 convolutions
    ending in a sigmoid, so the output is a grid of per-patch real/fake
    probabilities rather than a single score. Four stride-2 convolutions
    reduce each spatial dimension by a factor of 16.

    Args:
        image_shape: shape of one input image, e.g. ``(256, 256, 3)``.

    Returns:
        A compiled Keras ``Model`` with inputs ``[source, target]``.
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_src_image = Input(shape=image_shape)
    # target image input
    in_target_image = Input(shape=image_shape)
    # concatenate images channel-wise
    merged = Concatenate()([in_src_image, in_target_image])
    # C64 (no batch normalization on the first layer)
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer (stride 1, keeps the spatial size)
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    # define model
    model = Model([in_src_image, in_target_image], patch_out)
    # compile model; `learning_rate` replaces the deprecated `lr` keyword
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model


# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    """Append one downsampling block: Conv2D -> optional BatchNorm -> LeakyReLU.

    Args:
        layer_in: tensor the block is attached to.
        n_filters: number of filters of the 4x4, stride-2 convolution.
        batchnorm: when False the batch-normalization layer is skipped.

    Returns:
        The output tensor of the block (spatial size halved by the stride).
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add downsampling layer
    g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # conditionally add batch normalization; training=True keeps the layer in
    # training mode even when the model is used for prediction
    if batchnorm:
        g = BatchNormalization()(g, training=True)
    # leaky relu activation
    g = LeakyReLU(alpha=0.2)(g)
    return g


# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    """Append one upsampling block with a skip connection.

    Layer order: Conv2DTranspose -> BatchNorm -> optional Dropout ->
    concatenate with ``skip_in`` -> ReLU.

    Args:
        layer_in: tensor the block is attached to.
        skip_in: encoder tensor merged in through the skip connection; it
            must have the same spatial size as the upsampled tensor.
        n_filters: number of filters of the 4x4, stride-2 transposed convolution.
        dropout: when False the dropout layer is skipped.

    Returns:
        The output tensor of the block (spatial size doubled by the stride).
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add upsampling layer
    g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # add batch normalization
    g = BatchNormalization()(g, training=True)
    # conditionally add dropout; training=True keeps it active at prediction time too
    if dropout:
        g = Dropout(0.5)(g, training=True)
    # merge with skip connection
    g = Concatenate()([g, skip_in])
    # relu activation
    g = Activation('relu')(g)
    return g


# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
    """Build the U-Net encoder-decoder generator (not compiled).

    Seven encoder blocks plus the bottleneck convolution each halve the
    spatial size (256 -> 1 for the default shape); seven decoder blocks and
    a final transposed convolution restore it, with a skip connection from
    each encoder block to the decoder block of matching size.

    Args:
        image_shape: shape of the input image.

    Returns:
        A Keras ``Model`` mapping the input image to a 3-channel image with
        tanh-activated values in [-1, 1].
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # encoder model
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    # bottleneck, no batch norm and relu
    b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    # decoder model (dropout only in the first three blocks)
    d1 = decoder_block(b, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    # output
    g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model


# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
    """Build and compile the composite model used to update the generator.

    The source image is fed to the generator; the source image and the
    generated image are then fed to the discriminator. The discriminator is
    marked non-trainable inside this composite.

    Args:
        g_model: generator model.
        d_model: discriminator model; its ``trainable`` flag is set to False.
        image_shape: shape of the source image input.

    Returns:
        A compiled Keras ``Model`` mapping the source image to
        ``[discriminator_output, generated_image]``.
    """
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # define the source image
    in_src = Input(shape=image_shape)
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    # connect the source input and generator output to the discriminator input
    dis_out = d_model([in_src, gen_out])
    # debug output: show the symbolic discriminator output tensor
    print(dis_out)
    # src image as input, generated image and classification output
    model = Model(in_src, [dis_out, gen_out])
    # compile model; `learning_rate` replaces the deprecated `lr` keyword
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    # adversarial loss on the patch output, L1 (mae) loss on the image, weighted 1:100
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
    return model


# load and prepare training images
def load_real_samples(filename):
    """Load the prepared image arrays and rescale them to [-1, 1].

    Args:
        filename: path of an ``.npz`` archive holding the source images
            under ``'arr_0'`` and the target images under ``'arr_1'``, with
            pixel values in [0, 255].

    Returns:
        ``[X1, X2]``: the source and target arrays scaled to [-1, 1].
    """
    # load compressed arrays; the context manager closes the archive file
    with load(filename) as data:
        # unpack arrays
        X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]


# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    """Pick a random batch of real image pairs and build their 'real' labels.

    Args:
        dataset: pair ``(trainA, trainB)`` of arrays indexed along axis 0.
        n_samples: number of pairs to draw (indices are drawn with replacement).
        patch_shape: side length of the square label patch.

    Returns:
        ``([X1, X2], y)`` where ``X1``/``X2`` are the selected images from
        ``trainA``/``trainB`` (same indices for both) and ``y`` is an array
        of ones with shape ``(n_samples, patch_shape, patch_shape, 1)``.
    """
    # unpack dataset
    trainA, trainB = dataset
    # choose random instances
    ix = randint(0, trainA.shape[0], n_samples)
    # retrieve selected images
    X1, X2 = trainA[ix], trainB[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y


# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
    """Run the generator on a batch and build the matching 'fake' labels.

    Args:
        g_model: object exposing ``predict(samples)``.
        samples: batch of source images passed to ``g_model.predict``.
        patch_shape: side length of the square label patch.

    Returns:
        ``(X, y)`` where ``X`` is the generator output and ``y`` is an array
        of zeros with shape ``(len(X), patch_shape, patch_shape, 1)``.
    """
    # generate fake instance
    X = g_model.predict(samples)
    # create 'fake' class labels (0)
    y = zeros((len(X), patch_shape, patch_shape, 1))
    return X, y


# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
    """Save a comparison plot and a snapshot of the generator.

    Draws ``n_samples`` random pairs, runs the generator on the source
    images and writes a 3-row figure (source, generated, expected) to
    ``pix2pix_plot_<step+1>.png``; the generator is saved to
    ``pix2pix_model_<step+1>.h5``.

    Args:
        step: zero-based training step; ``step + 1`` is used in the filenames.
        g_model: generator model.
        dataset: pair of arrays as returned by ``load_real_samples``.
        n_samples: number of image columns in the plot.
    """
    # select a sample of input images (labels are not needed, so patch size 1)
    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
    # generate a batch of fake samples
    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
    # scale all pixels from [-1,1] to [0,1]
    X_realA = (X_realA + 1) / 2.0
    X_realB = (X_realB + 1) / 2.0
    X_fakeB = (X_fakeB + 1) / 2.0
    # one row per image set: real source, generated target, real target
    for row, images in enumerate((X_realA, X_fakeB, X_realB)):
        for i in range(n_samples):
            pyplot.subplot(3, n_samples, 1 + n_samples * row + i)
            pyplot.axis('off')
            pyplot.imshow(images[i])
    # save plot to file
    filename1 = 'pix2pix_plot_%06d.png' % (step+1)
    pyplot.savefig(filename1)
    pyplot.close()
    # save the generator model
    filename2 = 'pix2pix_model_%06d.h5' % (step+1)
    g_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))


# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
    """Alternate discriminator and generator updates for ``n_epochs`` epochs.

    Each step draws a random real batch, generates the matching fake batch,
    updates the discriminator on both, then updates the generator through
    the composite model. Every 10 epochs ``summarize_performance`` is called.

    Args:
        d_model: compiled discriminator.
        g_model: generator.
        gan_model: compiled composite model from ``define_gan``.
        dataset: pair of arrays as returned by ``load_real_samples``.
        n_epochs: number of passes over the dataset.
        n_batch: number of image pairs per step.
    """
    # determine the output square shape of the discriminator
    n_patch = d_model.output_shape[1]
    # unpack dataset
    trainA, trainB = dataset
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
        # update discriminator for real samples
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        # update discriminator for generated samples
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        # update the generator: it is trained to make the discriminator
        # answer "real" and to match the real target image
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
        # summarize performance
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
        # summarize model performance
        if (i+1) % (bat_per_epo * 10) == 0:
            summarize_performance(i, g_model, dataset)


def start_train():
    """Load the prepared dataset, build the three models and run training."""
    # load image data
    dataset = load_real_samples('maps_256.npz')
    print('Loaded', dataset[0].shape, dataset[1].shape)
    # define input shape based on the loaded dataset
    image_shape = dataset[0].shape[1:]
    # define the models
    d_model = define_discriminator(image_shape)
    print(image_shape)
    # summary() prints the table itself and returns None, so it is not wrapped in print()
    d_model.summary()
    g_model = define_generator(image_shape)
    # define the composite model
    gan_model = define_gan(g_model, d_model, image_shape)
    # train model
    train(d_model, g_model, gan_model, dataset)


if __name__ == '__main__':
    start_train()


