文章目录

  • 1、网络搭建
  • 2、反向传播过程
  • 3、PatchGAN
  • 4.与CGAN的不同之处

1、网络搭建

class UnetGenerator(nn.Module):"""Create a Unet-based generator"""def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):"""Construct a Unet generatorParameters:input_nc (int)  -- the number of channels in input imagesoutput_nc (int) -- the number of channels in output imagesnum_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,image of size 128x128 will become of size 1x1 # at the bottleneckngf (int)       -- the number of filters in the last conv layernorm_layer      -- normalization layerWe construct the U-Net from the innermost layer to the outermost layer.It is a recursive process."""super(UnetGenerator, self).__init__()# construct unet structureunet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layerfor i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filtersunet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)# gradually reduce the number of filters from ngf * 8 to ngfunet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layerdef forward(self, input):"""Standard forward"""return self.model(input)

Unet的模型结构如下图示,因此是从最内层开始搭建:

经过第一行后,网络结构如下,也就是最内层的下采样->上采样。

之后有一个循环,经过第一次循环后,在上一层的外围再次搭建了下采样和上采样:

经过第二次循环:

经过第三次循环:

可以看到每次反卷积的输入特征图的channel是1024,是因为它除了要接受上一层反卷积的输出(512维度),还要接受与其特征图大小相同的下采样层的输出(512维度),因此是1024的维度数。

循环完毕后,再次添加四次外部的降采样和反卷积,最终的网络结构如下:

UnetGenerator((model): UnetSkipConnectionBlock((model): Sequential((0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(1): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): ReLU(inplace=True)(3): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): Dropout(p=0.5, inplace=False)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): Dropout(p=0.5, inplace=False)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): Dropout(p=0.5, inplace=False)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ReLU(inplace=True)(5): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ReLU(inplace=True)(5): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(2): ReLU(inplace=True)(3): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))(4): Tanh()))
)

2、反向传播过程

我们这里假定pix2pix是风格A2B,风格A就是左边的图,风格B是右边的图。

反向传播的代码如下,整个是先更新D再更新G。

(1)首先向前传播,输入A,经过G,得到fakeB;

(2)开始更新D,进入backward_D函数:

  • 将A和fakeB cat起来,cat的整体相当于下图中的negative img,送入D,得到pred_fake;
  • 计算pred_fake的GAN损失,标签为0; 将A与real B cat起来,cat的整体相当于positive img,送入D,得到real_fake;
  • 计算pred_real的GAN损失,标签为1;
  • fake和real的GAN相加,得到总的判别器GAN损失。 (3)开始

(3)更新G,进入backward_G函数:

  • 将A和fakeB cat起来,cat的整体相当于下图中的negative img,送入D,得到pred_fake;
  • 计算pred_fake的GAN损失,标签为1;
  • 计算real B和fake B的逐像素损失L1;
  • 将GAN损失和逐像素损失L1相加,得到总损失。

下图就可视化了上述的过程。

    def backward_D(self):"""Calculate GAN loss for the discriminator"""# Fake; stop backprop to the generator by detaching fake_Bfake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminatorpred_fake = self.netD(fake_AB.detach())self.loss_D_fake = self.criterionGAN(pred_fake, False)# Realreal_AB = torch.cat((self.real_A, self.real_B), 1)pred_real = self.netD(real_AB)self.loss_D_real = self.criterionGAN(pred_real, True)# combine loss and calculate gradientsself.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5self.loss_D.backward()def backward_G(self):"""Calculate GAN and L1 loss for the generator"""# First, G(A) should fake the discriminatorfake_AB = torch.cat((self.real_A, self.fake_B), 1)pred_fake = self.netD(fake_AB)self.loss_G_GAN = self.criterionGAN(pred_fake, True)# Second, G(A) = Bself.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1# combine loss and calculate gradientsself.loss_G = self.loss_G_GAN + self.loss_G_L1self.loss_G.backward()def optimize_parameters(self):self.forward()                   # compute fake images: G(A)# update Dself.set_requires_grad(self.netD, True)  # enable backprop for Dself.optimizer_D.zero_grad()     # set D's gradients to zeroself.backward_D()                # calculate gradients for Dself.optimizer_D.step()          # update D's weights# update Gself.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing Gself.optimizer_G.zero_grad()        # set G's gradients to zeroself.backward_G()                   # calculate graidents for Gself.optimizer_G.step()             # udpate G's weights

3、PatchGAN

pix2pix还对判别器的结构做了一定的改动。之前都是对整张图像输出一个是否为真实的概率。pix2pix提出了PatchGan的概念。PatchGAN对图片中的每一个N×N的小块(patch)计算概率,然后再将这些概率求平均值作为整体的输出。

在上面的代码中pred_fake = self.netD(fake_AB.detach())的输出就不是一个概率值,而是30×30的特征图,相当于有30×30个patch。

下图表示标准的D网络结构(n_layers = 3),n_layers 为主要的特征卷积层数为3。如何理解?

下面(0)(1)表示head conv层,不算在n_layers layer中;
(2)(3)(4)才算做是标准的一个n_layers层,因此2-4、5-7、8-10一共是3层。
最后有一个卷积层,channel维度为1。
需要注意一下,patchgan channel维度最大为512

DataParallel((module): NLayerDiscriminator((model): Sequential((0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))(1): LeakyReLU(negative_slope=0.2, inplace=True)(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(4): LeakyReLU(negative_slope=0.2, inplace=True)(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): LeakyReLU(negative_slope=0.2, inplace=True)(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(10): LeakyReLU(negative_slope=0.2, inplace=True)(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))))
)

具体代码如下。与我们前面所述的稍微有些不一样,按照前面所述for n in range(1, n_layers)中相当于构建n_layers个特征提取层。但是代码中实际上构建了n_layers-1个,最后一个标准的特征提取层放在了sequence +=[…]中。

但是理解上还是可以按照前面。在spade框架中,就重新了构建patchgan的过程,其中就把最后一个标准的特征提取层也通过for n in range(1, n_layers)构建了。见[https://github.com/NVlabs/SPADE/blob/master/models/networks/discriminator.py]

class NLayerDiscriminator(nn.Module):def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):"""Construct a PatchGAN discriminatorParameters:input_nc (int)  -- the number of channels in input imagesndf (int)       -- the number of filters in the last conv layern_layers (int)  -- the number of conv layers in the discriminatornorm_layer      -- normalization layer"""super(NLayerDiscriminator, self).__init__()if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parametersuse_bias = norm_layer.func == nn.InstanceNorm2delse:use_bias = norm_layer == nn.InstanceNorm2dkw = 4  #卷积核的大小padw = 1  #padingsequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]  #head convnf_mult = 1nf_mult_prev = 1for n in range(1, n_layers):  # gradually increase the number of filtersnf_mult_prev = nf_multnf_mult = min(2 ** n, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # channel = 1self.model = nn.Sequential(*sequence)

4.与CGAN的不同之处

下面这张图是CGAN的示意图。可以看到

  • 在CGAN模型中,生成器的输入有两个,分别为一个噪声z,以及对应的条件y(在mnist训练中将图像和标签concat在一起),输出为符合该条件的图像G(z|y)
  • 判别器的输入同样也为两个,一个是条件,另一个满足该条件的真实图像x。

pix2pix模型与CGAN最大的不同在于,不再输入噪声z。因为实验中,即便给G输入一个噪声z,G也只学会将其忽略并生成图像,噪声z对输出结果的影响几乎微乎其微。因此为了简洁性,将z去掉了。

pix2pix模型中G的输入实际上等于CGAN模型的条件y。

Pix2Pix原理解析以及代码流程相关推荐

  1. 【NLP】Doc2vec原理解析及代码实践

    本文概览: 1. 句子向量简介 Word2Vec提供了高质量的词向量,并在一些任务中表现良好.虽然Word2Vec提供了高质量的词汇向量,但是仍然没有有效的方法将它们结合成一个高质量的文档向量.对于一 ...

  2. 如何用Diffusion models做interpolation插值任务?——原理解析和代码实战

    Diffusion Models专栏文章汇总:入门与实战  前言:很多Diffusion models的论文里都演示了插值任务,今天我们讲解一下如何用DDIM/DDPM做interpolation任务 ...

  3. Pix2Pix原理解析

    1.网络搭建 class UnetGenerator(nn.Module):"""Create a Unet-based generator""&qu ...

  4. 深度学习之pix2pix原理解析

    今天我来给大家介绍一下基于CGAN的pix2pix模型,给大家简单讲解一下pix2pix的原理. 这里我简单给大家CGAN,我用这张图给大家介绍CGAN的原理: CGAN和传统的GAN不同,传统GAN ...

  5. 深入理解ResNet原理解析及代码实现

    github地址:https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 论文地址:https://arx ...

  6. 深度学习核心技术精讲100篇(九)-Catboost算法原理解析及代码实现

    前言 今天博主来介绍一个超级简单并且又极其实用的boosting算法包Catboost,据开发者所说这一boosting算法是超越Lightgbm和XGBoost的又一个神器. catboost 简介 ...

  7. ORB_SLAM2 原理、论文解读、代码流程

    ORB_SLAM2 原理+论文解读+代码流程 算法原理 Tracking LocalMapping LoopClosing 代码流程 文件的调用关系 重要变量的数据结构 Tracking流程 Loca ...

  8. 【面试必备】奉上最通俗易懂的XGBoost、LightGBM、BERT、XLNet原理解析

    一只小狐狸带你解锁 炼丹术&NLP 秘籍 在非深度学习的机器学习模型中,基于GBDT算法的XGBoost.LightGBM等有着非常优秀的性能,校招算法岗面试中"出镜率"非 ...

  9. 微服务精通之Hystrix原理解析

    前言 经过微服务精通之Ribbon原理解析的学习,我们了解到了服务消费者获取服务提供者实例的过程,在这之后,服务消费者会调用服务提供者的接口.但是在调用接口的过程中,我们经常会遇见服务之间的延迟和通信 ...

最新文章

  1. HTML中Css详细介绍
  2. 【Tuxedo】Tuxedo入门
  3. PCGen的垃圾收集分析
  4. 【计算机网络】TCP端口
  5. 1-9月欧洲新能源车份额上升 混动车注册量增加8.8%
  6. CentOS启动时报错修复
  7. 我的QQ群,欢迎入坑!
  8. 微软张宏江出任金山CEO 求伯君正式退休
  9. linux网络编程 mingw,Windows网络编程
  10. 第五天 面向对象软件分析与设计
  11. 最新Flutter 微信分享功能实现
  12. java 句柄无效_Java开发网 - java.io.IOException: 句柄无效???
  13. 如何将多个excel表格合并成一个_相同表头的多个Excel表格合并成一个Excel表的方法...
  14. zblog php 优化,Zblog单页面优化,Zblog后台地址修改
  15. 高德AR 车道级导航技术演进与实践
  16. linux提取手机rom,提取安卓手机ROM固件中的APP
  17. 墨客科技执行董事袁英:MOAC区块链赋能实体产业的方案与实践
  18. ubuntu 下stl obj ply 3dx fbx等各种格式转pcd方法
  19. 【阿里面试】C++多态和虚函数
  20. excel中如何拷贝已经筛选出来的数据到另外的一表格中

热门文章

  1. 第六十七章 SQL函数 ISNUMERIC
  2. 大学生体测管理系统开发实战
  3. 【老生谈算法】matlab实现sift算法的图像匹配——sift算法
  4. filewriter
  5. SQL学习—基础查询
  6. SpringBoot快速打包注册服务之appassembler教程+绕坑
  7. ChatGPT研究框架(2023)
  8. 采用dlopen、dlsym、dlclose加载动态链接库【总结】
  9. BufferedReader和scanner用法和区别
  10. UE4Material_节点介绍(Unfinished)