本篇内容为动漫头像生成的主要代码部分,第一次写这种代码,从读取数据到生成走了一个完整的流程。创建TFrecord过程可以看上一篇内容。

代码内容:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-import tensorflow as  tf
import numpy as np
import importlib,sys
import matplotlib.pyplot as plt stdi,stdo,stde=sys.stdin,sys.stdout,sys.stderr  #编码问题,重载sys,不然读取图片会报错
importlib.reload(sys)
#sys.setdefaultencoding('utf-8')
sys.stdin,sys.stdout,sys.stderr=stdi,stdo,stdenoises_size = 128  #定义噪声的维度大小def gen_deconv(batch_input,out_channels):return tf.layers.conv2d_transpose(batch_input,out_channels,4,2,padding='same') #解卷积操作,这里的参数可以理解为把图像长宽x2#tensorflow的解卷积方式有layer和nn下两种方式,#nn下的可以指定输出的维度大小,但是使用时老是#报错,输出维度可以可根据解卷积核和步长推理出def batchnorm(inputs):               #batch normalization                                        return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))def lrelu(x, a):           #lrelu激活函数,看不懂的可以将x正负时代入试试with tf.name_scope("lrelu"):x = tf.identity(x)return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)def generator(noises,base=128,output_channels=3):layers =[] # Linear_layer: [batch, 1, 1, base*8]=> [batch, 3, 3, base*8]with tf.variable_scope("Linear_layer"):W = tf.get_variable("w", [noises_size, 3*3*8*base], tf.float32,tf.random_normal_initializer(stddev=0.02))b = tf.get_variable("b", [3*3*8*base])output = tf.matmul(noises,W) + b         #一开始是一个全连接层output = tf.reshape(output,[-1,3,3,base*8])    #[batch,3*3*base*8] -> [batch,3,3,base*8]output = batchnorm(output)              print("gen_layer_%d" % (len(layers)+1))print(output.shape)layers.append(output)layers_specs =[(base * 8,0.5), # deconv_2: [batch, 3, 3, base*8] => [batch, 6, 6, base*8](base * 4,0.5), # deconv_3: [batch, 6, 6, base*8] => [batch, 12, 12, base*4](base * 2,0.0), # deconv_4: [batch, 12, 12, base*4] => [batch, 24, 24, base*2](base * 1,0.0) # deconv_5: [batch, 24, 24, base*2] => [batch, 48, 48, base*1]     ]for (out_channels,dropout) in layers_specs:with tf.variable_scope("deconv_%d" % (len(layers)+1)):temp = tf.nn.relu(layers[-1])           #relu激活函数output = gen_deconv(temp,out_channels)     #解卷积output = batchnorm(output)                 #batch normalization#if dropout> 0.0:# output = tf.nn.dropout(output,keep_prob=1-dropout)     #这个drop操作我忽略掉了print("gen_layer_%d" % (len(layers)+1))print(output.shape)layers.append(output)# deconv_6: [batch, 48, 48, base*2] => [batch, 96, 96, output_channels]with tf.variable_scope("deconv_6"):temp = layers[-1]output = tf.nn.relu(temp)output = gen_deconv(output,output_channels)output = tf.tanh(output)                      #tanh函数将输出转化为(-1,1)print("gen_layer_%d" % (len(layers)+1))print(output.shape)layers.append(output)return layers[-1]def dis_conv(batch_input,out_channels):in_channels = int(batch_input.shape[3])kernel = tf.get_variable(initializer=tf.random_normal(shape=[4,4,in_channels,out_channels]),name='kernel')  #卷积操作,把图像长宽缩小为1/2,输出通道数为out_channelsreturn tf.nn.conv2d(batch_input,kernel,strides=[1,2,2,1],padding='SAME')def discriminator(dis_input,base=128):layers = []# layer_1: [batch, 96, 96, 3] => [batch, 48, 48, base]with tf.variable_scope("layer_1"):output = dis_conv(dis_input,base)output = lrelu(output,0.2)print("layer_1")print(output.shape)layers.append(output)layers_spec = [base*2,  # layer_2: [batch, 48, 48, base] => [batch, 24, 24, base*2]base*4,  # layer_3: [batch, 24, 24, base*2] => [batch, 12, 12, base*4]base*8,  # layer_4: [batch, 12, 12, base*4] => [batch, 6, 6, base*8]base*8   # layer_5: [batch, 6, 6, base*8] => [batch, 3, 3, base*8]                     ] for out_channels in layers_spec:     with tf.variable_scope("layer_%d" % (len(layers)+1)):output = dis_conv(layers[-1],out_channels)  #进行卷积output = batchnorm(output)                  #进行batch normalizationoutput = lrelu(output,0.2)                  #激活函数为lreluprint("layer_%d" % (len(layers)+1))       print(output.shape)layers.append(output)# layer_5: [batch, 3, 3, base*8] => [batch, 1]    with tf.variable_scope("layer_%d" % (len(layers)+1)):output = tf.reshape(layers[-1],[-1,3*3*base*8])   #[batch,3,3,base*8] -> [batch,3*3*8]W = tf.get_variable("w", [3*3*8*base,1], tf.float32,tf.random_normal_initializer(stddev=0.02))b = tf.get_variable("b", [1])output = tf.matmul(output,W) + b              #这里是一个全连接层output = tf.sigmoid(output)                   #sigmoid函数转化输出为0-1print("layer_%d" % (len(layers)+1))print(output.shape)layers.append(output)return layers[-1]def create_model(gen_inputs,dis_inputs,learning_rate):EPS =  1e-12with tf.variable_scope("generator"):gen_outputs = generator(gen_inputs)with tf.variable_scope("discriminator"):predict_real = discriminator(dis_inputs)with tf.variable_scope("discriminator",reuse=True):predict_fake = discriminator(gen_outputs)with tf.name_scope("discriminator_loss"):dis_loss = tf.reduce_mean(-tf.log(predict_real+EPS)-tf.log(1-predict_fake+EPS)) #这里加上EPS防止出现log(0),否则loss会变成nanwith tf.name_scope("generator_loss"):gen_loss = tf.reduce_mean(-tf.log(predict_fake+EPS))all_var = tf.trainable_variables() with tf.name_scope("discriminator_train"):dis_var = [var for var in all_var if var.name.startswith("discriminator")] dis_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(dis_loss, var_list=dis_var)  #定义D的优化with tf.name_scope("generator_train"):gen_var = [var for var in all_var if var.name.startswith("generator")]gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(gen_loss, var_list=gen_var) #定义G的优化return dis_optimizer,gen_optimizer,dis_loss,gen_loss,gen_outputsdef read():files = tf.train.match_filenames_once("./TFrecord/data-tfrecords-*")filename_queue = tf.train.string_input_producer(files,shuffle=True)  #将files输入到一个队列reader =tf.TFRecordReader()_,serialize_example = reader.read(filename_queue)    #从队列中读出数据features = tf.parse_single_example(serialize_example,features={ 'height':tf.FixedLenFeature([],tf.int64),'width':tf.FixedLenFeature([],tf.int64),'channel':tf.FixedLenFeature([],tf.int64),'image_raw':tf.FixedLenFeature([],tf.string)    })image_raw = features['image_raw']                      #读入图片数据decoded_image = tf.decode_raw(image_raw,tf.uint8)      #将字符串形式的数据解码images  = tf.reshape(decoded_image,[96,96,3])          #重新定义shapereturn images  def main():batch_size = 64    #定义一个batch的大小gen_inputs = tf.placeholder(tf.float32,[None,noises_size])  #定义Generator的输入dis_inputs = tf.placeholder(tf.float32,[None,96,96,3])      #定义Discriminator的输入dis_optimizer,gen_optimizer,dis_loss,gen_loss,gen_output= create_model(gen_inputs,dis_inputs,0.0002)  #创建模型gen_images = (gen_output + 1) * 127.5    #将Generator的输出转化为可以显示的图像gen_images = tf.cast(gen_images,tf.int32)images = read()images = tf.cast(images,tf.float32)images_input = images / 127.5 - 1   #将图像数据范围变成-1-1之间images_batch = tf.train.batch([images_input],batch_size=batch_size,capacity=5000)  #把image按batch输出,这里会按多线程加快读取速度#注意这里的[]不可省略saver = tf.train.Saver()gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)            #定义占用GPU的内存比例with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:tf.local_variables_initializer().run()                                  tf.global_variables_initializer().run()                       #变量初始化coord =tf.train.Coordinator()  #Coordinator类用来帮助多个线程协同工作,多个线程同步终止threads = tf.train.start_queue_runners(coord=coord)steps = 0while True:cur_batch = sess.run(images_batch)   #产生一个batchnoises = np.random.uniform(-1,1,size=(batch_size,noises_size)).astype(np.float32)  #产生一个噪声for i in range(2):     _,discriminator_loss = sess.run([dis_optimizer,dis_loss],feed_dict={gen_inputs:noises,dis_inputs:cur_batch})  #训练Discriminator_,generator_loss = sess.run([gen_optimizer, gen_loss],feed_dict={gen_inputs:noises})                 #训练Generator#训练过程会出现D过弱或过强的现象,可以通过加大D的训练次数#或者调整learning rate 来达到平衡if steps % 20 == 0:print("%d steps:  gen_loss is %f; dis_loss is %f" % (steps,float(generator_loss),float(discriminator_loss)))   #每训练20个batch输出一遍lossif steps % 100 == 0:                    #每训练100个batch保存一张图片now_image =sess.run(gen_images,{gen_inputs:noises})          plt.imshow(now_image[0].astype(np.uint32))#plt.show()plt.savefig("./result/R_%d.png" % steps)saver.save(sess,"./Models1/model.ckpt")steps += 1coord.request_stop()coord.join(threads)     #终止线程main()

生成的结果:

我选取了几张比较成功的:

可以看到训练效果还算可以,但还是比不上训练的数据好。实际上,一开始我的网络并不是DCGAN,是从pix2pix抄取了一部分,也能生成类似的结果,但是后来我把网络改造了一下比之前参数更多更强,希望生成更好的结果,但是却再也没训练出来头像,所以网络还是不能乱改的,这大概是一门玄学,我还需学习。

然后列几个心得吧:

1.学习写tensorflow时候,首先要搞清楚基本概念,计算图和variable,在程序开始训练的时候计算图是已经构建好的,写的generator和discriminator这些函数只是在构建计算图的时候运行了一次,这之后再也不会运行,因为计算图已经保存好了,写成函数的目的主要是更好地分隔不同的功能以及不必再写相同重复使用的步骤,比如卷积解卷积。模型在训练的时候只会训练variable类型的tensor,variable_scope的作用是把变量放入文件夹里,比如我把G的变量全部放进generator命名的文件夹里,训练的时候我把generator文件夹里的所有变量拿出来训练即可。

2.然后是一开始写代码的时候,如果没有自己的套路最好先参考别人的写法,我是借鉴了两个代码的,虽然都没完全看懂,但是写法还是学习到了的,只要形成了自己的写法,剩下的就是改造网络的事情,另外tensorflow的api有点多,特别是解卷积的比较复杂,我尝试的几个api都会报一些错误,所以我就不停的换,所以很多时间是花在这些细节上面的。

最后,为我的第一篇博客撒花 ^ - ^

基于Tensorflow和DCGAN生成动漫头像实践(二)相关推荐

  1. vs2019 利用Pytorch和TensorFlow分别实现DCGAN生成动漫头像

    这是针对于博客vs2019安装和使用教程(详细)的DCGAN生成动漫头像项目新建示例 目录 一.DCGAN架构及原理 二.项目结构 1.TensorFlow 2.Pytorch 三.数据集下载(两种方 ...

  2. DCGAN生成动漫头像【学习】

    DCGAN生成动漫头像 在假期看了李宏毅老师的GAN的介绍,看到了课后题DCGAN生成动漫头像的作业,实现一下.记录学习过程. 参考的文章: [Keras] 基于GAN自动生成动漫头像 因为使用的是t ...

  3. pytorch:DCGAN生成动漫头像

    动漫头像数据集下载地址:动漫头像数据集_百度云连接,DCGAN论文下载地址: https://arxiv.org/abs/1511.06434 数据集里面的图片是这个样子的: 这是DCGAN的主要改进 ...

  4. 使用TensorFlow2.0搭建DCGAN生成动漫头像(内含生成过程GIF图)

    文章目录 生成对抗网络介绍 一.造假 二.训练判别器 三.训练生成器 DCGAN介绍 搭建DCGAN 数据来源 必要工作 读取数据 构建生成器 构建判别器 连接模型 连接图片 生成函数 训练 生成对抗 ...

  5. 通过PyTorch用DCGAN生成动漫头像

    数据集 数据集我们用AnimeFaces数据集,共5万多张动漫头像. 链接:https://pan.baidu.com/s/1cp-A8ZV74YBelkSuKxuM6A 提取码:face 要把所有的 ...

  6. python 动漫卡通人物图片大全,『TensorFlow』DCGAN生成动漫人物头像_下

    一.计算图效果以及实际代码实现 计算图效果 实际模型实现 相关介绍移步我的github项目. 二.生成器与判别器设计 生成器 相关参量, 噪声向量z维度:100 标签向量y维度:10(如果有的话) 生 ...

  7. 『TensorFlow』DCGAN生成动漫人物头像_下

    『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...

  8. DCGAN生成动漫头像(附代码)

    DCGAN.顾名思义,就是深度卷积生成对抗神经网络,也就是引入了卷积的,但是它用的是反卷积,就是卷积的反操作. 我们看看DCGAN的图: 生成器开始输入的是噪声数据,然后经过一个全连接层,再把全连接层 ...

  9. 基于DCGAN的动漫头像生成

    基于DCGAN的动漫头像生成 数据 数据集:动漫图库爬虫获得,经过数据清洗,裁剪得到动漫头像.分辨率为3 * 96 * 96,共5万多张动漫头像的图片,从知乎用户何之源处下载. 生成器:输入为随机噪声 ...

最新文章

  1. [ python ] 基础技巧
  2. leetcode - 统计封闭岛屿的数目
  3. 所有controller interceptor_filter、interceptor、aspect不知如何选择
  4. 如何自建appender扩展Log4j框架
  5. python提高——进程、线程、协程对比及代码实现
  6. OKHTTP好文推荐
  7. Linux acpi off学习
  8. 使用WUCDCreator将SCSI、RAID、SATA、SAS驱动程序集成到光盘中
  9. python 顺序遍历文件夹下的文件
  10. python程序中1—10的乘积_[求助]1个数1到10的乘积
  11. 华为usg6320服务器映射,华为防火墙USG6320配置(简单)
  12. 快速排序与冒泡排序的效率对比
  13. 服务器怎么解绑网站域名,宝塔面板如何解绑域名
  14. cuem模拟器安装及使用
  15. 【学习资料】中国开放大学-电大-《教育学》形考作业答案(2018).docx
  16. 2021MySQL面试题
  17. HTML+CSS+JS网页设计期末课程大作业(家居网)
  18. cvpr2020 matlab_新zwpython 完胜 老matlab 篇二
  19. c语言实验求最小值,最小值c语言流程(C语言求最小值程序)
  20. 这个视频「橡皮擦」让你瞬间消失,头发丝都不留 | ECCV 2020

热门文章

  1. 假如你能给「微信」增加一个小功能
  2. Mysql数据库版本升级5.5---5.7
  3. 餐饮服务质量调查打分
  4. 解决IntelliJ IDEA控制台输出中文乱码问题(图文详解)
  5. 预测算法:具身智能如何应对不确定性[Reviews of Daniel Williams]
  6. 数据脱敏显示-用户名和手机号
  7. java——OOAD
  8. 天视通增加海康威视摄像头2023
  9. 电商系统设计艺术——秒杀业务设计
  10. 数据分析 | SVM模型