运行环境

pytorch1.3.0

简介

生成对抗神经网络分为两部分: 生成器

鉴别器

生成器

把一个满足高斯分布的向量映射为一个784维度的向量,期望这个784维度的向量就是我们要的图片。对于生成器的监督训练主要是靠鉴别器。

鉴别器

是一个二分类模型,输入是图片向量,输出: 这张图是真图的概率。

训练数据有两种:真图、假图(生成器生成的)

真图对应的label是1、假图对应的label是0

补充

从0.4起, Variable 正式合并入Tensor类,通过Variable嵌套实现的自动微分功能已经整合进入了Tensor类中。虽然为了代码的兼容性还是可以使用Variable(tensor)这种方式进行嵌套,但是这个操作其实什么都没做。

运行方法:把程序粘贴下来保存为 xx.py ,然后最好把xx.py放在一个空的文件夹里面 ,python装好需要的包,就可以运行了

运行后程序会自动再生成两个文件夹,总体长这个样子

本程序数据集会自动下载,如果慢可以使用迅雷下载下来然后放到指定的地方(data\MNIST\raw)。

代码

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os#超参数
batch_size = 128
num_epoch = 500
z_dimention = 100
d_optimizer_lr=0.0001
g_optimizer_lr=0.0001
use_gpu=Trueclass discriminator(nn.Module):  #784长度向量 -> (0,1)def __init__(self):super(discriminator, self).__init__()self.dis = nn.Sequential(nn.Linear(784, 256),nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = self.dis(x)return xclass generator(nn.Module):  # 100长度向量 -> 784长度向量def __init__(self):super(generator, self).__init__()self.gen = nn.Sequential(nn.Linear(100, 256),nn.ReLU(True),nn.Linear(256, 256),nn.ReLU(True),nn.Linear(256, 784),nn.Tanh())def forward(self, x):x = self.gen(x)return x#784向量 -> 1*28*28
def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return outimg_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])mnist = datasets.MNIST(root='./data/', train=True, transform=img_transform, download=True
)dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size,shuffle=True,drop_last=True
)if __name__=='__main__':if not os.path.exists('./img'):os.mkdir('./img')D = discriminator()G = generator()criterion = nn.BCELoss()  # 二分类的交叉熵d_optimizer = torch.optim.Adam(D.parameters(), lr=d_optimizer_lr)g_optimizer = torch.optim.Adam(G.parameters(), lr=g_optimizer_lr)z=torch.FloatTensor(batch_size,z_dimention)if use_gpu:D = D.cuda()G = G.cuda()z = z.cuda()for epoch in range(num_epoch):for i, (img, _) in enumerate(dataloader):img = img.view(batch_size, -1)if use_gpu:real_img = img.cuda()real_label = torch.ones(batch_size).cuda()fake_label = torch.zeros(batch_size).cuda()# =================train discriminatorreal_out = D(real_img)d_loss_real = criterion(real_out, real_label)  # 真实数据对应输出 1real_scores = real_outz.data.normal_(0,1)fake_img1 = G(z)fake_out = D(fake_img1.detach())d_loss_fake = criterion(fake_out, fake_label)fake_scores = fake_outd_loss = d_loss_real + d_loss_faked_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# ===============train generatorz.data.normal_(0,1)fake_img2 = G(z)output = D(fake_img2)g_loss = criterion(output, real_label)g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if epoch == 0:real_images = to_img(real_img.cpu().data)save_image(real_images, './img/real_images.png')print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} D real:{:.6f},D fake:{:.6f}'.format(epoch, num_epoch, d_loss.item(), g_loss.item(), real_scores.data.mean(), fake_scores.data.mean()))fake_images = to_img(fake_img2.cpu().data)save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

实验结果

epoch=0

epoch=10

epoch=100

epoch=200

epoch=500

epoch=2800

GAN(生成对抗神经网络)生成MNIST 基于pytorch实现相关推荐

  1. 条件生成对抗神经网络,生成对抗网络gan原理

    关于GAN生成式对抗网络中判别器的输出的问题 . ...摘要生成式对抗网络GAN(Generativeadversarialnetworks)目前已经成为人工智能学界一个热门的研究方向.GAN的基本思 ...

  2. ​清华大学提出基于生成对抗神经网络的自然图像多风格卡通化方法并开源代码

    近日,清华大学刘永进教授课题组在 IEEE Transactions on Visualization and Computer Graphics 上发表论文,提出基于生成对抗神经网络的自然图像多风格 ...

  3. 赠书 | 读懂生成对抗神经网络 GAN,看这文就够了

    生成对抗神经网络(Generative Adversarial Nets,GAN)是一种深度学习的框架,它是通过一个相互对抗的过程来完成模型训练的.典型的GAN包含两个部分,一个是生成模型(Gener ...

  4. ​清华大学提出基于生成对抗神经网络的自然图像多风格卡通化方法并开源代码...

    近日,清华大学刘永进教授课题组在 IEEE Transactions on Visualization and Computer Graphics 上发表论文,提出基于生成对抗神经网络的自然图像多风格 ...

  5. 【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

    [万物皆可 GAN]生成对抗网络生成手写数字 Part 1 概述 GAN 网络结构 GAN 训练流程 模型详解 生成器 判别器 概述 GAN (Generative Adversarial Netwo ...

  6. pytorch生成对抗网络生成动漫图像

    代码地址:pytorch实战,使用生成对抗网络生成动漫图像 dataset from torchvision import transforms from torch.utils.data impor ...

  7. 利用生成对抗网络生成海洋塑料合成图像

    问题陈述 过去十年来,海洋塑料污染一直是气候问题的首要问题.海洋中的塑料不仅能够通过勒死或饥饿杀死海洋生物,而且也是通过捕获二氧化碳使海洋变暖的一个主要因素. 近年来,非营利组织海洋清洁组织(Ocea ...

  8. 生成对抗网络生成多维数据集_生成没有数据集的新颖内容

    生成对抗网络生成多维数据集 介绍(Introduction) GAN architecture has been the standard for generating content through ...

  9. 以FGSM算法为例的对抗训练的实现(基于Pytorch)

    如果可以,请点个赞,这是我写博客的动力,谢谢各位观众 1. 前言 深度学习虽然发展迅速,但是由于其线性的特性,受到了对抗样本的影响,很容易造成系统功能的失效. 以图像分类为例子,对抗样本很容易使得在测 ...

  10. Pytorch:GAN生成对抗网络实现MNIST手写数字的生成

    github:https://github.com/SPECTRELWF/pytorch-GAN-study 个人主页:liuweifeng.top:8090 网络结构 最近在疯狂补深度学习一些基本架 ...

最新文章

  1. 15.1 自定义分词器
  2. HTTP2和HTTPS来不来了解一下?
  3. Java注释是一个大错误
  4. 针对新手的Java EE7和Maven项目-第3部分-定义ejb服务和jpa实体模块
  5. Win10远程连接自己的电脑提示“登陆没有成功”的解决方案
  6. 能让中年人放下面子赚到钱的副业
  7. 《程序设计技术基础》第1-5章例程
  8. android密码用户名和密码错误,Android之输入用户名和密码验证
  9. 斐讯路由器K2弹广告-刷机过程
  10. fences卸载_fences是什么意思?WIN10专业版彻底删除fences的技巧
  11. 实验5 Spark SQL 编程初级实践
  12. 如何确定硕士毕业论文选题?
  13. 增强现实技术(AR)的103个应用场景汇总
  14. sl4a库_SL4A、QPython学习笔记(2)
  15. Android Hybrid 方案之 离线文件加载
  16. OC-NSString
  17. 5A成绩通过PMP,备考经验总结——姜飞
  18. GPU中实现反距离加权插值(IDW)
  19. Proteus 中 Virtual Terminal无法自动弹出窗口的问题的解决
  20. 基于Wav2Lip的AI主播

热门文章

  1. vs2019 无法打开包括文件:“SDKDDKVer.h”: No such file or directory的另外一种解决思路
  2. dw添加下拉菜单_dreamweaver cs6中网页制作一个带有列表下拉菜单的详细操作方法...
  3. lj245a引脚功能图_lm324工作原理_引脚图功能_特性参数_内部电路及应用电路
  4. kali自带发包工具tcpreplay
  5. 在线画板_在线画画_在线画图工具-速写板
  6. CAD中 OLE不能旋转_AutoCAD小秘密042:光栅图像和OLE图像,究竟如何选择
  7. 卷积码编码和译码c语言,卷积码编码和译码.doc
  8. 虚拟现实技术虚拟校园解决方案
  9. 利用计算机指令清理垃圾,怎么用命令来清理系统垃圾
  10. HTML+css网站设计布局模板