GAN原理

相关数学推导可参考 李宏毅https://www.bilibili.com/video/av36779967/?p=4

通俗的比喻:制造假钞(G)和警察(D)对抗的过程。假钞制造者制造假钞,警察识别假钞

假钞制造者制造出当前警察无法识别的假钞
警察提高识别能力识别出假钞

假钞制造者再次制造出当前警察无法识别的假钞
警察再次识别出假钞


以上两个过程循环,生成器就能生成接近真实的数据,实际上是使生成数据与真实数据的有相同的分布。

数据准备

用GAN生成服从一下分布的点两个区域分别为
x∈(0,5),y∈(5,10)
x∈(10,15),y∈(10,15)
在两个区域内均服从均匀分布,每个区域个500个点

`

def data_gen():# 区域1x1 = np.random.uniform(0, 5, 500)y1 = np.random.uniform(5, 10, 500)# 区域2 x2 = np.random.uniform(10, 15, 500)y2 = np.random.uniform(10, 15, 500)# 拼接data_x = np.concatenate((x1,x2))data_y = np.concatenate((y1,y2))# 拼接x,y以作为网络输入data = np.transpose(np.vstack((data_x,data_y)))print(data.shape)return data
# 测试代码
if __name__ == '__main__':data = data_gen()plt.scatter(data[:,0],data[:,1])plt.show()

生成对抗网络

生成器

生成器D接受一个二维向量(采用服从正态分布的向量),生成坐标(x,y)

import keras
import tensorflow as tf
from keras import layers
import numpy as np
import matplotlib.pyplot as plt
from data_gen import data_gen as dgG_input = keras.Input(shape=(2,)) # 输入一个二维vector
x = layers.Dense(5,activation='relu')(G_input)
x = layers.Dense(5,activation='relu')(x)
x = layers.Dense(2,activation='tanh')(x) # 输出为一堆二维坐标
G = keras.models.Model(G_input,x)
G.summary()
判别器

一个普通的全连网络

# 判别器
D_input = keras.Input(shape=(2,))
x = layers.Dense(10,activation='relu')(D_input)
x = layers.Dense(10,activation='relu')(x)
x = layers.Dropout(0, 4)(x)
x = layers.Dense(1,activation='sigmoid')(x)
D = keras.models.Model(D_input,x)
D.compile(loss='binary_crossentropy',optimizer='rmsprop')
D.summary()
GAN

生成器和判别器的连接,D参数设置为不可训练

D.trainable = False
gan_input = keras.Input(shape=(2,))
gan_output = D(G(gan_input))
gan = keras.models.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy',optimizer='rmsprop')
gan.summary()

训练

准备数据
real_point = dg()/15.0 # 简单地归一化数据
fig = plt.figure()
# 限定坐标轴位置
plt.xlim(0,1)
plt.ylim(0,1)
ax = fig.add_subplot(1,1,1)
ax.scatter(real_point[:, 0], real_point[:, 1])
plt.ion()
训练
epochs = 1000 # 训练次数
for step in range(epochs):random_input = np.random.normal(size=(1000,2)) # 输入接受的随机向量gen_point = G.predict(random_input) # 生成fake points# 拼接数据加入标签用于训练判别器combined_point = np.concatenate([real_point,gen_point])labels = np.concatenate([np.ones((1000,1)),np.zeros((1000,1))])labels += 0.05*np.random.random(labels.shape)  # 标签加噪声,似乎绝对0和绝对1都对gan训练不利d_loss = D.train_on_batch(combined_point, labels) # 训练判别器random_input = np.random.normal(size=(1000, 2)) # 随机向量,用于对抗网络训练输入mis_targets = np.ones((1000,1)) # 加标签a_loss = gan.train_on_batch(random_input, mis_targets)# 可视化过程if step%10==0:print('discriminator loss:', d_loss)print('adversarial loss:', a_loss)try:ax.lines.remove(points[0])except Exception:passgen_point = G.predict(random_input)points = ax.plot(gen_point[:, 0], gen_point[:, 1],'ro')plt.pause(0.01)plt.pause(10)
训练过程结果

Reference

主要参考书籍:python深度学习 p260 - p263

用惯了matlab,matplotlib画图也很让人头疼。动态刷新plot可参考参考文章
matplotlib动态刷新指定曲线 https://blog.csdn.net/omodao1/article/details/81223240

Keras实现生成对抗网络(GAN)(生成二维平面上服从某一分布的点)相关推荐

  1. (五)使用生成对抗网络 (GAN)生成新的时装设计

    目录 介绍 预测新时尚形象的力量 构建GAN 初始化GAN参数和加载数据 从头开始构建生成器 从头开始构建鉴别器 初始化GAN的损失和优化器 下一步 下载源 - 120.7 MB 介绍 DeepFas ...

  2. 利用Tensorflow构建生成对抗网络GAN以生成数据

    使用生成对抗网络(GAN)生成数据 本文主要内容 介绍了自动编码器的基本原理 比较了生成模型与自动编码器的区别 描述了GAN模型的网络结构 分析了GAN模型的目标核函数以及训练过程 介绍了利用Goog ...

  3. [Python人工智能] 二十九.什么是生成对抗网络GAN?基础原理和代码普及(1)

    从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CN ...

  4. [论文阅读] (06) 万字详解什么是生成对抗网络GAN?经典论文及案例普及

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  5. 万字详解什么是生成对抗网络GAN

    摘要:这篇文章将详细介绍生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN).发展历程.预备知识,并通过Keras搭建最简答的手写数字图片生成案. ...

  6. 生成对抗网络——GAN(一)

    Generative adversarial network 据有关媒体统计:CVPR2018的论文里,有三分之一的论文与GAN有关 由此可见,GAN在视觉领域的未来多年内,将是一片沃土(CVer们是 ...

  7. 生成对抗网络gan原理_中国首个“芯片大学”即将落地;生成对抗网络(GAN)的数学原理全解...

    开发者社区技术周刊又和大家见面了,萌妹子主播为您带来第三期"开发者技术联播".让我们一起听听,过去一周有哪些值得我们开发者关注的重要新闻吧. 中国首个芯片大学,南京集成电路大学即将 ...

  8. 权重对生成对抗网络GAN性能的影响

    本文制作了一个生成对抗网络GAN网络,并通过调节权重的初始化方法来观察权重对网络性能的影响. 生成网络的结构是784*300*784,对抗网络的结构是784*300*1.生成网络的输入是一个28*28 ...

  9. 【GAN优化】长文综述解读如何定量评价生成对抗网络(GAN)

    欢迎大家来到<GAN优化>专栏,这里将讨论GAN优化相关的内容,本次将和大家一起讨论GAN的评价指标. 作者&编辑 | 小米粥 编辑 | 言有三 在判别模型中,训练完成的模型要在测 ...

最新文章

  1. Python数字类型及操作汇总(入门级)
  2. php cache缓存 购物车,Yii2使用Redis缓存购物车等数据
  3. UBUNTU手动安装JDK的详细步骤
  4. 现代前端开发路线图:从零开始,一步步成为前端工程师
  5. 关于Oracle undostat中的2012和ORA-01555问题的自我解答
  6. UE4学习-游戏退出、游戏打包
  7. 三年半Java后端面试经历
  8. mysql主要的收获_MySQL性能测试大总结
  9. 第一二三范式的简单理解
  10. oracle联合运算,Oracle UNION运算符
  11. 【Mac】Mac 使用 zsh 后, mvn 命令无效
  12. 联想服务器远程管理模块,联想慧眼远程管理模块-Lenovo服务网站.PDF
  13. catia今天突然打不开了_catia打不开的解答
  14. 夏令时国家时间java代码_程序里的国际时区和夏令时
  15. 前端推荐的书籍学习(必看)
  16. 如何在wince下添加和删除驱动(作者:wogoyixikexie@gliet)
  17. 网页三栏布局常用方法
  18. 对finalize的理解
  19. 微型计算机故障处理基本原则,微机故障处理的一般性原则和方法
  20. 抖音怎么申请企业蓝V认证?(含认证教程)

热门文章

  1. 蘑菇游戏_熊碰撞边界处理
  2. ASP.NET典型三层架构企业级医药行业ERP系统实战(8大模块22个子系统,价值3000万)...
  3. Win10配置jdk1.8
  4. Altium Designer v22.7.1.60 PCB板、电路原理图设计工具
  5. 隆重给大家拜早年了,并顺道推荐几部影片
  6. android开发里跳过的坑-电源锁WakeLock不起作用
  7. 在手机端查看CAD图纸有什么技巧呢?
  8. c++运算符重载与输入输出流重载
  9. 金融风险管理师FRM培训班多少费用?贵吗?
  10. CAD文件解析(DWG to SVG)