Masking GAN pytorch
github代码:https://github.com/tgeorgy/mgan
文章的创新点:
1.生成网络输入x,输出包括分割模板mask,和中间图像y,根据mask将输入x与中间图像y结合,得到生成图像.这样得到的生成图像背景与输入x相同,前景为生成部分.
2.采用端到端训练,在cyclegan损失函数的基础上,添加了对输出生成图像进行约束.
模型结构如下,
生成网络首先输出为分割模板mask,以及中间图像y,将中间图像y和mask混合,得到的输出作为最后的生成生成图像.生成网络代码如下,
class Generator(nn.Module):
def __init__(self, input_nc=3, output_nc=4, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6):
assert(n_blocks >= 0)
super(Generator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_dropout=use_dropout)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ReflectionPad2d(1),
nn.Conv2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=1),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True),
nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2)*4,
kernel_size=1, stride=1),
nn.PixelShuffle(2),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True),
]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
self.model = nn.Sequential(*model)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
代码中,生成网络输入通道为3,输出通道为4,第一个通道为mask,其他三个通道为中间生成图像.
def forward(self, input):
output = self.model(input)
mask = F.sigmoid(output[:, :1])
oimg = output[:, 1:]
mask = mask.repeat(1, 3, 1, 1)
oimg = oimg*mask + input*(1-mask)
return oimg, mask
1
2
3
4
5
6
7
8
采用cyclegan结构,也就是,包含两个生成网络,两个判别网络.
对于每个生成网络,损失函数包括三个部分,第一个为loss_P2N_cyc ,与cyclegan loss相同,即输入到生成网络g1的输出,在输入生成网络g2,得到输出与输入尽量相同.第二个loss_P2N_gan为gan损失函数,也就是判别网络判断label为真.第三个为loss_N2P_idnt,也就是生成网路g1的输出与label尽量相似,也就是文章是end to end(输入-label对应)训练,由于cyclegan不是end to end,所以没有这个损失函数,
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()
criterion_gan = nn.MSELoss()
1
2
3
# Train P2N Generator
real_pos_v = Variable(real_pos)
fake_neg, mask_neg = netP2N(real_pos_v)
rec_pos, _ = netN2P(fake_neg)
fake_neg_lbl = netDN(fake_neg)
loss_P2N_cyc = criterion_cycle(rec_pos, real_pos_v)
loss_P2N_gan = criterion_gan(fake_neg_lbl, Variable(real_lbl))
loss_N2P_idnt = criterion_identity(fake_neg, real_pos_v)
1
2
3
4
5
6
7
8
9
# Train N2P Generator
real_neg_v = Variable(real_neg)
fake_pos, mask_pos = netN2P(real_neg_v)
rec_neg, _ = netP2N(fake_pos)
fake_pos_lbl = netDP(fake_pos)
loss_N2P_cyc = criterion_cycle(rec_neg, real_neg_v)
loss_N2P_gan = criterion_gan(fake_pos_lbl, Variable(real_lbl))
loss_P2N_idnt = criterion_identity(fake_pos, real_neg_v)
loss_G = ((loss_P2N_gan + loss_N2P_gan)*0.5 +
(loss_P2N_cyc + loss_N2P_cyc)*lambda_cycle +
(loss_P2N_idnt + loss_N2P_idnt)*lambda_identity)
1
2
3
4
5
6
7
8
9
10
11
12
13
判别网络用于判别输入的真假,
# Train Discriminators
netDN.zero_grad()
netDP.zero_grad()
fake_neg_score = netDN(fake_neg.detach())
loss_D = criterion_gan(fake_neg_score, Variable(fake_lbl))
fake_pos_score = netDP(fake_pos.detach())
loss_D += criterion_gan(fake_pos_score, Variable(fake_lbl))
real_neg_score = netDN.forward(real_neg_v)
loss_D += criterion_gan(real_neg_score, Variable(real_lbl))
real_pos_score = netDP.forward(real_pos_v)
loss_D += criterion_gan(real_pos_score, Variable(real_lbl))
---------------------
作者:imperfect00
来源:CSDN
原文:https://blog.csdn.net/u011961856/article/details/79057469
版权声明:本文为博主原创文章,转载请附上博文链接!
本文提出了一个域转换网络(domain transfer network,DTN),网络的作用是,对于给定两个域S,T,我们希望学习一个生成函数G,将S域的样本映射到域T,这样,对于一个给定函数f,不管f的输入为来自域S或T,f的输出会保持不变.
网络结构如下:
生成网络包括函数f,g.f用于提取输入图像的特征,得到一个特征向量.g的输入为f的输出,输出为目标风格的图像.训练数据为为无监督数据,即,原图像,目标图像不一一对应,分别采用原图像库,目标风格图像库,作为训练.对于原图像,输入生成网络G,输出风格图像.对于目标库的图像,输入生成网络G,输出还是该图像.
网络还包括一个判别网络D,判别网络的作用是判别输入为生成图像(fake),还是输入图像(real).
损失函数
1.对于生成网络,输入原图像,输出为目标风格的图像.同时我们还希望,当输入为目标图像时,生成网络输出也为目标图像,即生成网络对目标图像起到identity matrix的作用,这样构造损失函数LTIDLTID,
式中,x∈tx∈t表示图像x为目标图像,t为目标图像集合.
2.对与函数f,我们希望输入原图像提取的特征向量和原图像通过生成网络G生成的图像的f函数特征向量尽量相似,
式中,x∈sx∈s表示图像x为原图像,s为原图像集合.
3.判别网络D,用于判别原图像的生成图像,目标图像及目标图像的生成图像,用于判别是生成图像还是输入图像,损失函数为:
式中,D1D1用表示判别原图像经过生成网络G的生成图像.D2D2用于判别目标图像经过生成网络G的生成图像.D2D2用于判别目标图像.
4.对于生成网络,损失函数为:
LG=LGANG+αLCONST+βLTID+γLTVLG=LGANG+αLCONST+βLTID+γLTV
式中,B=1.
代码分析
生成网络的输入,输出为32×32×332×32×3的图像.
特征提取函数f部分网络结构包括4个卷积层,前3个卷积层卷积核为3×33×3,第4个卷积核大小为4×44×4,卷积核的stride=2.对于特征提取函数f,可以在其后加一个卷积层,对输入进行分类,例如对于手写字体,可以将其分为10类.可以对该网络进行分类任务训练,这样便起到了对网络进行预训练的作用.代码如下:
def content_extractor(self, images, reuse=False):
# images: (batch, 32, 32, 3) or (batch, 32, 32, 1)
if images.get_shape()[3] == 1:
# For mnist dataset, replicate the gray scale image 3 times.
images = tf.image.grayscale_to_rgb(images)
with tf.variable_scope('content_extractor', reuse=reuse):
with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
activation_fn=tf.nn.relu, is_training=(self.mode=='train' or self.mode=='pretrain')):
net = slim.conv2d(images, 64, [3, 3], scope='conv1') # (batch_size, 16, 16, 64)
net = slim.batch_norm(net, scope='bn1')
net = slim.conv2d(net, 128, [3, 3], scope='conv2') # (batch_size, 8, 8, 128)
net = slim.batch_norm(net, scope='bn2')
net = slim.conv2d(net, 256, [3, 3], scope='conv3') # (batch_size, 4, 4, 256)
net = slim.batch_norm(net, scope='bn3')
net = slim.conv2d(net, 128, [4, 4], padding='VALID', scope='conv4') # (batch_size, 1, 1, 128)
net = slim.batch_norm(net, activation_fn=tf.nn.tanh, scope='bn4')
if self.mode == 'pretrain':
net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out')
net = slim.flatten(net)
return net
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
函数g为f的输出特征向量进行解码,得到输出图像,相当于f网络的逆过程,也就是说g的网络结构为4个反卷积层,
def generator(self, inputs, reuse=False):
# inputs: (batch, 1, 1, 128)
with tf.variable_scope('generator', reuse=reuse):
with slim.arg_scope([slim.conv2d_transpose], padding='SAME', activation_fn=None,
stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
activation_fn=tf.nn.relu, is_training=(self.mode=='train')):
net = slim.conv2d_transpose(inputs, 512, [4, 4], padding='VALID', scope='conv_transpose1') # (batch_size, 4, 4, 512)
net = slim.batch_norm(net, scope='bn1')
net = slim.conv2d_transpose(net, 256, [3, 3], scope='conv_transpose2') # (batch_size, 8, 8, 256)
net = slim.batch_norm(net, scope='bn2')
net = slim.conv2d_transpose(net, 128, [3, 3], scope='conv_transpose3') # (batch_size, 16, 16, 128)
net = slim.batch_norm(net, scope='bn3')
net = slim.conv2d_transpose(net, 1, [3, 3], activation_fn=tf.nn.tanh, scope='conv_transpose4') # (batch_size, 32, 32, 1)
return net
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
判别网络也为4个卷积层,
def discriminator(self, images, reuse=False):
# images: (batch, 32, 32, 1)
with tf.variable_scope('discriminator', reuse=reuse):
with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
activation_fn=tf.nn.relu, is_training=(self.mode=='train')):
net = slim.conv2d(images, 128, [3, 3], activation_fn=tf.nn.relu, scope='conv1') # (batch_size, 16, 16, 128)
net = slim.batch_norm(net, scope='bn1')
net = slim.conv2d(net, 256, [3, 3], scope='conv2') # (batch_size, 8, 8, 256)
net = slim.batch_norm(net, scope='bn2')
net = slim.conv2d(net, 512, [3, 3], scope='conv3') # (batch_size, 4, 4, 512)
net = slim.batch_norm(net, scope='bn3')
net = slim.conv2d(net, 1, [4, 4], padding='VALID', scope='conv4') # (batch_size, 1, 1, 1)
net = slim.flatten(net)
return net
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
输入为原图像src_images,目标图像trg_images:
self.src_images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
self.trg_images = tf.placeholder(tf.float32, [None, 32, 32, 1], 'mnist_images')
1
2
对于s域,将原图像输入f,g得到特征向量fx,生成图像fake_images,并将生成图像输入判别网络,
# source domain (svhn to mnist)
self.fx = self.content_extractor(self.src_images)
self.fake_images = self.generator(self.fx)
self.logits = self.discriminator(self.fake_images)
self.fgfx = self.content_extractor(self.fake_images, reuse=True)
# loss
self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits))
self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits))
self.f_loss_src = tf.reduce_mean(tf.square(self.fx - self.fgfx)) * 15.0
1
2
3
4
5
6
7
8
9
10
对于t域,将目标图像输入f,g,并将目标图像,生成图像分别输入判别网络,
# target domain (mnist)
self.fx = self.content_extractor(self.trg_images, reuse=True)
self.reconst_images = self.generator(self.fx, reuse=True)
self.logits_fake = self.discriminator(self.reconst_images, reuse=True)
self.logits_real = self.discriminator(self.trg_images, reuse=True)
# loss
self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake))
self.d_loss_real_trg = slim.losses.sigmoid_cross_entropy(self.logits_real, tf.ones_like(self.logits_real))
self.d_loss_trg = self.d_loss_fake_trg + self.d_loss_real_trg
self.g_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.ones_like(self.logits_fake))
self.g_loss_const_trg = tf.reduce_mean(tf.square(self.trg_images - self.reconst_images)) * 15.0
self.g_loss_trg = self.g_loss_fake_trg + self.g_loss_const_trg
1
2
3
4
5
6
7
8
9
10
11
12
13
试验结果
首先下载代码,
git clone https://github.com/yunjey/domain-transfer-network
下载训练数据:
cd domain-transfer-network/
./download.sh
1
2
3
将手写字体reize到32×3232×32:
python prepro.py
1
预训练:
python main.py --mode='pretrain'
1
训练:
python main.py --mode='train'
1
测试:
python main.py --mode='eval'
---------------------
作者:imperfect00
来源:CSDN
原文:https://blog.csdn.net/u011961856/article/details/78606706
版权声明:本文为博主原创文章,转载请附上博文链接!
Masking GAN pytorch相关推荐
- gan pytorch 实例_重新思考一阶段实例分割(Rethinking Single Shot Instance Segmentation)
点击上方"CVer",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者:谢恩泽 https://zhuanlan.zhihu.com/ ...
- 程序员的魔法——用Masking GAN让100,000人都露出灿烂笑容
首发地址:https://yq.aliyun.com/articles/324353 继卷积神经网络(CNN)掀起深度学习的浪潮后,生成对抗网络(GAN)逐渐成为了计算机视觉领域的另一重点关注的神经网 ...
- 【PyTorch】50行代码实现GAN——PyTorch
本文来源于PyTorch中文网. 一直想了解GAN到底是个什么东西,却一直没能腾出时间来认真研究,前几日正好搜到一篇关于PyTorch实现GAN训练的文章,特将学习记录如下,本文主要包含两个部分:GA ...
- gan pytorch 实例_GaN教程(1)|美国宜普(EPC)公司推出“如何使用氮化镓器件”系列视频教程,讲授GaN功率晶体管和集成电路设计的方方面面...
美国宜普电源转换公司(EPC)推出"如何使用氮化镓晶体管"系列视频教程,包括14个部分(目前已上线7个视频),旨在为功率系统设计工程师提供技术基础和应用工具集,使其了解如何使用氮化 ...
- Masking GAN
github代码:https://github.com/tgeorgy/mgan 文章的创新点: 1.生成网络输入x,输出包括分割模板mask,和中间图像y,根据mask将输入x与中间图像y结合,得到 ...
- 深度学习(三十三)——GAN参考资源
GAN参考资源 https://blog.csdn.net/liuxiao214/article/category/6940697 某GAN专栏 https://mp.weixin.qq.com/s/ ...
- batchnorm pytorch_GitHub趋势榜第一:TensorFlow+PyTorch深度学习资源大汇总
[新智元导读]该项目是Jupyter Notebook中TensorFlow和PyTorch的各种深度学习架构,模型和技巧的集合.内容非常丰富,适用于Python 3.7,适合当做工具书. 本文搜集整 ...
- 程序员技术进阶手册(二)
这次AI浪潮大火有三个原因:一是大数据的发展,二是计算能力的提高,三是深度学习的兴起.作为程序员,我们应该对深度学习多加关心.深度学习的概念源于人工神经网络的研究,如果追溯深度学习的概念还是要回到20 ...
- 涵盖18+ SOTA GAN实现,这个图像生成领域的PyTorch库火了
视学算法报道 转载自:机器之心 作者:杜伟.陈萍 GAN 自从被提出后,便迅速受到广泛关注.我们可以将 GAN 分为两类,一类是无条件下的生成:另一类是基于条件信息的生成.近日,来自韩国浦项科技大学的 ...
最新文章
- [验证码识别技术]-初级的滑动式验证图片识别
- 资本主义的历史仍未终结(作者:袁剑)【转】
- qt opencv cmake配置 单纯小白
- ML之回归预测之BE:利用BE算法解决回归(实数值评分预测)问题—线性方法解决非线性问题
- android脚步---自动完成文本框
- 帆软高级函数应用之报表函数
- 非常难得的 CMOS sensor 工作原理的深入技术科普
- 《女士品茶》读书笔记
- 高斯-拉格朗日(Gauss-Legendre )Ⅱ型求积公式 数值分析 勘误 P111
- 在 vmware ESXi上安装mac系统虚拟机
- zotero自动安装word插件失败
- 使用Kuboard spray部署Kubernetes 1.24.3 集成Harbor私有镜像库
- DC升压直流高压电源模块12V24v转100V150V200V250V300V350v1000伏线性变化电压控制输出
- python3+selenium框架设计04-封装测试基类
- 中望CAD.NET二次开发(C#)_第01篇_环境搭建
- python生成带有表格的图片
- CAN与RS485比较
- 一篇让你读懂java中的字符串(String)
- OpenJudge百炼-2745-显示器-C语言-模拟
- 《士兵突击》之伍六一:最钢铁的男儿最柔软的内心
热门文章
- vue 目录名称详解_使用脚手架创建vue项目目录详解
- 【c语言】简单计算器
- 双链表偶数节点求和java_java--删除链表偶数节点
- jieba分词_从语言模型原理分析如何jieba更细粒度的分词
- 微服务前端开发框架React-Admin
- Spring MVC+Stomp+Security+H2 Jetty
- C# 8.0的三个令人兴奋的新特性
- 为什么“15。。。”会导致微信ANR?
- 因看不见而恐惧!企业亟需“看得见”威胁
- lisp提取长方形坐标_求修改lisp程序,如何提取CAD中多个点的坐标,(本人想提取UCS坐标系)另外只需要提取X,Y值,不要Z...