文章目录

  • 前言
  • 一、SRGAN
    • 1.训练步骤
    • 2.生成器
    • 3.判别器
  • 二、其他准备
    • 1.数据读取
    • 2.VGG19提取特征
    • 4.训练过程代码
    • 5. 预测过程
  • 参考链接

前言

SRGAN 网络是用GAN网络来实现图像超分辨率重建的网络。训练完网络后。只用生成器来重建低分辨率图像。网络结构主要使用生成器(Generator)和判别器(Discriminator)。训练过程不太稳定。一般用于卫星图像,遥感图像的图像重建,人脸图像超分重建。
这里我们使用的高分辨率的数据集 (DIV2K)
数据集下载链接:链接:https://pan.baidu.com/s/1UBle5Cu74TRifcAVz14cDg 提取码:luly
github代码地址:https://github.com/jiantenggei/srgan
重制版代码仓库:https://github.com/jiantenggei/Srgan_

一、SRGAN

1.训练步骤

SRGAN 网络的训练思路如下图所示:

训练步骤如下:
(1) 将低分辨率输入到生成网络,生成高分辨率图像。
(2) 将高分辨率图像输入的判别网络判别真假,与0和1进行对比
(3) 将原始高分辨率图像和生成的高分辨率图像分别用VGG19 的前9层提取特征,将提取的特征计算loss。
(4). 将loss返回给生成器继续训练。
这就是SRGAN 的训练流程了。
接下来我们一一去实现上述步骤。

2.生成器

生成器网络结构如下图所示:

生成器主要有两部分构成,第一部分是residual block 残差块(图中红色方块),第二部分是上采样部分(图中蓝色方块)用来上采样特征图。
残差块:包含一个两个3x3的卷积 BN,PReLu
上采样:使用UpSampling2D,这里可能与原模型不同实现
生成器代码如下所示:

# 生成器中的残差块
def res_block_gen(x, kernal_size, filters, strides):gen = xx = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)x = BatchNormalization(momentum = 0.5)(x)# Using Parametric ReLUx = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(x)x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)x = BatchNormalization(momentum = 0.5)(x)x = add([gen, x])return x#上采样样块
def up_sampling_block(x, kernal_size, filters, strides):x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)x = UpSampling2D(size = 2)(x)x = LeakyReLU(alpha = 0.2)(x)return x
#--------------------------------------
# 亚像素卷积上采样块
# 生成器 还是用的 UpSampling2D
# 如果有需要可以自己更改
# -------------------------------------
def SubpixelConv2D(input_shape, scale=4):def subpixel_shape(input_shape):dims = [input_shape[0],input_shape[1] * scale,input_shape[2] * scale,int(input_shape[3] / (scale ** 2))]output_shape = tuple(dims)return output_shapedef subpixel(x):return tf.compat.v1.depth_to_space(x, scale)return Lambda(subpixel, output_shape=subpixel_shape)def Generator(input_shape=[128,128,3]):gen_input = Input(input_shape)x = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(gen_input)x = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(x)gen_x = x# 16 个残差快for index in range(16):x = res_block_gen(x, 3, 64, 1)x = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(x)x = BatchNormalization(momentum = 0.5)(x)x = add([gen_x, x])#两个上采样 -> 放大四倍for index in range(2):x = up_sampling_block(x, 3, 256, 1)x = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same")(x)x = Activation('tanh')(x)generator_x = Model(inputs = gen_input, outputs = x)return generator_x

3.判别器

判别器主要用于判断生成图片的真假。与0和1比较,1代表真图片,0代表假图片。这里的0和1 是与判别器输出大小想用的向量,而不是单纯的0,1,判别器网络结果如下所示:

判别网络由一个个包含卷积、BN、和LeakyRelu 激活函数的块组成,最后输出1或0, 实际上就相当于是一个二分类的分类网络,代码如下所示:

#判别器中的卷积块
def discriminator_block(x, filters, kernel_size, strides):x = Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = "same")(x)x = BatchNormalization(momentum = 0.5)(x)x = LeakyReLU(alpha = 0.2)(x)return xdef Discriminator(image_shape=[512,512,3]):dis_input = Input(image_shape)x = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(dis_input)x = LeakyReLU(alpha = 0.2)(x)x = discriminator_block(x, 64, 3, 2)x = discriminator_block(x, 128, 3, 1)x = discriminator_block(x, 128, 3, 2)x = discriminator_block(x, 256, 3, 1)x = discriminator_block(x, 256, 3, 2)x = discriminator_block(x, 512, 3, 1)x = discriminator_block(x, 512, 3, 2)#x = Flatten()(x) # 这里采用Flatten 太浪费现存了 改为 全局池化x = GlobalAveragePooling2D()(x)x = Dense(1024)(x)x = LeakyReLU(alpha = 0.2)(x)x = Dense(1)(x)x = Activation('sigmoid')(x) discriminator_x = Model(inputs = dis_input, outputs = x)return discriminator_x

网络主要分为生成器和判别器,训练时相互对抗,以达到一个很好的平衡为目的。

二、其他准备

1.数据读取

训练时,输入的高分辨率图像一般为很大的图片。需要将其随机裁剪为预设的大小。再将裁剪的图像,下采样作为低分辨率图像。代码过长,不全部贴出来了。

class SRganDataset(keras.utils.Sequence):def __init__(self, train_lines, lr_shape, hr_shape, batch_size):super(SRganDataset, self).__init__()self.train_lines    = train_linesself.train_batches  = len(train_lines)self.lr_shape       = lr_shapeself.hr_shape       = hr_shapeself.batch_size     = batch_sizedef __len__(self):return math.ceil(self.train_batches / float(self.batch_size))def __getitem__(self, index):if index == 0:self.on_epoch_begin()images_l = []images_h = []for i in range(index * self.batch_size, (index + 1) * self.batch_size):  i = i % self.train_batchesimage_origin = Image.open(self.train_lines[i].split()[0])if self.rand()<.5:img_h = self.get_random_data(image_origin, self.hr_shape)else:img_h = self.random_crop(image_origin, self.hr_shape[1], self.hr_shape[0])img_l = img_h.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)img_h = preprocess_input(np.array(img_h, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5])img_l = preprocess_input(np.array(img_l, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5])images_h.append(img_h)images_l.append(img_l)return np.array(images_l), np.array(images_h)def on_epoch_begin(self):shuffle(self.train_lines)def rand(self, a=0, b=1):return np.random.rand()*(b-a) + a

2.VGG19提取特征

VGG19提取生成高分辨率的图像特征与真实高分辨率图像特征进行比较。计算LOSS。

class VGG_LOSS(object):def __init__(self, image_shape):self.image_shape = image_shape# 用VGG19 计算 高清图和生成的高清图之间的差别def vgg_loss(self, y_true, y_pred):vgg19 = VGG19(include_top=False, weights='imagenet', input_shape=self.image_shape)vgg19.trainable = False# Make trainable as Falsefor l in vgg19.layers:l.trainable = Falsemodel = Model(inputs=vgg19.input, outputs=vgg19.get_layer('block5_conv4').output)model.trainable = Falsereturn K.mean(K.square(model(y_true) - model(y_pred)))

4.训练过程代码

def train(epochs, batch_size, model_save_dir):train_annotation_path = 'dataset.txt'#下采样倍数downscale_factor = 4#输入图片形状hr_shape = (384,384,3)#加载数据集with open(train_annotation_path, encoding='utf-8') as f:train_lines = f.readlines()#计算 生成图片 和 原高清图 之间的lossloss = VGG_LOSS(hr_shape) #打乱 random.shuffle(train_lines)batch_count = int(len(train_lines)/ batch_size)lr_shape = (hr_shape[0]//downscale_factor, hr_shape[1]//downscale_factor, hr_shape[2])generator = Generator(lr_shape)discriminator = Discriminator(hr_shape)optimizer =tf.optimizers.Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)generator.compile(loss=loss.vgg_loss, optimizer=optimizer)discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)gen                 = SRganDataset(train_lines, lr_shape[:2], hr_shape[:2], batch_size)gan = get_gan(discriminator, lr_shape, generator, optimizer,loss.vgg_loss)loss_file = open(model_save_dir + 'losses.txt' , 'w+')loss_file.close()for epoch in range(0,epochs):print ('-'*15, 'Epoch %d' % epoch, '-'*15)with tqdm(total=batch_count,desc=f'Epoch {epoch + 1}/{epochs}',postfix=dict,mininterval=0.3) as pbar:for iteration, batch in enumerate(gen):if iteration >= batch_count:breakimgs_lr, imgs_hr        = batch#生成器生成图片gen_img = generator.predict(imgs_lr)real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2fake_data_Y = np.random.random_sample(batch_size)*0.2discriminator.trainable = Trued_loss_real = discriminator.train_on_batch(imgs_hr, real_data_Y)d_loss_fake = discriminator.train_on_batch(gen_img, fake_data_Y)discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2discriminator.trainable = Falsegan_loss = gan.train_on_batch(imgs_lr, [imgs_hr,gan_Y])pbar.set_postfix(**{'G_loss'        : gan_loss[0] , 'D_loss'        : discriminator_loss,'PSNR'          : gan_loss[4]},)pbar.update(1)  print("discriminator_loss : %f" % discriminator_loss)print("gan_loss :", gan_loss)gan_loss = str(gan_loss)loss_file = open(model_save_dir + 'losses.txt' , 'a')loss_file.write('epoch%d : gan_loss = %s ; discriminator_loss = %f\n' %(epoch, gan_loss, discriminator_loss) )loss_file.close()show_result(epoch,generator,imgs_lr,imgs_hr)generator.save(model_save_dir + 'gen_model%d.h5' % epoch)discriminator.save(model_save_dir + 'dis_model%d.h5' % epoch)

训练时,在目录result 目录下会出现这样的图片。

lr_images : 低分辨率图
Fake_Hr_Images:生成的高分辨率图像
True_Hr_Images:远高分图像

5. 预测过程

预测部分代码:

from pickle import NONE
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
from nets.nets import Generator
before_image = Image.open(r"0.jpg")before_image = before_image.convert("RGB")
gen_model = Generator([None,None,3])
gen_model.load_weights('loss\gen_model99.h5')
# gen_model.summary()
new_img = Image.new('RGB', before_image.size, (128, 128, 128))
new_img.paste(before_image)
# plt.imshow(new_img)
# plt.show()new_image = np.array(new_img)/127.5 - 1
# 三维变4维  因为神经网络的输入是四维的
new_image = np.expand_dims(new_image, axis=0)  # [batch_size,w,h,c]
fake = (gen_model.predict(new_image)*0.5 + 0.5)*255
#将np array 形式的图片转换为unit8  把数据转换为图
fake = Image.fromarray(np.uint8(fake[0]))fake.save("out.png")
titles = ['Generated', 'Original']
plt.subplot(1, 2, 1)
plt.imshow(before_image)
plt.subplot(1, 2, 2)
plt.imshow(fake)
plt.show()

重建效果:

参考链接

https://github.com/bubbliiiing/srgan-keras
https://github.com/deepak112/Keras-SRGAN
https://github.com/JustinhoCHN/SRGAN_Wasserstein

SRGAN 图像超分辨率重建(Keras)相关推荐

  1. 基于SRGAN实现图像超分辨率重建或复原

    基于SRGAN实现图像超分辨率重建或复原 2018年04月20日 11:50:26 山水之间2018 阅读数 14064 文章标签: SRGAN图像超分辨率重建 更多 分类专栏: 深度学习 图像处理  ...

  2. 超分辨率——基于SRGAN的图像超分辨率重建(Pytorch实现)

    基于SRGAN的图像超分辨率重建 本文偏新手项,因此只是作为定性学习使用,因此不涉及最后的定量评估环节 目录 基于SRGAN的图像超分辨率重建 1 简要介绍 2 代码实现 2.1 开发环境 2.2 主 ...

  3. 学习笔记之——基于深度学习的图像超分辨率重建

    最近开展图像超分辨率( Image Super Resolution)方面的研究,做了一些列的调研,并结合本人的理解总结成本博文~(本博文仅用于本人的学习笔记,不做商业用途) 本博文涉及的paper已 ...

  4. 基于深度学习的图像超分辨率重建

    最近开展图像超分辨率( Image Super Resolution)方面的研究,做了一些列的调研,并结合本人的理解总结成本博文~(本博文仅用于本人的学习笔记,不做商业用途) 本博文涉及的paper已 ...

  5. GAN的应用-SRGAN图像超分辨率重构、U-net结构和字“姐”跳动学习心得

    GAN的应用 -- SRGAN图像超分辨率重构 项目地址:https://aistudio.baidu.com/aistudio/projectdetail/843989 文章来源:2017 IEEE ...

  6. 图像超分辨率重建算法,让模糊图像变清晰(附数据和代码)

    一.  图像超分辨率重建概述 1. 概念 图像分辨率是一组用于评估图像中蕴含细节信息丰富程度的性能参数,包括时间分辨率.空间分辨率及色阶分辨率等,体现了成像系统实际所能反映物体细节信息的能力.相较于低 ...

  7. 基于深度学习的图像超分辨率重建技术的研究

    1 超分辨率重建技术的研究背景与意义 图像分辨率是一组用于评估图像中蕴含细节信息丰富程度的性能参数,包括时间分辨率.空间分辨率及色阶分辨率等,体现了成像系统实际所能反映物体细节信息的能力.相较于低分辨 ...

  8. PANet:基于金字塔注意力网络的图像超分辨率重建(Pytorch实现)

    PANet:基于金字塔注意力网络的图像超分辨率重建 [!] 为了提高代码的可读性,本文模型的具体实现与原文具有一定区别,因此会造成性能上的差异 文章目录 PANet:基于金字塔注意力网络的图像超分辨率 ...

  9. 一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

    目录 一.  图像超分辨率重建概述 1. 概念 2. 应用领域 3. 研究进展 3.1 传统超分辨率重建算法 3.2 基于深度学习的超分辨率重建算法 二.  SRResNet算法原理和Pytorch实 ...

最新文章

  1. zk如何实现watch
  2. JavaScript为什么使用原型模式而不是类模式
  3. JZ2440学习总结5
  4. AI入门:无门槛可以玩的神经网络
  5. 斯特林数第一类数的应用 hdu3625
  6. matlab生成exe独立运行文件已破解(好用)
  7. torch tensor去掉1维_浑身是刺的“维c之王”,有人管它叫“菠萝”,有人管它叫“梨”...
  8. postman发送HTTP请求自动生成MD5签名
  9. arc科学计算机在线应用使用,ArcMap字段计算器(Field Calculator)的妙用
  10. Redis常见配置介绍
  11. 关于Mac电脑更新IP地址的解决方法!
  12. foreach进不去报错java_为什么阿里巴巴Java开发手册中强制要求不要在foreach循环里进行元素的remove和add操作...
  13. HTML学习基础教程
  14. 永中集成Office:既是创新产品,也是战略储备
  15. 小爱同学服务器响应,小爱同学反应慢
  16. clickhouse总结
  17. bigdecimal不等于0怎么写_我不知道三打白骨精读后感怎么写才算正能量
  18. 英伟达发布528.02 WHQL 新驱动,支持4070 Ti
  19. 传统支付方式不能满足线下支付的需求
  20. Java项目:公寓房屋出租系统(java+SSM+JSP+EasyUI+Echarts+Mysql)

热门文章

  1. matlab循环读取csv文件
  2. C#学习笔记之——委托、线程、事件、lamda表达式
  3. 全球及中国信息存储行业发展机遇与供需前景预测报告2022版
  4. C#PrintDocument打印尺寸调整
  5. 数学建模——lingo实现多目标规划
  6. tablayout下划线长度
  7. 烤仔TVの尚书房 | 对话 Chainlink 条子哥 共话 DeFi 狂热下的冷静
  8. HDU1009 FatMouse' Trade
  9. 寻宝 oj_寻宝体验正在流行
  10. java 正则判断身份证号