一:介绍
CGAN全程是Conditional Generative Adversarial Network,回想一下,传统的GAN或者其他的GAN都是通过一堆的训练数据,最后训练出了G网络,随机输入噪声最后产生的数据是这些训练数据类别中之一,我们提前无法预测是那哪一个?

因此,我们有的时候需要定向指定生成某些数据,比如我们想让G生成飞机,数字9,等等的图片数据。

怎么做呢:
1:就是给网络的输入噪声数据增加一些类别上的信息,就是说给定某些类别条件下,生成指定的数据,所以输入数据会有一些变化;

2:然后在损失函数那里,我们目标不再是输出1/0,也就是不再是简单的输出真实和构造。当判定是真实数据的时候,还需要判定出是哪一类别的图片。一般使用one-hot表示。

上图表示,改变输入噪声数据,给z增加类别y信息,怎么增加呢,就是简单的维度拼接,y可以是一个one-hot向量,或者其他表达形式。对于真实数据x不做变化,只用y来获取D的输出结果。

判别器D最后也应该输出是哪个类别,并且按照类别最小化来训练,也就是希望D(X)尽可能接近y。

二:实例操作
拿MNIST数据练手

网络的结构什么的都没有改变,唯一变化的就是,生成的噪声z拼接上了数据的类别标签,D的输出是数据的类别的one-hot向量,而不仅仅是0/1.
详细代码如下:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import pickle
import copyimport matplotlib.gridspec as gridspec
from torchvision.utils import save_image
import os# 定义展示图片的函数
def show_images(images):  # 定义画图工具print('images: ', images.shape)images = np.reshape(images, [images.shape[0], -1])sqrtn = int(np.ceil(np.sqrt(images.shape[0])))sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))fig = plt.figure(figsize=(sqrtn, sqrtn))gs = gridspec.GridSpec(sqrtn, sqrtn)gs.update(wspace=0.05, hspace=0.05)for i, img in enumerate(images):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(img.reshape([sqrtimg, sqrtimg]))returndef deprocess_img(img):out = 0.5 * (img + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return out# step 1: ===========================================加载数据
batch_size = 128
noise_dim = 100  # 噪声维度,还是选择100维度
label_dim = 10  # 标签维度,10个数字,10个维度
z_dimension = noise_dim + label_dim  # z dimension = 100 noise dim + 10 one-hot dimtransform_img = transforms.Compose([transforms.ToTensor()])
trainset = MNIST('./data', train=True, transform=transform_img, download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)# step 2: ===========================================定义模型
class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.dis = nn.Sequential(nn.Conv2d(1, 32, 5, stride=1, padding=2),nn.LeakyReLU(0.2, True),nn.MaxPool2d((2, 2)),nn.Conv2d(32, 64, 5, stride=1, padding=2),nn.LeakyReLU(0.2, True),nn.MaxPool2d((2, 2)))self.fc = nn.Sequential(nn.Linear(7 * 7 * 64, 1024),nn.LeakyReLU(0.2, True),nn.Linear(1024, 10),nn.Sigmoid())def forward(self, x):  # x: [batch_size, 1, 28, 28]x = self.dis(x)x = x.view(x.size(0), -1)x = self.fc(x)return x  # [batch_size, 10]class generator(nn.Module):def __init__(self, input_size, num_feature):super(generator, self).__init__()self.fc = nn.Linear(input_size, num_feature)  # 1*56*56self.gen = nn.Sequential(nn.BatchNorm2d(1),nn.ReLU(True),nn.Conv2d(1, 50, 3, stride=1, padding=1),nn.BatchNorm2d(50),nn.ReLU(True),nn.Conv2d(50, 25, 3, stride=1, padding=1),nn.BatchNorm2d(25),nn.ReLU(True),nn.Conv2d(25, 1, 2, stride=2),nn.Tanh())def forward(self, x):  # x: [batch_size, 110]x = self.fc(x)x = x.view(x.size(0), 1, 56, 56)x = self.gen(x)return x  # [batch_size, 1, 28, 28]# 实例化模型
D_Net = discriminator()
G_Net = generator(z_dimension, 3136)  # 1*56*56# step 3: ===========================================定义优化器和损失函数
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D_Net.parameters(), lr=0.0003)
g_optimizer = optim.Adam(G_Net.parameters(), lr=0.0003)# step 4: ===========================================开始训练
if __name__ == "__main__":iter_count = 0show_every = 100epoch = 100gepoch = 1for i in range(epoch):for (img, label) in trainloader:img = Variable(img)print(img.shape)# 生成 lable 的 one-hot 向量,且设置对应类别位置是 1labels_onehot = np.zeros((img.shape[0], label_dim))labels_onehot[np.arange(img.shape[0]), label.numpy()] = 1# 生成随机向量,也就是噪声z,带有标签信息z = Variable(torch.randn(img.shape[0], noise_dim))z = np.concatenate((z.numpy(), labels_onehot), axis=1)z = Variable(torch.from_numpy(z).float())# 真实数据标签和虚假数据标签,real_label = Variable(torch.from_numpy(labels_onehot).float())  # 真实label对应类别是为1fake_label = Variable(torch.zeros(img.shape[0], label_dim))  # 假的label全是为0# compute loss of real_imgreal_out = D_Net(img)  # 真实图片送入判别器D输出0~1d_loss_real = criterion(real_out, real_label)  # 得到loss# compute loss of fake_imgfake_img = G_Net(z)  # 将向量放入生成网络G生成一张图片fake_out = D_Net(fake_img)  # 判别器判断假的图片d_loss_fake = criterion(fake_out, fake_label)  # 假的图片的loss# D bp and optimized_loss = d_loss_real + d_loss_faked_optimizer.zero_grad()  # 判别器D的梯度归零d_loss.backward()  # 反向传播d_optimizer.step()  # 更新判别器D参数# 生成器G的训练compute loss of fake_imgfor j in range(gepoch):fake_img = G_Net(z)  # 将向量放入生成网络G生成一张图片output = D_Net(fake_img)  # 经过判别器得到结果g_loss = criterion(output, real_label)  # 得到假的图片与真实标签的loss# bp and optimizeg_optimizer.zero_grad()  # 生成器G的梯度归零g_loss.backward()  # 反向传播g_optimizer.step()  # 更新生成器G参数print("G")# 利用模型进行测试,指定按照顺序生成0~9的数字if (iter_count % show_every == 0):test_batch_size = 10test_label = torch.from_numpy(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))labels_onehot = np.zeros((test_batch_size, label_dim))labels_onehot[np.arange(test_batch_size), test_label.numpy()] = 1# 生成随机向量,也就是噪声z,带有标签信息test_z = Variable(torch.randn(test_batch_size, noise_dim))test_z = np.concatenate((test_z.numpy(), labels_onehot), axis=1)test_z = Variable(torch.from_numpy(test_z).float())fake_img = G_Net(test_z)  # 将向量放入生成网络G生成一张图片# imgs_numpy = deprocess_img(fake_img.data.cpu().numpy())# show_images(imgs_numpy)# plt.show()real_images = deprocess_img(fake_img.data)save_image(real_images, 'D:/software/Anaconda3/doc/3D_Img/cgan/test_%d.png' % (iter_count))iter_count += 1print('iter_count: ', iter_count)

最后按照顺序生成0~9的图像效果还是很不错的。





深度学习《CGAN模型》相关推荐

  1. 深度学习之自编码器(4)变分自编码器

    深度学习之自编码器(4)变分自编码器 1. VAE原理  基本的自编码器本质上是学习输入 x\boldsymbol xx和隐藏变量 z\boldsymbol zz之间映射关系,它是一个 判别模型(Di ...

  2. 深度学习之自编码器(5)VAE图片生成实战

    深度学习之自编码器(5)VAE图片生成实战 1. VAE模型 2. Reparameterization技巧 3. 网络训练 4. 图片生成 VAE图片生成实战完整代码  本节我们基于VAE模型实战F ...

  3. 深度学习之自编码器AutoEncoder

    深度学习之自编码器AutoEncoder 原文:http://blog.csdn.net/marsjhao/article/details/73480859 一.什么是自编码器(Autoencoder ...

  4. 深度学习之自编码器(3)自编码器变种

    深度学习之自编码器(3)自编码器变种 1. Denoising Auto-Encoder 2. Dropout Auto-Encoder 3. Adversarial Auto-Encoder  一般 ...

  5. 深度学习之自编码器(2)Fashion MNIST图片重建实战

    深度学习之自编码器(2)Fashion MNIST图片重建实战 1. Fashion MNIST数据集 2. 编码器 3. 解码器 4. 自编码器 5. 网络训练 6. 图片重建 完整代码  自编码器 ...

  6. 深度学习之自编码器(1)自编码器原理

    深度学习之自编码器(1)自编码器原理 自编码器原理  前面我们介绍了在给出样本及其标签的情况下,神经网络如何学习的算法,这类算法需要学习的是在给定样本 x\boldsymbol xx下的条件概率 P( ...

  7. 【深度学习】 自编码器(AutoEncoder)

    目录 RDAE稳健深度自编码 自编码器(Auto-Encoder) DAE 深度自编码器 RDAE稳健深度自编码 自编码器(Auto-Encoder) AE算法的原理 Auto-Encoder,中文称 ...

  8. 深入理解深度学习——Transformer:编码器(Encoder)部分

    分类目录:<深入理解深度学习>总目录 相关文章: ·注意力机制(AttentionMechanism):基础知识 ·注意力机制(AttentionMechanism):注意力汇聚与Nada ...

  9. 深度学习之自编码器实现——实现图像去噪

    大家好,我是带我去滑雪! 自编码器是一种无监督学习的神经网络,是一种数据压缩算法,主要用于数据降维和特征提取.它的基本思想是将输入数据经过一个编码器映射到隐藏层,再通过一个解码器映射到输出层,使得输出 ...

  10. 深度学习:自编码器、深度信念网络和深度玻尔兹曼机

    最近自己会把自己个人博客中的文章陆陆续续的复制到CSDN上来,欢迎大家关注我的 个人博客,以及我的github. 本文主要讲解有关自编码器.深度信念网络和深度玻尔兹曼机的相关知识. 一.自编码器 1. ...

最新文章

  1. (转自PHPer)成长的选择
  2. 定义员工类,职工类,管理类
  3. 让SAP云平台上的Web应用使用destination服务
  4. Delphi 的字符及字符串[4] - 字符串、字符指针与字符数组
  5. Spring MVC重定向和转发
  6. 使用QRCode生成二维码
  7. bert之我见 - positional encoding
  8. 【排列组合】只上代码不解释
  9. python将网页保存为pdf,python-网页保存为pdf
  10. 计算机兼容性测试怎么做,如何进行兼容性测试
  11. python教程,python小甲鱼
  12. 流媒体直播系统由哪几部分组成?
  13. 如何解决Mac电脑在启动时出现空白屏幕的情况?
  14. linux中数据库的4种状态,数据库的数据持久有几种方案_数据库_数据管理_数据结构_课课家...
  15. 《灰色预测(GM)的MATLAB实现》
  16. Python猫眼电影最近上映的电影票房信息
  17. SpringCloud—笔记(三)高级篇
  18. HR人力资源系统管理源码
  19. 爬取网站某网站所有通知
  20. CANopen从站伺服配置报文及使用

热门文章

  1. AOP的实现方式比较,cglib vs jdk
  2. Android Calender
  3. lintcode 中等题:Single number III 落单的数III
  4. 剖析 SurfaceView Callback以及SurfaceHolder
  5. 工作与生活 -- 平衡是必须的
  6. 几个重要的Linux系统内核文件介绍
  7. TypeScript Symbol
  8. 容器编排技术 -- Kubernetes 设计理念
  9. Tomcat 比 nio 、aio性能更好的apr介绍
  10. 2021 Axios 各种请求方式传递参数格式整理