轻松学Pytorch – 构建生成对抗网络
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
又好久没有继续写了,这个是我写的第21篇文章,我还在继续坚持写下去,虽然经常各种拖延症,但是我还记得,一直没有敢忘记!今天给大家分享一下Pytorch生成对抗网络代码实现。
01.什么是生成对抗网络
Ian J. Goodfellow在2014年提出生成对抗网络,从此打开了深度学习中另外一个重要分支,让生成对抗网络(GAN)成为与卷积神经网络(CNN)、循环神经网络(RNN/LSTM)可以并驾齐驱的分支领域。今天GAN仍然是计算机视觉领域研究热点之一,每年还有大量相关的论文产生,GAN已经被用在视觉任务的很多方面,主要包括:
图像合成与数据增广
图像翻译与变换
缺陷检测
图像去噪与重建
图像分割
但是GAN最基本的核心思想还是2014年Ian J. Goodfellow在论文中提到的两个基本的模型分别是:生成器与判别器
生成器(G):
根据输入噪声Z生成输出样本G(z)
目标:通过生成样本与目标样本分布一致,成功欺骗鉴别器
判别器(D):
根据输入样本数据来分辨真实样本概率
从数据中学习样本数据的差异性
从a到d,可以看到输入噪声的生成分布越来越接近真实分布X,最终达到一种平衡状态,这种稳定的平衡状态叫纳什均衡,还有一部电影跟这个有关系叫《美丽心灵》。
02.GAN代码实现
下面的代码实现了基于Mnist数据集实现判别器与生成器,最终通过生成器可以自动生成手写数字识别的图像,输入的z=100是随机噪声,输出的是784个数据表示28x28大小的手写数字样本,损失主要来自两个部分,生成器生成损失,判别器分别判别真实与虚构样本概率,基于反向传播训练两个网络,设置epoch=100,得到最终的生成器生成结果如下:
生成器与判别器代码实现如下
判别器与生成器代码:(后面文字忽略)2004论文中提出,其主要思想可以通过下面一张图像解释:
1transform = tv.transforms.Compose([tv.transforms.ToTensor(),2 tv.transforms.Normalize((0.5,), (0.5,))])3train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)4test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)5train_dl = DataLoader(train_ts, batch_size=128, shuffle=True, drop_last=False)6test_dl = DataLoader(test_ts, batch_size=128, shuffle=True, drop_last=False)789class Generator(t.nn.Module):
10 def __init__(self, g_input_dim, g_output_dim):
11 super(Generator, self).__init__()
12 self.fc1 = t.nn.Linear(g_input_dim, 256)
13 self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
14 self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
15 self.fc4 = t.nn.Linear(self.fc3.out_features, g_output_dim)
16
17 # forward method
18 def forward(self, x):
19 x = F.leaky_relu(self.fc1(x), 0.2)
20 x = F.leaky_relu(self.fc2(x), 0.2)
21 x = F.leaky_relu(self.fc3(x), 0.2)
22 return t.tanh(self.fc4(x))
23
24
25class Discriminator(t.nn.Module):
26 def __init__(self, d_input_dim):
27 super(Discriminator, self).__init__()
28 self.fc1 = t.nn.Linear(d_input_dim, 1024)
29 self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
30 self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
31 self.fc4 = t.nn.Linear(self.fc3.out_features, 1)
32
33 # forward method
34 def forward(self, x):
35 x = F.leaky_relu(self.fc1(x), 0.2)
36 x = F.dropout(x, 0.3)
37 x = F.leaky_relu(self.fc2(x), 0.2)
38 x = F.dropout(x, 0.3)
39 x = F.leaky_relu(self.fc3(x), 0.2)
40 x = F.dropout(x, 0.3)
41 return t.sigmoid(self.fc4(x))
损失与训练代码如下
分别定义生成网络训练与鉴别网络的训练方法,然后开始训练即可,代码实现如下:
1# 生成者与判别者2bs = 1283z_dim = 1004mnist_dim = 7845# loss6criterion = t.nn.BCELoss()78# optimizer9device = "cuda"
10gnet = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
11dnet = Discriminator(mnist_dim).to(device)
12lr = 0.0002
13G_optimizer = t.optim.Adam(gnet.parameters(), lr=lr)
14D_optimizer = t.optim.Adam(dnet.parameters(), lr=lr)
15
16
17def D_train(x):
18 # =======================Train the discriminator=======================#
19 dnet.zero_grad()
20
21 # train discriminator on real
22 x_real, y_real = x.view(-1, mnist_dim), t.ones(bs, 1)
23 x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))
24
25 D_output = dnet(x_real)
26 D_real_loss = criterion(D_output, y_real)
27
28 # train discriminator on facke
29 z = Variable(t.randn(bs, z_dim).to(device))
30 x_fake, y_fake = gnet(z), Variable(t.zeros(bs, 1).to(device))
31
32 D_output = dnet(x_fake)
33 D_fake_loss = criterion(D_output, y_fake)
34
35 # gradient backprop & optimize ONLY D's parameters
36 D_loss = D_real_loss + D_fake_loss
37 D_loss.backward()
38 D_optimizer.step()
39
40 return D_loss.data.item()
41
42
43def G_train(x):
44 # =======================Train the generator=======================#
45 gnet.zero_grad()
46
47 z = Variable(t.randn(bs, z_dim).to(device))
48 y = Variable(t.ones(bs, 1).to(device))
49
50 G_output = gnet(z)
51 D_output = dnet(G_output)
52 G_loss = criterion(D_output, y)
53
54 # gradient backprop & optimize ONLY G's parameters
55 G_loss.backward()
56 G_optimizer.step()
57
58 return G_loss.data.item()
59
60
61n_epoch = 100
62for epoch in range(1, n_epoch+1):
63 D_losses, G_losses = [], []
64 for batch_idx, (x, _) in enumerate(train_dl):
65 bs_, _,_,_ = x.size()
66 bs = bs_
67 D_losses.append(D_train(x))
68 G_losses.append(G_train(x))
69
70 print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
71 (epoch), n_epoch, t.mean(t.FloatTensor(D_losses)), t.mean(t.FloatTensor(G_losses))))
下载1:OpenCV-Contrib扩展模块中文版教程在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。交流群欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
轻松学Pytorch – 构建生成对抗网络相关推荐
- Pytorch Note46 生成对抗网络的数学原理
Pytorch Note46 生成对抗网络的数学原理 文章目录 Pytorch Note46 生成对抗网络的数学原理 全部笔记的汇总贴: Pytorch Note 快乐星球 之前介绍了什么是生成对抗, ...
- 基于PyTorch的生成对抗网络入门(3)——利用PyTorch搭建生成对抗网络(GAN)生成彩色图像超详解
目录 一.案例描述 二.代码详解 2.1 获取数据 2.2 数据集类 2.3 构建判别器 2.3.1 构造函数 2.3.2 测试判别器 2.4 构建生成器 2.4.1 构造函数 2.4.2 测试生成器 ...
- 『一起学AI』生成对抗网络(GAN)原理学习及实战开发
参考并翻译教程:https://d2l.ai/chapter_generative-adversarial-networks/gan.html,加入笔者的理解和心得 1.生成对抗网络原理 在Col ...
- 利用Tensorflow构建生成对抗网络GAN以生成数据
使用生成对抗网络(GAN)生成数据 本文主要内容 介绍了自动编码器的基本原理 比较了生成模型与自动编码器的区别 描述了GAN模型的网络结构 分析了GAN模型的目标核函数以及训练过程 介绍了利用Goog ...
- Pytorch:GAN生成对抗网络实现二次元人脸的生成
github:https://github.com/SPECTRELWF/pytorch-GAN-study 网络结构 最近在疯狂补深度学习一些基本架构的基础,看了一下大佬的GAN的原始论文,说实话一 ...
- Pytorch:GAN生成对抗网络实现MNIST手写数字的生成
github:https://github.com/SPECTRELWF/pytorch-GAN-study 个人主页:liuweifeng.top:8090 网络结构 最近在疯狂补深度学习一些基本架 ...
- 从零开始学keras之生成对抗网络GAN
生成对抗网络主要分为生成器网络和判别器网络. 生成器网络:他以一个随机向量(潜在空间的一个随机点)作为输入,并将其解码成一张合成图像. 判别器网络:以一张图像(真实的或合成的均可)作为输入,并预测该图 ...
- pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)
1.简介 这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码.数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G.生成器与判 ...
- 数据集制作_轻松学Pytorch自定义数据集制作与使用
点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,这是轻松学Pytorch系列的第六篇分享,本篇你将学会如何从头开始制作自己的数据集,并通过DataLo ...
- pytorch argmax_轻松学Pytorch使用ResNet50实现图像分类
点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的 ...
最新文章
- POJ 3260 多重背包+完全背包
- 报名丨24小时创新挑战:数字时代的人类健康与福祉
- MySQL中varchar类型在5.0.3后的变化
- JAVA8 十大新特性
- mysql generic安装_MySQL 5.6 Generic Binary安装与配置
- 单核工作法15:循序渐进
- IE、FF的基本注意事项
- 电商产品评论数据情感分析代码详解
- Angular应用i18n - internationalization翻译的实现单步调试
- 产品运行所需的信息检索失败_为服务业注入新活力,华北工控推出服务机器人专用计算机产品方案...
- 那些一眼就被看出包装过的简历
- (36)FPGA面试题D触发器实现4进制计数器
- springboot+前端实现文件(图片)上传到指定目录
- idle点开没反应_翟天临、靳东,一个人越是没文化越是喜欢装
- 开源API网关系统:Kong简介
- win10默认系统字体更改
- 利用JS模拟排队系统
- 基于JAVA公立医院绩效考核系统计算机毕业设计源码+数据库+lw文档+系统+部署
- Mac下载vscode 缓慢?以下解决方法起飞下载
- pvid与vid详解
热门文章
- 用pytorch实现深度学习;60分钟闪电战
- 携创教育:东莞市2022年成人高考报名啦
- 12星座七月桃花运势,赶紧看看吧!
- 王干娘和西门庆-UMLChina建模知识竞赛第4赛季第18轮
- 我们整理了一份《程序员健康指南》!
- canvas学习笔记04
- 黑苹果安装,usb 不能用,键盘不能用
- [导读]整合Spring MVC由于用到jstl,所以假如jstl便签用的jar包,启动tomcat时控制台出现了如下的输出:2014-3-25 23:54:49 org.apache.catal
- TCP/UDP 数据传输的链路解析
- Mac illustrator 输入特殊字符(如希腊字符)