彩色星球图片生成1:使用Gan实现(pytorch版)

  • 1. 描述
  • 2. 代码
    • 2.1 模型代码model.py
    • 2.2 训练代码main.py
  • 3. 效果
  • 4. 趣图

上一集: 使用Gan实现MNIST数据集手写数字生成(pytorch版)

1. 描述

在上一次的基础上,将代码扩展到理论上可以对任意类型的彩色图片进行生成,当然实际效果就是另一回事了, 这次使用的是自己在space engine软件中截图得到的32张星球照片(因为batch size在该实验中设置为32)作为数据集,试图让对抗式生成网络能够根据已有的星球学会画出星球。
局限性:该代码产生的图片仍然是低像素的,原因可能一方面受限于GPU性能限制只有一台可怜弱小又无助的小笔记本 ,训练图片输入只有265x265的分辨率,另一方面还没有使用PatchGan等高分辨率生成的改进方式。
该网络使用反卷积生成器,对标签进行单边平滑处理。
由于训练图片很少(全靠手动截图),所以效果可能有限。
使用了apex混合精度加速来减少训练时间和显存占用。

2. 代码

代码包括模型与训练两部分的py文件。

2.1 模型代码model.py

import torch
import torch.nn as nn# 生成器,基于上采样
class G_net(nn.Module):def __init__(self):super(G_net, self).__init__()self.expand = nn.Sequential(nn.Linear(128, 2048),nn.BatchNorm1d(2048),nn.Dropout(0.5),nn.LeakyReLU(0.2),nn.Linear(2048, 8192),nn.BatchNorm1d(8192),nn.Dropout(0.5),nn.LeakyReLU(0.2),)self.gen = nn.Sequential(# 反卷积扩张尺寸,保持kernel size能够被stride整除来减少棋盘效应nn.ConvTranspose2d(128, 128, kernel_size=4, stride=1, padding=0),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.ConvTranspose2d(128, 256, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(64),nn.LeakyReLU(0.2),nn.ConvTranspose2d(64, 32, kernel_size=6, stride=3, padding=1),nn.BatchNorm2d(32),nn.LeakyReLU(0.2),nn.ConvTranspose2d(32, 16, kernel_size=5, stride=1, padding=1),nn.BatchNorm2d(16),nn.LeakyReLU(0.2),# 尾部添加正卷积压缩减少棋盘效应nn.Conv2d(16, 8, kernel_size=5, stride=1, padding=1),nn.BatchNorm2d(8),nn.LeakyReLU(0.2),nn.Conv2d(8, 3, kernel_size=3, stride=1, padding=1),# 将输出约束到[-1,1]nn.Tanh())def forward(self, img_seeds):img_seeds = self.expand(img_seeds)# 将线性数据重组为二维图片img_seeds = img_seeds.view(-1, 128, 8, 8)output = self.gen(img_seeds)return output# 返回对应的生成器
def get_G_model(from_old_model, device, model_path):model = G_net()# 从磁盘加载之前保存的模型参数if from_old_model:model.load_state_dict(torch.load(model_path))# 将模型加载到用于运算的设备的内存model = model.to(device)return model# 判别器
class D_net(nn.Module):def __init__(self):super(D_net,self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, True),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, True),nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, True),nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, True),nn.Conv2d(512, 16, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(16),nn.LeakyReLU(0.2, True),)self.classifier = nn.Sequential(nn.Linear(1024, 1024),nn.ReLU(True),nn.Linear(1024, 1),nn.Sigmoid(),)def forward(self, img):features = self.features(img)features = features.view(features.shape[0], -1)output = self.classifier(features)return output# 返回判别器的模型
def get_D_model(from_old_model, device, model_path):model = D_net()# 从磁盘加载之前保存的模型参数if from_old_model:model.load_state_dict(torch.load(model_path))# 将模型加载到用于运算的设备的内存model = model.to(device)return model

2.2 训练代码main.py

from torch.utils.data import Dataset, DataLoader
import time
from torch.optim import AdamW
from model import *
from torchvision.utils import save_image
import random
from torch.autograd import Variable
import os
import cv2
from albumentations import Normalize, Compose, Resize
from albumentations.pytorch import ToTensorV2
from apex import amp# ------------------------------------config------------------------------------
class config:# 设置种子数,配置是否要固定种子数seed = 26use_seed = True# 配置是否要从磁盘加载之前保存的模型参数继续训练from_old_model = False# 使用apex加速训练use_apex = True# 运行多少个epoch之后停止epochs = 10000# 配置batch sizebatchSize = 32# 训练图片输入分辨率img_size = 265# 配置喂入生成器的随机正态分布种子数有多少维(如果改动,需要在model中修改网络对应参数)img_seed_dim = 128# 有多大概率在训练判别器D时交换正确图片的标签和伪造图片的标签D_train_label_exchange = 0.1# 保存模型参数文件的路径G_model_path = "G_model.pth"D_model_path = "D_model.pth"# 损失函数# 使用均方差损失函数criterion = nn.MSELoss()# ------------------------------------路径配置------------------------------------# 数据集来源img_path = "train_images/"# 输出图片的文件夹路径output_path = "output_images/"# 固定随机数种子
def seed_all(seed):random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Trueif config.use_seed:seed_all(seed=config.seed)# -----------------------------------transforms------------------------------------
def get_transforms(img_size):# 缩放分辨率并转换到0-1之间return Compose([Resize(img_size, img_size),Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0, p=1.0),ToTensorV2(p=1.0)])# ------------------------------------dataset------------------------------------
# create dataset
class image_dataset(Dataset):def __init__(self, file_list, img_path, transform):# files listself.file_list = file_listself.img_path = img_pathself.transform = transformdef __getitem__(self, index):image_path = self.img_path + self.file_list[index]img = cv2.imread(image_path)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = self.transform(image=img)['image']return imgdef __len__(self):return len(self.file_list)# ------------------------------------main------------------------------------
def main():# 如果可以使用GPU运算,则使用GPU,否则使用CPUdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')print("Use " + str(device))# 创建输出文件夹if not os.path.exists(config.output_path):os.mkdir(config.output_path)# 创建dataset# create datasetfile_list = Nonefor path, dirs, files in os.walk(config.img_path, topdown=False):file_list = list(files)train_dataset = image_dataset(file_list, config.img_path, transform=get_transforms(config.img_size))train_loader = DataLoader(dataset=train_dataset, batch_size=config.batchSize, shuffle=True)# 从model中获取判别器D和生成器G的网络模型G_model = get_G_model(config.from_old_model, device, config.G_model_path)D_model = get_D_model(config.from_old_model, device, config.D_model_path)# 定义G和D的优化器,此处使用AdamW优化器G_optimizer = AdamW(G_model.parameters(), lr=3e-4, weight_decay=1e-6)D_optimizer = AdamW(D_model.parameters(), lr=3e-4, weight_decay=1e-6)# 损失函数criterion = config.criterion# 混合精度加速if config.use_apex:G_model, G_optimizer = amp.initialize(G_model, G_optimizer, opt_level="O1")D_model, D_optimizer = amp.initialize(D_model, D_optimizer, opt_level="O1")# 记录训练时间train_start = time.time()# 开始训练的每一个epochfor epoch in range(config.epochs):print("start epoch "+str(epoch+1)+":")# 定义一些变量用于记录进度和损失batch_num = len(train_loader)D_loss_sum = 0G_loss_sum = 0count = 0# 从dataloader中提取数据for index, images in enumerate(train_loader):count += 1# 将图片放入运算设备的内存images = images.to(device)# 定义真标签,使用标签平滑的策略,生成0.9到1之间的随机数作为真实标签# real_labels = (1 - torch.rand(config.batchSize, 1)/10).to(device)# 定义真标签,全1# real_labels = Variable(torch.ones(config.batchSize, 1)).to(device)# 定义真标签,全0.9real_labels = Variable(torch.ones(config.batchSize, 1)-0.1).to(device)# 定义假标签,单向平滑,因此不对生成器标签进行平滑处理,全0fake_labels = Variable(torch.zeros(config.batchSize, 1)).to(device)# 将随机的初始数据喂入生成器生成假图像img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)fake_images = G_model(img_seeds)# 记录真假标签是否被交换过exchange_labels = False# 有一定概率在训练判别器时交换labelif random.uniform(0, 1) < config.D_train_label_exchange:real_labels, fake_labels = fake_labels, real_labelsexchange_labels = True# 训练判断器DD_optimizer.zero_grad()# 用真样本输入判别器real_output = D_model(images)# 对于数据集末尾的数据,长度不够一个batch size时需要去除过长的真实标签if len(real_labels) > len(real_output):D_loss_real = criterion(real_output, real_labels[:len(real_output)])else:D_loss_real = criterion(real_output, real_labels)# 用假样本输入判别器fake_output = D_model(fake_images)D_loss_fake = criterion(fake_output, fake_labels)# 将真样本与假样本损失相加,得到判别器的损失D_loss = D_loss_real + D_loss_fakeD_loss_sum += D_loss.item()# 重置优化器D_optimizer.zero_grad()# 用损失更新判别器Dif config.use_apex:with amp.scale_loss(D_loss, D_optimizer) as scaled_loss:scaled_loss.backward()else:D_loss.backward()D_optimizer.step()# 如果之前交换过标签,此时再换回来if exchange_labels:real_labels, fake_labels = fake_labels, real_labels# 训练生成器G# 将随机种子数喂入生成器G生成假数据img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)fake_images = G_model(img_seeds)# 将假数据输入判别器fake_output = D_model(fake_images)# 将假数据的判别结果与真实标签对比得到损失G_loss = criterion(fake_output, real_labels)G_loss_sum += G_loss.item()# 重置优化器G_optimizer.zero_grad()# 利用损失更新生成器Gif config.use_apex:with amp.scale_loss(G_loss, G_optimizer) as scaled_loss:scaled_loss.backward()else:G_loss.backward()G_optimizer.step()# 打印程序工作进度if (index + 1) % 200 == 0:print("Epoch: %2d, Batch: %4d / %4d" % (epoch + 1, index + 1, batch_num))if (epoch+1) % 10 == 0:# 在每N个epoch结束时保存模型参数到磁盘文件torch.save(G_model.state_dict(), config.G_model_path)torch.save(D_model.state_dict(), config.D_model_path)# 在每N个epoch结束时输出一组生成器产生的图片到输出文件夹img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)fake_images = G_model(img_seeds).cuda().data# 将假图像缩放到[0,1]的区间fake_images = 0.5 * (fake_images + 1)fake_images = fake_images.clamp(0, 1)# 连接所有生成的图片然后用自带的save_image()函数输出到磁盘文件fake_images = fake_images.view(-1, 3, config.img_size, config.img_size)save_image(fake_images, config.output_path+str(epoch+1)+'.png')# 打印该epoch的损失,时间等数据用于参考print("D_loss:", round(D_loss_sum / count, 3))print("G_loss:", round(G_loss_sum / count, 3))current_time = time.time()pass_time = int(current_time - train_start)time_string = str(pass_time // 3600) + " hours, " + str((pass_time % 3600) // 60) + " minutes, " + str(pass_time % 60) + " seconds."print("Time pass:", time_string)print()# 运行结束print("Done.")if __name__ == '__main__':main()

3. 效果

训练图片(32张,输入分辨率为265x265,来源space engine):

100个epoch:

1000个epoch:

2000个epoch:

3000个epoch:

4000个epoch:

5000个epoch:

6500个epoch:

可以看出,除了高像素细节外,在图像生成上已经非常理想,也出现了不同星球混合组成的新图像。受限于贫穷 显卡硬件限制,在输入像素上很难继续扩大,不过接下来会尝试用一些Gan的改进方案进一步提高清晰度。

4. 趣图

这些是从生成效果来说不理想,但看起来真的很炫酷的星球。
半透明的星壳与水晶星核:

AI委婉地表达了想要毁灭地球:

下一集:彩色星球图片生成2:同时使用传统Gan判别器和马尔可夫判别器(pytorch版)

彩色星球图片生成1:使用Gan实现(pytorch版)相关推荐

  1. 彩色星球图片生成4:转置卷积+插值缩放+卷积收缩(pytorch版)

    彩色星球图片生成4:转置卷积层+插值缩放+卷积收缩(pytorch版) 1. 改进方面 1.1 优化器与优化步长 1.2 交叉熵损失函数 1.3 Patch判别器 1.4 输入分辨率 1.5 转置卷积 ...

  2. 彩色星球图片生成3:代码改进(pytorch版)

    彩色星球图片生成3:代码改进(pytorch版) 1. 修改 1.1 预处理缩放 1.2 随机翻转 1.3 修改全局判别器 1.4 修改进度打印 2. 效果 3. 总结 上一集: 彩色星球图片生成2: ...

  3. 彩色星球图片生成5:先验条件约束与LapGAN(pytorch版)

    彩色星球图片生成5:先验条件约束与LapGAN(pytorch版) 1. 改进方面 1.1 训练集信息的人工标注 1.2 先验信息的条件约束 1.3 分类器C 1.4 LapGAN的分层残差拟合 2. ...

  4. 彩色星球科技旗下元宇宙平台“彩色世界”亚洲版即将发布;Branch宣布获得3亿美元融资 | 全球TMT...

    国内市场 彩色星球科技旗下"彩色世界"元宇宙软件亚洲版即将登陆.彩星科技将发布一个本地化的亚洲版本,以便让用户在亚洲体验并访问该平台应用软件.亚洲版将提供简体中文和繁体中文.日文和 ...

  5. Python批量导入图片生成能治疗颈椎病的HTML5版课件

    本文要点:Python文件操作,HTML5的figure元素和CSS3属性的用法. 说明:1):本文图片来自于相关阅读中Python批量导出多个PPT/PPTX文件中每个幻灯片为独立JPG图片或Pyt ...

  6. GAN (生成对抗网络) 手写数字图片生成

    GAN (生成对抗网络) 手写数字图片生成 文章目录 GAN (生成对抗网络) 手写数字图片生成 Discriminator Network Generator Network 简单版本的生成对抗网络 ...

  7. [人工智能-深度学习-63]:生成对抗网络GAN - 图片创作:普通GAN, pix2pix, CycleGAN和pix2pixHD的演变过程

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  8. 使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上

    阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读. 相关资料链接为:使用PyTorch构建GAN生成对抗 本次训练代码使用了 ...

  9. 独家 | 什么是生成模型和GAN?一文体验计算机视觉的魔力(附链接)

    作者:PULKIT SHARMA 翻译:吴金笛 校对:王婷 本文长度约为4700字,建议阅读15分钟 本文介绍了生成模型和生成对抗网络(GAN)的工作原理和训练步骤. 概况 生成模型和GAN是计算机视 ...

最新文章

  1. 投影参数_智能投影仪参数如何去看,其实很简单
  2. php让十进制输出十六进制(ascill)码
  3. Asp.net Web.Config - 配置元素customErrors
  4. 如何清除SQL数据库日志,清除后对数据库有什么影响
  5. Java学习----运算符与表达式
  6. 欧几里得算法及其扩展
  7. Python中的@property Decorator:其用例,优点和语法
  8. 树莓派 Learning 002 装机后的必要操作 --- 03 替换软件源
  9. 用C++开发与调用WebService的例子
  10. 扫盲——敏捷开发 Agile development 之 Scrum开发
  11. 百人计划(图形部分)Bump Mapping(凹凸贴图映射技术)
  12. u盘pe启动盘怎么制作?
  13. 如何给计算机d盘加密码,怎样给电脑文件夹加密
  14. 数字逻辑课上如何制作FPGA游戏?
  15. Java主要应用于哪些方面 Java就业方向有哪些
  16. 斐波那契数列由数字1 1 2 3 5 8 13 21 34等等组成,其中每一个数字(从第三个起) 都是由前两个数字的和。
  17. Mybatis xml映射文件错误,导致Tomcat无法启动,也不报异常
  18. 黑客技术论坛为什么越来越少了?
  19. 快速提高数学成绩的奇书《巧学妙解王》高中数学!
  20. ShareLatex+Overflow:PDF Rendering Error Something went wrong while rendering this PDF问题解决

热门文章

  1. Python带格式写表格实例
  2. 随心邮|你不知道的手机收发邮件的新方式
  3. Android图形验证码
  4. Java 图形验证码
  5. H3C网络设备密码不知道,如何清空?(跳过当前系统配置)
  6. Linux 卸载软件
  7. CSS学习之相对定位和绝对定位
  8. 图形学 射线相交算法_计算机图形学中的彩色阴极射线管
  9. 携程网国内酒店评论数据(超8163万条)
  10. 三种激活函数——Sigmoid,Tanh, ReLU以及卷积感受野的计算