深度学习《WGAN模型》
WGAN是一个对原始GAN进行重大改进的网络
主要是在如下方面做了改进
实例测试代码如下:
还是用我16张鸣人的照片搞一波事情,每一个上述的改进点,我再代码中都是用 Difference 标注的。
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import RMSprop
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CIFAR10
from pylab import plt
import os
import torchvision.datasets as datasets
from torchvision.utils import save_image# 至于 WGAN和GAN的区别请全文搜索 Importment Difference 即可查看# step 1: ========================================== 定义本程序运行需要的一些参数
class WGAN_Config:lr = 0.0001nz = 100 # noise dimensionimage_size = 64nc = 3 # chanel of imgngf = 64 # generator channelndf = 64 # discriminator channelbatch_size = 16max_epoch = 5000 # =1 when debugclamp_num = 0.01 # WGAN clip gradientwgan_opt = WGAN_Config()def deprocess_img(img):out = 0.5 * (img + 1)out = out.clamp(0, 1)out = out.view(-1, 3, wgan_opt.image_size, wgan_opt.image_size)return out# step 2: ========================================== 老流程,加载数据集。
# data preprocess
transform = transforms.Compose([transforms.Resize(wgan_opt.image_size),transforms.ToTensor(),transforms.Normalize([0.5] * 3, [0.5] * 3)
])# dataset = CIFAR10(root='cifar10/', transform=transform, download=True)
# dataloader = t.utils.data.DataLoader(dataset, wgan_opt.batch_size, shuffle=True)data_path = os.path.abspath("D:/software/Anaconda3/doc/3D_Naruto")
print (os.listdir(data_path))
# 请注意,在data_path下面再建立一个目录,存放所有图片,ImageFolder会在子目录下读取数据,否则下一步会报错。
dataset = datasets.ImageFolder(root=data_path, transform=transform)
dataloader = t.utils.data.DataLoader(dataset, batch_size=wgan_opt.batch_size, shuffle=True)# step 3: ========================================== 定义WGAN的G网络和D网络的模型
class generator(nn.Module):def __init__(self):super(generator, self).__init__()self.netg = nn.Sequential(nn.ConvTranspose2d(wgan_opt.nz, wgan_opt.ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(wgan_opt.ngf * 8),nn.ReLU(True),nn.ConvTranspose2d(wgan_opt.ngf * 8, wgan_opt.ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ngf * 4),nn.ReLU(True),nn.ConvTranspose2d(wgan_opt.ngf * 4, wgan_opt.ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ngf * 2),nn.ReLU(True),nn.ConvTranspose2d(wgan_opt.ngf * 2, wgan_opt.ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ngf),nn.ReLU(True),nn.ConvTranspose2d(wgan_opt.ngf, wgan_opt.nc, 4, 2, 1, bias=False),nn.Tanh())def forward(self, imgs):out = self.netg(imgs)return outclass discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.netd = nn.Sequential(nn.Conv2d(wgan_opt.nc, wgan_opt.ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(wgan_opt.ndf, wgan_opt.ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ndf * 2),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(wgan_opt.ndf * 2, wgan_opt.ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ndf * 4),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(wgan_opt.ndf * 4, wgan_opt.ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ndf * 8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(wgan_opt.ndf * 8, 1, 4, 1, 0, bias=False),# Importment Difference 1: do not use sigmoid func here any more.# nn.Sigmoid())def forward(self, imgs):out = self.netd(imgs)return out.view(imgs.shape[0])netd = discriminator()
netg = generator()# step 4: ========================================== 初始化两个网络的参数
# 这一步是新学习的。参数权重初始化过程
def weight_init(m):# weight_initialization: important for wganclass_name = m.__class__.__name__if class_name.find('Conv') != -1:m.weight.data.normal_(0, 0.02)elif class_name.find('Norm') != -1:m.weight.data.normal_(1.0, 0.02)netd.apply(weight_init)
netg.apply(weight_init)# step 5: ========================================== 定义优化器,这里使用 RMSprop,不使用Adam
# 也推荐使用 SGD
# Importment Difference 2: Use RMSprop instead of Adam
# optimizer
optimizerD = RMSprop(netd.parameters(), lr=wgan_opt.lr)
optimizerG = RMSprop(netg.parameters(), lr=wgan_opt.lr)# Importment Difference: No Log in loss
# criterion
# criterion = nn.BCELoss()# step 6: ========================================== 开始训练了
# begin training
rand_noise = Variable(t.FloatTensor(wgan_opt.batch_size, wgan_opt.nz, 1, 1).normal_(0, 1))iter_count = 0
# 将BCEloss 改为非log的loss,按照文章的记载,通常会使用直接同1和-1做比较
one = t.ones(wgan_opt.batch_size)
mone = -1 * one
for epoch in range(wgan_opt.max_epoch):for ii, data in enumerate(dataloader, 0):imgs = data[0] # real imagenoise = Variable(t.randn(imgs.size(0), wgan_opt.nz, 1, 1)) # fake imageprint(imgs.shape)# Importment Difference 4: clip param for discriminatorfor parm in netd.parameters():parm.data.clamp_(-wgan_opt.clamp_num, wgan_opt.clamp_num)# ----- train discriminator network -----netd.zero_grad()output = netd(imgs) # train netd with real imgoutput.backward(one) # 跟 1 进行比较fake_pic = netg(noise).detach() # train netd with real img, 梯度在此截断,不要继续往前传播。output2 = netd(fake_pic)output2.backward(mone) # 跟 -1 进行比较optimizerD.step()# ------ train generator later -------# we train the discriminator many times, and less train for generator.# train netd more times: because the better netd is the better netg will beif (ii + 1) % 1 == 0:netg.zero_grad()noise.data.normal_(0, 1)fake_pic = netg(noise)output = netd(fake_pic)output.backward(one) # 跟 1 进行比较optimizerG.step()if iter_count % 50 == 0:rand_imgs = netg(rand_noise)rand_imgs = deprocess_img(rand_imgs.data)save_image(rand_imgs, 'D:/software/Anaconda3/doc/3D_Img/wgan2/test_%d.png' % (iter_count))iter_count = iter_count + 1print('iter_count: ', iter_count)
效果如下:
最后都是用随机噪音产生的图片,时间太长了,训练次数不太够啊。
深度学习《WGAN模型》相关推荐
- 深度学习之自编码器(4)变分自编码器
深度学习之自编码器(4)变分自编码器 1. VAE原理 基本的自编码器本质上是学习输入 x\boldsymbol xx和隐藏变量 z\boldsymbol zz之间映射关系,它是一个 判别模型(Di ...
- 深度学习之自编码器(5)VAE图片生成实战
深度学习之自编码器(5)VAE图片生成实战 1. VAE模型 2. Reparameterization技巧 3. 网络训练 4. 图片生成 VAE图片生成实战完整代码 本节我们基于VAE模型实战F ...
- 深度学习之自编码器AutoEncoder
深度学习之自编码器AutoEncoder 原文:http://blog.csdn.net/marsjhao/article/details/73480859 一.什么是自编码器(Autoencoder ...
- 深度学习之自编码器(3)自编码器变种
深度学习之自编码器(3)自编码器变种 1. Denoising Auto-Encoder 2. Dropout Auto-Encoder 3. Adversarial Auto-Encoder 一般 ...
- 深度学习之自编码器(2)Fashion MNIST图片重建实战
深度学习之自编码器(2)Fashion MNIST图片重建实战 1. Fashion MNIST数据集 2. 编码器 3. 解码器 4. 自编码器 5. 网络训练 6. 图片重建 完整代码 自编码器 ...
- 深度学习之自编码器(1)自编码器原理
深度学习之自编码器(1)自编码器原理 自编码器原理 前面我们介绍了在给出样本及其标签的情况下,神经网络如何学习的算法,这类算法需要学习的是在给定样本 x\boldsymbol xx下的条件概率 P( ...
- 【深度学习】 自编码器(AutoEncoder)
目录 RDAE稳健深度自编码 自编码器(Auto-Encoder) DAE 深度自编码器 RDAE稳健深度自编码 自编码器(Auto-Encoder) AE算法的原理 Auto-Encoder,中文称 ...
- 深入理解深度学习——Transformer:编码器(Encoder)部分
分类目录:<深入理解深度学习>总目录 相关文章: ·注意力机制(AttentionMechanism):基础知识 ·注意力机制(AttentionMechanism):注意力汇聚与Nada ...
- 深度学习之自编码器实现——实现图像去噪
大家好,我是带我去滑雪! 自编码器是一种无监督学习的神经网络,是一种数据压缩算法,主要用于数据降维和特征提取.它的基本思想是将输入数据经过一个编码器映射到隐藏层,再通过一个解码器映射到输出层,使得输出 ...
- 深度学习:自编码器、深度信念网络和深度玻尔兹曼机
最近自己会把自己个人博客中的文章陆陆续续的复制到CSDN上来,欢迎大家关注我的 个人博客,以及我的github. 本文主要讲解有关自编码器.深度信念网络和深度玻尔兹曼机的相关知识. 一.自编码器 1. ...
最新文章
- (Mac-使用问题)Mac升级到 10.12后,下载的一些安装包提示损坏。
- python迭代器生成器装饰器
- 新疆计算机一级考试试题手机软件,新疆维吾尔自治区计算机一级考试理论题库(最新最完整)...
- access函数_ACCESS中的DLookUp函数是如何运算的?
- vue项目安装less_部署vue项目、安装mongodb
- 前端学习(1603):脚手架组件使用
- python中安装opencv一直说不是内部或外部文件_Window系统下Python如何安装OpenCV库
- Golang slice高级应用
- HTTP摘要认证原理以及HttpClient4.3实现
- 以太坊系列之十四: solidity特殊函数
- 写一个函数,求两个整数之和,要求在函数体内不得使用+、-、*、/四则运算符号
- Mac乐谱制作工具---Sibelius 8 for Mac西贝柳斯
- 【算法图解】 之 [二分查找法] 详解
- hibernate二级缓存(一)一级缓存与二级缓存
- 时频分析之STFT:短时傅里叶变换的原理与代码实现(非调用Matlab API)
- XML/HTML/CSS/JS之间的区别和联系
- 52_LSTM及简介,RNN单元的内部结构,LSTM单元的内部结构,原理,遗忘门,输入门,输出门,LSTM变体GRU,LSTM变体FC-LSTM,Pytorch LSTM API介绍,案例(学习笔记)
- 多媒体教学计算机遥控,多媒体教学系统使用说明
- Linux实现基于Loopback的NVI(NAT Virtual Interface)
- [高通MSM8953_64][Android10]解决制作差分包不生成system_manifest.xml的问题
热门文章
- slf4j 与log4j 日志管理
- chmod命令详解使用格式和方法
- 面试官系统精讲Java源码及大厂真题 - 25 整体设计:队列设计思想、工作中使用场景
- 分布式面试 - 如何基于 dubbo 进行服务治理、服务降级、失败重试以及超时重试?
- Docker搭建WebLogic服务器
- sqoop连接mysql_sqoop安装
- 【Java】输出50-100范围内所有的素数
- 【Java】编写Java GUI应用程序,完成从键盘输入矩形的长和宽,求矩形的周长和面积并输出结果的功能...
- C#LeetCode刷题之#11-盛最多水的容器(Container With Most Water)
- C#LeetCode刷题之#7-反转整数(Reverse Integer)