pytorch基于Unet的铁轨缺陷语义分割

  • 分割效果
  • Unet网络
  • 数据读取器
  • 训练及保存模型
  • 遇到问题

分割效果

Unet网络

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""import torch.nn.functional as Ffrom unet_parts import *class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = DoubleConv(n_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)self.down4 = Down(512, 1024)self.up1 = Up(1024, 512, bilinear)self.up2 = Up(512, 256, bilinear)self.up3 = Up(256, 128, bilinear)self.up4 = Up(128, 64, bilinear)self.outc = OutConv(64, n_classes)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logits# if __name__ == '__main__':
#     net = UNet(n_channels=3, n_classes=1)
#     print(net)
""" Parts of the U-Net model """
"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)else:self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = torch.tensor([x2.size()[2] - x1.size()[2]])diffX = torch.tensor([x2.size()[3] - x1.size()[3]])x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])# if you have padding issues, see# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bdx = torch.cat([x2, x1], dim=1)return self.conv(x)class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

数据读取器

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
from PIL import Image
import random
import numpy as npclass ISBI_Loader(Dataset):def __init__(self, data_path, label_path):# 初始化函数,读取所有data_path下的图片self.images_path = glob.glob(os.path.join(data_path, '*.jpg'))self.labels_path = glob.glob(os.path.join(label_path, '*.jpg'))print(self.images_path)print(self.labels_path)def augment(self, image, flipCode):# 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转flip = cv2.flip(image, flipCode)return flipdef __getitem__(self, index):# 根据index读取图片image_path = self.images_path[index]# 根据image_path生成label_pathlabel_path = self.labels_path[index]# 读取训练图片和标签图片image = cv2.imread(image_path)label = cv2.imread(label_path)# image = cv2.resize(image, dsize=(1000, 160))# label = cv2.resize(label, dsize=(1000, 160))# 将数据转为单通道的图片image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)image = image.reshape(1, image.shape[0], image.shape[1])label = label.reshape(1, label.shape[0], label.shape[1])# 处理标签,将像素值为255的改为1if label.max() > 1:label = label / 255# 随机进行数据增强,为2时不做处理flipCode = random.choice([-1, 0, 1, 2])if flipCode != 2:image = self.augment(image, flipCode)label = self.augment(label, flipCode)return image, labeldef __len__(self):# 返回训练集大小return len(self.images_path)if __name__ == "__main__":isbi_dataset = ISBI_Loader("Railsurfaceimages/", "GroundTruth/")print("数据个数:", len(isbi_dataset))train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,batch_size=2,shuffle=False)for image, label in train_loader:print(image.shape)

训练及保存模型

import torch.nn as nn
import torch
# from model import Unet
from data import ISBI_Loader
import numpy as np
from unet_model import UNetdef train_net(net, device, train_loader, epochs=40, batch_size=1, lr=0.00001):# 加载训练集train_loader = train_loader# 定义RMSprop算法optimizer = torch.optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)# 定义Loss算法criterion = nn.BCEWithLogitsLoss()# best_loss统计,初始化为正无穷best_loss = float('inf')# 训练epochs次for epoch in range(epochs):# 训练模式net.train()# 按照batch_size开始训练for image, label in train_loader:optimizer.zero_grad()# 将数据拷贝到device中image = image.to(device=device, dtype=torch.float32)label = label.to(device=device, dtype=torch.float32)# 使用网络参数,输出预测结果pred = net(image)# 计算lossloss = criterion(pred, label)print('Loss/train', loss.item())# 保存loss值最小的网络参数if loss < best_loss:best_loss = losstorch.save(net.state_dict(), 'best_model.pth')# 更新参数loss.backward()optimizer.step()if __name__ == '__main__':# 加载训练数据isbi_dataset = ISBI_Loader("Railsurfaceimages/", "GroundTruth/")train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,batch_size=2,shuffle=True)device = 'cuda'net = UNet(n_channels=1, n_classes=1)net.to(device=device)train_net(net, device, train_loader, epochs=40, batch_size=1, lr=0.00001)

遇到问题


padding值设为1,卷积时自动填充,不然会因尺寸为问题导致错误。

pytorch基于Unet的铁轨缺陷语义分割相关推荐

  1. 基于U-Net+残差网络的语义分割缺陷检测

    一.介绍 基于深度学习的缺陷检测主要集中在场景识别.object detection等方法,近年来,推出了一系列优秀的语义分割模型,比如SegNet.FCN.U-Net等.语义分割模型被广泛的应用到场 ...

  2. 【PytorchLearning】基于 UNet 的肺部影像语义分割案例保姆教程

    基于 UNet 的肺部影像分割 一般而言,计算机视觉领域包含三大主流任务:分类.检测.分割.其中,分类任务对模型的要求较为简单,在之前的Pytorch入门教程中已进行了较为详尽的介绍,有兴趣的小伙伴可 ...

  3. 【Keras】基于SegNet和U-Net的遥感图像语义分割

    from:[Keras]基于SegNet和U-Net的遥感图像语义分割 上两个月参加了个比赛,做的是对遥感高清图像做语义分割,美其名曰"天空之眼".这两周数据挖掘课期末projec ...

  4. Keras】基于SegNet和U-Net的遥感图像语义分割

    from:[Keras]基于SegNet和U-Net的遥感图像语义分割 上两个月参加了个比赛,做的是对遥感高清图像做语义分割,美其名曰"天空之眼".这两周数据挖掘课期末projec ...

  5. 【论文阅读】Swin Transformer Embedding UNet用于遥感图像语义分割

    [论文阅读]Swin Transformer Embedding UNet用于遥感图像语义分割 文章目录 [论文阅读]Swin Transformer Embedding UNet用于遥感图像语义分割 ...

  6. 深度学习(二十五)基于Mutil-Scale CNN的图片语义分割、法向量估计-ICCV 2015

    基于Mutil-Scale CNN的图片语义分割.法向量估计 原文地址:http://blog.csdn.net/hjimce/article/details/50443995 作者:hjimce 一 ...

  7. 深度学习(二十五)基于Mutil-Scale CNN的图片语义分割、法向量估计

    基于Mutil-Scale CNN的图片语义分割.法向量估计 原文地址:http://blog.csdn.net/hjimce/article/details/50443995 作者:hjimce 一 ...

  8. CV | Feature Space Optimization for Semantic Video Segmentation - 基于特征空间优化的视频语义分割

    前言:今天分享的这一篇文章是CVPR2016有关视频语义分割方向的,最近才开始学习语义分割相关的文献,有理解偏差的希望大家可以指正. 语义分割 在维基百科上面没有直接定义,但从字面上就可以理解,就是将 ...

  9. SAFNet 基于相似性感知的三维语义分割融合网络

    SAFNet 基于相似性感知的三维语义分割融合网络 论文 Similarity-Aware Fusion Network for 3D Semantic Segmentation IROS 2021 ...

  10. 语义分割源代码_综述 | 基于深度学习的实时语义分割方法:全面调研

    34页综述,共计119篇参考文献.本文对图像分割中的最新深度学习体系结构进行了全面分析,更重要的是,它提供了广泛的技术列表以实现快速推理和计算效率. A Survey on Deep Learning ...

最新文章

  1. 快快: 一点即玩的游戏客户端平台
  2. mysql mgr故障恢复实现_MGR实现分析 - 成员管理与故障恢复实现
  3. 基于RBAC模型的通用权限管理系统的设计(数据模型)的扩展
  4. ant接口自动化 junit_ant 学习(3)--结合junit形成自动化测试小框架
  5. C#连接4种类型数据库(Access、SQL Server、Oracle、MySQL)
  6. 错误: 找不到或无法加载主类 org.apache.flume.tools.GetJavaProperty
  7. Team Foundation 和 Visual SourceSafe 之间的区别
  8. 如何用3DsMax制作笔记本电脑
  9. FileNet 开发资料 官方红皮书
  10. 大数据与传统数仓的区别?
  11. 【仪器常用操作方法】33500B函数发生器常用操作方法
  12. 电商物流仓储流程图模板分享
  13. HCIA网络基础7-VRP和命令行基础
  14. APView500电能质量在线监测装置 谐波分析 电压不平衡
  15. 轻轻松松背单词软件测试,Englishfield词汇记忆与测试
  16. 100层楼2个鸡蛋,测试其最低破碎楼层问题
  17. Android 环信 消息免打扰 实现
  18. SQL零基础入门学习(八)
  19. Android 圆形 ImageView
  20. Gamma校正与线性空间

热门文章

  1. This.$Set的用法和作用
  2. 思修期末考试复习题库
  3. microsoft office professional plus 2007安装的最后一根稻草
  4. 【python教程入门学习】Pandas库下载、安装和更新
  5. python 生成图片人脸识别器(调用百度AI)
  6. 长沙课工场中秋节五仁来袭
  7. java高级 --- 各集合存null值问题
  8. 【Js13kGames】基于JavaScript 创造仅有13kb大小的游戏世界
  9. 一个北漂程序员的抉择:我要离开帝都了
  10. 明枪易躲暗箭难防,规避引入RPA时的规则、文化与管理难题