本文介绍CycleGAN原理以及在tensorflow中实现。

一、CycleGAN 的原理

cGAN 和对应的 pix2pix 模型,都能够解决一类“图像翻译 ”问题 。 但是 pix2pix 模型要求训练样本必须是“严格成对”的,这种样本往往比较难以获得,CycleGAN 不必使用成对样本也可以进行“图像翻译”。CycleGAN与 pix2pix的不同点在于,它可以利用不成对数据训练出从 X 空间到 Y 空间的映射 。 例如,只要搜集了大量照片以及大量油画图片,可以学习到如何把照片转换成油画。

CycleGAN 的原理:算法的目标是学习从空间 X 到空间 Y 的映射,设这个映射为 F。 它对应着 GAN 中的生成器, F 可以将 X 中的图片 x 转化为 Y 中的图片 F(x)。对于生成的图片,还需要 GAN 中的判别器来判别器是否为真实图片,由此构成对抗生成网络 。但由于没有成对数据,这个网络是无法训练的。对此,作者又提出了所谓的“循环一致性损失”( cycle consistency loss )。让再假设一个映射 G,它可以将 Y 空间中的图片 y 转换为 X 中的图片 G(y)。 CycleGAN,同时学习 F 和 G 两个映射,并要求 F(G(y)) = y,以及 G(F(x)) = x。也就是说,将 x 的图片转换到 Y 空间后,应该还可以转换回来。

循环一致性损失定义:

总损失定义:

CycleGAN 的主要想法是上述的“循环一致性损失”,利用这个损失 3 可以巧妙地处理 X 空间和 Y 空间训练样本不一一配对的问题。

二、在 TensorFlow 中用训练 CycleGAN 模型

1、下载数据集并进行训练

(1)下载数据集

apple2orange数据集包含了苹果和橘子的图像,运行命令下载数据集:bash download_dataset.sh apple2orange,运行报错:wget: command not found,显然是因为没有安装wget导致的,wget用英语定义就是the non-interactive network downloader,翻译过来就是非交互的网络下载器。这里使用homebrew安装wget,Homebrew为macOS提供缺失的软件包管理器,使用Homebrew可以安装Apple没有预装但你需要的东西,Homebrew会将软件包安装到独立目录,并将其文件软链接至 /usr/local。Homebrew 不会将文件安装到它本身目录之外,所以可将 Homebrew 安装到任意位置。先安装homebrew,命令为:/usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)",安装成功后,利用homwbrew安装wget,命令为:brew install wget,安装成功。download_dataset.sh文件内容如下:

URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./data/$FILE.zip
TARGET_DIR=./data/$FILE/
wget -N $URL -O $ZIP_FILE
mkdir -p $TARGET_DIR
unzip $ZIP_FILE -d ./data/
rm $ZIP_FILE

运行报错:时间戳与 -O 结合使用没有任何效果,去掉wget的 -N参数即可。-N参数表示只获取比本地新的文件,-O参数表示将文档写入$ZIP_FILE中。至此,运行下载数据集命令不报错,数据集下载完成,生成 data/apple2orange 目录。其中 trainA、 testA 中保存了苹果的图像, trainB、 testB 中保存了橙子的图像,如图:

(2)转换图片格式

由于该项目使用 tfrecords 读取数据,再将图片转换为tfrecords格式(大数据文件格式),命令为:

python build_data.py
--X_input_dir data/apple2orange/trainA
--Y_input_dir data/apple2orange/trainB
--X_output_file data/tfrecords/apple.tfrecords
--Y_output_file data/tfrecords/orange.tfrecords

运行报错 except os.error, e: SyntaxError: invalid syntax,这是python2的捕获方法,在python3中为except Exception as e,因此代码改为:

try:os.makedirs(output_dir)
except os.error as e:pass

至此,数据格式转换成功。

(3)训练模型

运行训练模型的命令:

python train.py
--X data/tfrecords/apple.tfrecords
--Y data/tfrecords/orange.tfrecords
--image_size 256

运行报错:absl.flags._exceptions.IllegalFlagValueError: flag --lambda1=10.0: Expect argument to be a string or int, found <class 'float'>,原因是需要int型,而传入float型,将 tf.flags.DEFINE_integer 改为 flags.DEFINE_float 即可。训练开始后,程序会在 checkpoints 文件夹中建立一个以当前时间命名的目录,如“checkpoints/20190624-1053”,训练时的曰志和模型都会保存在该文件夹中。ckpt为tensorflow的模型文件格式,其他几种格式参考:https://blog.csdn.net/sinat_31337047/article/details/81483006

此外,每隔 100 步,程序还会在屏幕上打出当前步数和损失, 可以通过它们来监控模型的训练。

更方便的做法是使用 TensorBoard,即运行 : tensorboard --logdir checkpoints/20190624-1053/,运行命令报错:AttributeError: module 'tensorboard.util' has no attribute 'Retrier',原因是tensorboard与tensorflow版本不符合。使用pip list 查看tensorflow版本,我的是1.13.1,在https://github.com/tensorflow/tensorboard/releases?after=1.13.1找到相对应的tensorboard版本,即1.13.0,重新安装1.13.0的tensorboard即可:pip install tensorboard==1.13.0。

(4)用训练好的模型进行测试

将模型导出为 pb 文件,运行命令:

python export_graph.py
--checkpoint_dir checkpoints/20190624-1146/
--XtoY_model apple2orange.pb
--YtoX_model orange2apole.pb
--image_size 256

运行命令使用模型pretrained/apple2orange.pb将图片data/apple2orange/testA/n07740461_1661.jpg进行转换,把生成的图片存放到data/apple2orange/output_sample.jpg中,如下:

python inference.py
--model pretrained/apple2orange.pb
--input data/apple2orange/testA/n07740461_1661.jpg
--output data/apple2orange/output_sample.jpg
--image_size 256

2、使用自己的数据进行训练

(1)准备两个文件夹, 一个文件夹中存放 X 空间内的图片,另一个文件夹中存放 Y 空间 中的文件 。使用数据集 man2woman.zip, 该数据集是一个人脸数据集,用 CycleGAN做一个实验:将男性变成女性以及将女性变成男性 。man2woman 数据集是从 CelebA 数据集中整理得到的,后者是一个大型的人脸数据集拥有 20 万张人脸图片。CelebA数据集下载地址https://pan.baidu.com/s/1eSNpdRG?errno=0&errmsg=Auth%20Login%20Sucess&&bduss=&ssnerror=0&traceid=#list/path=%2F&parentPath=%2F。

(2)为了训练 CycleGAN,需要先将图片转换成 tfrecords 形式。运行命令后,得到了两个 tfrecords 文件 。

(3)直接利用这两个文件进行训练即可。训练的过程比较漫长,最好都打开 TensorBoard 观察训练的 Loss 和图像生成情况 。如果训练的过程发生了中断,可以不从头开始训练,指定--load_model 参数,可以从之前保存的模型中恢复并继续训练。

(4)使用训练好的模型就可以进行测试了。最终男生照片变成女性照片,女性照片变成男性照片。

三、程序结构分析

1、CycleGAN 模型定义(model.py)

def model(self):# 读入x空间数据X_reader = Reader(self.X_train_file, name='X',image_size=self.image_size, batch_size=self.batch_size)# 读入y空间数据Y_reader = Reader(self.Y_train_file, name='Y',image_size=self.image_size, batch_size=self.batch_size)# 将读入数据保存到x、y变量中x = X_reader.feed()y = Y_reader.feed()# 根据 self.G、self.F、x、y定义循环一致性损失 cycle_losscycle_loss = self.cycle_consistency_loss(self.G, self.F, x, y)# X -> Y(self.G)fake_y = self.G(x)# 定义 self.G 生成图片的损失G_gan_loss = self.generator_loss(self.D_Y, fake_y, use_lsgan=self.use_lsgan)G_loss =  G_gan_loss + cycle_loss# 定义 Y 空间鉴别器的损失D_Y_loss = self.discriminator_loss(self.D_Y, y, self.fake_y, use_lsgan=self.use_lsgan)# Y -> X(self.F)fake_x = self.F(y)# 定义 self.F 生成图片的损失F_gan_loss = self.generator_loss(self.D_X, fake_x, use_lsgan=self.use_lsgan)F_loss = F_gan_loss + cycle_loss# 定义 X 空间鉴别器的损失D_X_loss = self.discriminator_loss(self.D_X, x, self.fake_x, use_lsgan=self.use_lsgan)

其中,self.F和self.G是生成器,D_X,D_Y是鉴别器

2、循环一致性损失定义(model.py)

def cycle_consistency_loss(self, G, F, x, y):# L1 损失forward_loss = tf.reduce_mean(tf.abs(F(G(x))-x))backward_loss = tf.reduce_mean(tf.abs(G(F(y))-y))loss = self.lambda1*forward_loss + self.lambda2*backward_lossreturn loss

3、生成器的损失(model.py)

def generator_loss(self, D, fake_y, use_lsgan=True):# use_lsgan指定了是否用LSGAN对应的损失函数。LSGAN是GAN的一种变体,损失函数略有不同。只关注use_lsgan=false的情况if use_lsgan:# 使用均方损失loss = tf.reduce_mean(tf.squared_difference(D(fake_y), REAL_LABEL))else:# D(fake_y)为生成器生成图像是真实图像的概率,D(fake_y)越大,说明生成器越好# 之所以加负号,是因为tensorflow的优化器都默认损失越小越好loss = -tf.reduce_mean(ops.safe_log(D(fake_y))) / 2return loss

4、鉴别器的损失(model.py)

def discriminator_loss(self, D, y, fake_y, use_lsgan=True):# 只关注use_lsgan=falseif use_lsgan:# use mean squared errorerror_real = tf.reduce_mean(tf.squared_difference(D(y), REAL_LABEL))error_fake = tf.reduce_mean(tf.square(D(fake_y)))else:# y是真实数据,D(y)是判别器判断真实数据的对应概率,该值越大,说明判别器的性能越好,同样取负号error_real = -tf.reduce_mean(ops.safe_log(D(y)))# 再使用交叉摘损失并取负值得到error_fakeerror_fake = -tf.reduce_mean(ops.safe_log(1-D(fake_y)))# 总损失loss = (error_real + error_fake) / 2return loss

5、为4个损失定义优化操作

最终定义出4个损失: G一loss、 F_loss、 D_Y一loss、 D_X一loss。其中, G_loss和 F_loss是生成器损失,这两个损失降低则意昧着生成器的性能提高 。D_Y_loss 和 D X loss 是判别器 , 这两个损失的降低意昧着判别器性能提高。在优化时, 对四个损失同时优化即可。

def optimize(self, G_loss, D_Y_loss, F_loss, D_X_loss):# 对四个损失定义优化操作G_optimizer = make_optimizer(G_loss, self.G.variables, name='Adam_G')D_Y_optimizer = make_optimizer(D_Y_loss, self.D_Y.variables, name='Adam_D_Y')F_optimizer =  make_optimizer(F_loss, self.F.variables, name='Adam_F')D_X_optimizer = make_optimizer(D_X_loss, self.D_X.variables, name='Adam_D_X')# tf.no_op()将优化操作保存,直接调用optimizers即可完成对四个损失的优化with tf.control_dependencies([G_optimizer, D_Y_optimizer, F_optimizer, D_X_optimizer]):return tf.no_op(name='optimizers')
# 优化器定义函数
def make_optimizer(loss, variables, name='Adam'):""" Adam optimizer with learning rate 0.0002 for the first 100k steps (~100 epochs)and a linearly decaying rate that goes to zero over the next 100k steps"""global_step = tf.Variable(0, trainable=False)starter_learning_rate = self.learning_rateend_learning_rate = 0.0start_decay_step = 100000decay_steps = 100000beta1 = self.beta1learning_rate = (tf.where(tf.greater_equal(global_step, start_decay_step),tf.train.polynomial_decay(starter_learning_rate, global_step-start_decay_step,decay_steps, end_learning_rate,power=1.0),starter_learning_rate))tf.summary.scalar('learning_rate/{}'.format(name), learning_rate)learning_step = (tf.train.AdamOptimizer(learning_rate, beta1=beta1, name=name).minimize(loss, global_step=global_step, var_list=variables))return learning_step

四、总结

本文首先介绍了CycleGAN的原理,接着在tensorflow中用CycleGAN训练了两个模型(苹果橘子转换,男性女性转换),最后介绍了模型和损失的定义细节。CycleGAN 不 需要成对数据就可以训练,具有较强的通用性,由此产生了大量有创意的应用,例如男女互换。

CycleGAN 与非配对图像转换相关推荐

  1. 图像翻译/Transformer:ITTR: Unpaired Image-to-Image Translation with Transformers用Transfor进行非配对图像对图像的转换

    图像翻译/Transformer:ITTR: https://arxiv.org/abs/2203.16015用Transformer进行非配对图像对图像的转换 0.摘要 1.概述 2.方法 2.1. ...

  2. CycleGAN非配对图像生成,定制你的卡通照

    点击上方"AI搞事情"关注我们 ❝ Paper:<Unpaired Image-to-Image Translation using Cycle-Consistent Adv ...

  3. GAN系列(三) —— CycleGAN无配对图像翻译

    引入 之前讲的Pix2Pix图像翻译模型,要求数据必须成对,也就是说数据都是label好的,有监督的数据 但是我们很多数据都是没有label的,没有配对的 也就是说pix2pix是有配对下的图像翻译, ...

  4. 图像转换 image translation系列(17)| 最新ICCV2021生成对抗GAN汇总梳理

    (1)GAN改进系列 | 最新ICCV2021生成对抗网络GAN论文梳理汇总 图像编辑系列之(2)基于StyleGAN(3)GAN逆映射(4)人脸 (5)语义生成 | ICCV2021生成对抗GAN梳 ...

  5. (一)带有图像到图像转换的移动风格迁移

    目录 介绍 图像到图像的转换 生成对抗学习 我们系列中的CycleGAN 下一步 下载项目代码 - 7.2 MB 介绍 在本系列文章中,我们将展示一个基于循环一致对抗网络(CycleGAN)的移动图像 ...

  6. 极链AI云丨图像转换代表作CycleGAN快速复现

    图像转换代表作CycleGAN快速复现 极链AI云 关注极链AI云公众号,学习更多知识! 目录 图像转换代表作CycleGAN快速复现 一.模型详情 1.模型简介 2.关键词 3.应用场景 4.模型结 ...

  7. DL之CycleGAN:基于TF利用CycleGAN模型对apple2orange数据集实现图像转换—训练测试过程全记录

    DL之CycleGAN:基于TF利用CycleGAN模型对apple2orange数据集实现图像转换-训练&测试过程全记录 目录 apple2orange数据集 输出结果 训练&测试过 ...

  8. 登顶Github趋势榜,非监督GAN算法U-GAT-IT大幅改进图像转换效果

    点击我爱计算机视觉标星,更快获取CVML新技术 近日,GAN的大家族又出一位重量级新成员U-GAT-IT,图像转换效果提升明显,原作者开源代码这两天登顶Github趋势榜,引起极大关注. U-GAT- ...

  9. 风格迁移篇---重用鉴别器进行编码:朝向无监督的图像到图像转换

    文章目录 Abstract 1. Introduction 2. Related Work 3. Our NICE-GAN 3.1. General Formulation 3.2. Architec ...

最新文章

  1. Linux 下 的 cc 和 gcc
  2. php下使用 $_FILE
  3. C++ STL容器——序列式容器(array、vector、deque、list)
  4. openwrt系统安装到云服务器异常,OpenWrt路由器系统下服务OpenClash 安装教程及其折腾踩坑记录...
  5. iOS中js与objective-c的交互(转)
  6. 磁盘分区原理:从MBR到GPT
  7. TypeScript 里的 module 概念
  8. 7-19 求链式线性表的倒数第K项 (20 分)(思路分析+极简代码+超容易理解)
  9. java log4j权限被否定_SLF4J简介与使用(整合log4j)
  10. 创建war类型的maven工程时报web.xml is missing and failOnMissingWebXml is set to true
  11. putty怎么更改为中文_Putty怎么样设置显示中文 设置Putty显示中文
  12. Python学习,55道django面试题,来试试吧
  13. 信息系统集成监理费收取标准_信息产业部信息系统工程监理与咨询服务收费参考标准Word1...
  14. jQuery练习t188,从0到1
  15. vue 多个filters_Vue filters过滤器的使用方法
  16. Node对象的一些方法
  17. 提高Java反射速度的方法以及对setAccessable的误解
  18. 百度地图清除指定覆盖物
  19. 安卓系统管理软件_BlackBerry为部署车载安卓系统保驾护航
  20. 地图分幅编号C 语言编程,地图分幅编号实验报告讲解

热门文章

  1. FreeRTOS检测堆栈溢出方法
  2. Pyramid Squeeze Attention
  3. Putty全屏/退出全屏快捷键
  4. New1cloud_新壹云-游戏加速器行业解决案例分享
  5. 怎样选择好的海外服务器?
  6. 正则匹配ipv4_IPV4和IPV6正则表达式的深入讲解
  7. vb.net 高精度定时器 1ms级
  8. Linux学习月活动zz
  9. 图说c语言,图说C语言重定向.doc
  10. 学习笔记:黑马程序员C++从0到1(3~4)