WGAN是一个对原始GAN进行重大改进的网络

主要是在如下方面做了改进

实例测试代码如下:

还是用我16张鸣人的照片搞一波事情,每一个上述的改进点,我再代码中都是用 Difference 标注的。

import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import RMSprop
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CIFAR10
from pylab import plt
import os
import torchvision.datasets as datasets
from torchvision.utils import save_image# 至于 WGAN和GAN的区别请全文搜索 Importment Difference 即可查看# step 1: ========================================== 定义本程序运行需要的一些参数
class WGAN_Config:lr = 0.0001nz = 100  # noise dimensionimage_size = 64nc = 3  # chanel of imgngf = 64  # generator channelndf = 64  # discriminator channelbatch_size = 16max_epoch = 5000  # =1 when debugclamp_num = 0.01  # WGAN clip gradientwgan_opt = WGAN_Config()def deprocess_img(img):out = 0.5 * (img + 1)out = out.clamp(0, 1)out = out.view(-1, 3, wgan_opt.image_size, wgan_opt.image_size)return out# step 2: ========================================== 老流程,加载数据集。
# data preprocess
transform = transforms.Compose([transforms.Resize(wgan_opt.image_size),transforms.ToTensor(),transforms.Normalize([0.5] * 3, [0.5] * 3)
])# dataset = CIFAR10(root='cifar10/', transform=transform, download=True)
# dataloader = t.utils.data.DataLoader(dataset, wgan_opt.batch_size, shuffle=True)data_path = os.path.abspath("D:/software/Anaconda3/doc/3D_Naruto")
print (os.listdir(data_path))
# 请注意,在data_path下面再建立一个目录,存放所有图片,ImageFolder会在子目录下读取数据,否则下一步会报错。
dataset = datasets.ImageFolder(root=data_path, transform=transform)
dataloader = t.utils.data.DataLoader(dataset, batch_size=wgan_opt.batch_size, shuffle=True)# step 3: ========================================== 定义WGAN的G网络和D网络的模型
class generator(nn.Module):def __init__(self):super(generator, self).__init__()self.netg = nn.Sequential(nn.ConvTranspose2d(wgan_opt.nz, wgan_opt.ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(wgan_opt.ngf * 8),nn.ReLU(True),nn.ConvTranspose2d(wgan_opt.ngf * 8, wgan_opt.ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ngf * 4),nn.ReLU(True),nn.ConvTranspose2d(wgan_opt.ngf * 4, wgan_opt.ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ngf * 2),nn.ReLU(True),nn.ConvTranspose2d(wgan_opt.ngf * 2, wgan_opt.ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ngf),nn.ReLU(True),nn.ConvTranspose2d(wgan_opt.ngf, wgan_opt.nc, 4, 2, 1, bias=False),nn.Tanh())def forward(self, imgs):out = self.netg(imgs)return outclass discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.netd = nn.Sequential(nn.Conv2d(wgan_opt.nc, wgan_opt.ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(wgan_opt.ndf, wgan_opt.ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ndf * 2),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(wgan_opt.ndf * 2, wgan_opt.ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ndf * 4),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(wgan_opt.ndf * 4, wgan_opt.ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(wgan_opt.ndf * 8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(wgan_opt.ndf * 8, 1, 4, 1, 0, bias=False),# Importment Difference 1: do not use sigmoid func here any more.# nn.Sigmoid())def forward(self, imgs):out = self.netd(imgs)return out.view(imgs.shape[0])netd = discriminator()
netg = generator()# step 4: ========================================== 初始化两个网络的参数
# 这一步是新学习的。参数权重初始化过程
def weight_init(m):# weight_initialization: important for wganclass_name = m.__class__.__name__if class_name.find('Conv') != -1:m.weight.data.normal_(0, 0.02)elif class_name.find('Norm') != -1:m.weight.data.normal_(1.0, 0.02)netd.apply(weight_init)
netg.apply(weight_init)# step 5: ========================================== 定义优化器,这里使用 RMSprop,不使用Adam
# 也推荐使用 SGD
# Importment Difference 2: Use RMSprop instead of Adam
# optimizer
optimizerD = RMSprop(netd.parameters(), lr=wgan_opt.lr)
optimizerG = RMSprop(netg.parameters(), lr=wgan_opt.lr)# Importment Difference: No Log in loss
# criterion
# criterion = nn.BCELoss()# step 6: ========================================== 开始训练了
# begin training
rand_noise = Variable(t.FloatTensor(wgan_opt.batch_size, wgan_opt.nz, 1, 1).normal_(0, 1))iter_count = 0
# 将BCEloss 改为非log的loss,按照文章的记载,通常会使用直接同1和-1做比较
one = t.ones(wgan_opt.batch_size)
mone = -1 * one
for epoch in range(wgan_opt.max_epoch):for ii, data in enumerate(dataloader, 0):imgs = data[0]  # real imagenoise = Variable(t.randn(imgs.size(0), wgan_opt.nz, 1, 1))  # fake imageprint(imgs.shape)# Importment Difference 4: clip param for discriminatorfor parm in netd.parameters():parm.data.clamp_(-wgan_opt.clamp_num, wgan_opt.clamp_num)# ----- train discriminator network -----netd.zero_grad()output = netd(imgs)  # train netd with real imgoutput.backward(one)  # 跟 1 进行比较fake_pic = netg(noise).detach()  # train netd with real img, 梯度在此截断,不要继续往前传播。output2 = netd(fake_pic)output2.backward(mone)  # 跟 -1 进行比较optimizerD.step()# ------ train generator later -------# we train the discriminator many times, and less train for generator.# train netd more times: because the better netd is the better netg will beif (ii + 1) % 1 == 0:netg.zero_grad()noise.data.normal_(0, 1)fake_pic = netg(noise)output = netd(fake_pic)output.backward(one)  # 跟 1 进行比较optimizerG.step()if iter_count % 50 == 0:rand_imgs = netg(rand_noise)rand_imgs = deprocess_img(rand_imgs.data)save_image(rand_imgs, 'D:/software/Anaconda3/doc/3D_Img/wgan2/test_%d.png' % (iter_count))iter_count = iter_count + 1print('iter_count: ', iter_count)

效果如下:
最后都是用随机噪音产生的图片,时间太长了,训练次数不太够啊。

深度学习《WGAN模型》相关推荐

  1. 深度学习之自编码器(4)变分自编码器

    深度学习之自编码器(4)变分自编码器 1. VAE原理  基本的自编码器本质上是学习输入 x\boldsymbol xx和隐藏变量 z\boldsymbol zz之间映射关系,它是一个 判别模型(Di ...

  2. 深度学习之自编码器(5)VAE图片生成实战

    深度学习之自编码器(5)VAE图片生成实战 1. VAE模型 2. Reparameterization技巧 3. 网络训练 4. 图片生成 VAE图片生成实战完整代码  本节我们基于VAE模型实战F ...

  3. 深度学习之自编码器AutoEncoder

    深度学习之自编码器AutoEncoder 原文:http://blog.csdn.net/marsjhao/article/details/73480859 一.什么是自编码器(Autoencoder ...

  4. 深度学习之自编码器(3)自编码器变种

    深度学习之自编码器(3)自编码器变种 1. Denoising Auto-Encoder 2. Dropout Auto-Encoder 3. Adversarial Auto-Encoder  一般 ...

  5. 深度学习之自编码器(2)Fashion MNIST图片重建实战

    深度学习之自编码器(2)Fashion MNIST图片重建实战 1. Fashion MNIST数据集 2. 编码器 3. 解码器 4. 自编码器 5. 网络训练 6. 图片重建 完整代码  自编码器 ...

  6. 深度学习之自编码器(1)自编码器原理

    深度学习之自编码器(1)自编码器原理 自编码器原理  前面我们介绍了在给出样本及其标签的情况下,神经网络如何学习的算法,这类算法需要学习的是在给定样本 x\boldsymbol xx下的条件概率 P( ...

  7. 【深度学习】 自编码器(AutoEncoder)

    目录 RDAE稳健深度自编码 自编码器(Auto-Encoder) DAE 深度自编码器 RDAE稳健深度自编码 自编码器(Auto-Encoder) AE算法的原理 Auto-Encoder,中文称 ...

  8. 深入理解深度学习——Transformer:编码器(Encoder)部分

    分类目录:<深入理解深度学习>总目录 相关文章: ·注意力机制(AttentionMechanism):基础知识 ·注意力机制(AttentionMechanism):注意力汇聚与Nada ...

  9. 深度学习之自编码器实现——实现图像去噪

    大家好,我是带我去滑雪! 自编码器是一种无监督学习的神经网络,是一种数据压缩算法,主要用于数据降维和特征提取.它的基本思想是将输入数据经过一个编码器映射到隐藏层,再通过一个解码器映射到输出层,使得输出 ...

  10. 深度学习:自编码器、深度信念网络和深度玻尔兹曼机

    最近自己会把自己个人博客中的文章陆陆续续的复制到CSDN上来,欢迎大家关注我的 个人博客,以及我的github. 本文主要讲解有关自编码器.深度信念网络和深度玻尔兹曼机的相关知识. 一.自编码器 1. ...

最新文章

  1. (Mac-使用问题)Mac升级到 10.12后,下载的一些安装包提示损坏。
  2. python迭代器生成器装饰器
  3. 新疆计算机一级考试试题手机软件,新疆维吾尔自治区计算机一级考试理论题库(最新最完整)...
  4. access函数_ACCESS中的DLookUp函数是如何运算的?
  5. vue项目安装less_部署vue项目、安装mongodb
  6. 前端学习(1603):脚手架组件使用
  7. python中安装opencv一直说不是内部或外部文件_Window系统下Python如何安装OpenCV库
  8. Golang slice高级应用
  9. HTTP摘要认证原理以及HttpClient4.3实现
  10. 以太坊系列之十四: solidity特殊函数
  11. 写一个函数,求两个整数之和,要求在函数体内不得使用+、-、*、/四则运算符号
  12. Mac乐谱制作工具---Sibelius 8 for Mac西贝柳斯
  13. 【算法图解】 之 [二分查找法] 详解
  14. hibernate二级缓存(一)一级缓存与二级缓存
  15. 时频分析之STFT:短时傅里叶变换的原理与代码实现(非调用Matlab API)
  16. XML/HTML/CSS/JS之间的区别和联系
  17. 52_LSTM及简介,RNN单元的内部结构,LSTM单元的内部结构,原理,遗忘门,输入门,输出门,LSTM变体GRU,LSTM变体FC-LSTM,Pytorch LSTM API介绍,案例(学习笔记)
  18. 多媒体教学计算机遥控,多媒体教学系统使用说明
  19. Linux实现基于Loopback的NVI(NAT Virtual Interface)
  20. [高通MSM8953_64][Android10]解决制作差分包不生成system_manifest.xml的问题

热门文章

  1. slf4j 与log4j 日志管理
  2. chmod命令详解使用格式和方法
  3. 面试官系统精讲Java源码及大厂真题 - 25 整体设计:队列设计思想、工作中使用场景
  4. 分布式面试 - 如何基于 dubbo 进行服务治理、服务降级、失败重试以及超时重试?
  5. Docker搭建WebLogic服务器
  6. sqoop连接mysql_sqoop安装
  7. 【Java】输出50-100范围内所有的素数
  8. 【Java】编写Java GUI应用程序,完成从键盘输入矩形的长和宽,求矩形的周长和面积并输出结果的功能...
  9. C#LeetCode刷题之#11-盛最多水的容器(Container With Most Water)
  10. C#LeetCode刷题之#7-反转整数(Reverse Integer)