主要使用pytorch进行训练
将图片裁剪为四个
只需要修改训练数据路径,数据里面只保留dng文件

import torch
from  torch.utils.data import Dataset
import torch.utils.data as Data
import numpy as np
import os
import rawpy
from unetTorch import Unet
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)#加载数据集
class facadeData1(Dataset):def __init__(self,path='./data/train',data_file='noisy',label_file='ground'):self.path=pathself.data_file=data_fileself.label_file=label_fileself.file_list=os.listdir(path+'/'+data_file)def __len__(self):return len(self.file_list)def read_image(self,input_path):raw = rawpy.imread(input_path)raw_data = raw.raw_image_visibleheight = raw_data.shape[0]width = raw_data.shape[1]raw_data_expand = np.expand_dims(raw_data, axis=2)raw_data_expand_c = np.concatenate((raw_data_expand[0:height:2, 0:width:2, :],raw_data_expand[0:height:2, 1:width:2, :],raw_data_expand[1:height:2, 0:width:2, :],raw_data_expand[1:height:2, 1:width:2, :]), axis=2)return raw_data_expand_c, height, widthdef normalization(self,input_data, black_level=1024, white_level=16383):output_data = (input_data.astype(float) - black_level) / (white_level - black_level)return output_datadef __getitem__(self, index):image_name=self.file_list[index]label_name=image_name.replace('noise','gt')image_path=self.path+'/'+self.data_file+'/'+image_namelabel_path=self.path+'/'+self.label_file+'/'+label_nameimage,height,width=self.read_image(image_path)label,height2,width2=self.read_image(label_path)#标准化image_norm=self.normalization(image)label_norm=self.normalization(label)#扩展维度image=torch.from_numpy(np.transpose(image_norm.reshape(-1, height//2, width//2, 4), (0, 3, 1, 2))).float()image=image[:,:,0:868,0:1156]label=torch.from_numpy(np.transpose(label_norm.reshape(-1, height // 2, width//2, 4), (0, 3, 1, 2))).float()label=label[:, :, 0:868, 0:1156]return image,label
class facadeData2(Dataset):def __init__(self,path='./data/train',data_file='noisy',label_file='ground'):self.path=pathself.data_file=data_fileself.label_file=label_fileself.file_list=os.listdir(path+'/'+data_file)def __len__(self):return len(self.file_list)def read_image(self,input_path):raw = rawpy.imread(input_path)raw_data = raw.raw_image_visibleheight = raw_data.shape[0]width = raw_data.shape[1]raw_data_expand = np.expand_dims(raw_data, axis=2)raw_data_expand_c = np.concatenate((raw_data_expand[0:height:2, 0:width:2, :],raw_data_expand[0:height:2, 1:width:2, :],raw_data_expand[1:height:2, 0:width:2, :],raw_data_expand[1:height:2, 1:width:2, :]), axis=2)return raw_data_expand_c, height, widthdef normalization(self,input_data, black_level=1024, white_level=16383):output_data = (input_data.astype(float) - black_level) / (white_level - black_level)return output_datadef __getitem__(self, index):image_name=self.file_list[index]label_name=image_name.replace('noise','gt')image_path=self.path+'/'+self.data_file+'/'+image_namelabel_path=self.path+'/'+self.label_file+'/'+label_nameimage,height,width=self.read_image(image_path)label,height2,width2=self.read_image(label_path)#标准化image_norm=self.normalization(image)label_norm=self.normalization(label)#扩展维度image=torch.from_numpy(np.transpose(image_norm.reshape(-1, height//2, width//2, 4), (0, 3, 1, 2))).float()image=image[:,:,868:1736,0:1156]label=torch.from_numpy(np.transpose(label_norm.reshape(-1, height // 2, width//2, 4), (0, 3, 1, 2))).float()label=label[:, :, 868:1736, 0:1156]return image,label
class facadeData3(Dataset):def __init__(self,path='./data/train',data_file='noisy',label_file='ground'):self.path=pathself.data_file=data_fileself.label_file=label_fileself.file_list=os.listdir(path+'/'+data_file)def __len__(self):return len(self.file_list)def read_image(self,input_path):raw = rawpy.imread(input_path)raw_data = raw.raw_image_visibleheight = raw_data.shape[0]width = raw_data.shape[1]raw_data_expand = np.expand_dims(raw_data, axis=2)raw_data_expand_c = np.concatenate((raw_data_expand[0:height:2, 0:width:2, :],raw_data_expand[0:height:2, 1:width:2, :],raw_data_expand[1:height:2, 0:width:2, :],raw_data_expand[1:height:2, 1:width:2, :]), axis=2)return raw_data_expand_c, height, widthdef normalization(self,input_data, black_level=1024, white_level=16383):output_data = (input_data.astype(float) - black_level) / (white_level - black_level)return output_datadef __getitem__(self, index):image_name=self.file_list[index]label_name=image_name.replace('noise','gt')image_path=self.path+'/'+self.data_file+'/'+image_namelabel_path=self.path+'/'+self.label_file+'/'+label_nameimage,height,width=self.read_image(image_path)label,height2,width2=self.read_image(label_path)#标准化image_norm=self.normalization(image)label_norm=self.normalization(label)#扩展维度image=torch.from_numpy(np.transpose(image_norm.reshape(-1, height//2, width//2, 4), (0, 3, 1, 2))).float()image=image[:,:,0:868,1156:2312]label=torch.from_numpy(np.transpose(label_norm.reshape(-1, height // 2, width//2, 4), (0, 3, 1, 2))).float()label=label[:, :, 0:868, 1156:2312]return image,label
class facadeData4(Dataset):def __init__(self,path='./data/train',data_file='noisy',label_file='ground'):self.path=pathself.data_file=data_fileself.label_file=label_fileself.file_list=os.listdir(path+'/'+data_file)def __len__(self):return len(self.file_list)def read_image(self,input_path):raw = rawpy.imread(input_path)raw_data = raw.raw_image_visibleheight = raw_data.shape[0]width = raw_data.shape[1]raw_data_expand = np.expand_dims(raw_data, axis=2)raw_data_expand_c = np.concatenate((raw_data_expand[0:height:2, 0:width:2, :],raw_data_expand[0:height:2, 1:width:2, :],raw_data_expand[1:height:2, 0:width:2, :],raw_data_expand[1:height:2, 1:width:2, :]), axis=2)return raw_data_expand_c, height, widthdef normalization(self,input_data, black_level=1024, white_level=16383):output_data = (input_data.astype(float) - black_level) / (white_level - black_level)return output_datadef __getitem__(self, index):image_name=self.file_list[index]label_name=image_name.replace('noise','gt')image_path=self.path+'/'+self.data_file+'/'+image_namelabel_path=self.path+'/'+self.label_file+'/'+label_nameimage,height,width=self.read_image(image_path)label,height2,width2=self.read_image(label_path)#标准化image_norm=self.normalization(image)label_norm=self.normalization(label)#扩展维度image=torch.from_numpy(np.transpose(image_norm.reshape(-1, height//2, width//2, 4), (0, 3, 1, 2))).float()image=image[:,:,868:1736,1156:2312]label=torch.from_numpy(np.transpose(label_norm.reshape(-1, height // 2, width//2, 4), (0, 3, 1, 2))).float()label=label[:, :, 868:1736, 1156:2312]return image,label
class Trainer(object):def __init__(self,net,lr=1e-4,batch_size=10,num_epoch=1000,train_data=None):self.net=netself.net=self.net.to(device)self.batch_size=batch_sizeself.lr=lrself.num_epoch=num_epochself.train_data=train_dataself.data_loader=Data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)self.loss=torch.nn.MSELoss(reduce=True, size_average=True)self.loss = self.loss.to(device)def train(self):optim=torch.optim.Adam(self.net.parameters(),lr=self.lr)#self.net.load_state_dict(torch.load('./models2/model30.pth'))for epoch in range(self.num_epoch):for i,(bx,by) in enumerate(self.data_loader):bx=bx.squeeze(dim = 1)bx =bx.to(device)by=by.squeeze(dim = 1)by = by.to(device)pre=self.net(bx)loss =self.loss(pre, by)  # 求损失  # patch 123#bugoptim.zero_grad()loss.backward()optim.step()print('i',i,'epoch',epoch,'loss',loss)if epoch%10==0:torch.save(self.net.state_dict(), f'./models3/model{epoch}.pth')return Noneif __name__ == '__main__':train_data1=facadeData1()train_data2=facadeData2()train_data3=facadeData3()train_data4=facadeData4()train_data=torch.utils.data.ConcatDataset([train_data1]+[train_data2]+[train_data3]+[train_data3])# # print(train_data)# image,label=data.__getitem__(100)# print(train_data)net=Unet()trainer=Trainer(net=net,train_data=train_data)trainer.train()

中兴捧月RAW图像去噪训练代码相关推荐

  1. 2020年中兴捧月算法大赛---埃德加考特派赛题解析及代码

    写在前面 三月份疫情期间在家闲来无事, 各大公司举办了很多的算法比赛, 但是大多是人工智能相关, 而我这个菜鸡又不会这方面的, 这时发现了中兴捧月的埃德加考特派赛道, 也就是数据库相关本科课设, 恰好 ...

  2. 谈谈中兴捧月大赛决赛以及总结

    前言 四月份,在师兄的推荐下,报名参加了中兴捧月大赛.一开始只是为了混一个面笔试的资格(因为提交有效成绩即可免笔试),然后为了找一个简单的赛道,注册了几个号看了两三个赛道的题目.发现自己每个都不熟悉, ...

  3. Deepsort_V2 2020中兴捧月阿尔法赛道多目标检测和跟踪初赛第一名

    2020中兴捧月阿尔法赛道多目标检测和跟踪初赛第一名方案 初赛:多目标跟踪:指标MOTA和MOTP, 后期的大量实验证明检测算法相对于跟踪更重要. 数据集分析: 1.人群密集稀疏场景: 2.场景(白天 ...

  4. “中兴捧月”报文监视器的实现

    前一段时间和几个好友组队参加了中兴举办的"中兴捧月杯"程序设计大赛,跌跌撞撞竟然进了复赛,不过最终还是没能入围区域决赛,还是感觉很遗憾.这里把当时复赛的题目以及我们被Out的代码拿 ...

  5. 2023第十三届“中兴捧月”全球精英挑战赛今日正式启动

    3月31日,第十三届"中兴捧月"全球精英挑战赛正式启动! 由中兴通讯主办的"中兴捧月"大赛,自2009年首次举办至今,已走过13个年头,是广大高校师生的重点关注 ...

  6. 2020中兴捧月算法大赛-阿尔法(MOT)赛道--赛后总结

    比赛结束了,很荣幸拿到了中兴捧月算法大赛 MOT赛道 全国总决赛第二名的亚军奖杯,这估计也是我找到工作前最后一个比赛了,收获满满,下面算是自己给自己写的一个简单的赛后总结,做的比较粗糙,细节也就不多赘 ...

  7. 2020中兴捧月算法赛道傅里叶派赛题菜鸡回顾

    最近抱着试水的心理参加了2020中兴捧月算法大赛傅里叶派赛题.从4.19号由旁观者转变为参赛者,到5.8号提交完成最后的文档和代码,前后算起来也有20天了.虽然自己比较菜,但毕竟是第一次参加这种比较正 ...

  8. 中兴捧月大赛之方案探讨

    昨天参加完中兴捧月的决赛,感觉特别的糟糕.说实话,感觉中兴这次比赛搞得真的很奇葩!一是比赛搞得让我觉得公司对整个比赛的态度有点随意,有点不正式.二是比赛的赛题要求每天都在变,感觉不天天关注活动交流区, ...

  9. 使用restormer网络做2022年中兴捧月图像去噪

    目录 1.预处理 2.神经网络 3.数据增强 4.训练细节 5.loss选择 6.小trick 7.没来得及用的手段 1.预处理 由于官方提供的图像的分辨率过大,所以为了计算高效,先将其裁剪为200* ...

最新文章

  1. CISCO HSRP
  2. response 流和写能一起吗_2133和2400能一起用吗
  3. 计算机视觉-自定义对象检测器
  4. 磁盘阵列服务器Intel C610系列,超微6048R-E1CR36N 36盘位存储服务器 磁盘阵列
  5. 列表推导(list comprehension)--Python
  6. [Flink]Flink的window介绍
  7. 全国计算机一级考试难度高吗,计算机等级考试分几级 考试难度大不大
  8. window - 安装 tomcat
  9. Java探索之旅(18)——多线程(2)
  10. zen3 服务器芯片,AMD EPYC霄龙服务器处理器亮相,Zen3架构性能飙升
  11. PPT封面怎样设计才更赏心悦目
  12. word排版案例报告_看完这4个文章排版要点,你就会排版啦!
  13. 后端开发面试自我介绍_java工程师面试自我介绍范文
  14. mvn help:system下载包失败解决
  15. 【数据结构与算法】数据结构有哪些?算法有哪些?
  16. [SWPUCTF 2021 新生赛]第一波放题(nssctf刷题)
  17. 服务器通过笔记本共享网络连接外网
  18. 当当网读书排行榜爬虫
  19. python必备源代码-Python 自用代码(某方标准类网页源代码清洗)
  20. 全国计算机等级模拟考试软件答案,全国计算机等级考试模拟题答案

热门文章

  1. 聚类之 FCM 算法原理及应用(Java可视化实现)
  2. 网新恒天2011.9.21招聘会笔试题
  3. 中兴ZXA10-F460 v3.0获取超级管理员密码
  4. 高低温对Nand Flash原始误码率(RBER)及Operation time的影响
  5. 2017第十五届中国国际科学仪器及实验室装备展览会会刊(参展商名录)
  6. Guitar Pro里的渐强渐弱符号
  7. 2FA的完整形式是什么?
  8. SAP ABAP 数据填充进EXCEL模板或代码生成EXCEL并维护数据下载到本地
  9. 三菱plc,x的n次方程序教程
  10. 什么是BS?BS和CS模式的区别是什么?