上周我们尝试使用StyleGAN的pretrained_example.py来生成人像和动漫少女头像,效果很不错,请参考:

https://blog.csdn.net/weixin_41943311/article/details/100539707

但是,StyleGAN还是显得很神秘,网上可以查到的中文资料也非常少,于是决定由浅入深,探一探StyleGAN功能和代码的究竟。

先看一下StyleGAN的网络模型,如下图所示:

在StyleGAN的网络模型中,先定义一个随机张量latent,归一化后经过8层全连接网络(Mapping network,也称:函数f),映射到向量w;向量w作为输入A,同时引入噪声B,再经过合成网络(Synthesis network,也称:函数g)生成图像。

在这里,latent总是一个大小为512的张量;对于1024x1024的目标图像,合成网络为18层;对于512x512的目标图像,合成网络为16层。

今天我们研究的是StyleGAN中比较简单的generate_figure.py的代码,说明如下:

(1)这里用已下载到本地的网络模型,代替需要通过Google网盘访问的链接,因此修改了Load_Gs()函数。

已训练好的网络模型的百度网盘下载地址,可以参考:

https://blog.csdn.net/weixin_41943311/article/details/100539707

(2)我的笔记本电脑是Windows 10,自带一块NVIDIA GeForce GTX 1060显卡,由于显卡内存有限,所以无法运行面向1024x1024人脸图像的全部功能测试(主要是运行draw_style_mixing_figure()时,会报告显卡内存分配失败错误),因此我改为面向512x512动漫头像的功能测试,在代码的许多地方将1024修改为512,将18层修改为16层,将剪切区域从1024x1024的范围缩小到512x512的范围。

(3)中文注释了generate_figure()中的全部函数,包括:

draw_uncurated_result_figure()

draw_style_mixing_figure()

draw_noise_detail_figure()

draw_noise_components_figure()

draw_truncation_trick_figure()

(4)值得专门说明的一点是,有三种方法可以使用StyleGAN预先训的生成器:

(4.1)使用Gs.run(),输入和输出均为numpy数组,这是一种最为简便的使用方法:

# 选择特征向量
rnd = np.random.RandomState(5)
latents = rnd.randn(1, Gs.input_shape[1])# 生成图像
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

第一个参数是一批形状为[num, 512]的特征向量,第二个参数预留给类别标签(StyleGan并没有使用,所以参数为None)。其余的关键字参数是可选的。输出是一批图像,其格式由output_transform参数决定。

(4.2)使用Gs.get_output_for()将生成器合并为一个更大的TensorFlow表达式的一部分:

latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
images = tflib.convert_images_to_uint8(images)
result_expr.append(inception_clone.get_output_for(images))

它生成一批随机图像,并将它们直接提供给Inception-v3网络,而无需在中间将数据转换为numpy数组。

(4.3)查找Gs.components.mapping和Gs.components.synthesis以访问生成器的各个子网络,即下面的src_dlatents:

src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds)
src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)

首先利用映射网络将一批特征向量转化为中间的W空间,然后利用合成网络将这些向量转化为一批图像。src_dlatents数组为合成网络的每一层存储同一w向量的单独副本,在使用合成网络之前,其中某些副本可以用其他的数据(如:draw_style_mixing_figure()函数中用到的dst_dlatents)的某些副本来替换,然后再通过合成网络来实现样式混合。

下面是在generate_figure.py基础上修改得到的generate_figure002.py的源代码及中文注释:

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA."""Minimal script for reproducing the figures of the StyleGAN paper using pre-trained generators."""import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config
import glob#----------------------------------------------------------------------------
# Helpers for loading and using pre-trained generators.# pre-trained network.
#Model = './cache/karras2019stylegan-ffhq-1024x1024.pkl'
Model = './cache/2019-03-08-stylegan-animefaces-network-02051-021980.pkl'synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)
_Gs_cache = dict()# 加载StyleGAN已训练好的网络模型
def load_Gs(model):if model not in _Gs_cache:model_file = glob.glob(Model)if len(model_file) == 1:model_file = open(model_file[0], "rb")else:raise Exception('Failed to find the model')_G, _D, Gs = pickle.load(model_file)# _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.# _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.# Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.# Print network details.Gs.print_layers()_Gs_cache[model] = Gsreturn _Gs_cache[model]#----------------------------------------------------------------------------
# Figures 2, 3, 10, 11, 12: Multi-resolution grid of uncurated result images.
# lod, level of detail
# 多层次细节处理,在相机距离不同的情况下,使得物体显示不同的模型,从而节省性能的开销。
# 这里的实际处理是把图片按比例缩小,按照不同大小的网格显示随机生成的图片,使得效果更炫一些。def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed):print(png)# 规划在网格中显示的图片的数量,按此数量定义latents的数量latents = np.random.RandomState(seed).randn(sum(rows * 2**lod for lod in lods), Gs.input_shape[1])# 使用Gs.run()直接生成输出为numpy数组的图像images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb]# 绘制空白画布canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white')# iteration()产生一个迭代器,使用next()方法获取下一个项image_iter = iter(list(images))# 在画布的网格中逐一绘制生成的图像for col, lod in enumerate(lods):for row in range(rows * 2**lod):image = PIL.Image.fromarray(next(image_iter), 'RGB')image = image.crop((cx, cy, cx + cw, cy + ch))image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.ANTIALIAS)canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod))canvas.save(png)#----------------------------------------------------------------------------
# Figure 3: Style mixing.
# 分别用不同的种子生成源图像和目标图像,然后用源图像的src_dlatents的一部分替换目标图像的dst_dlatents的对应部分,
# 然后用Gs.components.synthesis.run()函数生成风格混合后的图像def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):print(png)# 生成随机的latents,包括:src_latents和dst_latents,都是Gs.input_shape[1]大小的张量,对512x512的图片来说,就是512# 重复生成,src_latents共有len(src_seeds)个张量(主程序中设定为5),dst_latents共有len(dst_seeds)个张量(主程序中设定为6)src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds)dst_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds)# 按照StyleGAN的网络架构,从z变换到w,对于512x512的图片来说,src_dlatents的shape是5x16x512,dst_dlatents的shape是6x16x512src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]# 从w生成图像src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)# 画空白画布canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white')# 在画布的第一行画源图像,第一格空白for col, src_image in enumerate(list(src_images)):canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))# 在画布逐行绘制图像for row, dst_image in enumerate(list(dst_images)):# 首列绘制目标图像canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))# 将目标图像复制src_seeds份(主程序中设定为5),构成新数组row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))# 用src_dlatents的指定列替换row_dlatents的指定列,数据混合row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]# 调用用Gs.components.synthesis.run()函数生成风格混合后的图像row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)# 在画布上逐列绘制风格混合后的图像for col, image in enumerate(list(row_images)):canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))canvas.save(png)#----------------------------------------------------------------------------
# Figure 4: Noise detail.
# 以计算统计均方差的方式展现图片生成过程中的噪声的影响def draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds):print(png)# 画布为3列,一共len(seeds)行canvas = PIL.Image.new('RGB', (w * 3, h * len(seeds)), 'white')# 逐行画图for row, seed in enumerate(seeds):# latents是大小为len(Gs.input_shape[1]的张量,默认为512,相同的种子,一次生成num_samples个张量(主程序设定为100)latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1])] * num_samples)# 允许生成器使用“截断技巧”(truncation trick),通过通过设置阈值的方式来截断 z 的采样,提高图片的生成质量(但同时可能会降低生成图片的差异性)# 使用Gs.run()直接生成输出为numpy数组的图像images = Gs.run(latents, None, truncation_psi=1, **synthesis_kwargs)# 画图,第一列canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, row * h))# 剪裁,放大image1、2、3、4,并画图,让你看看图片细节上的某些差异(即:噪声)for i in range(4):crop = PIL.Image.fromarray(images[i + 1], 'RGB')# 我们的图是512x512,所以截取的区域都除以2crop = crop.crop((650/2, 180/2, 906/2, 436/2))crop = crop.resize((w//2, h//2), PIL.Image.NEAREST)canvas.paste(crop, (w + (i%2) * w//2, row * h + (i//2) * h//2))# 对所有图像的同一个像素(x,y)的数值先计算均值,然后在每一行上计算标准差,最后乘以4diff = np.std(np.mean(images, axis=3), axis=0) * 4# 0-255内取值,超出取值范围的按边界值取值,四舍五入diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8)# 在第三列画图,即多图计算方差后得出的“噪声”canvas.paste(PIL.Image.fromarray(diff, 'L'), (w * 2, row * h))canvas.save(png)#----------------------------------------------------------------------------
# Figure 5: Noise components.
# 显示不同的噪声区间对生成图片的影响def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips):print(png)# 创建Gs网络的一个克隆体,包括所有的变量Gsc = Gs.clone()# vars() 函数返回对象的属性和属性值的字典对象,这里返回的是随机生成的noise_input[]。noise_vars = [var for name, var in Gsc.components.synthesis.vars.items() if name.startswith('noise')]# tflib.run()运行网络,为noise_vars赋值noise_pairs = list(zip(noise_vars, tflib.run(noise_vars))) # [(var, val), ...]# 为Gsc创建初始张量latents,主程序给定2个种子latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds)all_images = []# 添加不同噪声后生成图片for noise_range in noise_ranges:# 赋值,在噪声区间内的值保留,否则置0tflib.set_vars({var: val * (1 if i in noise_range else 0) for i, (var, val) in enumerate(noise_pairs)})# 给定噪声区间,用Gsc生成不同的图片range_images = Gsc.run(latents, None, truncation_psi=1, randomize_noise=False, **synthesis_kwargs)range_images[flips, :, :] = range_images[flips, :, ::-1]all_images.append(list(range_images))# 绘制空白画布canvas = PIL.Image.new('RGB', (w * 2, h * 2), 'white')for col, col_images in enumerate(zip(*all_images)): # col = 2,两个种子,生成两组图片# 画图,第一组图片放在第一列,第二组图片放在第二列# image[0]左半边和image[1]右半边画在同一列的第一行,image[2]的左半边和image[3]的右半边画在同一列的第二行canvas.paste(PIL.Image.fromarray(col_images[0], 'RGB').crop((0, 0, w//2, h)), (col * w, 0))canvas.paste(PIL.Image.fromarray(col_images[1], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, 0))canvas.paste(PIL.Image.fromarray(col_images[2], 'RGB').crop((0, 0, w//2, h)), (col * w, h))canvas.paste(PIL.Image.fromarray(col_images[3], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, h))canvas.save(png)#----------------------------------------------------------------------------
# Figure 8: Truncation trick.
# 展现不同截断阈值对生成图片的影响def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis):print(png)# 随机生成初始张量latentslatents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds)# z映射到wdlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component]# 获取w的均值dlatent_avgdlatent_avg = Gs.get_var('dlatent_avg') # [component]# 绘制空白画布,共2行,每行展示同一种子不同截断阈值下生成的图片canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white')for row, dlatent in enumerate(list(dlatents)):# 将dlatent张量增加维度,依照主程序设定的截断阈值的个数复制了N份,用截断阈值的比例大小控制row_dlatents# row_dlatents满足下面synthesis.run()的输入数组的维度要求row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg# 将被截断阈值调控后的row_dlatents用于StyleGAN网络模型row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)# 在画布上画图for col, image in enumerate(list(row_images)):canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h))canvas.save(png)#----------------------------------------------------------------------------
# Main program.def main():tflib.init_tf()os.makedirs(config.result_dir, exist_ok=True)# 将许多未经筛选的图片集中绘制在一起,按不同的缩小比例布置在画布上draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure02-uncurated-animation.png'), load_Gs(Model), cx=0, cy=0, cw=512, ch=512, rows=3, lods=[0, 1, 2, 2, 3, 3], seed=5)# 这里除了要把w、h从1024修改为512以外,还得把range(8, 18)修改为range(8,16),因为StyleGAN在生成1024x1024的图片时的合成网络g是18层的,生成512x512图片时的合成网络g是16层的# 精心选择的种子(src_seeds, dst_seeds),应该与生成图片的效果有关draw_style_mixing_figure(os.path.join(config.result_dir, 'figure03-style-mixing.png'), load_Gs(Model), w=512, h=512, src_seeds=[639, 701, 687, 615, 2268], dst_seeds=[888, 829, 1898, 1733, 1614, 845], style_ranges=[range(0, 4)] * 3 + [range(4, 8)] * 2 + [range(8, 16)])# 自动生成的图片,展现在种子相同时不同图片之间的噪声draw_noise_detail_figure(os.path.join(config.result_dir, 'figure04-noise-detail.png'), load_Gs(Model), w=512, h=512, num_samples=100, seeds=[1157, 1012])# 给定两个种子,设置4个噪声区间,显示不同噪声水平下的图片内容draw_noise_components_figure(os.path.join(config.result_dir, 'figure05-noise-components.png'), load_Gs(Model), w=512, h=512, seeds=[1967, 1555], noise_ranges=[range(0, 16), range(0, 0), range(8, 16), range(0, 8)], flips=[1])# 给定两个种子,6个truncation_psi,用于比较不同“截断技巧”(truncation trick)水平下的图片生成质量和差异度draw_truncation_trick_figure(os.path.join(config.result_dir, 'figure08-truncation-trick.png'), load_Gs(Model), w=512, h=512, seeds=[91, 388], psis=[1, 0.7, 0.5, 0, -0.5, -1])#----------------------------------------------------------------------------if __name__ == "__main__":main()#----------------------------------------------------------------------------

(完)

下一篇:

轻轻松松使用StyleGAN(三):基于ResNet50构造StyleGAN的逆向网络,从目标图像提取特征码

【实战】轻轻松松使用StyleGAN(二):源代码初探+中文注释,generate_figure.py相关推荐

  1. 【实战】轻轻松松使用StyleGAN(一):创建令人惊讶的黄种人脸和专属于自己的老婆动漫头像

    NVIDIA(英伟达)开源了StyleGAN,用它可以生成令人惊讶的逼真人脸:也可以像某些人所说的,生成专属于自己的老婆动漫头像.这些生成的人脸或者动漫头像都是此前这个世界上从来没有过的,完全是被&q ...

  2. 轻轻松松使用StyleGAN2(六):StyleGAN2 Encoder是怎样加载训练数据的?源代码+中文注释,dataset_tool.py和dataset.py

    上一篇文章里,我们简单介绍了StyleGAN2 Encoder的一部分源代码,即:projector.py和project_images.py,内容请参考: 轻轻松松使用StyleGAN2(五):St ...

  3. 【实战】轻轻松松使用StyleGAN(六):StyleGAN Encoder找到真实人脸对应的特征码,核心源代码+中文注释

    在上一篇文章中,我们用了四种方法来寻找真实人脸对应的特征码,都没有成功,内容请参考: https://blog.csdn.net/weixin_41943311/article/details/102 ...

  4. 【实战】轻轻松松使用StyleGAN(七):用StyleGAN Encoder为女朋友制作美丽头像

    在上一篇文章里,我们下载了StyleGAN Encoder的源代码和相关资源,可以实现对真实人脸特征码的提取,内容请参考: https://blog.csdn.net/weixin_41943311/ ...

  5. 轻轻松松使用StyleGAN(五):提取真实人脸特征码的一些探索

    我们在上一篇文章中提到:能不能给出一个目标图像,使用神经网络自动提取出它的特征码呢? 如果可以,那么我们就可以方便地对这些图像进行编辑,创造出各种各样"酷炫"的风格人像. 这个工作 ...

  6. “物联网开发实战”学习笔记-(二)手机控制智能电灯

    "物联网开发实战"学习笔记-(二)手机控制智能电灯 如果搭建好硬件平台后,这一次我们的任务主要是调试好智能电灯,并且连接到腾讯云的物联网平台. 腾讯云物联网平台 腾讯物联网平台的优 ...

  7. 在APPLE从创建ID到申请发布AppStore账户(二)初探苹果开发者中心

    上篇说到:在APPLE从创建ID到申请发布AppStore账户(一)Apple ID注册自动登录 https://blog.csdn.net/AITop_Leader/article/details/ ...

  8. 《自然语言处理实战入门》 ---- 第4课 :中文分词原理及相关组件简介 之 汉语分词领域主要分词算法、组件、服务(上)...

    目录 0.内容梗概 1. 基于传统统计算法的分词组件 1.1 hanlp : Han Language Processing 1.2 语言技术平台(Language Technology Platfo ...

  9. SLAM导航机器人零基础实战系列:(二)ROS入门——2.ROS系统整体架构

    SLAM导航机器人零基础实战系列:(二)ROS入门--2.ROS系统整体架构 摘要 ROS机器人操作系统在机器人应用领域很流行,依托代码开源和模块间协作等特性,给机器人开发者带来了很大的方便.我们的机 ...

最新文章

  1. Java NIO系列教程(二) Channel
  2. 深入浅出.NET泛型编程(1)
  3. Date对象 IOS踩坑
  4. Springboot--Ehcache-Jpa (1)
  5. 2019.03.30 图解HTTP
  6. 使用Flash,HTML5和Unity开发网页游戏的对比
  7. ios实现读写锁,AFN的实现
  8. SiamFC论文理解及代码理解
  9. CCF CSP 数据中心 c++ python csp201812_4 100分
  10. [c#]删除PDF权限密码
  11. 不会讲故事,怎么带团队(用故事简化沟通,提升团队效率)--读后感
  12. FCC ES6篇中的解构赋值
  13. Python练手小项目(6)随机取红黑球并计算概率
  14. 273222-06-3,(2S,4R)-Boc-4-amino-1-Fmoc-pyrrolidine-2-carboxylic acid,(2S,4R)-Fmoc-4-叔丁氧羰基氨基吡咯烷-2-甲酸
  15. Python 真的好学吗?
  16. 计算机网络——网络层——思维导图
  17. 最大子矩阵(C语言)
  18. 超级玛丽3号MAX 达尔文3号,谁才是真正的重疾险王炸?
  19. (转)什么时候要抛出异常?
  20. 【Java开发岗:SpringCould篇】

热门文章

  1. 对话式AI系列:任务型多轮对话的实践与探索
  2. 从0开始教你编写Makefile文件
  3. < 数据结构 > 堆的应用 --- 堆排序和Topk问题
  4. 一根苹果数据线背后的“血战”
  5. 大一暑假实习day5_3
  6. 用python写一个程序
  7. 混合App应用实现本地头像剪切,压缩上传功能(支持任何H5框架)
  8. 12N65-ASEMI高压MOS管12N65
  9. 01-为什么要学爬虫-python小白爬虫入门教程
  10. 苹果备忘录丢失如何恢复,小编支招教你轻松搞定