1.介绍

在原始gan(GAN 简介与代码实战)中,生成数据的来源一般是一个固定分布噪声z,z可以生成不同的图片,z代表着很多意思,我们无法知道z的那个维度代表什么(比如在生成数字手写图片的时候,0维度是否代表笔画的风格,我们不得而知),z是不可解释的。为了解决这个问题,InfoGAN就横空出世了,更加详细的内容参见论文:Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

2.模型结构

网络是基于DC-GAN(Deep Convolutional GAN)的,G和D都由CNN构成。在此基础上,Q和D共享卷积网络,然后分别通过各自的全连接层输出不同的内容:Q输出对应于生成图片的c'(与之对应的c是可以控制图片的生成,比如生成什么数字),D则仍然判别真伪。

3.模型特点

相对于原始gan,作者将Z分成z(固定分布噪声)和c(一些隐变量信息,比如笔画风格,字体大小等),损失函数里面用到互信息,使得隐变量c与生成的变量G(z,c)拥有尽可能多的共同信息。

4.代码实现keras

class INFOGAN():def __init__(self):self.img_rows = 28self.img_cols = 28self.channels = 1self.num_classes = 10self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 72optimizer = Adam(0.0002, 0.5)losses = ['binary_crossentropy', self.mutual_info_loss]# Build and the discriminator and recognition networkself.discriminator, self.auxilliary = self.build_disk_and_q_net()self.discriminator.compile(loss=['binary_crossentropy'],optimizer=optimizer,metrics=['accuracy'])# Build and compile the recognition network Qself.auxilliary.compile(loss=[self.mutual_info_loss],optimizer=optimizer,metrics=['accuracy'])# Build the generatorself.generator = self.build_generator()# The generator takes noise and the target label as input# and generates the corresponding digit of that labelgen_input = Input(shape=(self.latent_dim,))img = self.generator(gen_input)# For the combined model we will only train the generatorself.discriminator.trainable = False# The discriminator takes generated image as input and determines validityvalid = self.discriminator(img)# The recognition network produces the labeltarget_label = self.auxilliary(img)# The combined model  (stacked generator and discriminator)self.combined = Model(gen_input, [valid, target_label])self.combined.compile(loss=losses,optimizer=optimizer)def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(Activation("relu"))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(Activation("relu"))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(self.channels, kernel_size=3, padding='same'))model.add(Activation("tanh"))gen_input = Input(shape=(self.latent_dim,))img = model(gen_input)model.summary()return Model(gen_input, img)def build_disk_and_q_net(self):img = Input(shape=self.img_shape)# Shared layers between discriminator and recognition networkmodel = Sequential()model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Flatten())img_embedding = model(img)# Discriminatorvalidity = Dense(1, activation='sigmoid')(img_embedding)# Recognitionq_net = Dense(128, activation='relu')(img_embedding)label = Dense(self.num_classes, activation='softmax')(q_net)# Return discriminator and recognition networkreturn Model(img, validity), Model(img, label)def mutual_info_loss(self, c, c_given_x):"""The mutual information metric we aim to minimize"""eps = 1e-8conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))return conditional_entropy + entropydef sample_generator_input(self, batch_size):# Generator inputssampled_noise = np.random.normal(0, 1, (batch_size, 62))sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)return sampled_noise, sampled_labelsdef train(self, epochs, batch_size=128, sample_interval=50):# Load the dataset(X_train, y_train), (_, _) = mnist.load_data()# Rescale -1 to 1X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)y_train = y_train.reshape(-1, 1)# Adversarial ground truthsvalid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------#  Train Discriminator# ---------------------# Select a random half batch of imagesidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# Sample noise and categorical labelssampled_noise, sampled_labels = self.sample_generator_input(batch_size)gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)# Generate a half batch of new imagesgen_imgs = self.generator.predict(gen_input)# Train on real and generated datad_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)# Avg. lossd_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  Train Generator and Q-network# ---------------------g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])# Plot the progressprint ("%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[1], g_loss[2]))# If at save interval => save generated image samplesif epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 10, 10fig, axs = plt.subplots(r, c)for i in range(c):sampled_noise, _ = self.sample_generator_input(c)label = to_categorical(np.full(fill_value=i, shape=(r,1)), num_classes=self.num_classes)gen_input = np.concatenate((sampled_noise, label), axis=1)gen_imgs = self.generator.predict(gen_input)gen_imgs = 0.5 * gen_imgs + 0.5for j in range(r):axs[j,i].imshow(gen_imgs[j,:,:,0], cmap='gray')axs[j,i].axis('off')fig.savefig("images/%d.png" % epoch)plt.close()

InfoGAN 简介与代码实战相关推荐

  1. BART原理简介与代码实战

    写在前面 最近huggingface的transformer库,增加了BART模型,Bart是该库中最早的Seq2Seq模型之一,在文本生成任务,例如摘要抽取方面达到了SOTA的结果. 本次放出了三组 ...

  2. Flume NG 简介及配置实战

    2019独角兽企业重金招聘Python工程师标准>>> Flume NG 简介及配置实战 博客分类: 分布式计算 1.Flume 的一些核心概念: 1.1 数据流模型 1.2 高可靠 ...

  3. 【强化学习】Sarsa算法求解悬崖行走问题 + Python代码实战

    文章目录 一.Sarsa算法简介 1.1 更新公式 1.2 预测策略 1.3 详细资料 二.Python代码实战 2.1 运行前配置 2.2 主要代码 2.3 运行结果展示 2.4 关于可视化寻路过程 ...

  4. 【强化学习】优势演员-评论员算法(Advantage Actor-Critic , A2C)求解倒立摆问题 + Pytorch代码实战

    文章目录 一.倒立摆问题介绍 二.优势演员-评论员算法简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.2.1 网络参数不共享版本 4.2.2 网络参数共享版本 ...

  5. 【强化学习】Q-Learning算法求解悬崖行走问题 + Python代码实战

    文章目录 一.Q-Learning算法简介 1.1 更新公式 1.2 预测策略 1.3 详细资料 二.Python代码实战 2.1 运行前配置 2.2 主要代码 2.3 运行结果展示 2.4 关于可视 ...

  6. 【强化学习】双深度Q网络(DDQN)求解倒立摆问题 + Pytorch代码实战

    文章目录 一.倒立摆问题介绍 二.双深度Q网络简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.3 运行结果展示 4.4 关于可视化的设置 一.倒立摆问题介绍 A ...

  7. 【强化学习】竞争深度Q网络(Dueling DQN)求解倒立摆问题 + Pytorch代码实战

    文章目录 一.倒立摆问题介绍 二.竞争深度Q网络简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.3 运行结果展示 4.4 关于可视化的设置 一.倒立摆问题介绍 ...

  8. 【强化学习】PPO算法求解倒立摆问题 + Pytorch代码实战

    文章目录 一.倒立摆问题介绍 二.PPO算法简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.3 运行结果展示 4.4 关于可视化的设置 一.倒立摆问题介绍 Ag ...

  9. Neural Collaborative Filtering(NCF) 代码实战(Keras)

    博客主要分为两部分.第一部分为论文简介,第二部分为代码实战. 论文简介: 1. 通用框架 下图是作者提出的用神经网络解决推荐系统问题的通用框架.论文先将用户与物品分别进行one-hot编码,然后通过一 ...

最新文章

  1. 荣耀3OS怎么升级鸿蒙系统,华为鸿蒙OS正式发布!教你如何升级
  2. xpath提取html属性,xpath提取 html标签的文字内容
  3. 011_AOP注解开发
  4. 3.1 cat:合并文件或查看文件内容
  5. ai画面怎么调大小_ai如何调整对象大小
  6. 银屑病缺乏的营养汇总(持续更新中)
  7. python中print的用法_Python中print函数简单使用总结
  8. word公式编辑器_毕业论文里面的各种公式该如何编辑
  9. android 应用切换滑动,Android应用中利用ViewPager实现多页面滑动切换效果示例
  10. 在asp.net web api 2 (ioc autofac) 使用 Serilog 记录日志
  11. js中的instanceof运算符
  12. php架构师生涯一个月总结
  13. zigbee设备类型
  14. 给斐讯K1刷机并拨号e信(湖北地区测试无问题)
  15. Packet(信息包)
  16. OpenCV实现目标跟踪
  17. Win10系统设置炫酷下拉关机(其实很简单啦!)
  18. Dubbo系列之框架概括(一)
  19. Android CPU 双核,双核到底强在哪?四大手机处理器终极横评
  20. 瑞芯微RV1126/1109开发流程之驱动升级

热门文章

  1. sql模糊查询判断字符串包含一些字符串
  2. crypt密码加密函数的基本用法
  3. 在线点餐系统设计文档
  4. oracle 10046
  5. 声母-字母查询工具-词语缩写查询在线工具
  6. 贪吃蛇智能版(专家)
  7. 硬件电路设计纯纯小白-2-Altium Designer软件汉化、再变成英文
  8. 【前端知识之Vue】对插槽(slot)的理解
  9. 基于stm32的非接触式红外测温系统
  10. 三星s10android10功能,安卓老大王者归来 三星S10曝光黑科技一览