DCGAN原理
Deep Convolution Generative Adversarial Networks(深度卷积生成对抗网络)

DCGAN是将CNN 与 GAN 结合,原理与GAN一样,只是将G和D换成两个卷积神经网络(CNN),DCGAN将CNN做了一些改变,为提高样本质量和收敛速度。

·strided convolution 替代确定性的pooling(从而可以让网络自己学习downsampling(下采样) 。G 网络中使用微步幅度卷积(fractionally strided convolution) 代替 pooling 层,D 网络中使用步幅卷积(strided convolution)代替 pooling 层。
·在 D 和 G 中均使用 batch normalization批量归一化
·去掉 FC 层,使网络变为全卷积网络
·G 网络中使用 ReLU 激活函数,最后一层使用 tanh激活函数
·D 网络中所有层都使用 LeakyReLU 作为激活函数

下图为 DCGAN 中 G 的具体网络结构:
生成器的输入是一个 100 维的噪声,中间会通过 4 层卷积层,每通过一个卷 积层通道数减半,长宽扩大一倍 ,最终产生一个 64643 大小的图片输出.

值得注意的是G中卷积层是微步幅卷积是反卷积(deconv)
上图左边是反卷积,用 33 的卷积核把 22 的矩阵反卷积成 44 的矩阵;而右边是微步 幅度卷积,用 33 的卷积核把 33 的矩阵卷积成 55 的矩阵。
反卷积是在 整个输入矩阵周围添 0,微步幅度卷积把输入矩阵拆开,在每一个像素点的周围添 0。

代码下载链接:github DCGAN Pytorch

环境需求:
torch>=0.4.0
torchvision
matplotlib
numpy
scipy
pillow
urllib3
scikit-image

DCGAN代码(语句后面附注释)

import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable      import torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()          #命令行选项、参数和子命令解析器
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")  #迭代次数
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")          #batch大小
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")            #学习率
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") #动量梯度下降第一个参数
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") #动量梯度下降第二个参数
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") #CPU个数
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")  #噪声数据生成维度
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")  #输入数据的维度
parser.add_argument("--channels", type=int, default=1, help="number of image channels")      #输入数据的通道数
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")  #保存图像的迭代数
opt = parser.parse_args()
print(opt)cuda = True if torch.cuda.is_available() else False        #判断GPU可用,有GPU用GPU,没有用CPUdef weights_init_normal(m):            #自定义初始化参数classname = m.__class__.__name__   #获得类名if classname.find("Conv") != -1:   #在类classname中检索到了Convtorch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.init_size = opt.img_size // 4self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2)) #l1函数进行Linear变换。线性变换的两个参数是变换前的维度,和变换之后的维度self.conv_blocks = nn.Sequential(           #nn.sequential{}是一个组成模型的壳子,用来容纳不同的操作nn.BatchNorm2d(128),                    # BatchNorm2d的目的是使我们的一批(batch)feature map 满足均值0方差1,就是改变数据的量纲nn.Upsample(scale_factor=2),            #上采样,将图片放大两倍(这就是为啥class最先开始将图片的长宽除了4,下面还有一次放大2倍)nn.Conv2d(128, 128, 3, stride=1, padding=1), #二维卷积函数,(输入数据channel,输出的channel,步长,卷积核大小,padding的大小)nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),        #relu激活函数nn.Upsample(scale_factor=2),            #上采样nn.Conv2d(128, 64, 3, stride=1, padding=1),#二维卷积nn.BatchNorm2d(64, 0.8),                #BNnn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),                              #Tanh激活函数)def forward(self, z):out = self.l1(z)              #l1函数进行的是Linear变换 (第50行定义了)out = out.view(out.shape[0], 128, self.init_size, self.init_size)#view是维度变换函数,可以看到out数据变成了四维数据,第一个是batch_size(通过整个的代码,可明白),第二个是channel,第三,四是单张图片的长宽img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]#Conv卷积,Relu激活,Dropout将部分神经元失活,进而防止过拟合if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))    #如果bn这个参数为True,那么就需要在block块里面添加上BatchNorm的归一化函数return blockself.model = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# The height and width of downsampled imageds_size = opt.img_size // 2 ** 4self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid()) #先进行线性变换,再进行激活函数激活#上一句中 128是指model中最后一个判别模块的最后一个参数决定的,ds_size由model模块对单张图片的卷积效果决定的,而2次方是整个模型是选取的长宽一致的图片def forward(self, img):out = self.model(img)out = out.view(out.shape[0], -1)    #将处理之后的数据维度变成batch * N的维度形式validity = self.adv_layer(out)      #第92行定义return validity# Loss function
adversarial_loss = torch.nn.BCELoss()         #定义了一个BCE损失函数# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:                                #初始化,将数据放在cuda上generator.cuda()discriminator.cuda()adversarial_loss.cuda()# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(     #显卡加速datasets.MNIST("../../data/mnist",                  #进行训练集下载train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# Optimizers                             定义神经网络的优化器  Adam就是一种优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# Adversarial ground truthsvalid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(Tensor))     #将真实的图片转化为神经网络可以处理的变量# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()   #把梯度置零  每次训练都将上一次的梯度置零,避免上一次的干扰# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))#生成的噪音 随机构00维向量 均值0方差1维度(64,100)的噪音,随机初始化一个64大小batch的向量# 输入0到1之间,形状为imgs.shape[0], opt.latent_dim的随机高斯数据。np.random.normal()正态分布# Generate a batch of imagesgen_imgs = generator(z)           #得到一个批次的图片# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()         #反向传播和模型更新optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid)     #判别器判别真实图片是真的的损失fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)  #判别器判别假图片是假的的损失d_loss = (real_loss + fake_loss) / 2     #判别器去判别真实图片是真的和生成图片是假的的损失之和,让这个和越大,说明判别器越准确d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

基于pytorch的DCGAN代码实现(DCGAN基本原理+代码讲解)相关推荐

  1. 基于pytorch搭建多特征CNN-LSTM时间序列预测代码详细解读(附完整代码)

    系列文章目录 lstm系列文章目录 1.基于pytorch搭建多特征LSTM时间序列预测代码详细解读(附完整代码) 2.基于pytorch搭建多特征CNN-LSTM时间序列预测代码详细解读(附完整代码 ...

  2. 基于Pytorch的NLP入门任务思想及代码实现:判断文本中是否出现指定字

    今天学了第一个基于Pytorch框架的NLP任务: 判断文本中是否出现指定字 思路:(注意:这是基于字的算法) 任务:判断文本中是否出现"xyz",出现其中之一即可 训练部分: 一 ...

  3. 基于PyTorch的TinyMind 汉字书法识别部分代码详解

    文章目录 0. 前言 1 遇到的问题 1.1 NameError: name 'cv2' is not defined 1.1.1 OpenCV下载 1.1.2 OpenCV安装 打开Anacoda3 ...

  4. 【项目实战课】基于Pytorch的DCGAN人脸嘴部表情图像生成实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的DCGAN人脸嘴部表情图像生成实战>. 所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行 ...

  5. vs2019 利用Pytorch和TensorFlow分别实现DCGAN生成动漫头像

    这是针对于博客vs2019安装和使用教程(详细)的DCGAN生成动漫头像项目新建示例 目录 一.DCGAN架构及原理 二.项目结构 1.TensorFlow 2.Pytorch 三.数据集下载(两种方 ...

  6. Deep Learning:基于pytorch搭建神经网络的花朵种类识别项目(内涵完整文件和代码)—超详细完整实战教程

    基于pytorch的深度学习花朵种类识别项目完整教程(内涵完整文件和代码) 相关链接:: 超详细--CNN卷积神经网络教程(零基础到实战) 大白话pytorch基本知识点及语法+项目实战 文章目录 基 ...

  7. 【深度学习】梯度和方向导数概念解析(代码基于Pytorch实现)

    [深度学习]梯度和方向导数概念解析(代码基于Pytorch实现) 文章目录 1 方向导数 2 梯度 3 自动求导实现 4 梯度下降4.1 概述4.2 小批量梯度下降 5 总结 1 方向导数 方向导数的 ...

  8. Pytorch之CNN:基于Pytorch框架实现经典卷积神经网络的算法(LeNet、AlexNet、VGG、NIN、GoogleNet、ResNet)——从代码认知CNN经典架构

    Pytorch之CNN:基于Pytorch框架实现经典卷积神经网络的算法(LeNet.AlexNet.VGG.NIN.GoogleNet.ResNet)--从代码认知CNN经典架构 目录 CNN经典算 ...

  9. 【图像分类】基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)

    写在前面: 首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌. 在https://blog.csdn.net/A ...

  10. 基于PyTorch搭建CNN实现视频动作分类任务代码详解

    数据及具体讲解来源: 基于PyTorch搭建CNN实现视频动作分类任务 import torch import torch.nn as nn import torchvision.transforms ...

最新文章

  1. 如何快速融入团队(八)
  2. ASP.NET MVC4中@model使用多个类型实例的方法
  3. 每天OnLineJudge 之 “杨辉三角 ”
  4. innodb行锁理解
  5. AAAI21最佳论文Informer:效果远超Transformer的长序列预测神器!
  6. 获取页面可见区域,屏幕区域的尺寸
  7. maven ant_如何在Maven中运行Ant目标?
  8. Hyperledger Fabric 1.0 从零开始(十二)——fabric-sdk-java应用
  9. 【Python】logging内置模块基本使用
  10. Spring Cloud学习笔记-009
  11. python绘制国际象棋规则口诀_用Python编写一个国际象棋AI程序
  12. django小站,数据3w+
  13. (一)伤不起--java调用dll
  14. 8个经典无线射频识别(RFID)优选方案
  15. 三角色:程序员、技术主管与架构师
  16. 苹果电脑修改MAC地址方法
  17. AI cs5序列号 注册机
  18. 从键盘输入一个整数,判断它是正数,负数,0
  19. 这个Kaggle三项排行榜的“顶级大师”,今年17岁
  20. 支持M1芯片Mac电脑的 Adobe Photoshop 2021 for Mac 中文版本

热门文章

  1. 一个关于银行卡号规则的问题,根据规则进行银行卡的验证
  2. SRM高维特征隐写分析原理与应用
  3. CSG:清华大学提出通过分化类特定卷积核来训练可解释的卷积网络 | ECCV 2020 Oral
  4. 球员评历史最佳阵:乔丹魔术师坚如磐石詹皇选自己
  5. 经典卷积神经网络模型盘点
  6. 实战演示k8s部署go服务,实现滚动更新、重新创建、蓝绿部署、金丝雀发布
  7. 微信和QQ内置浏览器为什么老是提示已停止访问该网页?
  8. foss测试_印度最大的针对语言技术的FOSS活动
  9. android签名图片不显示,android手写签名遇见bitmap黑屏和本地html插入签名图片
  10. 极客学院Docker45集视频教程 Docker全面解读零基础实战全套