pytorch:DCGAN生成动漫头像
动漫头像数据集下载地址:动漫头像数据集_百度云连接,DCGAN论文下载地址: https://arxiv.org/abs/1511.06434
数据集里面的图片是这个样子的:
这是DCGAN的主要改进地方:
下面是所有代码:
第一个模块:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.init as init
import data_helper
from torchvision import transformstrans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((.5, .5, .5), (.5, .5, .5))]
)
G_LR = 0.0002
D_LR = 0.0002
BATCHSIZE = 50
EPOCHES = 3000def init_ws_bs(m):if isinstance(m, nn.ConvTranspose2d):init.normal_(m.weight.data, std=0.2)init.normal_(m.bias.data, std=0.2)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.deconv1 = nn.Sequential(nn.ConvTranspose2d( # stride(input_w-1)+k-2*Paddingin_channels=100,out_channels=64 * 8,kernel_size=4,stride=1,padding=0,bias=False,),nn.BatchNorm2d(64 * 8),nn.ReLU(inplace=True),)self.deconv2 = nn.Sequential(nn.ConvTranspose2d( # stride(input_w-1)+k-2*Paddingin_channels=64 * 8,out_channels=64 * 4,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64 * 4),nn.ReLU(inplace=True),)self.deconv3 = nn.Sequential(nn.ConvTranspose2d( # stride(input_w-1)+k-2*Paddingin_channels=64 * 4,out_channels=64 * 2,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64 * 2),nn.ReLU(inplace=True),)self.deconv4 = nn.Sequential(nn.ConvTranspose2d( # stride(input_w-1)+k-2*Paddingin_channels=64 * 2,out_channels=64 * 1,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64),nn.ReLU(inplace=True),)self.deconv5 = nn.Sequential(nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),nn.Tanh(),)def forward(self, x):x = self.deconv1(x)x = self.deconv2(x)x = self.deconv3(x)x = self.deconv4(x)x = self.deconv5(x)return xclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d( # batchsize,3,96,96in_channels=3,out_channels=64,kernel_size=5,padding=1,stride=3,bias=False,),nn.BatchNorm2d(64),nn.LeakyReLU(.2, inplace=True),)self.conv2 = nn.Sequential(nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False, ), # batchsize,16,32,32nn.BatchNorm2d(64 * 2),nn.LeakyReLU(.2, inplace=True),)self.conv3 = nn.Sequential(nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(64 * 4),nn.LeakyReLU(.2, inplace=True),)self.conv4 = nn.Sequential(nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(64 * 8),nn.LeakyReLU(.2, inplace=True),)self.output = nn.Sequential(nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid() #)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = self.output(x)return xg = Generator().cuda()
d = Discriminator().cuda()init_ws_bs(g), init_ws_bs(d)g_optimizer = torch.optim.Adam(g.parameters(), betas=(.5, 0.999), lr=G_LR)
d_optimizer = torch.optim.Adam(d.parameters(), betas=(.5, 0.999), lr=D_LR)g_loss_func = nn.BCELoss()
d_loss_func = nn.BCELoss()label_real = torch.ones(BATCHSIZE).cuda()
label_fake = torch.zeros(BATCHSIZE).cuda()real_img = data_helper.get_imgs()for epoch in range(EPOCHES):np.random.shuffle(real_img)count = 0batch_imgs = []for i in range(len(real_img)):count = count + 1batch_imgs.append(trans(real_img[i]).numpy()) # tensor类型#这里经过trans操作通道维度从第四个到第二个了if count == BATCHSIZE:count = 0batch_real = torch.Tensor(batch_imgs).cuda()batch_imgs.clear()d_optimizer.zero_grad()pre_real = d(batch_real).squeeze()d_real_loss = d_loss_func(pre_real, label_real)d_real_loss.backward()batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).cuda()img_fake = g(batch_fake).detach()pre_fake = d(img_fake).squeeze()d_fake_loss = d_loss_func(pre_fake, label_fake)d_fake_loss.backward()d_optimizer.step()g_optimizer.zero_grad()batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).cuda()img_fake = g(batch_fake)pre_fake = d(img_fake).squeeze()g_loss = g_loss_func(pre_fake, label_real)g_loss.backward()g_optimizer.step()print(i,(d_real_loss + d_fake_loss).detach().cpu().numpy(), g_loss.detach().cpu().numpy())torch.save(g, "pkl/" + str(epoch) + "g.pkl")
以上网络结构和参数,是从另一个博客找来的Pytorch实战3:DCGAN深度卷积对抗生成网络生成动漫头像
其中调用了一个同目录下的data_helper模块,用来从本地数据文件夹中获取图片(list of arrray),其中需要把文件夹改成自己的文件夹:
import cv2
import os
MAIN_PATH="E:/DataSets/faces/"
def get_imgs():files = os.listdir(MAIN_PATH)imgs = []for file in files:imgs.append(cv2.imread(MAIN_PATH + file))print("get_imgs")return imgs
由于是对每个epoch都保存了网络结构,所以可以在训练完成后选择需要加载的本地网络文件,然后测试效果:
用下面的代码来测试:
import torch.nn as nn import torch import cv2 class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.deconv1=nn.Sequential(nn.ConvTranspose2d(#stride(input_w-1)+k-2*Paddingin_channels=100,out_channels=64*8,kernel_size=4,stride=1,padding=0,bias=False,),nn.BatchNorm2d(64*8),nn.ReLU(inplace=True),)#14self.deconv2=nn.Sequential(nn.ConvTranspose2d(#stride(input_w-1)+k-2*Paddingin_channels=64*8,out_channels=64*4,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64*4),nn.ReLU(inplace=True),)#24self.deconv3 = nn.Sequential(nn.ConvTranspose2d( # stride(input_w-1)+k-2*Paddingin_channels=64*4,out_channels=64*2,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64*2),nn.ReLU(inplace=True),)#48self.deconv4 = nn.Sequential(nn.ConvTranspose2d( # stride(input_w-1)+k-2*Paddingin_channels=64*2,out_channels=64*1,kernel_size=4,stride=2,padding=1,bias=False,),nn.BatchNorm2d(64),nn.ReLU(inplace=True),)self.deconv5=nn.Sequential(nn.ConvTranspose2d(64,3,5,3,1,bias=False),nn.Tanh(),)def forward(self, x):x=self.deconv1(x)x=self.deconv2(x)x=self.deconv3(x)x=self.deconv4(x)x=self.deconv5(x)return xg=torch.load("pkl/15g.pkl") imgs=g(torch.randn(100,100,1,1).cuda()) for i in range(len(imgs)):img=imgs[i].permute(1,2,0).cpu().detach().numpy()*255cv2.imwrite("bitmaps/"+str(i)+".jpg",img,)#这里需要在同目录下建立一个bitmap文件夹 print("done")
接下来是运行的效果:
第一个epoch结束之后:
第15个epoch之后:
一共训练了67个epoch,最终的结果:
可以看到效果和15个迭代器的差别不大,说明网络的收敛速度还是可以的,但是效果也不算太好,没有原先数据集中的美观。
pytorch:DCGAN生成动漫头像相关推荐
- vs2019 利用Pytorch和TensorFlow分别实现DCGAN生成动漫头像
这是针对于博客vs2019安装和使用教程(详细)的DCGAN生成动漫头像项目新建示例 目录 一.DCGAN架构及原理 二.项目结构 1.TensorFlow 2.Pytorch 三.数据集下载(两种方 ...
- DCGAN生成动漫头像【学习】
DCGAN生成动漫头像 在假期看了李宏毅老师的GAN的介绍,看到了课后题DCGAN生成动漫头像的作业,实现一下.记录学习过程. 参考的文章: [Keras] 基于GAN自动生成动漫头像 因为使用的是t ...
- 通过PyTorch用DCGAN生成动漫头像
数据集 数据集我们用AnimeFaces数据集,共5万多张动漫头像. 链接:https://pan.baidu.com/s/1cp-A8ZV74YBelkSuKxuM6A 提取码:face 要把所有的 ...
- 使用TensorFlow2.0搭建DCGAN生成动漫头像(内含生成过程GIF图)
文章目录 生成对抗网络介绍 一.造假 二.训练判别器 三.训练生成器 DCGAN介绍 搭建DCGAN 数据来源 必要工作 读取数据 构建生成器 构建判别器 连接模型 连接图片 生成函数 训练 生成对抗 ...
- 基于Tensorflow和DCGAN生成动漫头像实践(二)
本篇内容为动漫头像生成的主要代码部分,第一次写这种代码,从读取数据到生成走了一个完整的流程.创建TFrecord过程可以看上一篇内容. 代码内容: #!/usr/bin/env python2 # - ...
- DCGAN生成动漫头像(附代码)
DCGAN.顾名思义,就是深度卷积生成对抗神经网络,也就是引入了卷积的,但是它用的是反卷积,就是卷积的反操作. 我们看看DCGAN的图: 生成器开始输入的是噪声数据,然后经过一个全连接层,再把全连接层 ...
- 有趣的图像生成——使用DCGAN与pytorch生成动漫头像
有趣的图像生成--使用DCGAN与pytorch生成动漫头像 文章目录 有趣的图像生成--使用DCGAN与pytorch生成动漫头像 一.源码下载 二.什么是DCGAN 三.DCGAN的实现 1.** ...
- pytorch实现DCGAN生成动漫人物头像
pytorch实现DCGAN生成动漫人物头像 DCGAN原理 参考这一系列文章 数据集 21551张64*64动漫人物头像 生成效果 训练1个epoch(emm-) 训练10个epoch(起码有颜色了 ...
- 【PyTorch】12 生成对抗网络实战——用GAN生成动漫头像
GAN 生成动漫头像 1. 获取数据 2. 用GAN生成 2.1 Generator 2.2 Discriminator 2.3 其它细节 2.4 训练思路 3. 全部代码 4. 结果展示与分析 小结 ...
最新文章
- 安卓Socket连接实现连接实现发送接收数据,openwrt wifi转串口连接单片机实现控制...
- MVC在基控制器中实现处理Session的逻辑
- mongodb安装_MongoDB索引策略和索引类型
- gitblit.cmd运行自动关闭
- 3 地理位置定位_IP地理定位API的十大用途和应用
- centos 使vim支持+python和+python3
- 【VMCloud云平台】拥抱Docker(六)关于DockerFile(1)
- python 抽象语法树_用python演示一个简单的AST(抽象语法树)
- 【产品评测】华为开源镜像站体验:美好终将不期而遇
- java stub_Java Stub 研究学习(2)
- 为什么一个程序中变量只能定义一次_#带你学Python# 从简单程序出发理解Python基本语法
- 读写执行Druapl7 Note-5: 利用FTP安装module或theme时出错(FIXED)
- 让你的软件支持繁体中文
- dlna 斐讯r1怎么用_挽救智障——斐讯R1:固件升级、安装DLNA和Soundwire
- 使用WPF设计类似Visio的简单绘图软件
- 实习周记----第三周
- 网络链接错误,请检查配置后重试!
- 【生信】全基因组测序(WGS)
- Unity -- 用EasyAR制作出AR红包
- dB单位与放大倍数关系