转载自:https://blog.csdn.net/weixin_46773169/article/details/105462644,本文只做个人记录学习使用,版权归原作者所有。

github链接:https://github.com/zhilin007/FFA-Net

代码注释:

data_utils.py

import torch.utils.data as data
import torchvision.transforms as tfs
from torchvision.transforms import functional as FF
import os, syssys.path.append('.')
sys.path.append('..')
import numpy as np
import torch
import random
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torchvision.utils import make_grid
from net.metrics import *                # metrics.py
from net.option import opt               # option.pyBS = opt.bs
print('BS:',BS)
crop_size = 'whole_img'   # 裁剪图片的大小
if opt.crop:crop_size = opt.crop_sizedef tensorShow(tensors, titles=None):'''t:BCWH'''fig = plt.figure()for tensor, tit, i in zip(tensors, titles, range(len(tensors))):img = make_grid(tensor)npimg = img.numpy()ax = fig.add_subplot(211 + i)ax.imshow(np.transpose(npimg, (1, 2, 0)))ax.set_title(tit)plt.show()class RESIDE_Dataset(data.Dataset):def __init__(self, path, train, size=crop_size, format='.png'):super(RESIDE_Dataset, self).__init__()self.size = size# print('crop size:', size) # ---本人测试命令self.train = trainself.format = formatself.haze_imgs_dir = os.listdir(os.path.join(path, 'hazy'))# 返回指定路径下所有文件和文件夹的名字,并存放于一个列表中# print('self_haze_imgs_dir :', self.haze_imgs_dir) # 本人测试命令self.haze_imgs = [os.path.join(path, 'hazy', img) for img in self.haze_imgs_dir]# hazy图像所有的路径,并存放于一个列表中# print('self_haze_imgs:',self.haze_imgs) # ---本人测试命令self.clear_dir = os.path.join(path, 'clear')# print('self_clean:', self.clear_dir) # ---本人测试命令def __getitem__(self, index):haze = Image.open(self.haze_imgs[index])# print('haze_size:',haze.size,haze.size[::-1]) # ---本人测试命令# print('index:', index) # ---本人测试命令if isinstance(self.size, int):  # 如果size是int型,则返回True# print('这个isinstance方法被调用') # ---本人测试命令while haze.size[0] < self.size or haze.size[1] < self.size:index = random.randint(0, 20000)haze = Image.open(self.haze_imgs[index])img = self.haze_imgs[index]  # 从haze_imgs(路径名称列表)中取出对于索引值的路径# print('img:', img) # ---本人测试命令# id = img.split('/')[-1].split('_')[0] # 此命令在windows下执行会报路径错误,改为以下命令id = img.split('\\')[-1].split('_')[0]# 提取最后‘\’之后和第一个‘_’之前的内容,以hazy图像的路径找到对应clear图像的序号# print('id:',id) # ---本人测试命令clear_name = id + self.format# print('clear_name:', clear_name) # ---本人测试命令# test_dir = os.path.join(self.clear_dir, clear_name) # ---本人测试命令# print('clear_dir:',test_dir) # ---本人测试命令clear = Image.open(os.path.join(self.clear_dir, clear_name))clear = tfs.CenterCrop(haze.size[::-1])(clear)# haze.size=(W, H) -> haze.size[::-1]=(H, W),然后按(H, W)对clear进行中心裁剪if not isinstance(self.size, str): # 如果size不是str类型,则返回True# print('这个not isinstance方法被调用')i, j, h, w = tfs.RandomCrop.get_params(haze, output_size=(self.size, self.size))'''w, h = haze.sizeth, tw = output_sizei = random.randint(0, h - th)j = random.randint(0, w - tw)return i, j, th, tw'''haze = FF.crop(haze, i, j, h, w)  # 把haze随机裁剪成(i, j, h, w)的大小clear = FF.crop(clear, i, j, h, w)haze, clear = self.augData(haze.convert("RGB"), clear.convert("RGB")) # 使用数据增强后把图片转为RGB格式return haze, cleardef augData(self, data, target):  # 数据增强if self.train:rand_hor = random.randint(0, 1)  # 从[0, 1]中随机选一个数rand_rot = random.randint(0, 3)  # 从[0, 1, 2, 3]中随机选一个数data = tfs.RandomHorizontalFlip(rand_hor)(data)# 依据概率rand_hor对data(图片)进行水平翻转(这里,rand_hor=0:不翻转;=1:翻转)target = tfs.RandomHorizontalFlip(rand_hor)(target)if rand_rot:  # rand_rot>0时执行此命令data = FF.rotate(data, 90 * rand_rot)  # 将data旋转的角度为90*rand_rottarget = FF.rotate(target, 90 * rand_rot)data = tfs.ToTensor()(data)  # range [0, 255] -> [0.0, 1.0]data = tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152])(data)# 归一化操作# 输入的data(图片)大小为CxWxH(三维张量),mean为各通道的均值,std为各通道的方差# output = (input - mean) / stdtarget = tfs.ToTensor()(target)return data, targetdef __len__(self):return len(self.haze_imgs)import os
pwd = os.getcwd()
print(pwd)
# path = '/FFA-Net-master/data'  # path to your 'data' folder
path = '../data'  # path to your 'data' folderITS_train_loader = DataLoader(dataset=RESIDE_Dataset(path + '/RESIDE/ITS', train=True, size=crop_size), batch_size=BS,shuffle=True)
ITS_test_loader = DataLoader(dataset=RESIDE_Dataset(path + '/RESIDE/SOTS/indoor', train=False, size='whole img'),batch_size=1, shuffle=False)
OTS_train_loader = DataLoader(dataset=RESIDE_Dataset(path + '/RESIDE/OTS', train=True, format='.jpg'), batch_size=BS,shuffle=True)
OTS_test_loader = DataLoader(dataset=RESIDE_Dataset(path + '/RESIDE/SOTS/outdoor', train=False, size='whole img', format='.png'), batch_size=1,shuffle=False)
# 如果train_loader没有数据,即检查Dataset的__len__()函数输出为零,会报ValueError:num_samples...的错if __name__ == "__main__":pass

option.py

import torch,os,sys,torchvision,argparse
import torchvision.transforms as tfs
import time,math
import numpy as np
from torch.backends import cudnn
from torch import optim
import torch,warnings
from torch import nn
import torchvision.utils as vutils
warnings.filterwarnings('ignore')parser=argparse.ArgumentParser()  # 命令行选项、参数和子命令解析器
'''
argparse 模块可以让人轻松编写用户友好的命令行接口。
程序定义它需要的参数,然后 argparse 将弄清如何从 sys.argv 解析出那些参数。
argparse 模块还会自动生成帮助和使用手册,并在用户给程序传入无效参数时报出错误信息。
'''# 添加参数
# default - 当参数未在命令行中出现时使用的值。
# type - 命令行参数应当被转换成的类型。
# action='store_true',只要运行时该变量有传参就将该变量设为True
parser.add_argument('--steps',type=int,default=10) # 10000
parser.add_argument('--device',type=str,default='Automatic detection')
parser.add_argument('--resume',type=bool,default=True)
parser.add_argument('--eval_step',type=int,default=5)  # 5000
parser.add_argument('--lr', default=0.0001, type=float, help='learning rate')
parser.add_argument('--model_dir',type=str,default='./trained_models/')
parser.add_argument('--trainset',type=str,default='its_train')
parser.add_argument('--testset',type=str,default='its_test')
parser.add_argument('--net',type=str,default='ffa')
parser.add_argument('--gps',type=int,default=3,help='residual_groups')
parser.add_argument('--blocks',type=int,default=20,help='residual_blocks')
parser.add_argument('--bs',type=int,default=16,help='batch size')
parser.add_argument('--crop',action='store_true')
parser.add_argument('--crop_size',type=int,default=240,help='Takes effect when using --crop ')
parser.add_argument('--no_lr_sche',action='store_true',help='no lr cos schedule')
parser.add_argument('--perloss',action='store_true',help='perceptual loss')opt=parser.parse_args()  # 解析参数
opt.device='cuda' if torch.cuda.is_available() else 'cpu'
model_name=opt.trainset+'_'+opt.net.split('.')[0]+'_'+str(opt.gps)+'_'+str(opt.blocks)
# split('.')[0] , 以'.'作分隔符,输出'.'之前的内容opt.model_dir=opt.model_dir+model_name+'.pk'
log_dir='logs/'+model_name# ---以下为本人自己的测试命令---
# print('opt:', opt)
# print('model_name:', model_name)
# print('model_dir:',opt.model_dir)
# print('log_dir:', log_dir)if not os.path.exists('trained_models'):os.mkdir('trained_models')  # 创建路径
if not os.path.exists('numpy_files'):os.mkdir('numpy_files')
if not os.path.exists('logs'):os.mkdir('logs')
if not os.path.exists('samples'):os.mkdir('samples')
if not os.path.exists(f"samples/{model_name}"):os.mkdir(f'samples/{model_name}')
if not os.path.exists(log_dir):os.mkdir(log_dir)

metrics.py

from math import exp
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from  torchvision.transforms import ToPILImagedef gaussian(window_size, sigma):gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])return gauss / gauss.sum()def create_window(window_size, channel):_1D_window = gaussian(window_size, 1.5).unsqueeze(1)  # 添加一个轴,变成二维张量_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)# torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵# torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵# .t(), 求转置,输入tensor结构维度<=2D# 在二维张量前面添加2个轴,变成四维张量window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())# 把张量扩展成(channel, 1, window_size, window_size)的大小,以原来的值填充(其自身的值不变)# contiguous:view只能用在contiguous的variable上。contiguous一般与transpose,permute,view搭配使用# 即使用transpose或permute进行维度变换后,需要用contiguous()来返回一个contiguous copy,然后方可使用view对维度进行变形return windowdef _ssim(img1, img2, window, window_size, channel, size_average=True):mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)mu1_sq = mu1.pow(2)  # mul的2次方mu2_sq = mu2.pow(2)mu1_mu2 = mu1 * mu2sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sqsigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sqsigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2C1 = 0.01 ** 2C2 = 0.03 ** 2ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))if size_average:return ssim_map.mean()else:return ssim_map.mean(1).mean(1).mean(1)def ssim(img1, img2, window_size=11, size_average=True):img1=torch.clamp(img1,min=0,max=1)# 将输入img1张量每个元素的范围限制到区间[min, max],返回结果到一个新张量。img2=torch.clamp(img2,min=0,max=1)(_, channel, _, _) = img1.size()  # 取出img1的通道数window = create_window(window_size, channel)if img1.is_cuda:window = window.cuda(img1.get_device())window = window.type_as(img1)  # 将window张量转换为给定img1类型的张量return _ssim(img1, img2, window, window_size, channel, size_average)def psnr(pred, gt):pred=pred.clamp(0,1).cpu().numpy() # 将gpu上的数据类型转为cpu上的数据类型,然后转化为numpy()数组gt=gt.clamp(0,1).cpu().numpy()imdff = pred - gtrmse = math.sqrt(np.mean(imdff ** 2))if rmse == 0:return 100return 20 * math.log10( 1.0 / rmse)if __name__ == "__main__":pass

FFA.py

import torch.nn as nn
import torchdef default_conv(in_channels, out_channels, kernel_size, bias=True):return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)  # '//'整数除法,'/'浮点数除法class PALayer(nn.Module):def __init__(self, channel):super(PALayer, self).__init__()self.pa = nn.Sequential(# PA层的卷积核不应该是3x3么,为什么这里是1x1?# 这样的话PA层与CA层只差一个全局平均池化操作的区别,而且1x1是抓通道特征,并不能实现像素注意的功能# 论文中“实施细节”处写道只有CA模块的卷积核为1x1,怀疑此处代码失误nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),nn.ReLU(inplace=True),  # inplace 原位操作,即不经过复制操作,而是直接在原来的内存上改变它的值nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),# 第一个'1'表示输出的通道数为1,即实现CxHxW -> 1xHxWnn.Sigmoid())def forward(self, x):y = self.pa(x)return x * yclass CALayer(nn.Module):def __init__(self, channel):super(CALayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)# 自适应平均池化,输出大小为: 1 x 1,即把一张图片(HxW)的所有的值加起来取平均,大小变为1x1self.ca = nn.Sequential(# 这里,'1'表示卷积核的大小为1x1,这是实现特征注意功能的关键:# 用channel个channel//8层的conv2D 1x1滤镜作逐点卷积,抓通道相关性nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),nn.ReLU(inplace=True),nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),nn.Sigmoid())def forward(self, x):y = self.avg_pool(x)y = self.ca(y)return x * yclass Block(nn.Module):def __init__(self, conv, dim, kernel_size, ):super(Block, self).__init__()self.conv1 = conv(dim, dim, kernel_size, bias=True)self.act1 = nn.ReLU(inplace=True)self.conv2 = conv(dim, dim, kernel_size, bias=True)self.calayer = CALayer(dim)self.palayer = PALayer(dim)def forward(self, x):res = self.act1(self.conv1(x))res = res + xres = self.conv2(res)res = self.calayer(res)res = self.palayer(res)res += xreturn resclass Group(nn.Module):def __init__(self, conv, dim, kernel_size, blocks):super(Group, self).__init__()modules = [Block(conv, dim, kernel_size) for _ in range(blocks)]# moduels列表里有n(=blocks)个Block块modules.append(conv(dim, dim, kernel_size))self.gp = nn.Sequential(*modules)# modules列表前加*号,表示将列表解开成独立的参数。# 转化为Sequential模型,网络为n个Block块线性堆叠。def forward(self, x):res = self.gp(x)res += xreturn resclass FFA(nn.Module):def __init__(self, gps, blocks, conv=default_conv):super(FFA, self).__init__()self.gps = gpsself.dim = 64kernel_size = 3pre_process = [conv(3, self.dim, kernel_size)]assert self.gps == 3self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks)self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks)self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks)self.ca = nn.Sequential(*[nn.AdaptiveAvgPool2d(1),nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, padding=0),nn.ReLU(inplace=True),nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, padding=0, bias=True),nn.Sigmoid()])self.palayer = PALayer(self.dim)post_precess = [conv(self.dim, self.dim, kernel_size),conv(self.dim, 3, kernel_size)]self.pre = nn.Sequential(*pre_process)self.post = nn.Sequential(*post_precess)def forward(self, x1):x = self.pre(x1)res1 = self.g1(x)res2 = self.g2(res1)res3 = self.g3(res2)w = self.ca(torch.cat([res1, res2, res3], dim=1))# 按序号为1的轴进行拼接,即按通道进行拼接,每个res大小为([1, 64, H, W]),# cat后大小为([1, 192, H, W]),w.size() = ([1, 192, 1, 1])w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]  # 添加两个轴(元素是None)# w.size() = ([1, 3, 64, 1, 1])out = w[:, 0, ::] * res1 + w[:, 1, ::] * res2 + w[:, 2, ::] * res3# w的三个通道分别与res1,2,3相乘再相加,out.size()=([1, 64, H, W])out = self.palayer(out)x = self.post(out)return x + x1if __name__ == "__main__":# 当.py文件被直接运行时,if __name__ == '__main__'之下的代码块将被运行;# 当.py文件以模块形式被导入时, if __name__ == '__main__'之下的代码块不被运行net = FFA(gps=3, blocks=19)print(net)

main.py

import torch, os, sys, torchvision, argparse
import torchvision.transforms as tfsfrom net.models.FFA import FFA # FFA.py
from net.metrics import psnr, ssim # metrics.py
from net.models import *
import time, math
import numpy as np
from torch.backends import cudnn
from torch import optim
import torch, warnings
from torch import nn
# from tensorboardX import SummaryWriter
import torchvision.utils as vutilswarnings.filterwarnings('ignore')
from net.option import opt, model_name, log_dir # option.py
from net.data_utils import *  # data_utils.py
from torchvision.models import vgg16print('log_dir :', log_dir)
print('model_name:', model_name)models_ = {'ffa': FFA(gps=opt.gps, blocks=opt.blocks),
}loaders_ = {'its_train': ITS_train_loader,'its_test': ITS_test_loader,'ots_train': OTS_train_loader,'ots_test': OTS_test_loader
}start_time = time.time()  # 返回当前时间的时间戳
T = opt.steps  # default=100000def lr_schedule_cosdecay(t, T, init_lr=opt.lr):# 文章中公式(9),采用cosine annealing strategy进行学习率衰减,直到0lr = 0.5 * (1 + math.cos(t * math.pi / T)) * init_lrreturn lrdef train(net, loader_train, loader_test, optim, criterion):losses = []start_step = 0max_ssim = 0max_psnr = 0ssims = []psnrs = []if opt.resume and os.path.exists(opt.model_dir):  # 如果已有训练好的模型,返回trueprint(f'resume from {opt.model_dir}')  # 带f的print可以执行表达式ckp = torch.load(opt.model_dir)  # 将对象文件反序列化为内存losses = ckp['losses']  # 取出已训练好的模型的lossnet.load_state_dict(ckp['model'])# 使用反序列化状态字典加载model’s参数字典# state_dict是个简单的Python dictionary对象,它将每个层映射到它的参数张量start_step = ckp['step']max_ssim = ckp['max_ssim']max_psnr = ckp['max_psnr']psnrs = ckp['psnrs']ssims = ckp['ssims']print(f'start_step:{start_step} start training ---')else:print('train from scratch *** ')for step in range(start_step + 1, opt.steps + 1):  # opt.steps=10(default)net.train()  # 定义的网络进入训练模式lr = opt.lrif not opt.no_lr_sche:lr = lr_schedule_cosdecay(step, T)for param_group in optim.param_groups:  # 在训练中动态的调整学习率param_group["lr"] = lrx, y = next(iter(loader_train))# 读取一个读取一个batch的数据,batch size=16时实际对应16张图像# dataloader本质上是一个可迭代对象,使用iter()访问,不能使用next()访问;# 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问x = x.to(opt.device)  # 若opt.device=cuda,即转移到GPU运行y = y.to(opt.device)out = net(x)  # 把x输入网络训练loss = criterion[0](out, y)if opt.perloss:  # Perceptual loss为L1损失和L2损失的加权和loss2 = criterion[1](out, y)loss = loss + 0.04 * loss2loss.backward()  # 反向传播求梯度optim.step()  # 更新参数optim.zero_grad()  # 清除梯度,为下一个batch训练做准备losses.append(loss.item())  # loss是个标量,item表示取出这个标量,然后放入losses中print(f'\rtrain loss : {loss.item():.5f}| step :{step}/{opt.steps}|lr :{lr :.7f} |time_used :{(time.time() - start_time) / 60 :.1f}',end='', flush=True)# with SummaryWriter(logdir=log_dir,comment=log_dir) as writer:#  writer.add_scalar('data/loss',loss,step)if step % opt.eval_step == 0:  # default=5000with torch.no_grad():  # 切断梯度计算,不会进行反向传播,因为SSIM和PSNR的计算不需要ssim_eval, psnr_eval = test(net, loader_test, max_psnr, max_ssim, step)  # 计算SSIM,PSNRprint(f'\nstep :{step} |ssim:{ssim_eval:.4f}| psnr:{psnr_eval:.4f}')# with SummaryWriter(logdir=log_dir,comment=log_dir) as writer:#  writer.add_scalar('data/ssim',ssim_eval,step)#  writer.add_scalar('data/psnr',psnr_eval,step)#  writer.add_scalars('group',{#     'ssim':ssim_eval,#     'psnr':psnr_eval,#     'loss':loss#  },step)ssims.append(ssim_eval)psnrs.append(psnr_eval)if ssim_eval > max_ssim and psnr_eval > max_psnr:max_ssim = max(max_ssim, ssim_eval)max_psnr = max(max_psnr, psnr_eval)torch.save({'step': step,'max_psnr': max_psnr,'max_ssim': max_ssim,'ssims': ssims,'psnrs': psnrs,'losses': losses,'model': net.state_dict()}, opt.model_dir)  # 保存各项参数到model_dir中print(f'\n model saved at step :{step}| max_psnr:{max_psnr:.4f}|max_ssim:{max_ssim:.4f}')# 把参数保存为.npy文件np.save(f'./numpy_files/{model_name}_{opt.steps}_losses.npy', losses)     np.save(f'./numpy_files/{model_name}_{opt.steps}_ssims.npy', ssims)np.save(f'./numpy_files/{model_name}_{opt.steps}_psnrs.npy', psnrs)def test(net, loader_test, max_psnr, max_ssim, step):net.eval()  # 网络参数会被固定,权值不会被更新torch.cuda.empty_cache()  # 清空显存ssims = []psnrs = []# s=Truefor i, (inputs, targets) in enumerate(loader_test):inputs = inputs.to(opt.device)targets = targets.to(opt.device)pred = net(inputs)# # print(pred)# tfs.ToPILImage()(torch.squeeze(targets.cpu())).save('111.png')# vutils.save_image(targets.cpu(),'target.png')# vutils.save_image(pred.cpu(),'pred.png')ssim1 = ssim(pred, targets).item()psnr1 = psnr(pred, targets)ssims.append(ssim1)psnrs.append(psnr1)# if (psnr1>max_psnr or ssim1 > max_ssim) and s :#     ts=vutils.make_grid([torch.squeeze(inputs.cpu()),torch.squeeze(targets.cpu()),torch.squeeze(pred.clamp(0,1).cpu())])#     vutils.save_image(ts,f'samples/{model_name}/{step}_{psnr1:.4}_{ssim1:.4}.png')#     s=Falsereturn np.mean(ssims), np.mean(psnrs)if __name__ == "__main__":'''直接执行该模块(main.py),此时__name__=main.py,以下语句才会被执行;如果该模块 import 到其他模块中,此时__name__=main,以下语句不会被执行,。'''loader_train = loaders_[opt.trainset]loader_test = loaders_[opt.testset]net = models_[opt.net]net = net.to(opt.device)if opt.device == 'cuda':net = torch.nn.DataParallel(net)# 在多个GPU上并行计算,是将输入一个batch的数据均分成多份,分别送到对应的GPU进行计算,各个GPU得到的梯度累加。# cudnn.benchmark = True让内置的cuDNN 的auto-tuner自动寻找最适合当前配置的高效算法,来达到优化运行的效率criterion = []criterion.append(nn.L1Loss().to(opt.device))  # 采用L1损失,放入certerion[0]中if opt.perloss:vgg_model = vgg16(pretrained=True).features[:16]# 使用预训练的权重,只调用特征提取部分的前16层,分类部分已抛弃掉vgg_model = vgg_model.to(opt.device)for param in vgg_model.parameters():param.requires_grad = False  # vgg_model不进行梯度计算criterion.append(PerLoss(vgg_model).to(opt.device))  # 计算的Perceptual loss损失放入criterion[1]中optimizer = optim.Adam(params=filter(lambda x: x.requires_grad, net.parameters()), lr=opt.lr, betas=(0.9, 0.999),eps=1e-08)# filter函数将net模型中属性requires_grad = True的参数筛选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新optimizer.zero_grad()train(net, loader_train, loader_test, optimizer, criterion)
n.L1Loss().to(opt.device))  # 采用L1损失,放入certerion[0]中if opt.perloss:vgg_model = vgg16(pretrained=True).features[:16]# 使用预训练的权重,只调用特征提取部分的前16层,分类部分已抛弃掉vgg_model = vgg_model.to(opt.device)for param in vgg_model.parameters():param.requires_grad = False  # vgg_model不进行梯度计算criterion.append(PerLoss(vgg_model).to(opt.device))  # 计算的Perceptual loss损失放入criterion[1]中optimizer = optim.Adam(params=filter(lambda x: x.requires_grad, net.parameters()), lr=opt.lr, betas=(0.9, 0.999),eps=1e-08)# filter函数将net模型中属性requires_grad = True的参数筛选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新optimizer.zero_grad()train(net, loader_train, loader_test, optimizer, criterion)

附:为方便理解网络,将FFA.py的blocks改为1

if __name__ == "__main__":net = FFA(gps=3, blocks=1)  # blocks改为1print(net)
FFA((g1): Group((gp): Sequential((0): Block((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act1): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(calayer): CALayer((avg_pool): AdaptiveAvgPool2d(output_size=1)(ca): Sequential((0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))(3): Sigmoid()))(palayer): PALayer((pa): Sequential((0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))(3): Sigmoid())))(1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))))(g2): Group((gp): Sequential((0): Block((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act1): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(calayer): CALayer((avg_pool): AdaptiveAvgPool2d(output_size=1)(ca): Sequential((0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))(3): Sigmoid()))(palayer): PALayer((pa): Sequential((0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))(3): Sigmoid())))(1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))))(g3): Group((gp): Sequential((0): Block((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act1): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(calayer): CALayer((avg_pool): AdaptiveAvgPool2d(output_size=1)(ca): Sequential((0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))(3): Sigmoid()))(palayer): PALayer((pa): Sequential((0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))(3): Sigmoid())))(1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))))(ca): Sequential((0): AdaptiveAvgPool2d(output_size=1)(1): Conv2d(192, 4, kernel_size=(1, 1), stride=(1, 1))(2): ReLU(inplace=True)(3): Conv2d(4, 192, kernel_size=(1, 1), stride=(1, 1))(4): Sigmoid())(palayer): PALayer((pa): Sequential((0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))(3): Sigmoid()))(pre): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))(post): Sequential((0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
)

用summary()调出每层的输出大小和参数

pip install torchsummary

在FFA.py末添加:

from torchsummary import summary
summary(net, input_size=(3, 64, 64), batch_size=1)

结果如下:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1            [1, 64, 64, 64]           1,792Conv2d-2            [1, 64, 64, 64]          36,928ReLU-3            [1, 64, 64, 64]               0Conv2d-4            [1, 64, 64, 64]          36,928AdaptiveAvgPool2d-5              [1, 64, 1, 1]               0Conv2d-6               [1, 8, 1, 1]             520ReLU-7               [1, 8, 1, 1]               0Conv2d-8              [1, 64, 1, 1]             576Sigmoid-9              [1, 64, 1, 1]               0CALayer-10            [1, 64, 64, 64]               0Conv2d-11             [1, 8, 64, 64]             520ReLU-12             [1, 8, 64, 64]               0Conv2d-13             [1, 1, 64, 64]               9Sigmoid-14             [1, 1, 64, 64]               0PALayer-15            [1, 64, 64, 64]               0Block-16            [1, 64, 64, 64]               0Conv2d-17            [1, 64, 64, 64]          36,928Group-18            [1, 64, 64, 64]               0Conv2d-19            [1, 64, 64, 64]          36,928ReLU-20            [1, 64, 64, 64]               0Conv2d-21            [1, 64, 64, 64]          36,928
AdaptiveAvgPool2d-22              [1, 64, 1, 1]               0Conv2d-23               [1, 8, 1, 1]             520ReLU-24               [1, 8, 1, 1]               0Conv2d-25              [1, 64, 1, 1]             576Sigmoid-26              [1, 64, 1, 1]               0CALayer-27            [1, 64, 64, 64]               0Conv2d-28             [1, 8, 64, 64]             520ReLU-29             [1, 8, 64, 64]               0Conv2d-30             [1, 1, 64, 64]               9Sigmoid-31             [1, 1, 64, 64]               0PALayer-32            [1, 64, 64, 64]               0Block-33            [1, 64, 64, 64]               0Conv2d-34            [1, 64, 64, 64]          36,928Group-35            [1, 64, 64, 64]               0Conv2d-36            [1, 64, 64, 64]          36,928ReLU-37            [1, 64, 64, 64]               0Conv2d-38            [1, 64, 64, 64]          36,928
AdaptiveAvgPool2d-39              [1, 64, 1, 1]               0Conv2d-40               [1, 8, 1, 1]             520ReLU-41               [1, 8, 1, 1]               0Conv2d-42              [1, 64, 1, 1]             576Sigmoid-43              [1, 64, 1, 1]               0CALayer-44            [1, 64, 64, 64]               0Conv2d-45             [1, 8, 64, 64]             520ReLU-46             [1, 8, 64, 64]               0Conv2d-47             [1, 1, 64, 64]               9Sigmoid-48             [1, 1, 64, 64]               0PALayer-49            [1, 64, 64, 64]               0Block-50            [1, 64, 64, 64]               0Conv2d-51            [1, 64, 64, 64]          36,928Group-52            [1, 64, 64, 64]               0
AdaptiveAvgPool2d-53             [1, 192, 1, 1]               0Conv2d-54               [1, 4, 1, 1]             772ReLU-55               [1, 4, 1, 1]               0Conv2d-56             [1, 192, 1, 1]             960Sigmoid-57             [1, 192, 1, 1]               0Conv2d-58             [1, 8, 64, 64]             520ReLU-59             [1, 8, 64, 64]               0Conv2d-60             [1, 1, 64, 64]               9Sigmoid-61             [1, 1, 64, 64]               0PALayer-62            [1, 64, 64, 64]               0Conv2d-63            [1, 64, 64, 64]          36,928Conv2d-64             [1, 3, 64, 64]           1,731
================================================================
Total params: 379,939
Trainable params: 379,939
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 56.35
Params size (MB): 1.45
Estimated Total Size (MB): 57.85
----------------------------------------------------------------

注: CALayer-10即为残差连接部分,对应于Class CALayer 中最后一条语句 return x * y。假如summary()中不指定batch_size,那么Output Shape 的第一个轴将为-1。

总结:
整个网络由1个卷积层+3个群结构+Concatenate模块+1个CA模块+1个PA模块组成+2个卷积层组成,其中,每个群结构包含19个基础块结构,每个基础块结构又由1个卷积层+1个relu层+1个卷积层+1个CA模块+1个PA模块组成,CA和PA模块详细见“主要内容”部分,另外通过长跳和短跳残差连接绕过薄雾或低频区域等不太重要的信息,使得信息的流动更加容易。一般网络越深(如大于400层),网络训练将更加困难,使用残差连接能够让很深的网络训练更加容易。本文网络共704层,训练总参数:4455913。
疑问:作者在PA模块代码中使用1x1卷积核和论文描述不符。(见疑惑1)
疑问:在PA模块中,实现像素注意的原理。(见疑惑2)

FFA-Net:文章理解于代码注释相关推荐

  1. FFA-Net:文章理解与代码注释

    FFA-Net: Feature Fusion Attention Network for Single Image Dehazing (AAAI 2020) Pytorch代码(GitHub) 本文 ...

  2. 归并排序(代码注释超详细)

    归并排序: (复制粘贴百度百科没什么意思),简单来说,就是对数组进行分组,然后分组进行排序,排序完最后再整合起来排序! 我看了很多博客,都是写的8个数据呀什么的(2^4,分组方便),我就想着,要是10 ...

  3. 代码注释//_您应该停止编写//的五个代码注释,并且//应该开始的一个注释

    代码注释// 提供来自您最喜欢和最受欢迎的开源项目的示例-React,Angular,PHP,Pandas等! (With examples from your favorite and most p ...

  4. tensorflow笔记:流程,概念和简单代码注释

    tensorflow是google在2015年开源的深度学习框架,可以很方便的检验算法效果.这两天看了看官方的tutorial,极客学院的文档,以及综合tensorflow的源码,把自己的心得整理了一 ...

  5. yolov3网络结构图_目标检测——YOLO V3简介及代码注释(附github代码——已跑通)...

    GitHub: liuyuemaicha/PyTorch-YOLOv3​github.com 注:该代码fork自eriklindernoren/PyTorch-YOLOv3,该代码相比master分 ...

  6. Kotlin------函数和代码注释

    定义函数 Kotlin定义一个函数的风格大致如下 访问控制符 fun 方法名(参数,参数,参数) : 返回值类型{...... } 访问控制符:与Java有点差异,Kotlin的访问范围从大到小分别是 ...

  7. php代码注释处理类库,php代码注释

    代码注释在多人开发的时候非常重要,现象一下,一段代码没有任何主要你去结合运行的效果去看实现的逻辑,那是非常费劲的事. 如果让别人看懂你写的代码,代码注释启动非常重要的作用.一个不会写代码注释的不是一个 ...

  8. 竟有如此沙雕的代码注释!

    点击上方蓝色"程序猿DD",选择"设为星标" 回复"资源"获取独家整理的学习资料! 某站后端代码被"开源",同时刷遍全网 ...

  9. java的注释规范_Java代码注释规范

    1,单行(单行)-简短说明: ///... 单行注释: 代码中的单行注释. 最好在注释前有一个空行,并在其后加上与代码相同的缩进级别. 如果无法完成一行,则应使用块注释. 评论格式: 在行首注释: 在 ...

最新文章

  1. css编写要注意什么 及一些公用的样式和外部引用 转码
  2. linux统计日志,Linux一些常使用的统计日志 方法
  3. Jupyter中打印所有结果的解决办法
  4. [转]Android Studio系列教程六--Gradle多渠道打包
  5. Python中的字符串与字符编码:编码和转换问题
  6. 数据库 查询XML XQuery
  7. 论文笔记 - 《ImageNet Classification with Deep Convolutional Neural Networks》 精典
  8. gin post 数据参数_golang--gin获取post里body的参数
  9. 二维数组作数据源填充到repeater
  10. protobuf-3.0 win环境编译
  11. 24.Creating Customer Groups
  12. SqlServer中Group By高级使用--Inner Join分组统计
  13. 20155202 《Java程序设计》实验二(面向对象程序设计)实验报告
  14. Basic knowledge about python
  15. ansys workbench汉化教程_ansys16.0软件下载及安装教程
  16. 物联网技术-RFID
  17. 小布老师oracle,小布老师-oracle-1
  18. ilo看服务器信息,查询ILO信息
  19. Java实现 四舍五入取整到百位 四舍五入取整到千位 数字取整到千位 数字取值到千位 数字取整到百位 数字取值到百位
  20. Laravel :Laravel、Symfony、 Zend 对比测试

热门文章

  1. 《道德经》第四十五章
  2. 2018面试题大总结(一)
  3. CISCO交换机概览
  4. 光模块的中心波长和传输距离
  5. Java实现模拟电梯上下楼,初学者练手
  6. Python 逢7拍手小游戏
  7. MyBatis配置文件(八):mappers配置
  8. 静态方法中注入bean对象
  9. 解决【unity3d】播放视频的两种操作方式
  10. [转]永远的Beyond