Pix2Pix原理解析
1.网络搭建
class UnetGenerator(nn.Module):"""Create a Unet-based generator"""def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):"""Construct a Unet generatorParameters:input_nc (int) -- the number of channels in input imagesoutput_nc (int) -- the number of channels in output imagesnum_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,image of size 128x128 will become of size 1x1 # at the bottleneckngf (int) -- the number of filters in the last conv layernorm_layer -- normalization layerWe construct the U-Net from the innermost layer to the outermost layer.It is a recursive process."""super(UnetGenerator, self).__init__()# construct unet structureunet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layerfor i in range(num_downs - 5): # add intermediate layers with ngf * 8 filtersunet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)# gradually reduce the number of filters from ngf * 8 to ngfunet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layerdef forward(self, input):"""Standard forward"""return self.model(input)
Unet的模型结构如下图示,因此是从最内层开始搭建:
经过第一行后,网络结构如下,也就是最内层的下采样->上采样。
之后有一个循环,经过第一次循环后,在上一层的外围再次搭建了下采样和上采样:
经过第二次循环:
经过第三次循环:
可以看到每次反卷积的输入特征图的channel是1024,是因为它除了要接受上一层反卷积的输出(512维度),还要接受与其特征图大小相同的下采样层的输出(512维度),因此是1024的维度数。
循环完毕后,再次添加四次外部的降采样和反卷积,最终的网络结构如下:
UnetGenerator((model): UnetSkipConnectionBlock((model): Sequential((0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(1): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): ReLU(inplace=True)(3): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): Dropout(p=0.5, inplace=False)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): Dropout(p=0.5, inplace=False)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): Dropout(p=0.5, inplace=False)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ReLU(inplace=True)(5): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ReLU(inplace=True)(5): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(2): ReLU(inplace=True)(3): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))(4): Tanh()))
)
2.反向传播过程
我们这里假定pix2pix是风格A2B,风格A就是左边的图,风格B是右边的图。
反向传播的代码如下,整个是先更新D再更新G。
(1)首先向前传播,输入A,经过G,得到fakeB;
(2)开始更新D,进入backward_D函数:
- 将A和fakeB cat起来,cat的整体相当于下图中的negative img,送入D,得到pred_fake;
- 计算pred_fake的GAN损失,标签为0;
- 将A与real B cat起来,cat的整体相当于positive img,送入D,得到real_fake;
- 计算pred_real的GAN损失,标签为1;
- fake和real的GAN相加,得到总的判别器GAN损失。
(3)开始更新G,进入backward_G函数:
- 将A和fakeB cat起来,cat的整体相当于下图中的negative img,送入D,得到pred_fake;
- 计算pred_fake的GAN损失,标签为1;
- 计算real B和fake B的逐像素损失L1;
- 将GAN损失和逐像素损失L1相加,得到总损失。
下图就可视化了上述的过程。
def backward_D(self):"""Calculate GAN loss for the discriminator"""# Fake; stop backprop to the generator by detaching fake_Bfake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminatorpred_fake = self.netD(fake_AB.detach())self.loss_D_fake = self.criterionGAN(pred_fake, False)# Realreal_AB = torch.cat((self.real_A, self.real_B), 1)pred_real = self.netD(real_AB)self.loss_D_real = self.criterionGAN(pred_real, True)# combine loss and calculate gradientsself.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5self.loss_D.backward()def backward_G(self):"""Calculate GAN and L1 loss for the generator"""# First, G(A) should fake the discriminatorfake_AB = torch.cat((self.real_A, self.fake_B), 1)pred_fake = self.netD(fake_AB)self.loss_G_GAN = self.criterionGAN(pred_fake, True)# Second, G(A) = Bself.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1# combine loss and calculate gradientsself.loss_G = self.loss_G_GAN + self.loss_G_L1self.loss_G.backward()def optimize_parameters(self):self.forward() # compute fake images: G(A)# update Dself.set_requires_grad(self.netD, True) # enable backprop for Dself.optimizer_D.zero_grad() # set D's gradients to zeroself.backward_D() # calculate gradients for Dself.optimizer_D.step() # update D's weights# update Gself.set_requires_grad(self.netD, False) # D requires no gradients when optimizing Gself.optimizer_G.zero_grad() # set G's gradients to zeroself.backward_G() # calculate graidents for Gself.optimizer_G.step() # udpate G's weights
3.PatchGAN
pix2pix还对判别器的结构做了一定的改动。之前都是对整张图像输出一个是否为真实的概率。pix2pix提出了PatchGan的概念。PatchGAN对图片中的每一个N×N的小块(patch)计算概率,然后再将这些概率求平均值作为整体的输出。
在上面的代码中pred_fake = self.netD(fake_AB.detach())的输出就不是一个概率值,而是30×30的特征图,相当于有30×30个patch。
下图表示标准的D网络结构(n_layers = 3),n_layers 为主要的特征卷积层数为3。如何理解?
- 下面(0)(1)表示head conv层,不算在n_layers layer中;
- (2)(3)(4)才算做是标准的一个n_layers层,因此2-4、5-7、8-10一共是3层。
- 最后有一个卷积层,channel维度为1。
需要注意一下,patchgan channel维度最大为512。
DataParallel((module): NLayerDiscriminator((model): Sequential((0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))(1): LeakyReLU(negative_slope=0.2, inplace=True)(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(4): LeakyReLU(negative_slope=0.2, inplace=True)(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): LeakyReLU(negative_slope=0.2, inplace=True)(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(10): LeakyReLU(negative_slope=0.2, inplace=True)(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))))
)
具体代码如下。与我们前面所述的稍微有些不一样,按照前面所述for n in range(1, n_layers)中相当于构建n_layers个特征提取层。但是代码中实际上构建了n_layers-1个,最后一个标准的特征提取层放在了sequence +=[...]中。
但是理解上还是可以按照前面。在spade框架中,就重新了构建patchgan的过程,其中就把最后一个标准的特征提取层也通过for n in range(1, n_layers)构建了。见https://github.com/NVlabs/SPADE/blob/master/models/networks/discriminator.py
class NLayerDiscriminator(nn.Module):def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):"""Construct a PatchGAN discriminatorParameters:input_nc (int) -- the number of channels in input imagesndf (int) -- the number of filters in the last conv layern_layers (int) -- the number of conv layers in the discriminatornorm_layer -- normalization layer"""super(NLayerDiscriminator, self).__init__()if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parametersuse_bias = norm_layer.func == nn.InstanceNorm2delse:use_bias = norm_layer == nn.InstanceNorm2dkw = 4 #卷积核的大小padw = 1 #padingsequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] #head convnf_mult = 1nf_mult_prev = 1for n in range(1, n_layers): # gradually increase the number of filtersnf_mult_prev = nf_multnf_mult = min(2 ** n, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # channel = 1self.model = nn.Sequential(*sequence)
4.与CGAN的不同之处
下面这张图是CGAN的示意图。可以看到
- 在CGAN模型中,生成器的输入有两个,分别为一个噪声z,以及对应的条件y(在mnist训练中将图像和标签concat在一起),输出为符合该条件的图像G(z|y)
- 判别器的输入同样也为两个,一个是条件,另一个满足该条件的真实图像x。
pix2pix模型与CGAN最大的不同在于,不再输入噪声z。因为实验中,即便给G输入一个噪声z,G也只学会将其忽略并生成图像,噪声z对输出结果的影响几乎微乎其微。因此为了简洁性,将z去掉了。
pix2pix模型中G的输入实际上等于CGAN模型的条件y。
Pix2Pix原理解析相关推荐
- 深度学习之pix2pix原理解析
今天我来给大家介绍一下基于CGAN的pix2pix模型,给大家简单讲解一下pix2pix的原理. 这里我简单给大家CGAN,我用这张图给大家介绍CGAN的原理: CGAN和传统的GAN不同,传统GAN ...
- Pix2Pix原理解析以及代码流程
文章目录 1.网络搭建 2.反向传播过程 3.PatchGAN 4.与CGAN的不同之处 1.网络搭建 class UnetGenerator(nn.Module):""" ...
- Spark Shuffle原理解析
Spark Shuffle原理解析 一:到底什么是Shuffle? Shuffle中文翻译为"洗牌",需要Shuffle的关键性原因是某种具有共同特征的数据需要最终汇聚到一个计算节 ...
- 秋色园QBlog技术原理解析:性能优化篇:用户和文章计数器方案(十七)
2019独角兽企业重金招聘Python工程师标准>>> 上节概要: 上节 秋色园QBlog技术原理解析:性能优化篇:access的并发极限及分库分散并发方案(十六) 中, 介绍了 ...
- Tomcat 架构原理解析到架构设计借鉴
点击上方"方志朋",选择"设为星标" 回复"666"获取新整理的面试文章 Tomcat 架构原理解析到架构设计借鉴 Tomcat 发展这 ...
- 秋色园QBlog技术原理解析:性能优化篇:数据库文章表分表及分库减压方案(十五)...
文章回顾: 1: 秋色园QBlog技术原理解析:开篇:整体认识(一) --介绍整体文件夹和文件的作用 2: 秋色园QBlog技术原理解析:认识整站处理流程(二) --介绍秋色园业务处理流程 3: 秋色 ...
- CSS实现元素居中原理解析
原文:CSS实现元素居中原理解析 在 CSS 中要设置元素水平垂直居中是一个非常常见的需求了.但就是这样一个从理论上来看似乎实现起来极其简单的,在实践中,它往往难住了很多人. 让元素水平居中相对比较简 ...
- 秋色园QBlog技术原理解析:Web之页面处理-内容填充(八)
文章回顾: 1: 秋色园QBlog技术原理解析:开篇:整体认识(一) --介绍整体文件夹和文件的作用 2: 秋色园QBlog技术原理解析:认识整站处理流程(二) --介绍秋色园业务处理流程 3: 秋色 ...
- 秋色园QBlog技术原理解析:UrlRewrite之无后缀URL原理(三)
文章回顾: 1: 秋色园QBlog技术原理解析:开篇:整体认识(一) --介绍整体文件夹和文件的作用 2: 秋色园QBlog技术原理解析:认识整站处理流程(二) --介绍秋色园业务处理流程 本节,将从 ...
最新文章
- HDU2255(最全权完美匹配)
- 有查看自己dian nao mi |W| ma 的软件
- linux的pthread.h
- XCode4 实践HelloWorld
- z-index的取值范围
- 关于TransactionScope出错:“与基础事务管理器的通信失败”的解决方法
- js 台阶有n级_乔欣这是“开眼角”了?只在眼妆中多加这一步,整个人变美了N倍...
- Entity Framework 学习中级篇1—EF支持复杂类型的实现
- Java的“ for each”循环如何工作?
- MATLAB IIR滤波器设计函数buttord与butter
- hdu acm 1010
- HTML学生网页设计作业源码~开心旅游网站设计与实现(HTML期末大作业)
- 实验二 VB基本界面设计
- springboot 报错“LoggerFactory is not a Logback LoggerContext but Logback is on the classpath.” 解决方式
- sizeof 32位和64位操作系统的区别
- 天津大学仁爱学院c语言期末考试题,天津大学仁爱学院2014-2015学年第1学期期末C语言复习.doc...
- Apache Dubbo详解
- VM虚拟机局域网组网配置
- 稳健收益,缺你不可—A股优秀的基金和基金经理
- word之插入LaTex公式
热门文章
- 亿级流量架构:为什么要扩容?服务器扩容思路及问题分析
- 修改props的属性值,Vue warn]: Avoid mutating a prop directly since the value will be overwritten
- 【OpenCV】 n 点透视问题数学建模及其求解(P3P方法)
- SAP 银企直联 批量获取银行账户数据
- 【Java并发编程】主线程等待子线程的多种方法
- 号称ChatGPT“最强竞争对手”的Claude,今天迎来史诗级更新!
- SVN中检出(check out) 和导出(export) 的区别
- 欧盟自由销售证_小公司和自由职业者如何应对VATMOSS欧盟增值税变更
- 【Fusion360】常用快捷键和技巧
- linux开机界面改为图画,CentOS6.5启动界面的更改的方法