使用8个高斯分布做对抗网络训练。
wgan_gp.py,代码:

import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
from torch.nn import functional as F
from matplotlib import pyplot as plt
import randomh_dim = 400
batchsz = 512
viz = visdom.Visdom()class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.net = nn.Sequential(nn.Linear(2, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 2),)def forward(self, z):output = self.net(z)return outputclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.net = nn.Sequential(nn.Linear(2, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 1),nn.Sigmoid())def forward(self, x):output = self.net(x)return output.view(-1)def data_generator():scale = 2.centers = [(1, 0),(-1, 0),(0, 1),(0, -1),(1. / np.sqrt(2), 1. / np.sqrt(2)),(1. / np.sqrt(2), -1. / np.sqrt(2)),(-1. / np.sqrt(2), 1. / np.sqrt(2)),(-1. / np.sqrt(2), -1. / np.sqrt(2))]centers = [(scale * x, scale * y) for x, y in centers]while True:dataset = []for i in range(batchsz):point = np.random.randn(2) * .02center = random.choice(centers)point[0] += center[0]point[1] += center[1]dataset.append(point)dataset = np.array(dataset, dtype='float32')dataset /= 1.414  # stdevyield datasetdef generate_image(D, G, xr, epoch):"""Generates and saves a plot of the true distribution, the generator, and thecritic."""N_POINTS = 128RANGE = 3plt.clf()points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]points = points.reshape((-1, 2))# (16384, 2)# print('p:', points.shape)# draw contourwith torch.no_grad():points = torch.Tensor(points).cuda()  # [16384, 2]disc_map = D(points).cpu().numpy()  # [16384]x = y = np.linspace(-RANGE, RANGE, N_POINTS)cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())plt.clabel(cs, inline=1, fontsize=10)# plt.colorbar()# draw sampleswith torch.no_grad():z = torch.randn(batchsz, 2).cuda()  # [b, 2]samples = G(z).cpu().numpy()  # [b, 2]xr = xr.cpu()plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))def weights_init(m):if isinstance(m, nn.Linear):# m.weight.data.normal_(0.0, 0.02)nn.init.kaiming_normal_(m.weight)m.bias.data.fill_(0)def gradient_penalty(D, xr, xf):""":param D::param xr::param xf::return:"""LAMBDA = 0.3# only constrait for Discriminatorxf = xf.detach()xr = xr.detach()# [b, 1] => [b, 2]alpha = torch.rand(batchsz, 1).cuda()alpha = alpha.expand_as(xr)interpolates = alpha * xr + ((1 - alpha) * xf)interpolates.requires_grad_()disc_interpolates = D(interpolates)gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,grad_outputs=torch.ones_like(disc_interpolates),create_graph=True, retain_graph=True, only_inputs=True)[0]gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDAreturn gpdef main():torch.manual_seed(23)np.random.seed(23)G = Generator().cuda()D = Discriminator().cuda()G.apply(weights_init)D.apply(weights_init)optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))data_iter = data_generator()print('batch:', next(data_iter).shape)viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss',legend=['D', 'G']))for epoch in range(50000):# 1. train discriminator for k stepsfor _ in range(5):x = next(data_iter)xr = torch.from_numpy(x).cuda()# [b]predr = (D(xr))# max log(lossr)lossr = - (predr.mean())# [b, 2]z = torch.randn(batchsz, 2).cuda()# stop gradient on G# [b, 2]xf = G(z).detach()# [b]predf = (D(xf))# min predflossf = (predf.mean())# gradient penaltygp = gradient_penalty(D, xr, xf)loss_D = lossr + lossf + gpoptim_D.zero_grad()loss_D.backward()# for p in D.parameters():#     print(p.grad.norm())optim_D.step()# 2. train Generatorz = torch.randn(batchsz, 2).cuda()xf = G(z)predf = (D(xf))# max predfloss_G = - (predf.mean())optim_G.zero_grad()loss_G.backward()optim_G.step()if epoch % 100 == 0:viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')generate_image(D, G, xr, epoch)print(loss_D.item(), loss_G.item())if __name__ == '__main__':main()

使用命令打开visdom,代码:

python -m visdom.server


最后好的效果是所有绿色的加号都会集中在黄色的点上,此时是因为训练次数不够多,所以效果不好。

pytorch训练WGAN网络相关推荐

  1. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  2. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  3. pytorch建立mobilenetV3-ssd网络并进行训练与预测

    pytorch建立mobilenetV3-ssd网络并进行训练与预测 前言 Step1:搭建mobilenetV3-ssd网络框架 需要提前准备的函数和类. mobilenetV3_large 调用m ...

  4. pytorch空间变换网络

    pytorch空间变换网络 本文将学习如何使用称为空间变换器网络的视觉注意机制来扩充网络.可以在DeepMind paper 有关空间变换器网络的内容. 空间变换器网络是对任何空间变换的差异化关注的概 ...

  5. 送你9个快速使用Pytorch训练解决神经网络的技巧(附代码)

    来源:读芯术 本文约4800字,建议阅读10分钟. 本文为大家介绍9个使用Pytorch训练解决神经网络的技巧 图片来源:unsplash.com/@dulgier 事实上,你的模型可能还停留在石器时 ...

  6. pytorch训练过程中loss出现NaN的原因及可采取的方法

    在pytorch训练过程中出现loss=nan的情况 1.学习率太高. 2.loss函数 3.对于回归问题,可能出现了除0 的计算,加一个很小的余项可能可以解决 4.数据本身,是否存在Nan,可以用n ...

  7. pytorch贝叶斯网络_贝叶斯神经网络:2个在TensorFlow和Pytorch中完全连接

    pytorch贝叶斯网络 贝叶斯神经网络 (Bayesian Neural Net) This chapter continues the series on Bayesian deep learni ...

  8. pytorch训练GAN的代码(基于MNIST数据集)

    论文:Generative Adversarial Networks 作者:Ian J. Goodfellow 年份:2014年 从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简 ...

  9. 实战:使用Pytorch搭建分类网络(肺结节假阳性剔除)

    实战:使用Pytorch搭建分类网络(肺结节假阳性剔除) 阅前可看: 实战:使用yolov3完成肺结节检测(Luna16数据集)及肺实质分割 其中的脚本资源getMat.py文件是对肺结节进行切割. ...

最新文章

  1. Python:CrawlSpiders
  2. java鉴权_一个开箱即用的高效认证鉴权框架,专注于restful api的认证鉴权动态保护...
  3. vue创建二:引入本地图片
  4. 北京低利用率数据中心将有序关闭腾退
  5. duilib各种布局的作用,相对布局与绝对布局的的意义与用法
  6. React-引领未来的用户界面开发框架-读书笔记(八)
  7. Java 9和Project Jigsaw如何破坏您的代码
  8. cuSPARSE库:(十七)cusparseStatus_t 返回信息
  9. 《统计学习方法》—— 感知机对偶算法、推导以及python3代码实现(二)
  10. tcpdump命令--详解
  11. 2048小游戏后端的实现
  12. 盘点番茄汁有益于身体的N多好处
  13. Linux iptables 防火墙相关资料
  14. win10禁用驱动程序强制签名_只需一个简单命令,在Win10上启用Windows恢复环境(WinRE)...
  15. 滤波反投影重建算法(FBP)实现及应用(matlab)
  16. Html5 Egret游戏开发 成语大挑战(九)设置界面和声音管理
  17. sqlserver阻止保存要求重新创建表的更改
  18. 记一次IOS与H5-SPA页面的交互经验
  19. 选择灰盒安全测试工具,看准以下几点
  20. java版我的世界MITE怎么下_我的世界mite作弊图文教程

热门文章

  1. 经典C源程序100例
  2. php 抓取所有div,快速了解PHP抓取网页内容的技巧
  3. 用Java写的验证码程序
  4. 自然语言处理入门学习系列一
  5. 深入浅出FPGA-12-VMM(验证方法学)
  6. 2022年Java面经分享,腾讯Java面试题
  7. 关于反向传播算法的理解
  8. 光安检场景下危险品检测
  9. Nginx 配置支持PHP
  10. 绿幕特效视频的透明通道输出与拼合为图像矩阵