文章目录

  • 引入
  • 1 生成器
  • 2 鉴别器
  • 3 模型训练:生成器与鉴别器的交互
  • 4 参数设置
  • 5 数据载入
  • 6 完整代码
  • 7 部分输出图像示意
    • 7.1 真实图像
    • 7.2 训练200个批次
    • 7.2 训练400个批次
    • 7.2 训练600个批次

引入

  论文详解Unsupervised representation learning with deep convolutional generative adversarial networks
  对抗生成网络的核心在于生成器鉴别器以及两者之间的交互,本文将详细对这几个部分进行介绍。

1 生成器

  DCGAN生成器的本质是多个卷积层、批量归一化、激活函数的堆叠,具体结构如下表:

结构 输入通道 输出通道 卷积核大小 步幅 填充 后续
ConvTranspose2d nz ngf×\times× 8 4 1 0 BatchNorm2d+ReLU
ConvTranspose2d ngf×\times× 8 ngf×\times× 4 4 2 1 BatchNorm2d+ReLU
ConvTranspose2d ngf×\times× 4 ngf×\times× 2 4 2 1 BatchNorm2d+ReLU
ConvTranspose2d ngf×\times× 2 ngf 4 2 1 BatchNorm2d+ReLU
ConvTranspose2d ngf nc 4 2 1 Tanh

其中nz为输入通道数,ngf为给定结点数,nc是输出类别数。例如对于MNIST数据集,可设置nz=100,ngf=64,nc=1。
  对应代码如下:

class Generator(nn.Module):"""生成器"""def __init__(self):super(Generator, self).__init__()# 使用的GPU数量self.ngpu = ngpu# 生成器结构,与表格中一致self.main = nn.Sequential(# 输入大小:nznn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# 大小:(ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 大小:(ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 大小:(ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 大小:(ngf) x 32 x 32nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),nn.Tanh()# 大小:(nc) x 64 x 64)def forward(self, input):if input.is_cuda and self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)return output

  生成器的结构输出如下 (接下来都以mnist为例):

Generator((main): Sequential((0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(8): ReLU(inplace=True)(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(11): ReLU(inplace=True)(12): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(13): Tanh())
)

2 鉴别器

  鉴别器与生成器的不同之处在于,其卷积层、激活函数的设置不同,输入通道数也是逐渐增加的:

结构 输入通道 输出通道 卷积核大小 步幅 填充 后续
Conv2d nc ndf 4 2 1 LeakyReLU(0.2)
Conv2d ndf ndf×\times× 2 4 2 1 BatchNorm2d+LeakyReLU(0.2)
Conv2d ndf×\times× 2 ndf×\times× 4 4 2 1 BatchNorm2d+LeakyReLU(0.2)
Conv2d ndf×\times× 4 ndf×\times× 8 4 2 1 BatchNorm2d+LeakyReLU(0.2)
Conv2d ndf×\times× 8 1 4 1 0 Sigmid()

其中nc为输入通道数,ndf为给定结点数。注意输出通道变为1了哟。
  对应代码如下:

class Discriminator(nn.Module):"""鉴别器"""def __init__(self):super(Discriminator, self).__init__()# 使用的GPU数量self.ngpu = ngpu# 鉴别器的结构self.main = nn.Sequential(# 输入大小: (nc) x 64 x 64nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):# 与生成器类似哟if input.is_cuda and self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)# 注意输出已经延展成一列的张量了return output.view(-1, 1).squeeze(1)

  鉴别器的结构输出如下:

Discriminator((main): Sequential((0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(1): LeakyReLU(negative_slope=0.2, inplace=True)(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(4): LeakyReLU(negative_slope=0.2, inplace=True)(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): LeakyReLU(negative_slope=0.2, inplace=True)(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(10): LeakyReLU(negative_slope=0.2, inplace=True)(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)(12): Sigmoid())
)

3 模型训练:生成器与鉴别器的交互

  下图绘制的是DCGAN生成器与鉴别器的交互过程,数字代表该步骤在程序中的运行过程:

  训练过程的代码如下:

def DCGAN():"""DCGAN主函数"""# 每一轮训练for epoch in range(opt.nepoch):# 每一个批次的训练for i, data in enumerate(dataset, 0):"""步骤1:训练鉴别器,即最大化log(D(x)) + log(1 - D(G(z)))"""# 首先基于真实图像进行训练# 鉴别器的梯度清零netD.zero_grad()# 格式化当前批次,这里data = data[0]的原因是因为所有批次的图像是放在一个列表里面的data = data[0].to(device)# 获取当前批次的图像的数量batch_size = data.size(0)# 将当前批次所有图像的标签设置为指定的真实标签,如1label = torch.full((batch_size, ), real_label, dtype=data.dtype, device=device)# 先鉴别器输出一下output = netD(data)# 计算鉴别器上基于真实图像计算的损失errorD_real = loss(output, label)errorD_real.backward()D_x = output.mean().item()# 训练虚假图像# 随机生成一个虚假图像noise = torch.randn(batch_size, nz, 1, 1, device=device)# 生成器开始造假fake = netG(noise)# 标签设置为假的标签label.fill_(fake_label)# 鉴别器来判断output = netD(fake.detach())# 假图片的损失errorD_fake = loss(output, label)errorD_fake.backward()# 假图片的梯度D_G_z1 = output.mean().item()# 鉴别器的总损失errorD = errorD_real + errorD_fake# 鉴别器优化一下optimD.step()"""训练生成器"""# 生成器梯度清零netG.zero_grad()# 生成器填真实标签,毕竟想造假label.fill_(real_label)# 得到假图片的输出output = netD(fake)# 计算生成器的损失errorG = loss(output, label)# 生成器梯度清零errorG.backward()# 再一次假图片的梯度D_G_z2 = output.mean().item()optimG.step()# 输出一些关键信息print("[%d/%d][%d/%d] lossD: %.4f lossG: %.4f ""D(x): %.4f D(G(z)): %.4f/%.4f" % (epoch, opt.nepoch, i, len(dataset),errorD.item(), errorG.item(),D_x, D_G_z1, D_G_z2))# 存储图像,可以设置想要的存储时间结点哈if i % 100 == 0:vutils.save_image(data, "real_image.png", normalize=True)fake = netG(fixed_noise)vutils.save_image(fake.detach(),'fake_image_%03d_%03d.png' % (epoch, i), normalize=True)

4 参数设置

  用的如下:

from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

  相关参数设置,变量的信息于help中给出:

def get_parser():"""获取参数设置器"""parser = argparse.ArgumentParser()# 设置实验用数据集的类型,help中为所支持的数据集类型parser.add_argument("--dataset", required=False, default="mnist",help="数据集类型:cifar10 | lsun | mnist |imagenet | folder | lfw | fake")parser.add_argument("--data_root", required=False, help="数据集的存储路径",default=r"D:\Data\OneDrive\Code\MIL1\Data")parser.add_argument("--workers", type=int, default=2, help="数据集下载并行数")parser.add_argument("--batch_size", type=int, default=64, help="数据集的输入批次大小")parser.add_argument("--image_size", type=int, default=64, help="输入图像的高/宽")parser.add_argument("--nz", type=int, default=100, help="隐含向量z的大小")parser.add_argument("--ngf", type=int, default=64, help="生成器隐藏层结点数")parser.add_argument("--ndf", type=int, default=64, help="鉴别器隐藏层结点数")parser.add_argument("--nepoch", type=int, default=5, help="训练轮次数")parser.add_argument("--lr", type=float, default=0.0002, help="学习率")parser.add_argument("--beta1", type=float, default=0.5, help="Adam的beta1")parser.add_argument("--cuda", action="store_true", help="CUDA是否可用")parser.add_argument("--ngpu", type=int, default=1, help="GPU使用数量")parser.add_argument("--manual_seed", type=int, help="随机种子")parser.add_argument("--classes", default="bedroom", help="LSUN卧室数据集的列表划分分隔符")return parser.parse_args()

  管理随机种子

def get_seed():"""管理随机种子"""if opt.manual_seed is None:opt.manual_seed = random.randint(1, 10000)random.seed(opt.manual_seed)torch.manual_seed(opt.manual_seed)

  网络的权重等设置:

def init_weight(m):"""初始化权重"""classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight, 1.0, 0.02)torch.nn.init.zeros_(m.bias)

  主函数

if __name__ == '__main__':# 参数管理器opt = get_parser()# 设备device = torch.device("cuda:0" if opt.cuda else "cpu")# GPU数量、生成器输入通道数、生成器结点数设置、鉴别器结点数设置ngpu, nz, ngf, ndf = int(opt.ngpu), int(opt.nz), int(opt.ngf), int(opt.ndf)# 数据集、输出通道数dataset, nc = get_data()# 启动生成器netG = Generator().to(device)netG.apply(init_weight)# 启动鉴别器netD = Discriminator().to(device)netD.apply(init_weight)# 损失函数loss = nn.BCELoss()# 设置噪声及标签fixed_noise, real_label, fake_label = torch.randn(opt.batch_size, nz, 1, 1, device=device), 1, 0# 启动优化器optimG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))optimD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))DCGAN()

5 数据载入

  所集成的数据集如下:

def get_data():"""获取数据集"""# 输出通道数nc = 3if opt.dataset in ["imagenet", "folder", "lfw"]:dataset = dset.ImageFolder(root=opt.data_root,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.CenterCrop(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "lsun":classes = [c + "_train" for c in opt.classes.split(',')]dataset = dset.LSUN(root=opt.data_root, classes=classes,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.CenterCrop(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "cifar10":dataset = dset.CIFAR10(root=opt.data_root, download=True,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "mnist":dataset = dset.MNIST(root=opt.data_root, download=True,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),]))nc = 1else:dataset = dset.FakeData(image_size=(3, opt.image_size, opt.image_size),transform=transforms.ToTensor())return (torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers)),nc)

6 完整代码

from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutilsdef get_parser():"""获取参数设置器"""parser = argparse.ArgumentParser()# 设置实验用数据集的类型,help中为所支持的数据集类型parser.add_argument("--dataset", required=False, default="mnist",help="数据集类型:cifar10 | lsun | mnist |imagenet | folder | lfw | fake")parser.add_argument("--data_root", required=False, help="数据集的存储路径",default=r"D:\Data\OneDrive\Code\MIL1\Data")parser.add_argument("--workers", type=int, default=2, help="数据集下载并行数")parser.add_argument("--batch_size", type=int, default=64, help="数据集的输入批次大小")parser.add_argument("--image_size", type=int, default=64, help="输入图像的高/宽")parser.add_argument("--nz", type=int, default=100, help="隐含向量z的大小")parser.add_argument("--ngf", type=int, default=64, help="生成器隐藏层结点数")parser.add_argument("--ndf", type=int, default=64, help="鉴别器隐藏层结点数")parser.add_argument("--nepoch", type=int, default=5, help="训练轮次数")parser.add_argument("--lr", type=float, default=0.0002, help="学习率")parser.add_argument("--beta1", type=float, default=0.5, help="Adam的beta1")parser.add_argument("--cuda", action="store_true", help="CUDA是否可用")parser.add_argument("--ngpu", type=int, default=1, help="GPU使用数量")parser.add_argument("--manual_seed", type=int, help="随机种子")parser.add_argument("--classes", default="bedroom", help="LSUN卧室数据集的列表划分分隔符")return parser.parse_args()def get_seed():"""管理随机种子"""if opt.manual_seed is None:opt.manual_seed = random.randint(1, 10000)random.seed(opt.manual_seed)torch.manual_seed(opt.manual_seed)def get_data():"""获取数据集"""# 输出通道数nc = 3if opt.dataset in ["imagenet", "folder", "lfw"]:dataset = dset.ImageFolder(root=opt.data_root,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.CenterCrop(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "lsun":classes = [c + "_train" for c in opt.classes.split(',')]dataset = dset.LSUN(root=opt.data_root, classes=classes,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.CenterCrop(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "cifar10":dataset = dset.CIFAR10(root=opt.data_root, download=True,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))elif opt.dataset == "mnist":dataset = dset.MNIST(root=opt.data_root, download=True,transform=transforms.Compose([transforms.Resize(opt.image_size),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),]))nc = 1else:dataset = dset.FakeData(image_size=(3, opt.image_size, opt.image_size),transform=transforms.ToTensor())return (torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers)),nc)def init_weight(m):"""初始化权重"""classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight, 1.0, 0.02)torch.nn.init.zeros_(m.bias)class Generator(nn.Module):"""生成器"""def __init__(self):super(Generator, self).__init__()# 使用的GPU数量self.ngpu = ngpu# 生成器结构,与表格中一致self.main = nn.Sequential(# 输入大小:nznn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# 大小:(ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 大小:(ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 大小:(ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 大小:(ngf) x 32 x 32nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),nn.Tanh()# 大小:(nc) x 64 x 64)def forward(self, input):if input.is_cuda and self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)return outputclass Discriminator(nn.Module):"""鉴别器"""def __init__(self):super(Discriminator, self).__init__()# 使用的GPU数量self.ngpu = ngpu# 鉴别器的结构self.main = nn.Sequential(# 输入大小: (nc) x 64 x 64nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):# 与生成器类似哟if input.is_cuda and self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)# 注意输出已经延展成一列的张量了return output.view(-1, 1).squeeze(1)def DCGAN():"""DCGAN主函数"""# 每一轮训练for epoch in range(opt.nepoch):# 每一个批次的训练for i, data in enumerate(dataset, 0):"""步骤1:训练鉴别器,即最大化log(D(x)) + log(1 - D(G(z)))"""# 首先基于真实图像进行训练# 鉴别器的梯度清零netD.zero_grad()# 格式化当前批次,这里data = data[0]的原因是因为所有批次的图像是放在一个列表里面的data = data[0].to(device)# 获取当前批次的图像的数量batch_size = data.size(0)# 将当前批次所有图像的标签设置为指定的真实标签,如1label = torch.full((batch_size, ), real_label, dtype=data.dtype, device=device)# 先鉴别器输出一下output = netD(data)# 计算鉴别器上基于真实图像计算的损失errorD_real = loss(output, label)errorD_real.backward()D_x = output.mean().item()# 训练虚假图像# 随机生成一个虚假图像noise = torch.randn(batch_size, nz, 1, 1, device=device)# 生成器开始造假fake = netG(noise)# 标签设置为假的标签label.fill_(fake_label)# 鉴别器来判断output = netD(fake.detach())# 假图片的损失errorD_fake = loss(output, label)errorD_fake.backward()# 假图片的梯度D_G_z1 = output.mean().item()# 鉴别器的总损失errorD = errorD_real + errorD_fake# 鉴别器优化一下optimD.step()"""训练生成器"""# 生成器梯度清零netG.zero_grad()# 生成器填真实标签,毕竟想造假label.fill_(real_label)# 得到假图片的输出output = netD(fake)# 计算生成器的损失errorG = loss(output, label)# 生成器梯度清零errorG.backward()# 再一次假图片的梯度D_G_z2 = output.mean().item()optimG.step()# 输出一些关键信息print("[%d/%d][%d/%d] lossD: %.4f lossG: %.4f ""D(x): %.4f D(G(z)): %.4f/%.4f" % (epoch, opt.nepoch, i, len(dataset),errorD.item(), errorG.item(),D_x, D_G_z1, D_G_z2))# 存储图像,可以设置想要的存储时间结点哈if i % 100 == 0:vutils.save_image(data, "real_image.png", normalize=True)fake = netG(fixed_noise)vutils.save_image(fake.detach(),'fake_image_%03d_%03d.png' % (epoch, i), normalize=True)if __name__ == '__main__':# 参数管理器opt = get_parser()# 设备device = torch.device("cuda:0" if opt.cuda else "cpu")# GPU数量、生成器输入通道数、生成器结点数设置、鉴别器结点数设置ngpu, nz, ngf, ndf = int(opt.ngpu), int(opt.nz), int(opt.ngf), int(opt.ndf)# 数据集、输出通道数dataset, nc = get_data()# 启动生成器netG = Generator().to(device)netG.apply(init_weight)# 启动鉴别器netD = Discriminator().to(device)netD.apply(init_weight)# 损失函数loss = nn.BCELoss()# 设置噪声及标签fixed_noise, real_label, fake_label = torch.randn(opt.batch_size, nz, 1, 1, device=device), 1, 0# 启动优化器optimG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))optimD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))DCGAN()

7 部分输出图像示意

7.1 真实图像

7.2 训练200个批次

7.2 训练400个批次

7.2 训练600个批次

  设备有限,就这么多了:

torch学习 (三十七):DCGAN详解相关推荐

  1. Mybatis源码学习(三)SqlSession详解

    前言 上一章节我们学习了SqlSessionFactory的源码,SqlSessionFactory中的方法都是围绕着SqlSession来的.,那么SqlSession又是什么东东呢?这一章节我们就 ...

  2. 深度学习网络模型——Vision Transformer详解 VIT详解

    深度学习网络模型--Vision Transformer详解 VIT详解 通用深度学习网络效果改进调参训练公司自己的数据集,训练步骤记录: 代码实现version-Transformer网络各个流程, ...

  3. python 命令-python解析命令行参数的三种方法详解

    这篇文章主要介绍了python解析命令行参数的三种方法详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 python解析命令行参数主要有三种方法: ...

  4. java多线程学习-java.util.concurrent详解

    http://janeky.iteye.com/category/124727 java多线程学习-java.util.concurrent详解(一) Latch/Barrier 博客分类: java ...

  5. ELK学习笔记之Logstash详解

    0x00 Logstash概述 官方介绍:Logstash is an open source data collection engine with real-time pipelining cap ...

  6. 【正点原子FPGA连载】 第三章 硬件资源详解 摘自【正点原子】DFZU2EG/4EV MPSoC 之FPGA开发指南V1.0

    1)实验平台:正点原子MPSoC开发板 2)平台购买地址:https://detail.tmall.com/item.htm?id=692450874670 3)全套实验源码+手册+视频下载地址: h ...

  7. Python基础学习之 os 模块详解

    Python基础学习之 os 模块详解 文章目录 Python基础学习之 os 模块详解 1. 路径操作 1.1 os.chdir(),切换当前工作目录: 1.2 os.getcwd(),返回工作目录 ...

  8. 大咖说*图书分享-Node布道师狼叔|三卷书详解Node.js

    狼书系列图书以Node.js为主,讲解了Node.js的基础知识.开发调试方法.源码原理和应用场景等内容,旨在向读者展示如何通过新的Node.js和npm编写出更具前端特色.更具工程化优势的代码. 嘉 ...

  9. Quartz学习之Cron表达式详解

    Quartz学习之Cron表达式详解 一.cron表达式结构 **二.各字段的含义** 解释: 注意要点: 三.示例 一.cron表达式结构 cron表达式从左到右(用空格隔开):**秒 分 小时 月 ...

最新文章

  1. C/C++中“#”和“##”的作用和用法
  2. 单机运行环境搭建之 --CentOS-6.4安装MySQL 5.6.10并修改MySQL的root用户密码
  3. JESD204B的AXI4-Lite时序分析(对比SRIO的AXI4-Lite时序分析)
  4. 使用 Eclipse C/C++ Development Toolkit 开发应用程序
  5. python万年历时钟_python实现万年历类calendar
  6. 使用持久内存开发工具包 (PMDK) 创建持久内存感知队列
  7. 【muduo源码分析】Buffer类的设计
  8. mysql 主从复制日志_mysql主从复制基于日志复制
  9. linux对硬盘进行分区吗,linux对4T硬盘进行分区
  10. F2FS文件系统一 设计背景及框架结构
  11. 国外一些DICOM资源下载网址
  12. Mybatis插件动态数据库链接
  13. 【网络存储】存储区域网络SAN
  14. 爱德泰科普 | 了解单模光纤跳线和多模光纤跳线,看着一篇就够了
  15. 测试:如何测试微信朋友圈的点赞功能
  16. 数组之entries
  17. 5G技术即将到来,5G网络的基本特点和应用你了解了多少
  18. 26.什么是梯度爆炸
  19. win10-yolov5环境搭建
  20. 用vant做一个登陆页面

热门文章

  1. torch.mean()
  2. 网页骨架屏自动生成方案(dps
  3. php调用nusoap 实现soap出现Premature end of data in tag html异常
  4. python word自动排版_用Python实现Word文档的自动比较
  5. 头插法建立单链表(详解版)
  6. G - dfs POJ - 2386
  7. 程序员的十年之功(经典文章翻译)
  8. erp5开源制造业erp外协加工设置
  9. ffmpeg视频处理教程
  10. html学生选课系统源码,学生选课系统