【参考资料】
【1】https://github.com/aitorzip/PyTorch-CycleGAN
【2】《深入浅出GAN生成对抗网络》 8.2 CycleGan

尝试了下cyclegan,训练大约2小时500次迭代,效果不甚理想。估计是学习率未优化,炼丹时间太短吧;)

"""
采用数据集:http://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip CycleGan训练的几个特点:
1、归一化用 IN 代替 BN
2、损失函数采用 LSGAN平方差
3、生成器采用残差网络
4、使用缓存历史图像训练生成器"""
from random import randint
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image
import itertoolsdef to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)  out = out.view(-1, 3, 256, 256)  return out"""
残差网络blockflectionPad2d: 镜像填充,例如 0 1 填充至 4个数,则 0 1 0 1
InstanceNorm2d: 对单个样本的每一层特征图抽出来一层层求均值、方差然后归一化"""
class ResidualBlock(nn.Module):def __init__(self, in_features):super(ResidualBlock, self).__init__()self.block_layer = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features),nn.ReLU(inplace=True),nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features))def forward(self, x):return x + self.block_layer(x)"""
生成器
PS: 这段代码完全copy 参考文献【1】,对于复杂网络确实这种构建方式更加清晰!!
"""
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()model = [   nn.ReflectionPad2d(3),nn.Conv2d(3, 64, 7),nn.InstanceNorm2d(64),nn.ReLU(inplace=True) ]in_features = 64out_features = in_features*2for _ in range(2):model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True) ]in_features  = out_featuresout_features = in_features*2for _ in range(9):model += [ResidualBlock(in_features)]out_features = in_features//2for _ in range(2):model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True) ]in_features = out_featuresout_features = in_features//2model += [  nn.ReflectionPad2d(3),nn.Conv2d(64, 3, 7),nn.Tanh() ]self.gen = nn.Sequential(*model)def forward(self, x):x = self.gen(x)return x """
判别器
1、这里判别器的最后一层是FCN全卷积网络avg_pool2d:以均值方式池化,以下述代码为例:
input  = torch.randn(10, 3, 128, 128)
m      = Discriminator()
output = m(input)此时在 avg_pool2d 前 x.size() 为 torch.Size([10, 1, 14, 14])
对 10个[14, 15]的tensor求均值并返回
tensor([[0.1162],[0.1298],[0.1266],[0.1229],[0.1085],[0.1121],[0.1064],[0.1044],[0.1077],[0.1139]], grad_fn=<ViewBackward>)"""
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.dis = nn.Sequential(nn.Conv2d(3, 64, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 128, 4, 2, 1, bias=False),nn.InstanceNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, 4, 2, 1, bias=False),nn.InstanceNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 512, 4, padding=1),nn.InstanceNorm2d(512),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(512, 1, 4, padding=1))        def forward(self, x):x = self.dis(x)return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)"""
数据加载
"""
data_path  = os.path.abspath("E:/dataset/cyclegan/horse2zebra")
image_size = 256
batch_size = 1transform = transforms.Compose([transforms.Resize(int(image_size * 1.12), Image.BICUBIC), transforms.RandomCrop(image_size), transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])def _get_train_data(batch_size=1):train_a_filepath = data_path + "\\trainA\\"train_b_filepath = data_path + "\\trainB\\"train_a_list = os.listdir(train_a_filepath)train_b_list = os.listdir(train_b_filepath)train_a_result = []train_b_result = [] numlist = random.sample(range(0,len(train_a_list)), batch_size)for i in numlist:a_filename = train_a_list[i]a_img = Image.open(train_a_filepath + a_filename)res_a_img = transform(a_img)train_a_result.append(torch.unsqueeze(res_a_img, 0))b_filename = train_b_list[i]b_img = Image.open(train_b_filepath + b_filename)res_b_img = transform(b_img)train_b_result.append(torch.unsqueeze(res_b_img, 0))return torch.cat(train_a_result,dim=0), torch.cat(train_b_result,dim=0)"""
训练神经网络
存在生成网络 G_A2B 从A类图片生成B类图片
存在生成网络 G_B2A 从B类图片生成A类图片
存在判别器   D_A   对A类图片真伪进行判断
存在判别器   D_B   对B类图片真伪进行判断 存在损失函数:criterion_GAN MSELoss:均方损失函数,即 两个向量各分量差的平方
存在损失函数:criterion_cycle L1Loss:平均均对误差,即 两个向量各分量差的绝对值求和再除分量数
存在损失函数: criterion_identity L1Loss:第一步:训练生成器
1、利用 G_A2B 将 real_B 生成 same_B,获取 real_B 和 same_B 之间的损失
2、利用 G_B2A 将 real_A 生成 same_A,获取 real_A 和 same_A 之间的损失
3、利用 G_A2B 将 real_A 生成 fake_B, 通过 netD_B 判断 fake_B,获取其与判真之间的损失
4、利用 G_B2A 将 real_B 生成 faka_A,通过 netD_A 判断 fake_A,获取其与判真之间的损失
5、利用 G_B2A 将 fake_B 生成 recovered_A,获取其与 real_A 之间的损失
6、利用 G_A2B 将 fake_A 生成 recovered_B, 获取其与 real_B 之间的损失此时对所有损失求和后训练生成器,此时:对于 G_A2B:1. 如果输入是 B类图片,它生成的图片same_B总体像素层面上接近B2. 如果输入是 A类图片,它生成的图片fake_B会具有B类图片的卷积特征3. 对于生成的图片fake_B, 它经过 G_B2A 生成的recovered_A 会总体像素层面上接近A此时我们有了一张 A 经过 G_A2B 和 G_B2A 成为一张 新的A,这张新A像素总体上是A(是一批马)但细节纹理上具有 B 的特征(斑马纹理)同理有 G_B2A第二步:训练判别器
1、对判别器 netD_A 可以对 real_A 判真
2、对判别器 netD_A 可以对 fake_A 判伪
3、对判别器 netD_B 可以对 real_B 判真
4、对判别器 netD_B 可以对 fake_B 判伪
"""class ReplayBuffer():"""缓存队列,若不足则新增,否则随机替换"""def __init__(self, max_size=50):self.max_size = max_sizeself.data = []def push_and_pop(self, data):to_return = []for element in data.data:element = torch.unsqueeze(element, 0)if len(self.data) < self.max_size:self.data.append(element)to_return.append(element)else:if random.uniform(0,1) > 0.5:i = random.randint(0, self.max_size-1)to_return.append(self.data[i].clone())self.data[i] = elementelse:to_return.append(element)return Variable(torch.cat(to_return))fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()netG_A2B = Generator()
netG_B2A = Generator()
netD_A   = Discriminator()
netD_B   = Discriminator()"""
显存不够,删除
if torch.cuda.is_available():print("use cuda")netG_A2B = netG_A2B.cuda()netG_B2A = netG_B2A.cuda()netD_A   = netD_A.cuda()netD_B   = netD_B.cuda()
"""criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()d_learning_rate = 3e-4  # 3e-4
g_learning_rate = 3e-4
optim_betas     = (0.5, 0.999)g_optimizer  = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=d_learning_rate)
da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate)
db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate)num_epochs   = 1000 #检查train_b是否有灰度图
"""
print(np.arange(196608).reshape(256,256,3).shape)
train_b_filepath = data_path + "\\trainB\\"
train_b_list = os.listdir(train_b_filepath)
for i in range(len(train_b_list)):b_filename = train_b_list[i]b_img = Image.open(train_b_filepath + b_filename)if np.array(b_img).shape != np.arange(196608).reshape(256,256,3).shape:print(b_filename)os.remove(train_b_filepath + b_filename)
"""for epoch in range(num_epochs): real_a, real_b = _get_train_data(batch_size)target_real = torch.full((batch_size,), 1)target_fake = torch.full((batch_size,), 0)g_optimizer.zero_grad()# 第一步:训练生成器same_B          = netG_A2B(real_b)loss_identity_B = criterion_identity(same_B, real_b) * 5.0   same_A          = netG_B2A(real_a)loss_identity_A = criterion_identity(same_A, real_a) * 5.0fake_B          = netG_A2B(real_a)pred_fake       = netD_B(fake_B)loss_GAN_A2B    = criterion_GAN(pred_fake, target_real)fake_A          = netG_B2A(real_b)pred_fake       = netD_A(fake_A)loss_GAN_B2A    = criterion_GAN(pred_fake, target_real)recovered_A     = netG_B2A(fake_B)loss_cycle_ABA  = criterion_cycle(recovered_A, real_a) * 10.0recovered_B     = netG_A2B(fake_A)loss_cycle_BAB  = criterion_cycle(recovered_B, real_b) * 10.0  loss_G          = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BABloss_G.backward()    g_optimizer.step()# 第二步:训练判别器# 训练判别器Ada_optimizer.zero_grad()pred_real = netD_A(real_a)loss_D_real = criterion_GAN(pred_real, target_real)fake_A = fake_A_buffer.push_and_pop(fake_A)pred_fake = netD_A(fake_A.detach())loss_D_fake = criterion_GAN(pred_fake, target_fake)loss_D_A = (loss_D_real + loss_D_fake)*0.5loss_D_A.backward()da_optimizer.step()# 训练判别器Bdb_optimizer.zero_grad()pred_real = netD_B(real_b)loss_D_real = criterion_GAN(pred_real, target_real)fake_B = fake_B_buffer.push_and_pop(fake_B)pred_fake = netD_B(fake_B.detach())loss_D_fake = criterion_GAN(pred_fake, target_fake)loss_D_B = (loss_D_real + loss_D_fake)*0.5loss_D_B.backward()db_optimizer.step()#损失打印,存储伪造图片print("Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}".format(epoch,loss_G.data.item(),loss_D_A.data.item(),loss_D_B.data.item()))if (epoch + 1) % 20 == 0 or epoch == 0:  b_fake = to_img(fake_B.data)a_fake = to_img(fake_A.data)a_real = to_img(real_a.data)b_real = to_img(real_b.data)save_image(a_fake, './img/cyclegan/a_fake.png') save_image(b_fake, './img/cyclegan/b_fake.png') save_image(a_real, './img/cyclegan/a_real.png') save_image(b_real, './img/cyclegan/b_real.png') 

【pytorch基础笔记六】基于CYCLEGAN的马转斑马尝试相关推荐

  1. PyTorch基础(六)迁移学习

    在实际工程中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集(网络很深,需要足够大数据集).通常的做法是在一个很大的数据集上进行预训练得到卷积网络 ...

  2. 论文阅读笔记(六)——基于改进深度学习方法的股骨x线骨折自动检测与定位

    Automatic detection and localization of thighbone fractures in X-ray based on improved deep learning ...

  3. Pytorch基础(六)——激活函数

    一.概念 激活函数顾名思义,就是一种可以给神经网络注入灵魂的一种方法,也可以称之为激活层.其计算就是将线性的函数转变为非线性函数的过程,只有这样,我们制作的深层神经网络才能无限逼近真实值. 自神经网络 ...

  4. 《Go语言程序设计》读书笔记(六) 基于共享变量的并发

    竞争条件 在一个线性(就是说只有一个goroutine的)的程序中,程序的执行顺序只由程序的逻辑来决定.在有两个或更多goroutine的程序中,每一个goroutine内的语句也是按照既定的顺序去执 ...

  5. C语言if和汇编jcc程序对比,逆向基础笔记六 汇编跳转和比较指令

    JCC指令 cc 代表 condition code(状态码) Jcc不是单个指令,它只是描述了跳转之前检查条件代码的跳转助记符 例如JNE,在跳转之前检查条件代码 典型的情况是进行比较(设置CC), ...

  6. PyTorch基础(六)----- torch.eq()方法

    一.torch.eq()方法详解 对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True:若不同,返回False. torch.eq(input, other, *, out ...

  7. PyTorch学习笔记(二):PyTorch简介与基础知识

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  8. pyTorch——基础学习笔记

    pytorch基础学习笔记博文,在整理的时候借鉴的大量的网上资料,存在和一部分图片定义的直接复制黏贴,在本博文的最后将会表明所有的参考链接.由于参考的内容众多,所以博文的更新是一个长久的过程,如果大佬 ...

  9. PyTorch学习笔记(六):PyTorch进阶训练技巧

    PyTorch实战:PyTorch进阶训练技巧 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: P ...

最新文章

  1. socket 服务器浏览器与服务器客户端实例
  2. BugkuCTF-WEB题file_get_contents
  3. win+mysql自动备份吗_Windows下mysql自动备份的最佳方案
  4. RabbitMQ负载均衡
  5. SCI论文全攻略:选刊\投稿\修回与退稿
  6. java 504错误怎么解决_求助java.lang.NoClassDefFoundError怎么解决,报错信息如下
  7. 清除string内容_前端面试之javascript相关内容整理一
  8. MR Shuffle流程 入门
  9. 如何不如计算机科学,第四轮学科评估结果:西交不如华中武大,你怎么看?很多网友表示不满!...
  10. linux下的p2p终结者
  11. 汇编语言程序设计的实验环境及上机步骤
  12. mockserver
  13. linux显卡驱动与opengl,NVIDIA率先发布OpenGL 3.0 Linux驱动
  14. Android应用开发——记事本
  15. c# 通过手淘分享查询淘宝优惠券
  16. Oracle 考试概要
  17. LT7911D功能概述 LT7911D是Type-C/DP1.2转双路MIPI/Lvds的一款芯片
  18. 【题解】【循环】幂级数求和
  19. php抓取网页上的指定内容
  20. 安装elasticsearch

热门文章

  1. 面试整理:关于代价函数,正则化
  2. 一台服务器能支持多少docker,一台物理机器部署多个docker
  3. java poker_Java超级高手成长之路!一个Java编写的斗地主游戏
  4. charles抓app包教程_抓包工具--charles(青花瓷)及获取AppStore数据包
  5. 安网路由器 静态IP和PPOE混用时,如果设置了路由器定时重启可能导致路由器罢工...
  6. 数据钻取,详细数据一览无遗!
  7. SAP ABAP 业务对象 BUS6038 AssetDownPayment 资产:预付款 BAPI 清单和相关 TCODE
  8. 计算机硬件与软件ppt,计算机硬件和软件.ppt
  9. 得生态者得天下,亚马逊云科技开启“无边界”合作伙伴模式
  10. Acrel-5000能耗管理系统在武清体育中心项目的应用-安科瑞耿敏花