论文传送门:https://arxiv.org/pdf/1704.00028.pdf

WGAN存在的问题:在WGAN中,为使得判别器D(x)满足Lipschitz连续条件,从而对网络参数进行了[-c,c]的区间限制,使得网络参数分布极端,参数均接近于-c或c。

WGAN-gp的目的:解决WGAN参数分布极端的问题。 

WGAN-gp的方法:在判别器D的loss中增加梯度惩罚项,代替WGAN中对判别器D的参数区间限制,同样能保证D(x)满足Lipschitz连续条件。(证明过程见论文补充材料)

红框部分:与WGAN不同之处,即判别器D的loss增加梯度惩罚项和优化器选择Adam

梯度惩罚项的计算实现见代码70-87行,判别器D的损失函数修改见代码156行。

import os
import torch
from torch.utils.data import DataLoaderimport torch.nn as nnfrom torchvision import datasets, transforms
from torchvision.utils import save_imagefrom tqdm import tqdmclass Discriminator(nn.Module):  # 定义判别器(WS-divergence)def __init__(self, img_shape=(1, 28, 28)):  # 初始化方法super(Discriminator, self).__init__()  # 继承初始化方法self.img_shape = img_shape  # 图片形状self.linear1 = nn.Linear(self.img_shape[0] * self.img_shape[1] * self.img_shape[2], 512)  # linear映射self.linear2 = nn.Linear(512, 256)  # linear映射self.linear3 = nn.Linear(256, 1)  # linear映射self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)  # leakyrelu激活函数def forward(self, x):  # 前传函数x = torch.flatten(x, 1)  # 输入图片从三维压缩至一维特征向量,(n,1,28,28)-->(n,784)x = self.linear1(x)  # linear映射,(n,784)-->(n,512)x = self.leakyrelu(x)  # leakyrelu激活函数x = self.linear2(x)  # linear映射,(n,512)-->(n,256)x = self.leakyrelu(x)  # leakyrelu激活函数x = self.linear3(x)  # linear映射,(n,256)-->(n,1)return x  # 返回近似拟合的Wasserstein距离class Generator(nn.Module):  # 定义生成器def __init__(self, img_shape=(1, 28, 28), latent_dim=100):  # 初始化方法super(Generator, self).__init__()self.img_shape = img_shape  # 图片形状self.latent_dim = latent_dim  # 噪声z的长度self.linear1 = nn.Linear(self.latent_dim, 128)  # linear映射self.linear2 = nn.Linear(128, 256)  # linear映射self.bn2 = nn.BatchNorm1d(256, 0.8)  # bn操作self.linear3 = nn.Linear(256, 512)  # linear映射self.bn3 = nn.BatchNorm1d(512, 0.8)  # bn操作self.linear4 = nn.Linear(512, 1024)  # linear映射self.bn4 = nn.BatchNorm1d(1024, 0.8)  # bn操作self.linear5 = nn.Linear(1024, self.img_shape[0] * self.img_shape[1] * self.img_shape[2])  # linear映射self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)  # leakyrelu激活函数self.tanh = nn.Tanh()  # tanh激活函数,将输出压缩至(-1.1)def forward(self, z):  # 前传函数z = self.linear1(z)  # linear映射,(n,100)-->(n,128)z = self.leakyrelu(z)  # leakyrelu激活函数z = self.linear2(z)  # linear映射,(n,128)-->(n,256)z = self.bn2(z)  # 一维bn操作z = self.leakyrelu(z)  # leakyrelu激活函数z = self.linear3(z)  # linear映射,(n,256)-->(n,512)z = self.bn3(z)  # 一维bn操作z = self.leakyrelu(z)  # leakyrelu激活函数z = self.linear4(z)  # linear映射,(n,512)-->(n,1024)z = self.bn4(z)  # 一维bn操作z = self.leakyrelu(z)  # leakyrelu激活函数z = self.linear5(z)  # linear映射,(n,1024)-->(n,784)z = self.tanh(z)  # tanh激活函数z = z.view(-1, self.img_shape[0], self.img_shape[1], self.img_shape[2])  # 从一维特征向量扩展至三维图片,(n,784)-->(n,1,28,28)return z  # 返回生成的图片def cal_gp(D, real_imgs, fake_imgs, cuda):  # 定义函数,计算梯度惩罚项gpr = torch.rand(size=(real_imgs.shape[0], 1, 1, 1))  # 真假样本的采样比例r,batch size个随机数,服从区间[0,1)的均匀分布if cuda:  # 如果使用cudar = r.cuda()  # r加载到GPUx = (r * real_imgs + (1 - r) * fake_imgs).requires_grad_(True)  # 输入样本x,由真假样本按照比例产生,需要计算梯度d = D(x)  # 判别网络D对输入样本x的判别结果D(x)fake = torch.ones_like(d)  # 定义与d形状相同的张量,代表梯度计算时每一个元素的权重if cuda:  # 如果使用cudafake = fake.cuda()  # fake加载到GPUg = torch.autograd.grad(  # 进行梯度计算outputs=d,  # 计算梯度的函数d,即D(x)inputs=x,  # 计算梯度的变量xgrad_outputs=fake,  # 梯度计算权重create_graph=True,  # 创建计算图retain_graph=True  # 保留计算图)[0]  # 返回元组的第一个元素为梯度计算结果gp = ((g.norm(2, dim=1) - 1) ** 2).mean()  # (||grad(D(x))||2-1)^2 的均值return gp  # 返回梯度惩罚项gpif __name__ == "__main__":# 训练参数total_epochs = 100  # 训练轮次batch_size = 64  # 批大小lr_D = 4e-3  # 判别网络D学习率lr_G = 1e-3  # 生成网络G学习率num_workers = 8  # 数据加载线程数latent_dim = 100  # 噪声z长度image_size = 28  # 图片尺寸channel = 1  # 图片通道a = 10  # 梯度惩罚项系数clip_value = 0.01  # 判别器参数限定范围dataset_dir = "dataset/mnist"  # 训练数据集路径gen_images_dir = "gen_images"  # 生成样例图片路径cuda = True if torch.cuda.is_available() else False  # 设置是否使用cudaos.makedirs(dataset_dir, exist_ok=True)  # 创建训练数据集路径os.makedirs(gen_images_dir, exist_ok=True)  # 创建样例图片路径image_shape = (channel, image_size, image_size)  # 图片形状# 模型D = Discriminator(image_shape)  # 实例化判别器G = Generator(image_shape, latent_dim)  # 实例化生成器if cuda:  # 如果使用cudaD = D.cuda()  # 模型加载到GPUG = G.cuda()  # 模型加载到GPU# 数据集transform = transforms.Compose(  # 数据预处理方法[transforms.Resize(image_size),  # resizetransforms.ToTensor(),  # 转为tensortransforms.Normalize([0.5], [0.5])]  # 标准化)dataloader = DataLoader(  # dataloaderdataset=datasets.MNIST(  # 数据集选取MNIST手写体数据集root=dataset_dir,  # 数据集存放路径train=True,  # 使用训练集download=True,  # 自动下载transform=transform  # 应用数据预处理方法),batch_size=batch_size,  # 设置batch sizenum_workers=num_workers,  # 设置读取数据线程数shuffle=True  # 设置打乱数据)# 优化器optimizer_D = torch.optim.Adam(D.parameters(), lr=lr_D)  # 定义判别网络Adam优化器,传入学习率lr_Doptimizer_G = torch.optim.Adam(G.parameters(), lr=lr_G)  # 定义生成网络Adam优化器,传入学习率lr_G# 训练循环for epoch in range(total_epochs):  # 循环epochpbar = tqdm(total=len(dataloader), desc=f'Epoch {epoch + 1}/{total_epochs}', postfix=dict,mininterval=0.3)  # 设置当前epoch显示进度LD = 0LG = 0for i, (real_imgs, _) in enumerate(dataloader):  # 循环iterif cuda:  # 如果使用cudareal_imgs = real_imgs.cuda()  # 数据加载到GPUbs = real_imgs.shape[0]  # batchsize# 开始训练判别网络Doptimizer_D.zero_grad()  # 判别网络D清零梯度z = torch.randn((bs, latent_dim))  # 生成输入噪声z,服从标准正态分布,长度为latent_dimif cuda:  # 如果使用cudaz = z.cuda()  # 噪声z加载到GPUfake_imgs = G(z).detach()  # 噪声z输入生成网络G,得到生成图片,并阻止其反向梯度传播gp = cal_gp(D, real_imgs, fake_imgs, cuda)loss_D = -torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs)) + a * gp  # 判别网络D的损失函数,相较于WGAN,增加了梯度惩罚项a*gploss_D.backward()  # 反向传播,计算当前梯度optimizer_D.step()  # 根据梯度,更新网络参数LD += loss_D.item()  # 累计判别网络D的loss# 开始训练生成网络Goptimizer_G.zero_grad()  # 生成网络G清零梯度gen_imgs = G(z)  # 噪声z输入生成网络G,得到生成图片loss_G = -torch.mean(D(gen_imgs))  # 生成网络G的损失函数loss_G.backward()  # 反向传播,计算当前梯度optimizer_G.step()  # 根据梯度,更新网络参数LG += loss_G.item()  # 累计生成网络G的losspbar.set_postfix(**{'D_loss': loss_D.item(), 'G_loss': loss_G.item()})  # 显示判别网络D和生成网络G的损失pbar.update(1)  # 步进长度pbar.close()  # 关闭当前epoch显示进度print("total_D_loss:%.4f,total_G_loss:%.4f" % (LD / len(dataloader), LG / len(dataloader)))  # 显示当前epoch训练完成后,判别网络D和生成网络G的总损失save_image(gen_imgs.data[:25], "%s/ep%d.png" % (gen_images_dir, (epoch + 1)), nrow=5,normalize=True)  # 保存生成图片样例(5x5)

WGAN-gp模型——pytorch实现相关推荐

  1. ArcGIS API for javascript开发笔记(五)——GP服务调用之GP模型的发布及使用详解...

    感谢一路走来默默陪伴和支持的你~~~ ----------------欢迎来访,拒绝转载---------------- 关于GP模型的制作请点我! 一.GP发布 ArcGIS Desktop可以作为 ...

  2. ArcGIS API for Silverlight 调用GP服务准备---GP模型建立、发布、测试

    ArcGIS API for Silverlight 调用GP服务准备---GP模型建立.发布.测试 原文:ArcGIS API for Silverlight 调用GP服务准备---GP模型建立.发 ...

  3. DIN模型pytorch代码逐行细讲

    DIN模型pytorch代码逐行细讲 文章目录 DIN模型pytorch代码逐行细讲 一.DIN模型的结构 二.代码介绍 三.导入包 四.导入数据 五.数据处理 六.模型定义 七.封装训练集,测试集 ...

  4. WGAN模型——pytorch实现

    论文传送门:https://arxiv.org/pdf/1701.07875.pdf 参考文章:令人拍案叫绝的Wasserstein GAN - 知乎​​​​​​ WGAN的目的:解决GAN的梯度不稳 ...

  5. python人脸识别训练模型生产_深度学习-人脸识别DFACE模型pytorch训练(二)

    首先介绍一下MTCNN的网络结构,MTCNN有三种网络,训练网络的时候需要通过三部分分别进行,每一层网络都依赖前一层网络产生训练数据供当前训练网络,这样也推动了两个网络之间的最小损耗. Pnet Rn ...

  6. ArcGIS GP应用-GP模型服务发布

    1.双击模型名称打开运行窗体 2.在图上交互和窗体中输入数据后,点击确定运行模型,查看运行结果 3.在当前模型"缓冲区分析"的目录树上,右击含有图形(点.线.面)的节点,点击添加至 ...

  7. 轴承故障诊断经典模型pytorch复现(一)——WDCNN

    论文地址:<A New Deep Learning Model for Fault Diagnosis with Good Anti-Noise and Domain Adaptation Ab ...

  8. 车牌识别 远距离监控视角 自创简化模型 Pytorch

    甲方一拍脑门,让我去实现车牌识别,还是远距离监控视角的,真开心. 数据?呵~ 不会有人期待甲方提供数据吧?? 先逛逛某宝,一万张车辆图片,0.4元/张. 甲方:阿巴阿巴- 嗯,那没事了. 再逛逛全球同 ...

  9. Seq2Seq模型PyTorch版本

    Seq2Seq模型介绍以及Pytorch版本代码详解 一.Seq2Seq模型的概述 Seq2Seq是一种循环神经网络的变种,是一种端到端的模型,包括 Encoder编码器和 Decoder解码器部分, ...

最新文章

  1. Android 弱网测试(小米手机切换3g和2g)
  2. OSChina 周一乱弹 —— 嫂子我帮你们照顾放心吧
  3. 动手推导Self-Attention
  4. linux AS 5 DNS 配置中的小错误
  5. 在应用环境中如何构造最优的数据库模式
  6. C# WPF MVVM项目实战(进阶①)
  7. ubuntu解压缩zip/tar/tar.gz/tar.bz2
  8. 软件工程概论课后作业01
  9. Python验证和可视化冰雹猜想、角谷猜想、考拉兹猜想
  10. 2 环境设置_用友U8V10.1安装(Windows 7环境)
  11. 盖瑞特金属探测门受多个严重漏洞影响,可遭篡改
  12. PTA:图的理论习题集
  13. 第八章第二层交换和生成树协议(STP)
  14. bootstrap创建响应式网站
  15. 【Android 第三方SDK】breakpad在linux下编译
  16. 日志:每个软件工程师应该知道的实时数据的统一抽象概念
  17. bmob php,文档-Bmob移动后端云服务平台
  18. 字节(byte)、位(bit)、KB、B、字符之间关系以及编码占用位数
  19. 长期换衣行人重识别(Long-Term Clothes-Changing Person Reid)数据集汇总
  20. 【指标】GMV和销售额、SPU、SKU、商品、单品

热门文章

  1. 家用计算机有辐射吗,哪些家用电器有辐射
  2. Led方阵和串口通信COM2(读取字膜的数据并用LED显示)
  3. Qt QString转lpctstr
  4. element dialog的z-index与element-select组件下拉菜单的z-index同值,导致第一次点击时下拉菜单不可见
  5. Android开发——项目实例(三)迷你背单词软件(第三版)单词录入、背诵、联网查词、单词库
  6. java web 上传附件_JAVA WEB文件上传步骤
  7. 行为者网络理论(ANT,Actor Network Theory):一切皆是映射
  8. JAVA实验接口,内部类,抽象类的声明及使用
  9. 每日笔记-2017/03/30
  10. 计算机系统结构结构相关名词解释,体系结构复习题