还没有女朋友的朋友们,你们有福了,学会CycleGAN把男朋友变成女朋友

  • 前言
  • 效果展示
  • 使用 CycleGAN 进行不成对的图像转换
    • 不成对的数据集
  • CycleGAN模型
    • 数据集
    • 数据加载与预处理
    • 模型构建
    • 训练结果可视化函数
    • 训练步骤
  • 效果二次展示

前言

事情的起因是这样的,室友在经历的4年的找女朋友之旅后,终于放弃了,而我为了让他的青春不留遗憾,只能使用 CycleGAN 把下铺壮汉变成萌妹了。
转眼又到了毕业季,还在为没有女朋友而着急么?还在为没有谈一场青春的恋爱而遗憾么?还没有女朋友的朋友们,你们有福了!!!没有女朋友,还能没有男朋友么?学会 CycleGAN ,把男朋友变成女朋友,赶快学起来吧。

效果展示

在学习之前,大家肯定想先知道CycleGAN模型进行男女性别转换的效果如何,所以先让大家看看模型训练的效果.

效果这么惊人,还不快学起来???

使用 CycleGAN 进行不成对的图像转换

CycleGAN 可以使用两个生成器和两个鉴别器训练不成对(unpaired)的图像。
本文主要以实战为主,如果想要了解 CycleGAN 背后的具体原理,请参考 CycleGAN 原理与实现(采用tensorflow2.x实现).

不成对的数据集

CycleGAN 的一个重要贡献是,改变了pix2pix需要成对的训练数据集的缺点。某些情况下,我们可以很容易地创建成对数据集,如彩色的图像对应的灰度图像数据集,完成成对数据集的构建,从而用于训练灰度图像上色的深度学习模型。但是,更多数的情况下,无法创建成对的数据集,例如从男性到女性的图像转换。
这便是 CycleGAN 的优势所在,因为它不需要成对的数据, CycleGAN 可以训练不成对的数据集!

CycleGAN模型

简单看下CycleGAN的体系架构:

数据集

数据集取自 Celeb A ,可以自行构建数据集,也可以使用此数据集,提取码:nql9。

数据加载与预处理

# 导入必要库
import tensorflow as tf
import os
import time
from matplotlib import pyplot as plt
import tensorflow_datasets as tfds
AUTOTUNE = tf.data.experimental.AUTOTUNE# 定义超参数
BUFFER_SIZE = 128
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
OUTPUT_CHANNELS = 3
LAMBDA = 10
EPOCHS = 100
"""
# 数据预处理函数
"""
def random_crop(image):cropped_image = tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])return cropped_image# normalizing the images to [-1, 1]
def normalize(image):image = tf.cast(image, tf.float32)image = (image / 127.5) - 1return imagedef random_jitter(image):# resizing to 286 x 286 x 3image = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)# randomly cropping to 256 x 256 x 3image = random_crop(image)# random mirroringimage = tf.image.random_flip_left_right(image)return imagedef preprocess_image_train(image):image = random_jitter(image)image = normalize(image)return imagedef preprocess_image_test(image):image = normalize(image)return imagedef load(image_file):image = tf.io.read_file(image_file)image = tf.image.decode_jpeg(image)input_image = tf.cast(image, tf.float32)return input_imagedef load_image_train(image_file):image = load(image_file)image = preprocess_image_train(image)return imagedef load_image_test(image_file):image = load(image_file)image = preprocess_image_test(image)return image# 加载男性图片,构建训练数据集
train_man = tf.data.Dataset.list_files('./man2woman/trainA/*.jpg')
train_man = train_man.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_man = train_man.shuffle(BUFFER_SIZE)
train_man = train_man.batch(BATCH_SIZE, drop_remainder=True)
# 加载女性图片,构建训练数据集
train_woman = tf.data.Dataset.list_files('./man2woman/trainB/*.jpg')
train_woman = train_woman.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_woman = train_woman.shuffle(BUFFER_SIZE)
train_woman = train_woman.batch(BATCH_SIZE, drop_remainder=True)

模型构建

在 CycleGAN 中,使用实例归一化代替批归一化,但在 tensorflow 中,未包含实例归一化层,因此需要自行实现。

class InstanceNormalization(tf.keras.layers.Layer):"""Instance Normalization Layer."""def __init__(self, epsilon=1e-5):super(InstanceNormalization, self).__init__()self.epsilon = epsilondef build(self, input_shape):self.scale = self.add_weight(name='scale', shape=input_shape[-1:],initializer=tf.random_normal_initializer(1., 0.02),trainable=True)self.offset = self.add_weight(name='offset',shape=input_shape[-1:],initializer='zeros',trainable=True)def call(self, x):mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)inv = tf.math.rsqrt(variance + self.epsilon)normalized = (x - mean) * invreturn self.scale * normalized + self.offset

为了减少代码量,定义上采样块和下采样块:

# 下采样块
def downsample(filters, size, norm_type='batchnorm', apply_norm=True):initializer = tf.random_normal_initializer(0., 0.02)result = tf.keras.Sequential()result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',kernel_initializer=initializer, use_bias=False))if apply_norm:if norm_type.lower() == 'batchnorm':result.add(tf.keras.layers.BatchNormalization())elif norm_type.lower() == 'instancenorm':result.add(InstanceNormalization())result.add(tf.keras.layers.LeakyReLU())return result# 上采样快
def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):initializer = tf.random_normal_initializer(0., 0.02)result = tf.keras.Sequential()result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2,padding='same',kernel_initializer=initializer,use_bias=False))if norm_type.lower() == 'batchnorm':result.add(tf.keras.layers.BatchNormalization())elif norm_type.lower() == 'instancenorm':result.add(InstanceNormalization())if apply_dropout:result.add(tf.keras.layers.Dropout(0.5))result.add(tf.keras.layers.ReLU())return result

接下来构建生成器:

def unet_generator(output_channels, norm_type='batchnorm'):down_stack = [downsample(64, 4, norm_type, apply_norm=False), downsample(128, 4, norm_type),downsample(256, 4, norm_type),downsample(512, 4, norm_type),downsample(512, 4, norm_type),downsample(512, 4, norm_type),downsample(512, 4, norm_type),downsample(512, 4, norm_type),]up_stack = [upsample(512, 4, norm_type, apply_dropout=True),upsample(512, 4, norm_type, apply_dropout=True),upsample(512, 4, norm_type, apply_dropout=True),upsample(512, 4, norm_type),upsample(256, 4, norm_type),upsample(128, 4, norm_type),upsample(64, 4, norm_type),]initializer = tf.random_normal_initializer(0., 0.02)last = tf.keras.layers.Conv2DTranspose(output_channels, 4, strides=2,padding='same', kernel_initializer=initializer,activation='tanh')  # (bs, 256, 256, 3)concat = tf.keras.layers.Concatenate()inputs = tf.keras.layers.Input(shape=[None, None, 3])x = inputs# Downsampling through the modelskips = []for down in down_stack:x = down(x)skips.append(x)skips = reversed(skips[:-1])# Upsampling and establishing the skip connectionsfor up, skip in zip(up_stack, skips):x = up(x)x = concat([x, skip])x = last(x)return tf.keras.Model(inputs=inputs, outputs=x)

构建鉴别器:

def discriminator(norm_type='batchnorm', target=True):initializer = tf.random_normal_initializer(0., 0.02)inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')x = inpif target:tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, channels*2)down1 = downsample(64, 4, norm_type, False)(x)  # (bs, 128, 128, 64)down2 = downsample(128, 4, norm_type)(down1)  # (bs, 64, 64, 128)down3 = downsample(256, 4, norm_type)(down2)  # (bs, 32, 32, 256)zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)conv = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer,use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)if norm_type.lower() == 'batchnorm':norm1 = tf.keras.layers.BatchNormalization()(conv)elif norm_type.lower() == 'instancenorm':norm1 = InstanceNormalization()(conv)leaky_relu = tf.keras.layers.LeakyReLU()(norm1)zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)last = tf.keras.layers.Conv2D(1, 4, strides=1,kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)if target:return tf.keras.Model(inputs=[inp, tar], outputs=last)else:return tf.keras.Model(inputs=inp, outputs=last)

实例化生成器与鉴别器:

generator_g = unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
discriminator_x = discriminator(norm_type='instancenorm', target=False)
discriminator_y = discriminator(norm_type='instancenorm', target=False)

损失函数与优化器的定义:

loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# 鉴别器损失
def discriminator_loss(real, generated):real_loss = loss_obj(tf.ones_like(real), real)generated_loss = loss_obj(tf.zeros_like(generated), generated)total_disc_loss = real_loss + generated_lossreturn total_disc_loss * 0.5
# 生成器损失
def generator_loss(generated):return loss_obj(tf.ones_like(generated), generated)
# 循环一致性损失
def calc_cycle_loss(real_image, cycled_image):loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))return LAMBDA * loss1
# identity loss
def identity_loss(real_image, same_image):loss = tf.reduce_mean(tf.abs(real_image - same_image))return LAMBDA * 0.5 * loss
# 优化器
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

训练结果可视化函数

创建 generate_images 函数用于在训练过程中查看模型效果.

def generate_images(model, test_input):prediction = model(test_input)plt.figure(figsize=(12, 12))display_list = [test_input[0], prediction[0]]title = ['Input Image', 'Predicted Image']for i in range(2):plt.subplot(1, 2, i+1)plt.title(title[i])# getting the pixel values between [0, 1] to plot it.plt.imshow(display_list[i] * 0.5 + 0.5)plt.axis('off')# plt.show()plt.savefig('results/{}.png'.format(int(time.time())))

训练步骤

首先需要定义训练函数:

@tf.function
def train_step(real_x, real_y):with tf.GradientTape(persistent=True) as tape:# Generator G translates X -> Y# Generator F translates Y -> X.fake_y = generator_g(real_x, training=True)cycled_x = generator_f(fake_y, training=True)fake_x = generator_f(real_y, training=True)cycled_y = generator_g(fake_x, training=True)# same_x and same_y are used for identity loss.same_x = generator_f(real_x, training=True)same_y = generator_g(real_y, training=True)disc_real_x = discriminator_x(real_x, training=True)disc_real_y = discriminator_y(real_y, training=True)disc_fake_x = discriminator_x(fake_x, training=True)disc_fake_y = discriminator_y(fake_y, training=True)# calculate the lossgen_g_loss = generator_loss(disc_fake_y)gen_f_loss = generator_loss(disc_fake_x)total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)# Total generator loss = adversarial loss + cycle losstotal_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)# Calculate the gradients for generator and discriminatorgenerator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)# Apply the gradients to the optimizergenerator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))

最后进行模型的训练:

for epoch in range(EPOCHS):start = time.time()n = 0for image_x, image_y in tf.data.Dataset.zip((train_man, train_woman)):train_step(image_x, image_y)# generate_images(generator_g, sample_man)if n % 10 == 0:print ('.', end='')n += 1# 采样测试数据集, 测试模型效果sample_man = next(iter(train_man))sample_woman = next(iter(train_woman))generate_images(generator_g, sample_man)generate_images(generator_f, sample_woman)print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))

效果二次展示

我们已经在开始时看到了 CycleGAN 在男性转换为女性的结果,再看下将女性转换为男性的效果吧!

还没有女朋友的朋友们,你们有福了,学会CycleGAN把男朋友变成女朋友相关推荐

  1. 驾照还没考的朋友有福了!应该是最全的了……

    驾照还没考的朋友有福了!应该是最全的了-- 真的很有用啊!以后留着随时看看,快快转载吧!~ 总结版: 理论知识就背书上的  在路上就关注这些知识    自己印象会很深 容易记住 对自己以后开车  也有 ...

  2. 考驾照01_驾照还没考的女人,你们有福了!

    从此学车人可以自学自考了 不仅如此,驾照还没考的朋友有福了 驾校考试秘笈,不用看书就能通过!! 暂时不考的也先留着,别等着急的时候没处去 速记方法!!! 1.题目里有"口"的选50 ...

  3. 还在为微信朋友圈的大量广告而苦恼吗?一文教你如何清除微信朋友圈的广告!!!

    还在为微信朋友圈的大量广告而苦恼吗?一文教你如何清除微信朋友圈的广告!!! 大家好,我叫亓官劼(qí guān jié ),在CSDN中记录学习的点滴历程,时光荏苒,未来可期,加油~博客地址为:亓官劼 ...

  4. 还好我们有朋友-蝌蚪

    (一) 我们几个聚会唱K的时候,常点的歌就是动力火车的<当>.(当然还爱唱最炫民族风) 虽然我们唱歌常恶搞,但每次我们唱到"让我们红尘作伴,活的潇潇洒洒:策马奔腾共享人世繁华:对 ...

  5. 女朋友可能的若干职业,您选择哪种女孩做女朋友?

    女朋友可能的若干职业,您选择哪种女孩做女朋友? 文/飞天含雪 一, 秘书 特点分析:"秘书"女友一般很漂亮,很有气质,尤其是给大公司老板做秘书的女孩,更是"国色天香&qu ...

  6. miui升级系统无服务器,细数MIUI 11的BUG,还没升级的朋友,先来了解一下

    大家好,我们都知道,作为国内知名的安卓UI,MIUI已经经过十年的发展,今年更是推出了MIUI 11版本.在很多朋友看来,字体.动态音效.息屏显示.儿童空间和MI GO出行等功能,确实是又一次很大的改 ...

  7. 5年软件测试工程师感悟——写给还在迷茫的朋友

    中秋节假期和朋友聚餐,和朋友谈到互联网行业的发展,为什么互联网大厂纷纷大规模裁员?前两 年年大家还在996的工作模式中度过的水深火热,今年好像突然没有了声音,因为很多人都被裁了. 继K12行业断崖式的 ...

  8. 分手的情人还不如最普通的朋友

    心碎离开 转身回到最初荒凉里等待 为了寂寞 是否找个人填心中空白 我们变成了世上 最熟悉的陌生人 今后各自曲折 各自悲哀 只怪我们爱得那么汹涌 爱得那么深 于是梦醒了搁浅了沉默了挥手了 却回不了神 如 ...

  9. 一个老话题~分手还能成为好朋友吗?

    一个老话题--分手了~还能成为朋友吗?   分手了还能成为朋友吗? 以前我看过这里有的!可是我想知道的具体些~ 她我们分手5个月了! 可是我们偶尔还是会联系~ 她打电话我~我也打给她(很少) 很多人说 ...

最新文章

  1. C语言实现的Web服务器
  2. 关于python中的作用域问题
  3. 2019年最新10份开源Java精选资料
  4. 安装Synchronization service (Project Server 2007) 时出现 MSMQ 错误的解决
  5. Job_search_collection
  6. 查看sql_一键查看Oracle数据库当前SQL_WORKAREA_ACTIVE的相关操作
  7. 一个简单的MDX案例及说明 (转载)
  8. C语言程序设计谭浩强第五版复习梳理3
  9. rdkit Recap、BRICS分子片段拆分与合成
  10. flex builder 破解
  11. 趣头条投放广告需要哪些资质?趣头条推广广告怎么样搭建账户?
  12. php随机产生六位数密码
  13. python生成词云图、特殊图形_Python模块---Wordcloud生成词云图
  14. Python:try……excepted捕获方法
  15. pythonmathcot函数_sin cos tan cot公式
  16. Android截屏截图方法所有方法汇总(包括Activity、View、ScrollView、ListView、RecycleView、WebView截屏截图)
  17. 三维绕任意轴旋转矩阵
  18. NO_PROXY is not set
  19. 什么是应用分发?应用分发是什么?
  20. Flex布局脑图总结

热门文章

  1. python的安全插件
  2. solr索引大小对比
  3. [转载] python面面观单元测试_python 使用unittest进行单元测试
  4. [转载] 快速入门(完整):Python实例100个(基于最新Python3.7版本)
  5. Broadwell I7-5775c/5675c BSOD 蓝屏问题
  6. [置顶] Embedded Server:像写main函数一样写Web Server
  7. idea tomcat debug不能启动的问题
  8. Java学习目录(持续更新中)
  9. Python字符的转义
  10. POJ 2457 BFS