有一段时间没有写博客了,前段时间一直在看一篇关于SSAH算法的行人重识别,也是争对跨模态的一篇文章,只不过它跨的是图片和文本的模态,导师想让我把这篇文章中的想法应用到RGB-IR行人重识别当中来,但是研究了一段时间后发现这篇文章的算法并不适合RGB-IR的SYSU-MM01数据集,因此只能先放一边,以后能作个参考。之后看了一篇RGB-IR跨模态行人再识别的文章,叫《Cross-ModalityPersonRe-IdentificationwithGenerativeAdversarialTraining_IJCAI2018》是厦门大学的一个团队做的,里面用到了GAN,看了之后挺感兴趣,就去了解了一下它的原理,并且用GAN拿手写数字做了一下实验,做完之后才发现,那篇论文用的GAN跟生成图像用的GAN完全是两码事。

那篇论文其实是用了GAN的对坑思想,但是GAN最大的用处其实是生成图像,行人重识别说到底是一个分类问题,GAN基本上不会用在分类领域,但是对坑思想可以用上,既然做都做了,那还是写一篇博客记录一下吧

代码如下:

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import variable
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets,transforms
from torchvision.utils import save_imageG_in_dim=100#模型的参数参考别人的网络设置
D_in_dim=784
hidden1_dim=256
hidden2_dim=256
G_out_dim=784
D_out_dim=1epoch=50
batch_num=60
lr_rate=0.0003def to_img(x):#这个函数参考自别人的网络,是将生成的假图像经过一系列操作能更清晰的显示出来,具体为什么这样设置没研究过out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return outclass G_Net(nn.Module):#生成网络,或者叫生成器,负责生成假数据def __init__(self):super().__init__()self.layer=nn.Sequential(nn.Linear(G_in_dim,hidden1_dim),nn.ReLU(),nn.Linear(hidden1_dim,hidden2_dim),nn.LeakyReLU(),nn.Linear(hidden2_dim,G_out_dim),nn.Tanh())def forward(self,x):x=self.layer(x)return xclass D_Net(nn.Module):#判别网络,或者叫判别器,用来判别数据真假def __init__(self):super().__init__()self.layer=nn.Sequential(nn.Linear(D_in_dim,hidden1_dim),nn.LeakyReLU(0.2),nn.Linear(hidden1_dim,hidden2_dim),nn.LeakyReLU(0.2),nn.Linear(hidden2_dim,D_out_dim),nn.Sigmoid())def forward(self,x):x=self.layer(x)return xdata_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
train_set=datasets.MNIST(root='data',train=True,transform=data_tf,download=True)
train_loader=DataLoader(train_set,batch_size=batch_num,shuffle=True)g_net=G_Net()
d_net=D_Net()
if torch.cuda.is_available():g_net = g_net.cuda()d_net = d_net.cuda()criterion = nn.BCELoss()
G_optimizer = optim.Adam(g_net.parameters(), lr=lr_rate)
D_optimizer = optim.Adam(d_net.parameters(), lr=lr_rate)for e in range(epoch):for data in train_loader:img,l=dataimg = img.view(img.size(0), -1)if torch.cuda.is_available():img=variable(img).cuda()r_label = variable(torch.ones(batch_num)).cuda()f_label = variable(torch.zeros(batch_num)).cuda()g_input = variable(torch.randn(batch_num,G_in_dim)).cuda()r_output=d_net(img)r_loss=criterion(r_output,r_label)f_output=g_net(g_input)d_f_output=d_net(f_output)f_loss=criterion(d_f_output,f_label)sum_loss=r_loss+f_lossD_optimizer.zero_grad()sum_loss.backward()D_optimizer.step()if torch.cuda.is_available():g_input1 = variable(torch.randn(batch_num,G_in_dim)).cuda()g_output=g_net(g_input1)d_output=d_net(g_output)d_loss=criterion(d_output,r_label)G_optimizer.zero_grad()d_loss.backward()G_optimizer.step()g_img=g_net(variable(torch.randn(batch_num,G_in_dim)).cuda())images = to_img(g_img)save_image(images, './img/fake_images-{}.png'.format(e))

整体的实现思路可以总结如下:首先先训练判别器,用生成器生成假的图像和mnist中的真图像去训练判别器,因此判别器的输出就只有两种情况,真(1)或者假(0),从代码中可以看到,生成器是一个输入为100,输出为784的全连接网络,输入就用pytorch随机生成,也可以理解为输入一组随机噪声吧,在这个过程中,我们只更新判别器的参数,而不更新生成器的参数,损失就为真图片和假图片各自的损失和。之后保持判别器参数不变,我们去训练生成器,将100维数据输入到生成器中得到784维的数据,再将得到的这个784维数据输入到判别器中,将判别器的目标设置为1,计算损失并更新生成器的参数,这样生成器就会生成越来越像真图片的假图片了

我是看懂了原理顺着思路借鉴别人模型的参数实现了一遍,用的只有全连接层的神经网络,效果没有那么好,如果有卷积层应该会更好。训练了50个epoch,得到了如下一些假图像:

总的来说还可以,有几个生成的还是挺像的,这个算法其实很有意思,可惜以后应该不会研究它了

GAN生成手写数字假图片相关推荐

  1. 生成式对抗网络GAN生成手写数字

    GAN(Generative Adversarial Networks)是较为火热的一种神经网络,具有较多的优势和特点. 一.GAN 1. 原理 源自于零和博弈(zero-sum game),包括生成 ...

  2. 《Gans in Action》第三章 用GAN生成手写数字

    此为<Gans in Action>(对抗神经网络实战)第三章读书笔记 Chapter 3. Your first GAN: Generating handwritten digits 用 ...

  3. 深度学习之基于GAN实现手写数字生成

    在弄毕设的时候,室友的毕设是基于DCGAN实现音乐的自动生成.那是第一次接触对抗神经网络,当时听室友的描述就是两个CNN,一个生成一个监测,在互相博弈. 最近我关注的一个大神在弄有关于GAN的东西,所 ...

  4. 生成对抗网络(GAN)——MNIST手写数字生成

    前言 正文 一.什么是GAN 二.GAN的应用 三.GAN的网络模型 对抗生成手写数字 一.引入必要的库 一.引入必要的库 二.进行准备工作 三.定义生成器和判别器模型 四.设置损失函数和优化器,以及 ...

  5. GAN变种ACGAN利用手写数字识别mnist生成手写数字

    1.摘要 本文主要讲解:GAN变种ACGAN利用手写数字识别mnist数据集进行训练,最终生成手写数字图片 主要思路: Initialize generator and discriminator I ...

  6. 【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

    [万物皆可 GAN]生成对抗网络生成手写数字 Part 1 概述 GAN 网络结构 GAN 训练流程 模型详解 生成器 判别器 概述 GAN (Generative Adversarial Netwo ...

  7. 深度卷积生成对抗网络DCGAN——生成手写数字图片

    前言 本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的. 本文用到 ...

  8. 深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天

    文章目录 一.前期工作 1. 设置GPU 2. 定义训练参数 二.什么是生成对抗网络 1. 简单介绍 2. 应用领域 三.网络结构 四.构建生成器 五.构建鉴别器 六.训练模型 1. 保存样例图片 2 ...

  9. 利用GAN原始框架生成手写数字

    这一篇GAN文章只是让产生的结果尽量真实,还不能分类. 本次手写数字GAN的思想: 对于辨别器,利用真实的手写数字(真样本,对应的标签为真标签)和随机噪声经过生成器产生的样本(假样本,对应的标签为假标 ...

最新文章

  1. SAP RETAIL 为门店维护多个存储地点
  2. 《Oracle系列》:oracle job详解
  3. ASP.NET Core [1]:Hosting(笔记)
  4. Oracle查询数据库编码
  5. css 定位兼容性,CSS基础:定位与浏览器兼容性
  6. [转载]对 Linux 新手非常有用的20个命令
  7. 鸟哥的Linux私房菜(服务器)- 簡易 APT/YUM 伺服器設定
  8. Python语言学习之字符串那些事:python和字符串的使用方法之详细攻略
  9. 打包工具的配置教程见的多了,但它们的运行原理你知道吗?
  10. 统一的.NET文档体验发布
  11. Spring-framework应用程序启动loadtime源码分析笔记(二)——@Transactional
  12. Python format功能
  13. 论文学习15-Table Filling Multi-Task Recurrent Neural Network(联合实体关系抽取模型)
  14. IDEA中Spring Boot项目报错:There was an unexpected error (type=Not Found, status=404)
  15. 基于JAVA+SpringMVC+Mybatis+MYSQL的疫情防控物业管理系统
  16. python能做什么工作-学完Python我们可以做什么工作?
  17. Linux 内核软中断(softirq)执行分析
  18. 红帽linux开启vnc服务器,红帽Linux上使用VNC
  19. 3ds max 旋转及角度
  20. 3.3v稳压芯片有哪些

热门文章

  1. ubuntu重启后显卡驱动失效的问题
  2. 企业数字化转型,财务应该做些什么?
  3. 多孔材料导热模型篇—复现论文结果
  4. JavaScript调用OCX控件,运行时报错:对象不支持“XXX”属性或方法【已解决】
  5. [附源码]Java计算机毕业设计SSM防疫期社区人员信息动态管理系统
  6. 基于人脸图像识别学生宿舍系统的设计与实现(论文+源码)_kaic
  7. 人气漫画《蜡笔小新》作者被确定已死亡
  8. OCR之R^2AM(Recursive Recurrent Nets with Attention Modeling for OCR in the Wild)论文笔记
  9. 2023-04-23 学习记录--C/C++-函数
  10. Java编程基础30——SE经典案例