数据集

kaggle:https://www.kaggle.com/datasets/soumikrakshit/anime-faces

代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils import data
from torchvision import transforms
import glob
from PIL import Image# glob获取全部图像的路径
imgs_path = glob.glob(r'anime-faces/*.png')# 画6张看看
# plt.figure(figsize=(12, 8))
# for i, img_path in enumerate(imgs_path[:6]):
#     img = np.array(Image.open(img_path))
#     plt.subplot(2, 3, i+1)
#     plt.imshow(img)
#     print(img.shape)
# plt.show()# GAN 输入-1,1 便于训练
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5),  # 数据减均值除方差
])# 创建数据集
class Face_dataset(data.Dataset):def __init__(self, imgs_path):self.imgs_path = imgs_pathdef __getitem__(self, index):img_path = self.imgs_path[index]pil_img = Image.open(img_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.imgs_path)dataset = Face_dataset(imgs_path)
dataloader = data.DataLoader(dataset,batch_size=32,shuffle=True)
imgs_batch = next(iter(dataloader))  # 32,3,64,64# 画出来看看
plt.figure(figsize=(12, 8))
for i, img in enumerate(imgs_batch[:6]):img = (img.permute(1, 2, 0).numpy() + 1) / 2  # 64,64,3plt.subplot(2, 3, i+1)plt.imshow(img)
plt.show()# 定义生成器,依然输入长度100的噪声
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(100, 256*16*16)self.bn1 = nn.BatchNorm1d(256*16*16)self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=1, padding=1)self.bn2 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1)self.bn3 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=2, padding=1)def forward(self, x):x = F.relu(self.linear1(x))x = self.bn1(x)x = x.view(-1, 256, 16, 16)x = F.relu(self.deconv1(x))x = self.bn2(x)x = F.relu(self.deconv2(x))x = self.bn3(x)x = torch.tanh(self.deconv3(x))return x# 判别器,输入(64, 64)图片
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2)  # (64, 31, 31)self.conv2 = nn.Conv2d(64, 128, 3, 2)self.bn = nn.BatchNorm2d(128)  # 128 * 15 * 15self.fc = nn.Linear(128*15*15, 1)def forward(self, x):x = F.dropout2d(F.leaky_relu(self.conv1(x)), p=0.3)x = F.dropout2d(F.leaky_relu(self.conv2(x)), p=0.3)x = self.bn(x)x = x.view(-1, 128*15*15)x = torch.sigmoid(self.fc(x))return xdevice = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':print('using cuda:', torch.cuda.get_device_name(0))
else:print(device)Gen = Generator().to(device)
Dis = Discriminator().to(device)loss_fun = nn.BCELoss()
d_optimizer = torch.optim.Adam(Dis.parameters(), lr=1e-5)  # 小技巧
g_optimizer = torch.optim.Adam(Gen.parameters(), lr=1e-4)def generate_and_save_image(model, test_input):predictions = model(test_input).permute(0, 2, 3, 1).cpu().numpy()# fig = plt.figure(figsize=(40, 80))  # 画布设置太大会导致错误for i in range(predictions.shape[0]):plt.subplot(2, 4, i+1)plt.imshow((predictions[i]+1) / 2)plt.axis('off')plt.show()test_seed = torch.randn(8, 100, device=device)
D_loss = []
G_loss = []for epoch in range(500):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader)  # 批次数for step, img in enumerate(dataloader):img = img.to(device)size = img.size(0)random_noise = torch.randn(size, 100, device=device)d_optimizer.zero_grad()real_output = Dis(img)  # 判别器输入真实图片# 判别器在真实图像上的损失d_real_loss = loss_fun(real_output,torch.ones_like(real_output))d_real_loss.backward()gen_img = Gen(random_noise)fake_output = Dis(gen_img.detach())  # 判别器输入生成图片,fake_output对生成图片的预测# gen_img是由生成器得来的,但我们现在只对判别器更新,所以要截断对Gen的更新# detach()得到了没有梯度的tensor,求导到这里就停止了,backward的时候就不会求导到Gen了d_fake_loss = loss_fun(fake_output,torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optimizer.step()# 更新生成器g_optimizer.zero_grad()fake_output = Dis(gen_img)g_loss = loss_fun(fake_output,torch.ones_like(fake_output))g_loss.backward()g_optimizer.step()with torch.no_grad():d_epoch_loss += d_loss.item()g_epoch_loss += g_loss.item()with torch.no_grad():  # 之后的内容不进行梯度的计算(图的构建)d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch+1)generate_and_save_image(model=Gen, test_input=test_seed)plt.plot(D_loss, label='D_loss')plt.plot(G_loss, label='G_loss')plt.legend()plt.show()

效果

跑了50轮

DCGAN---生成动漫头像相关推荐

  1. vs2019 利用Pytorch和TensorFlow分别实现DCGAN生成动漫头像

    这是针对于博客vs2019安装和使用教程(详细)的DCGAN生成动漫头像项目新建示例 目录 一.DCGAN架构及原理 二.项目结构 1.TensorFlow 2.Pytorch 三.数据集下载(两种方 ...

  2. DCGAN生成动漫头像【学习】

    DCGAN生成动漫头像 在假期看了李宏毅老师的GAN的介绍,看到了课后题DCGAN生成动漫头像的作业,实现一下.记录学习过程. 参考的文章: [Keras] 基于GAN自动生成动漫头像 因为使用的是t ...

  3. pytorch:DCGAN生成动漫头像

    动漫头像数据集下载地址:动漫头像数据集_百度云连接,DCGAN论文下载地址: https://arxiv.org/abs/1511.06434 数据集里面的图片是这个样子的: 这是DCGAN的主要改进 ...

  4. 使用TensorFlow2.0搭建DCGAN生成动漫头像(内含生成过程GIF图)

    文章目录 生成对抗网络介绍 一.造假 二.训练判别器 三.训练生成器 DCGAN介绍 搭建DCGAN 数据来源 必要工作 读取数据 构建生成器 构建判别器 连接模型 连接图片 生成函数 训练 生成对抗 ...

  5. 通过PyTorch用DCGAN生成动漫头像

    数据集 数据集我们用AnimeFaces数据集,共5万多张动漫头像. 链接:https://pan.baidu.com/s/1cp-A8ZV74YBelkSuKxuM6A 提取码:face 要把所有的 ...

  6. 基于Tensorflow和DCGAN生成动漫头像实践(二)

    本篇内容为动漫头像生成的主要代码部分,第一次写这种代码,从读取数据到生成走了一个完整的流程.创建TFrecord过程可以看上一篇内容. 代码内容: #!/usr/bin/env python2 # - ...

  7. DCGAN生成动漫头像(附代码)

    DCGAN.顾名思义,就是深度卷积生成对抗神经网络,也就是引入了卷积的,但是它用的是反卷积,就是卷积的反操作. 我们看看DCGAN的图: 生成器开始输入的是噪声数据,然后经过一个全连接层,再把全连接层 ...

  8. 基于DCGAN的动漫头像生成

    基于DCGAN的动漫头像生成 数据 数据集:动漫图库爬虫获得,经过数据清洗,裁剪得到动漫头像.分辨率为3 * 96 * 96,共5万多张动漫头像的图片,从知乎用户何之源处下载. 生成器:输入为随机噪声 ...

  9. 有趣的图像生成——使用DCGAN与pytorch生成动漫头像

    有趣的图像生成--使用DCGAN与pytorch生成动漫头像 文章目录 有趣的图像生成--使用DCGAN与pytorch生成动漫头像 一.源码下载 二.什么是DCGAN 三.DCGAN的实现 1.** ...

  10. pytorch实现DCGAN生成动漫人物头像

    pytorch实现DCGAN生成动漫人物头像 DCGAN原理 参考这一系列文章 数据集 21551张64*64动漫人物头像 生成效果 训练1个epoch(emm-) 训练10个epoch(起码有颜色了 ...

最新文章

  1. 学习Hadoop时遇到的问题以及解决方法
  2. Android自定义XML属性以及遇到的命名空间的问题
  3. python简易木马(一)
  4. 送书福利 | 浙江大学陈华钧教授新作,全面梳理知识图谱技术体系
  5. 老男孩最近几年常用的免费的开源软件
  6. Xvid编码器流程(基于xvid1.1.0)
  7. 使用Spring Boot 2.0的Spring Security:保护端点
  8. ppt课堂流程图_除了直线能设计PPT,没想到曲线也实用,太赞了!
  9. .net一个函数要用另一个函数的值_VLOOKUP函数
  10. Java Web学习总结(35)——HTTP状态码汇总
  11. xForm应用开发手册
  12. Java编程解密-Dubbo负载均衡与集群容错机制
  13. MSSQL日期格式转换函数(使用CONVERT)
  14. linux下esc退不出vi
  15. MySQL指令集集合
  16. android实现弹出输入法时,顶部固定,中间部分上移的效果,使用 Dialog 制作紧贴输入法顶部的输入框...
  17. (三)MFC学习之动画
  18. 计算机安装系统说明,电脑操作系统安装方法-详细图解说明-简单安装Windows系统...
  19. 华为自带时钟天气下载_华为天气时钟农历插件,求华为自带的天气时钟
  20. 曲线运动与万有引力公式_高中物理公式大总结:曲线运动、万有引力

热门文章

  1. 有什么平价好用的蓝牙耳机?适合学生党的性价比耳机推荐
  2. 谁不是一边升学求职,一边死在路上
  3. P2085 最小函数值(优先队列 分组法运用)
  4. 学习RHCSA的第二天
  5. android 仿预订日历时间选择(如去哪儿,携程时间选择)
  6. 计算机英语sectionB,计算机专业英语教程课文翻译 Chapter Two SectionB.pdf
  7. [人体图像相关技术] -(一)概述
  8. 程序员——伤不起的三十岁
  9. MB52字段增加 显示物料库存报表
  10. ANSYS经典界面参数类型、定义及赋值