自编码器一般功能是压缩数据。但是变分自编码器(variational autoencoder,VAE)的主要作用却是可以生成数据。变分自编码器的这种功能可以应用到很多领域,从生成手写文字到合成音乐。
变分自编码器的实现在标准自编码器上增加了一个网络层,该网络层有2个子网络组成:均值网络和方差网络。如下图所示:

变分自编码器的训练过程与标准自编码器不同。首先是在编码器中引入了2个子网络均值和方差。其次是在编码器和解码器之间加入了一个采样层。采样层的输入为均值、方差以及高斯噪声,算法如下:

最后采样层的结果输入到解码器。在进行误差反传的时候,变分自编码器并不是简单的使用mse等损失函数,而是损失函数的基础上增加了Kullback-Leibler散度(KL算法)。增加KL算法的原因是要确定均值和方差子网络是符合正态分布的。
变分自编码器的使用是用正态分布的随机数作为解码器的输入,在输出端就可以得到与输入类似但又不同的结果(比如图像等)。这就是变分自编码器最具吸引力的地方。变分自编码器是目前比较流行的生成模型之一。
变分自编码器之所以可以生成结果,是因为它提取了输入的特征。比如输入的是人脸图片,变分自编码器相当于保留了人脸的基本特征的均值和方差信息。通过调整这些信息(输入正态分布随数),就可以得到不同的人脸图片。
变分自编码器代码:

import os
import tensorflow as tf
from tensorflow import keras
from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras import Sequential, layers
import numpy as nptf.random.set_seed(2322)
np.random.seed(23422)os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert  tf.__version__.startswith('2.')# 把num张图片保存到一张
def save_images(img, name,num):new_im = Image.new('L', (28*num, 28*num))index = 0for i in range(0, 28*num, 28):for j in range(0, 28*num, 28):im = img[index]im = Image.fromarray(im, mode='L')new_im.paste(im, (i, j))index += 1new_im.save(name)# 定义超参数
batchsz = 256
lr = 1e-4# 数据集加载,自编码器不需要标签因为是无监督学习
(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)# 搭建模型
z_dim = 10
class VAE(keras.Model):def __init__(self,z_dim,units=256):super(VAE, self).__init__()self.z_dim = z_dimself.units = units# 编码网络self.vae_encoder = layers.Dense(self.units)# 均值网络self.vae_mean = layers.Dense(self.z_dim)      # get mean prediction# 方差网络(均值和方差是一一对应的,所以维度相同)self.vae_variance = layers.Dense(self.z_dim)      # get variance prediction# 解码网络self.vae_decoder = layers.Dense(self.units)# 输出网络self.vae_out = layers.Dense(784)# encoder传播的过程def encoder(self, x):h = tf.nn.relu(self.vae_encoder(x))#计算均值mu = self.vae_mean(h)#计算方差log_var = self.vae_variance(h)return  mu, log_var# decoder传播的过程def decoder(self, z):out = tf.nn.relu(self.vae_decoder(z))out = self.vae_out(out)return outdef reparameterize(self, mu, log_var):eps = tf.random.normal(log_var.shape)std = tf.exp(log_var)         # 去掉log, 得到方差;std = std**0.5                # 开根号,得到标准差;z = mu + std * epsreturn zdef call(self, inputs):mu, log_var = self.encoder(inputs)# reparameterizaion trick:最核心的部分z = self.reparameterize(mu, log_var)# decoder 进行还原x_hat = self.decoder(z)# Variational auto-encoder除了前向传播不同之外,还有一个额外的约束;# 这个约束使得你的mu, var更接近正太分布;所以我们把mu, log_var返回;return x_hat, mu, log_varmodel = VAE(z_dim,units=128)
model.build(input_shape=(128, 784))
optimizer = keras.optimizers.Adam(lr=lr)epochs = 30
for epoch in range(epochs):for step, x in enumerate(train_db):x = tf.reshape(x, [-1, 784])with tf.GradientTape() as tape:# shapex_hat, mu, log_var = model(x)# 把每个像素点当成一个二分类的问题;rec_loss = tf.losses.binary_crossentropy(x, x_hat, from_logits=True)rec_loss = tf.reduce_mean(rec_loss)# compute kl divergence (mu, var) ~ N(0, 1): 我们得到的均值方差和正太分布的;# 链接参考: https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussianskl_div = -0.5 * (log_var + 1 -mu**2 - tf.exp(log_var))kl_div = tf.reduce_mean(kl_div) / batchszloss = rec_loss + 1. * kl_divgrads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 ==0:print('\repoch: %3d, step:%4d, kl_div: %5f, rec_loss:%9f' %(epoch, step, float(kl_div), float(rec_loss)),end="")num_pic = 9# evaluation 1: 从正太分布直接sample;z = tf.random.normal((batchsz, z_dim))                              # 从正太分布中sample这个尺寸的logits = model.decoder(z)                                           # 通过这个得到decoderx_hat = tf.sigmoid(logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.logits = x_hat.astype(np.uint8)                                     # 标准的图片格式;save_images(logits, 'd:\\vae_images\\sampled_epoch%d.png' %epoch,num_pic)         # 直接sample出的正太分布;# evaluation 2: 正常的传播过程;x = next(iter(test_db))x = tf.reshape(x, [-1, 784])x_hat_logits, _, _  = model(x)                       # 前向传播返回的还有mu, log_varx_hat = tf.sigmoid(x_hat_logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)                       # 标准的图片格式;# print(x_hat.shape)save_images(x_hat, 'd:\\vae_images\\rec_epoch%d.png' %epoch,num_pic)

变分自编码器输出结果:

手写数字的生成效果就差一些:

Tensorflow实现变分自编码器相关推荐

  1. VAE变分自编码器实现

    变分自编码器(VAE)组合了神经网络和贝叶斯推理这两种最好的方法,是最酷的神经网络,已经成为无监督学习的流行方法之一. 变分自编码器是一个扭曲的自编码器.同自编码器的传统编码器和解码器网络一起,具有附 ...

  2. 【阿里云课程】深度生成模型基础,自编码器与变分自编码器

    大家好,继续更新有三AI与阿里天池联合推出的深度学习系列课程,本次更新内容为第11课中两节,介绍如下: 第1节:生成模型基础 本次课程是阿里天池联合有三AI推出的深度学习系列课程第11期,深度生成模型 ...

  3. 变分自编码器VAE:一步到位的聚类方案

    作者丨苏剑林 单位丨广州火焰信息科技有限公司 研究方向丨NLP,神经网络 个人主页丨kexue.fm 由于 VAE 中既有编码器又有解码器(生成器),同时隐变量分布又被近似编码为标准正态分布,因此 V ...

  4. VAE【变分自编码器】

    使用通用自编码器的时候,首先将输入encoder压缩为一个小的 form,然后将其decoder转换成输出的一个估计.如果目标是简单的重现输入效果很好,但是若想生成新的对象就不太可行了,因为其实我们根 ...

  5. 变分自编码器(VAE)详解与实现(tensorflow2.x)

    变分自编码器(VAE)详解与实现(tensorflow2.x) VAE介绍 VAE原理 变分推理 VAE核心方程 优化方式 重参数化技巧(Reparameterization trick) VAE实现 ...

  6. 深入理解自编码器(用变分自编码器生成图像)

    文章目录 自编码器 欠完备自编码器 正则自编码器 稀疏自编码器 去噪自编码器 收缩自编码器 变分自编码器 References 内容总结自花书<Deep Learning>以及<Py ...

  7. 理解变分自编码器,GAN的近亲

    转自 专知 原文: https://www.jeremyjordan.me/variational-autoencoders/ [导读]自编码器是一种非常直观的无监督神经网络方法,由编码器和解码器两部 ...

  8. 【深度学习】用变分自编码器生成图像和生成式对抗网络

    目录 问题描述: 代码展示: VAE代码段 GAN部分(仅供参考) 运行截图: 参考: 问题描述: 从图像的潜在空间中采样,并创建全新图像或编辑现有图像,这是目前最流行也是最成 功的创造性人工智能应用 ...

  9. 【自然语言处理系列】自编码器AE、变分自编码器VAE和条件变分自编码器CVAE

    作者:CHEONG 公众号:AI机器学习与知识图谱 研究方向:自然语言处理与知识图谱 本文主要分享自编码器.变分自编码器和条件变分自编码器的相关知识以及在实际实践中的应用技巧,原创不易转载请注明出处, ...

最新文章

  1. C语言程序设计最佳分组,求助把一些数值按指定的和进行分组
  2. 编译linux内核的错误,linux内核编译错误
  3. Python Day23 stark组件1
  4. 如何使用MATLAB绘制不同类型的二维图形
  5. JS实现各种复制到剪贴板
  6. View、Text、Button的drawableLeft左侧图片自定义大小
  7. 支付宝借呗利息万3和万2.5的,都是些什么大神级的人物?
  8. 在主线程执行_深入理解JavaScript执行机制
  9. LaTeX的历史:图灵奖得主1977年开启的计划,引发学术圈重大变革
  10. 一文讲解ARM、STM32之间的关系以及STM单片机介绍
  11. html网页设计课程的思维导图,html思维导图
  12. 【wordpress】wordpress自己制作主题看这一篇就够了/常用函数/注意事项
  13. 计算机维修培训教材,计算机芯片级维修中心芯片级维修培训教材.pdf
  14. 苹果笔记本电脑怎么编辑html,苹果笔记本电脑怎么操作 苹果笔记本电脑操作方法【详解】...
  15. 双矩阵对策MATLAB,带有模糊收益的双矩阵对策研究
  16. 创建输入控件(input控件、文本框、密码框、单项选择、多项选择、重置与提交按钮的设置)
  17. Windows解决运行slmgr.vbs -xpr 找不到应用程序问题
  18. SVN在Eclipse中的安装步骤以及使用方法和建立分支
  19. 1-1HTML笔记总结
  20. 基于C#的机器学习--c# .NET中直观的深度学习

热门文章

  1. Flink1.11 intervalJoin watermark生成,状态清理机制源码理解Demo分析
  2. logic pro x 下载
  3. 邯郸学院计算机怎样百度,邯郸学院怎么样?为你深度解读邯郸学院。
  4. Security Storage Management using Tivoli – Wrap –up! Part 1 2
  5. Anolis OS8.6QU1通过cephadm部署ceph17.2.0分布式块存储(三)添加其它主机和添加mgr节点
  6. 爬虫,百度搜索热点排行
  7. php账单明细功能怎么实现,php 账单生成
  8. 极路由 1s HC5661 玩转 openwrt
  9. S7-1200PLC与MCGS触摸屏实现以太网通信的具体方法示例
  10. cisco ap 上线不成功