动漫头像数据集下载地址:动漫头像数据集_百度云连接,DCGAN论文下载地址: https://arxiv.org/abs/1511.06434

数据集里面的图片是这个样子的:

这是DCGAN的主要改进地方:

下面是所有代码:

第一个模块:

import torch
import torch.nn as nn
import numpy as np
import torch.nn.init as init
import data_helper
from torchvision import transformstrans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((.5, .5, .5), (.5, .5, .5))]
)
G_LR = 0.0002
D_LR = 0.0002
BATCHSIZE = 50
EPOCHES = 3000def init_ws_bs(m):if isinstance(m, nn.ConvTranspose2d):init.normal_(m.weight.data, std=0.2)init.normal_(m.bias.data, std=0.2)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.deconv1 = nn.Sequential(nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Paddingin_channels=100,out_channels=64 * 8,kernel_size=4,stride=1,padding=0,bias=False,),nn.BatchNorm2d(64 * 8),nn.ReLU(inplace=True),)self.deconv2 = nn.Sequential(nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Paddingin_channels=64 * 8,out_channels=64 * 4,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64 * 4),nn.ReLU(inplace=True),)self.deconv3 = nn.Sequential(nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Paddingin_channels=64 * 4,out_channels=64 * 2,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64 * 2),nn.ReLU(inplace=True),)self.deconv4 = nn.Sequential(nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Paddingin_channels=64 * 2,out_channels=64 * 1,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64),nn.ReLU(inplace=True),)self.deconv5 = nn.Sequential(nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),nn.Tanh(),)def forward(self, x):x = self.deconv1(x)x = self.deconv2(x)x = self.deconv3(x)x = self.deconv4(x)x = self.deconv5(x)return xclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(  # batchsize,3,96,96in_channels=3,out_channels=64,kernel_size=5,padding=1,stride=3,bias=False,),nn.BatchNorm2d(64),nn.LeakyReLU(.2, inplace=True),)self.conv2 = nn.Sequential(nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False, ),  # batchsize,16,32,32nn.BatchNorm2d(64 * 2),nn.LeakyReLU(.2, inplace=True),)self.conv3 = nn.Sequential(nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(64 * 4),nn.LeakyReLU(.2, inplace=True),)self.conv4 = nn.Sequential(nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(64 * 8),nn.LeakyReLU(.2, inplace=True),)self.output = nn.Sequential(nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid()  #)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = self.output(x)return xg = Generator().cuda()
d = Discriminator().cuda()init_ws_bs(g), init_ws_bs(d)g_optimizer = torch.optim.Adam(g.parameters(), betas=(.5, 0.999), lr=G_LR)
d_optimizer = torch.optim.Adam(d.parameters(), betas=(.5, 0.999), lr=D_LR)g_loss_func = nn.BCELoss()
d_loss_func = nn.BCELoss()label_real = torch.ones(BATCHSIZE).cuda()
label_fake = torch.zeros(BATCHSIZE).cuda()real_img = data_helper.get_imgs()for epoch in range(EPOCHES):np.random.shuffle(real_img)count = 0batch_imgs = []for i in range(len(real_img)):count = count + 1batch_imgs.append(trans(real_img[i]).numpy())  # tensor类型#这里经过trans操作通道维度从第四个到第二个了if count == BATCHSIZE:count = 0batch_real = torch.Tensor(batch_imgs).cuda()batch_imgs.clear()d_optimizer.zero_grad()pre_real = d(batch_real).squeeze()d_real_loss = d_loss_func(pre_real, label_real)d_real_loss.backward()batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).cuda()img_fake = g(batch_fake).detach()pre_fake = d(img_fake).squeeze()d_fake_loss = d_loss_func(pre_fake, label_fake)d_fake_loss.backward()d_optimizer.step()g_optimizer.zero_grad()batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).cuda()img_fake = g(batch_fake)pre_fake = d(img_fake).squeeze()g_loss = g_loss_func(pre_fake, label_real)g_loss.backward()g_optimizer.step()print(i,(d_real_loss + d_fake_loss).detach().cpu().numpy(), g_loss.detach().cpu().numpy())torch.save(g, "pkl/" + str(epoch) + "g.pkl")

以上网络结构和参数,是从另一个博客找来的Pytorch实战3:DCGAN深度卷积对抗生成网络生成动漫头像

其中调用了一个同目录下的data_helper模块,用来从本地数据文件夹中获取图片(list of arrray),其中需要把文件夹改成自己的文件夹:

import cv2
import os
MAIN_PATH="E:/DataSets/faces/"
def get_imgs():files = os.listdir(MAIN_PATH)imgs = []for file in files:imgs.append(cv2.imread(MAIN_PATH + file))print("get_imgs")return imgs

由于是对每个epoch都保存了网络结构,所以可以在训练完成后选择需要加载的本地网络文件,然后测试效果:

用下面的代码来测试:

import torch.nn as nn
import torch
import cv2
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.deconv1=nn.Sequential(nn.ConvTranspose2d(#stride(input_w-1)+k-2*Paddingin_channels=100,out_channels=64*8,kernel_size=4,stride=1,padding=0,bias=False,),nn.BatchNorm2d(64*8),nn.ReLU(inplace=True),)#14self.deconv2=nn.Sequential(nn.ConvTranspose2d(#stride(input_w-1)+k-2*Paddingin_channels=64*8,out_channels=64*4,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64*4),nn.ReLU(inplace=True),)#24self.deconv3 = nn.Sequential(nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Paddingin_channels=64*4,out_channels=64*2,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64*2),nn.ReLU(inplace=True),)#48self.deconv4 = nn.Sequential(nn.ConvTranspose2d(  # stride(input_w-1)+k-2*Paddingin_channels=64*2,out_channels=64*1,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64),nn.ReLU(inplace=True),)self.deconv5=nn.Sequential(nn.ConvTranspose2d(64,3,5,3,1,bias=False),nn.Tanh(),)def forward(self, x):x=self.deconv1(x)x=self.deconv2(x)x=self.deconv3(x)x=self.deconv4(x)x=self.deconv5(x)return  xg=torch.load("pkl/15g.pkl")
imgs=g(torch.randn(100,100,1,1).cuda())
for i in range(len(imgs)):img=imgs[i].permute(1,2,0).cpu().detach().numpy()*255cv2.imwrite("bitmaps/"+str(i)+".jpg",img,)#这里需要在同目录下建立一个bitmap文件夹
print("done")

接下来是运行的效果:

第一个epoch结束之后:

第15个epoch之后:

一共训练了67个epoch,最终的结果:

可以看到效果和15个迭代器的差别不大,说明网络的收敛速度还是可以的,但是效果也不算太好,没有原先数据集中的美观。

pytorch:DCGAN生成动漫头像相关推荐

  1. vs2019 利用Pytorch和TensorFlow分别实现DCGAN生成动漫头像

    这是针对于博客vs2019安装和使用教程(详细)的DCGAN生成动漫头像项目新建示例 目录 一.DCGAN架构及原理 二.项目结构 1.TensorFlow 2.Pytorch 三.数据集下载(两种方 ...

  2. DCGAN生成动漫头像【学习】

    DCGAN生成动漫头像 在假期看了李宏毅老师的GAN的介绍,看到了课后题DCGAN生成动漫头像的作业,实现一下.记录学习过程. 参考的文章: [Keras] 基于GAN自动生成动漫头像 因为使用的是t ...

  3. 通过PyTorch用DCGAN生成动漫头像

    数据集 数据集我们用AnimeFaces数据集,共5万多张动漫头像. 链接:https://pan.baidu.com/s/1cp-A8ZV74YBelkSuKxuM6A 提取码:face 要把所有的 ...

  4. 使用TensorFlow2.0搭建DCGAN生成动漫头像(内含生成过程GIF图)

    文章目录 生成对抗网络介绍 一.造假 二.训练判别器 三.训练生成器 DCGAN介绍 搭建DCGAN 数据来源 必要工作 读取数据 构建生成器 构建判别器 连接模型 连接图片 生成函数 训练 生成对抗 ...

  5. 基于Tensorflow和DCGAN生成动漫头像实践(二)

    本篇内容为动漫头像生成的主要代码部分,第一次写这种代码,从读取数据到生成走了一个完整的流程.创建TFrecord过程可以看上一篇内容. 代码内容: #!/usr/bin/env python2 # - ...

  6. DCGAN生成动漫头像(附代码)

    DCGAN.顾名思义,就是深度卷积生成对抗神经网络,也就是引入了卷积的,但是它用的是反卷积,就是卷积的反操作. 我们看看DCGAN的图: 生成器开始输入的是噪声数据,然后经过一个全连接层,再把全连接层 ...

  7. 有趣的图像生成——使用DCGAN与pytorch生成动漫头像

    有趣的图像生成--使用DCGAN与pytorch生成动漫头像 文章目录 有趣的图像生成--使用DCGAN与pytorch生成动漫头像 一.源码下载 二.什么是DCGAN 三.DCGAN的实现 1.** ...

  8. pytorch实现DCGAN生成动漫人物头像

    pytorch实现DCGAN生成动漫人物头像 DCGAN原理 参考这一系列文章 数据集 21551张64*64动漫人物头像 生成效果 训练1个epoch(emm-) 训练10个epoch(起码有颜色了 ...

  9. 【PyTorch】12 生成对抗网络实战——用GAN生成动漫头像

    GAN 生成动漫头像 1. 获取数据 2. 用GAN生成 2.1 Generator 2.2 Discriminator 2.3 其它细节 2.4 训练思路 3. 全部代码 4. 结果展示与分析 小结 ...

最新文章

  1. 安卓Socket连接实现连接实现发送接收数据,openwrt wifi转串口连接单片机实现控制...
  2. MVC在基控制器中实现处理Session的逻辑
  3. mongodb安装_MongoDB索引策略和索引类型
  4. gitblit.cmd运行自动关闭
  5. 3 地理位置定位_IP地理定位API的十大用途和应用
  6. centos 使vim支持+python和+python3
  7. 【VMCloud云平台】拥抱Docker(六)关于DockerFile(1)
  8. python 抽象语法树_用python演示一个简单的AST(抽象语法树)
  9. 【产品评测】华为开源镜像站体验:美好终将不期而遇
  10. java stub_Java Stub 研究学习(2)
  11. 为什么一个程序中变量只能定义一次_#带你学Python# 从简单程序出发理解Python基本语法
  12. 读写执行Druapl7 Note-5: 利用FTP安装module或theme时出错(FIXED)
  13. 让你的软件支持繁体中文
  14. dlna 斐讯r1怎么用_挽救智障——斐讯R1:固件升级、安装DLNA和Soundwire
  15. 使用WPF设计类似Visio的简单绘图软件
  16. 实习周记----第三周
  17. 网络链接错误,请检查配置后重试!
  18. 【生信】全基因组测序(WGS)
  19. Unity -- 用EasyAR制作出AR红包
  20. dB单位与放大倍数关系

热门文章

  1. 因为你在,才有我最好的年华
  2. JavaWeb学生信息管理系统_查询V1.0
  3. wubi装双系统,可能导致无线网卡无法工作
  4. python中的字符、编码、转换
  5. C语言指针的初步了解
  6. 8A计算机游戏问题课文翻译,2013-新译林牛津英语8a课文翻译.pdf
  7. PDF分割文件该怎么分割
  8. 联想微型计算机c21r3,联想c21r3一体机拆机暗器,必须先拆开这个。
  9. NetSuite 中国现金流量表(直接法)功能包
  10. DPDK原理探索: igb_uio