1. 概念

模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN理论中,并不要求 G 和 D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D 。
图例:

其目标函数:

简单来说,就是分每一轮训练分两步,首先固定G训练D:min C(D,1)+C(D(G),0)C(D,1)+C(D(G),0)C(D,1)+C(D(G),0),然后固定D训练G:min C(G,1)C(G,1)C(G,1)。其中C表示cross entrophy函数,后面的1表示判断为真实,0表示判断为虚假。

2. 简单的GAN代码分析

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transformstransform = transforms.Compose([transforms.ToTensor(),  transforms.Normalize(0.5, 0.5)
])
# 加载内置数据  做生成只需要图片就行,不需要标签 也不需要测试数据集
train_ds = torchvision.datasets.MNIST('data',   # 当前目录下的data文件夹train=True,  # train数据transform=transform,download=True)dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)# 定义生成器
# 输入是长度为100的噪声
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.gen = nn.Sequential(nn.Linear(100, 256),  # 输入长度为100nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28 * 28),nn.Tanh())def forward(self, x):  # 定义前向传播 x表示长度为100的noise输入img = self.gen(x)img = img.view(-1, 28, 28)  return img# 定义判别器
# 输入为(1,28,28)的图片 输出为二分类的概率值,使用sigmoid激活
# BCEloss 计算交叉熵损失
# 判别器中推荐使用LeakyReLU
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.disc = nn.Sequential(nn.Linear(28*28, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28) x = self.disc(x)return x# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)# 损失计算函数
loss_function = torch.nn.BCELoss()# 绘图函数
def gen_img_plot(model, test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4,4))for i in range(16):plt.subplot(4, 4, i+1)plt.imshow((prediction[i] + 1)/2)  # 由于tanh是在-1 1 之间 要恢复道0 1 之间plt.axis("off")plt.show()
test_input =torch.randn(16, 100, device=device)# 开始训练
D_loss = []
G_loss = []
# 训练循环
for epoch in range(50):d_epoch_loss = 0g_epoch_loss = 0batch_count = len(dataloader.dataset)# 对全部的数据集做一次迭代for step, (img, _) in enumerate(dataloader):img = img.to(device)  # 上传到设备上size = img.size(0)    # 返回img的第一维的大小random_noise = torch.randn(size, 100, device=device)d_optim.zero_grad()  # 将上述步骤的梯度归零real_output = dis(img)  # 对判别器输入真实的图片,real_output是对真实图片的预测结果d_real_loss = loss_function(real_output,torch.ones_like(real_output))d_real_loss.backward() #求解梯度# 得到判别器在生成图像上的损失gen_img = gen(random_noise)fake_output = dis(gen_img.detach())  # 判别器输入生成的图片,对生成图片的预测结果d_fake_loss = loss_function(fake_output,torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optim.step()  # 优化# 得到生成器的损失g_optim.zero_grad()fake_output = dis(gen_img)g_loss = loss_function(fake_output,torch.ones_like(fake_output))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= batch_countg_epoch_loss /= batch_countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch)

3. DCGAN:将全连接用卷积替代

DCGAN的生成器和鉴别器都舍弃了CNN的pooling层(池化层),鉴别器保留CNN的整体架构,生成器则是将卷积层替换成了反卷积层(ConvTranspose2d)
在鉴别器和生成器中使用了BN(Batch Normalization)层,加速模型训练,提升了训练的稳定性。但是在生成器的输出层和鉴别器的输入层不使用BN层【直接应用batchnorm到所有层会导致样本振荡和模型不稳定】
生成器网络中使用ReLU作为激活函数,最后一层使用Tanh()【使用有界激活(a bounded activation)可以让模型更快地学习,以饱和和覆盖训练分布的颜色空间】
鉴别器网络中使用LeakyReLU作为激活函数
使用Adam优化器,一阶矩估计的指数衰减率的值设置为0.5
代码变化的部分如下:

# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.linear1 = nn.Linear(100, 256*7*7)  # 希望生成1*28*28的图片 7反卷积后14,再反卷积28 pytorch中channel在前self.bn1 = nn.BatchNorm1d(256*7*7)self.deconv1 = nn.ConvTranspose2d(256, 128,kernel_size=(3,3),stride=1,  padding=1 )   # 得到128*7*7的图像self.bn2 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64,kernel_size=(4,4),stride=2,padding=1  # 64*14*14)self.bn3 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 1,kernel_size=(4, 4),stride=2,padding=1  # 1*28*28)def forward(self, x):x = F.relu(self.linear1(x))x = self.bn1(x)x = x.view(-1, 256, 7, 7)x = F.relu(self.deconv1(x))x = self.bn2(x)x = F.relu(self.deconv2(x))x = self.bn3(x)x = torch.tanh(self.deconv3(x))return x# 定义判别器
# input:1,28,28
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2) # 第一层不适用bn  64,13,13self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2) #128,6,6self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值def forward(self, x):x = F.dropout2d(F.leaky_relu(self.conv1(x)))x = F.dropout2d(F.leaky_relu(self.conv2(x)))  # (batch, 128,6,6)x = self.bn(x)x = x.view(-1, 128*6*6)   # (batch, 128,6,6)--->  (batch, 128*6*6)x = torch.sigmoid(self.fc(x))return x

深度学习系列32:GAN入门:DCGAN相关推荐

  1. 【完结】给新手的12大深度学习开源框架快速入门项目

    文/编辑 | 言有三 这是一篇总结文,给大家来捋清楚12大深度学习开源框架的快速入门,这是有三AI的GitHub项目,欢迎大家star/fork. https://github.com/longpen ...

  2. [GAN学习系列2] GAN的起源

    本文大约 5000 字,阅读大约需要 10 分钟 这是 GAN 学习系列的第二篇文章,这篇文章将开始介绍 GAN 的起源之作,鼻祖,也就是 Ian Goodfellow 在 2014 年发表在 ICL ...

  3. 深度学习系列笔记——贰 (基于Tensorflow2 Keras搭建的猫狗大战模型 三)

    深度学习系列笔记--贰 (基于Tensorflow Keras搭建的猫狗大战模型 一) 深度学习系列笔记--贰 (基于Tensorflow Keras搭建的猫狗大战模型 二) 前面两篇博文已经介绍了如 ...

  4. 机器学习与深度学习系列连载(NTU-Machine Learning, cs229, cs231n, cs224n, cs294):欢迎进入机器学习的世界

    欢迎进入机器学习的世界 本教程是根据台湾大学李弘毅老师的课程机器学习课程,斯坦福大学CS229.CS231N.CS224N.CS20i.伦敦大学学院 ([UCL-Course])(http://www ...

  5. 五本必读的深度学习圣经书籍,入门 AI 从 深度学习 开始

    原标题:`五本必读的深度学习圣经书籍,入门 AI 从「深度学习」开始` (以下以 Daniel Jeffries 第一人称撰写) 多年来,由于实验室研究和现实应用效果之间的鸿沟,少有人持续研究人工智能 ...

  6. 【github干货】主流深度学习开源框架从入门到熟练

    文章首发于微信公众号<有三AI> [github干货]主流深度学习开源框架从入门到熟练 今天送上有三AI学院第一个github项目 01项目背景 目前深度学习框架呈百家争鸣之态势,光是为人 ...

  7. 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理(1)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  8. python系列文章(基础,应用,后端,运维,自动化测试,爬虫,数据分析,可视化,机器学习,深度学习系列内容)

    python基础教程 python基础系列教程--Python的安装与测试:python解释器.PyDev编辑器.pycharm编译器 python基础系列教程--Python库的安装与卸载 pyth ...

  9. unet是残差网络吗_深度学习系列(三)卷积神经网络模型(ResNet、ResNeXt、DenseNet、DenceUnet)...

    深度学习系列(三)卷积神经网络模型(ResNet.ResNeXt.DenseNet.Dence Unet) 内容目录 1.ResNet2.ResNeXt3.DenseNet4.Dence Unet 1 ...

最新文章

  1. 4. 编程规范和编程安全指南--go语言
  2. 安卓模拟器BlueStacks 安装使用教程(图解)
  3. linux挂载windows共享的远程目录
  4. mysql中Invalid default value for 'stime'问题
  5. 消费者做出购买决策的流程
  6. 产品经理该如何全局思考和分析行业产业链?
  7. 巧妙解决:access denied (javax.management.MBeanTrust...
  8. 关于@Alias注解的几个问题
  9. ant指定servlet版本_[转载]程序开发常见错误
  10. 通过AccessKey调用阿里云CDN接口刷新CDN资源案例
  11. IDC:“互联网+流通”将进一步释放活力
  12. S5PV210体系结构与接口01:ARM体系结构概述
  13. JavaScript实现气球打字游戏
  14. mysql俩个表怎么创主外洁_单独招生面试题极其详细答案
  15. 国产蓝牙耳机哪款好?双11平价高性价比不输大牌蓝牙耳机推荐
  16. 2022年武汉市人工智能领域技术成果征集内容及申报条件
  17. 来自GPU的Hello World-基于Win10+VS2019+CUDA 11.0搭建CUDA编程环境
  18. 简单答题系统(判断题)
  19. 哪些产品需要做3C认证
  20. 35 行代码实现一个简单的 shell

热门文章

  1. 如果实现银行卡三四要素? 银行卡实名认证
  2. Android 仿饿了么点餐页面报告,RecyclerView仿饿了吗点菜页面
  3. 仓库规划方法论-六大原则
  4. E0070——不允许使用不完整的类型和E3365——不允许使用不完整的类类型解决办法
  5. css伪类元素 添加 选中下划线 自定义长度
  6. Render Flow of Divinity II (part 2 shadow map)
  7. android横竖屏切换动画,横竖屏切换时候Activity的生命周期
  8. css box-sizing:border-box
  9. python函数变量的作用域_python函数变量的作用域
  10. ESP8266-----SNTP获取网络时间