import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import timetime_start = time.time()# 生成器生成的数据在 [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),  # 会做0-1归一化,也会channels, height, widthtransforms.Normalize((0.5,), (0.5,))
])train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform)
dataLoader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)# 生成器网络定义
# 输入是长度为100的噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28*28),nn.Tanh())def forward(self, x):img = self.main(x)img = img.view(-1, 28, 28, 1)return img# 判别器网络定义
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(28*28, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28)x = self.main(x)return xdevice = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)# 损失函数
loss_fn = torch.nn.BCELoss()# 绘图函数
def gen_img_plot(model, test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i+1)plt.imshow((prediction[i] + 1)/2)plt.axis('off')plt.show()test_input = torch.randn(16, 100, device=device)# GAN训练
D_loss = []
G_loss = []# 训练循环
for epoch in range(20):d_epoch_loss = 0g_epoch_loss = 0count = len(dataLoader)  # 返回批次数for step, (img, _) in enumerate(dataLoader):img = img.to(device)size = img.size(0)random_noise = torch.randn(size, 100, device=device)# 判别器的损失与优化d_optimizer.zero_grad()real_output = dis(img)  # 对判别器输入真实图片, real_output是对真实图片的判断结果d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 判别器在真实图像上的损失d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())  # 判别器输入生成的图片,fake_output对生成图片的预测d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 判别器在生成图像上的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optimizer.step()# 生成器的损失与优化g_optimizer.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # 生成器的损失g_loss.backward()g_optimizer.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print("Epoch:", epoch)gen_img_plot(gen, test_input)time_end = time.time()
print("花费总时间为:", time_end - time_start)

GAN实战——生成手写字体相关推荐

  1. 用c语言实现knn算法要有训练集和测试集,KNN算法实战:手写字体识别

    我们已经知道手写字体数据集是一个8×8的矩阵,共有64个特征.让我们看一下K最近邻算法对手写字体数据集处理的效果. 1) 导入相关包 这里我们将用到 datasets 中的手写字体数据,使用 trai ...

  2. #21天学习挑战赛—深度学习实战100例#——生成手写字体

    ​ ​ 活动地址:CSDN21天学习挑战赛 本文为

  3. 生成式对抗网络GAN之实现手写字体的生成(基于keras Tensorflow2.0实现)详细分析训练过程和代码

  4. 使用WGAN生成手写字体

    import sys; sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages") import nump ...

  5. Pytorch实现GAN之生成手写数字图片

    1.导入所需库 import torch import torch.optim as optim import torch.nn as nn import torch.nn.functional as ...

  6. gan网络原理(通俗)+minist手写字体实战

     gan网络原理如下: mnist手写字体实战: import torch import torchvision from torchvision import transforms from tor ...

  7. 《Gans in Action》第三章 用GAN生成手写数字

    此为<Gans in Action>(对抗神经网络实战)第三章读书笔记 Chapter 3. Your first GAN: Generating handwritten digits 用 ...

  8. GAN变种ACGAN利用手写数字识别mnist生成手写数字

    1.摘要 本文主要讲解:GAN变种ACGAN利用手写数字识别mnist数据集进行训练,最终生成手写数字图片 主要思路: Initialize generator and discriminator I ...

  9. 「zi2zi」:用AI生成自己的手写字体

    导读 如果想要自己做一套字体,无论是电脑软件FontCreator还是网站flexifont都为我们带来了极大的便利. 但是最低的国标字体数量近7000个,若采用传统的方法则需要手写相同数量的汉字,这 ...

最新文章

  1. xp 远程计算机需要远程验证登陆,xp远程桌面登陆需要身份验证问题解决
  2. IMP出现的ORA-01401错误可能和字符集有关(转载)
  3. guice 实例_使用Google Guice消除实例之间的歧义
  4. 前端学习(2286):react之无状态组件
  5. BZOJ 1878: [SDOI2009]HH的项链 | 莫队
  6. ORA-01089 数据库无法正常关闭
  7. 怎么成为日上会员直邮_放福利啦,免税店现在一件也能直邮,不用出入境、不用出入境、不用找代购...
  8. Backup Volume 操作 - 每天5分钟玩转 OpenStack(59)
  9. 逆向常用命令android常用逆向命令
  10. 织梦mysql安装教程视频教程_dedecms织梦模板安装教程视频/图文步骤(模板秀出品)...
  11. 联想计算机wifi卸载,怎么卸载联想笔记本电源管理软件
  12. 五阶魔方公式java_5阶魔方教程(五阶魔方一步一步图解)
  13. 通俗易懂地理解傅里叶变换
  14. 文件无法删除 你需要计算机管理员 提供的权限才能对此文件进行更改解决办法
  15. Linux命令:configure --prefix=/ 有什么作用
  16. CLRS 16.2贪心算法的原理
  17. live555 官方网站源码下载地址
  18. pytorch的训练测试流程总结,以及model.evel(), model.train(),torch.no_grad()作用
  19. 棒棒糖球球机器人_球球大作战刷棒棒糖
  20. linux安装SecureCRT安装教学

热门文章

  1. 【SEUSE】编译原理 - 词法分析器实验报告
  2. js`${}` 艺术字体语法
  3. 刷脸支付在支付前后商家可以做无限延展
  4. spark-shell启动出现的Error creating transactional connection factory解决办法
  5. java一键安装_新工具一键安装Java环境!微软又双叒叕造福开发者
  6. 汇编语言程序设计---1~4章习题答案(王爽第二版)
  7. 鲁迅先生文学作品合集
  8. 计算机考试反思2000字,期中考试后的反思,400字左右,谢谢,期中考试后的反思(2000字)谢谢啦!...
  9. 头歌计算机组成原理实验—运算器设计(8)第8关:乘法流水线设计
  10. 25.QAbstractionButton