The DeepLabV3 Algorithm

  • Parameter Configuration
    • Dataset Configuration
    • Training Parameters
  • Data Preprocessing Module
    • Dataset Construction Module
    • Testing the Dataset
    • Denormalization
    • Model Loading Module
    • DeepLabV3+

Parameter Configuration

Dataset Configuration

import argparse

parser = argparse.ArgumentParser()

# Dataset Options
parser.add_argument("--data_root", type=str, default=r'D:/',
                    help="path to Dataset")
parser.add_argument("--dataset", type=str, default='voc',
                    choices=['voc', 'cityscapes'], help='Name of dataset')
parser.add_argument("--num_classes", type=int, default=None,
                    help="num classes (default: None)")

# Deeplab Options
# Model architecture: the backbone (feature extractor) is MobileNet or ResNet-50/101
parser.add_argument("--model", type=str, default='deeplabv3plus_resnet50',
                    choices=['deeplabv3_resnet50', 'deeplabv3plus_resnet50',
                             'deeplabv3_resnet101', 'deeplabv3plus_resnet101',
                             'deeplabv3_mobilenet', 'deeplabv3plus_mobilenet'],
                    help='model name')
parser.add_argument("--separable_conv", action='store_true', default=False,
                    help="apply separable conv to decoder and aspp")
parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])

Training Parameters

# Train Options
# Evaluation-only mode
parser.add_argument("--test_only", action='store_true', default=False)
parser.add_argument("--save_val_results", action='store_true', default=False,
                    help="save segmentation results to \"./results\"")
parser.add_argument("--total_itrs", type=int, default=60e3,
                    help="total number of training iterations (default: 60k)")

# Learning rate
parser.add_argument("--lr", type=float, default=0.01,
                    help="learning rate (default: 0.01)")
parser.add_argument("--lr_policy", type=str, default='poly', choices=['poly', 'step'],
                    help="learning rate scheduler policy")
parser.add_argument("--step_size", type=int, default=10000)
parser.add_argument("--crop_val", action='store_true', default=False,
                    help='crop validation (default: False)')
parser.add_argument("--batch_size", type=int, default=8,
                    help='batch size (default: 8)')
parser.add_argument("--val_batch_size", type=int, default=4,
                    help='batch size for validation (default: 4)')
parser.add_argument("--crop_size", type=int, default=513)

# Path to pretrained / checkpoint weights
parser.add_argument("--ckpt", default="./checkpoint/best_deeplabv3_resnet50_voc_os16.pth", type=str,
                    help="restore from checkpoint")
parser.add_argument("--continue_training", action='store_true', default=True)
parser.add_argument("--loss_type", type=str, default='cross_entropy',
                    choices=['cross_entropy', 'focal_loss'],
                    help="loss type (default: cross_entropy)")
parser.add_argument("--gpu_id", type=str, default='0',
                    help="GPU ID")

# Regularization (weight decay)
parser.add_argument("--weight_decay", type=float, default=1e-4,
                    help='weight decay (default: 1e-4)')
parser.add_argument("--random_seed", type=int, default=1,
                    help="random seed (default: 1)")
parser.add_argument("--print_interval", type=int, default=10,
                    help="print interval of loss (default: 10)")
parser.add_argument("--val_interval", type=int, default=100,
                    help="iteration interval for eval (default: 100)")
parser.add_argument("--download", action='store_true', default=False,
                    help="download datasets")
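For context, a minimal sketch of how these options might be consumed. The get_argparser() wrapper and the num_classes defaults below are assumptions added for illustration, not code from this post.

import argparse

def get_argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default='voc', choices=['voc', 'cityscapes'])
    parser.add_argument("--model", type=str, default='deeplabv3plus_resnet50')
    parser.add_argument("--num_classes", type=int, default=None)
    parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])
    # ... plus the remaining add_argument calls from the two blocks above ...
    return parser

if __name__ == '__main__':
    opts = get_argparser().parse_args()
    if opts.num_classes is None:
        # Pascal VOC: 20 object classes + background; Cityscapes: 19 train classes
        opts.num_classes = 21 if opts.dataset == 'voc' else 19
    print(opts.model, opts.num_classes, opts.output_stride)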

Data Preprocessing Module

Separate data-augmentation pipelines are defined for the training and validation splits of each dataset; when --crop_val is set, the validation images are additionally resized and center-cropped.

from utils import ext_transforms as et
from datasets import VOCSegmentation, Cityscapes


def get_dataset(opts):
    """ Dataset and augmentation. """
    if opts.dataset == 'voc':
        train_transform = et.ExtCompose([
            # et.ExtResize(size=opts.crop_size),
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        if opts.crop_val:
            val_transform = et.ExtCompose([
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        else:
            val_transform = et.ExtCompose([
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                    image_set='train', download=opts.download,
                                    transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                  image_set='val', download=False,
                                  transform=val_transform)

    if opts.dataset == 'cityscapes':
        train_transform = et.ExtCompose([
            # et.ExtResize(512),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        val_transform = et.ExtCompose([
            # et.ExtResize(512),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        train_dst = Cityscapes(root=opts.data_root,
                               split='train', transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root,
                             split='val', transform=val_transform)
    return train_dst, val_dst
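As a usage sketch (assuming opts comes from the parser above), the two datasets returned by get_dataset would typically be wrapped in DataLoaders; the num_workers value here is arbitrary.

from torch.utils import data

train_dst, val_dst = get_dataset(opts)
train_loader = data.DataLoader(
    train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2, drop_last=True)
val_loader = data.DataLoader(
    val_dst, batch_size=opts.val_batch_size, shuffle=False, num_workers=2)
print("Dataset: %s, Train set: %d, Val set: %d" %
      (opts.dataset, len(train_dst), len(val_dst)))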

Dataset Construction Module


import os

import numpy as np
from PIL import Image
from torch.utils import data

# DATASET_YEAR_DICT and download_extract are helpers defined elsewhere in the datasets module (not shown here).


def voc_cmap(N=256, normalized=False):
    def bitget(byteval, idx):
        return ((byteval & (1 << idx)) != 0)

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << 7 - j)
            g = g | (bitget(c, 1) << 7 - j)
            b = b | (bitget(c, 2) << 7 - j)
            c = c >> 3
        cmap[i] = np.array([r, g, b])
    cmap = cmap / 255 if normalized else cmap
    return cmap


class VOCSegmentation(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
    """
    cmap = voc_cmap()

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None):

        is_aug = False
        if year == '2012_aug':
            is_aug = True
            year = '2012'

        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform

        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if is_aug and image_set == 'train':
            mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
            assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
            split_f = os.path.join(self.root, 'train_aug.txt')  # './datasets/data/train_aug.txt'
        else:
            mask_dir = os.path.join(voc_root, 'SegmentationClass')
            splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
            split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        with open(os.path.join(split_f), "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transform is not None:
            img, target = self.transform(img, target)

        return img, target

    def __len__(self):
        return len(self.images)

    @classmethod
    def decode_target(cls, mask):
        """Decode a semantic mask into an RGB image."""
        return cls.cmap[mask]
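The VOC color map interleaves the bits of the class index across the R/G/B channels, which is what voc_cmap implements. A quick check of the first few entries (a small illustrative snippet, reusing voc_cmap from the module above):

cmap = voc_cmap()
for cls_id in range(4):
    print(cls_id, cmap[cls_id])
# 0 [0 0 0]          background -> black
# 1 [128   0   0]    aeroplane  -> dark red
# 2 [  0 128   0]    bicycle    -> dark green
# 3 [128 128   0]    bird       -> olive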

Testing the Dataset

import numpy as np
from datasets import VOCSegmentation
from utils import ext_transforms as et
import cv2

train_transform = et.ExtCompose([
    # et.ExtResize(size=opts.crop_size),
    et.ExtRandomScale((0.5, 2.0)),
    et.ExtRandomCrop(size=(224, 224), pad_if_needed=True),
    et.ExtRandomHorizontalFlip(),
    et.ExtToTensor(),
    et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
])

data = VOCSegmentation(root=r"D:/", year="2012", image_set='train',
                       download=False, transform=train_transform)

if __name__ == '__main__':
    print(data[0][0].shape)
    print(data[0][1].shape)
    res = data.decode_target(data[0][1])
    cv2.imshow("Res", np.array(res))
    cv2.waitKey(0)
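One caveat worth noting: decode_target returns RGB values (following the VOC color map), while cv2.imshow expects BGR, so the palette colours appear swapped in the window above. A small sketch of the fix, reusing res from the script above:

import cv2
import numpy as np

res_bgr = cv2.cvtColor(np.asarray(res, dtype=np.uint8), cv2.COLOR_RGB2BGR)
cv2.imshow("Res", res_bgr)
cv2.waitKey(0)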

Denormalization

import numpy as np
from torchvision.transforms.functional import normalize


class Denormalize(object):
    """Inverts Normalize(mean, std): x_orig = x * std + mean, expressed as another
    normalization with mean' = -mean/std and std' = 1/std."""
    def __init__(self, mean, std):
        mean = np.array(mean)
        std = np.array(std)
        self._mean = -mean/std
        self._std = 1/std

    def __call__(self, tensor):
        if isinstance(tensor, np.ndarray):
            return (tensor - self._mean.reshape(-1, 1, 1)) / self._std.reshape(-1, 1, 1)
        return normalize(tensor, self._mean, self._std)
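A usage sketch: undoing the ImageNet normalization on a sample from the dataset built earlier before visualizing it (the data variable comes from the test script above; the clipping and transpose are added here for display):

import numpy as np

denorm = Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img, target = data[0]                          # normalized CHW float tensor
img_np = denorm(img.numpy())                   # back to the [0, 1] range
img_np = (np.clip(img_np, 0, 1) * 255).astype(np.uint8).transpose(1, 2, 0)  # CHW -> HWC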

Model Loading Module

# Set up model
model_map = {
    'deeplabv3_resnet50': network.deeplabv3_resnet50,
    'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
    'deeplabv3_resnet101': network.deeplabv3_resnet101,
    'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
    'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
    'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
}
model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)


def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)


def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3', 'resnet101', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)


def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs):
    """Constructs a DeepLabV3 model with a MobileNetV2 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3', 'mobilenetv2', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)


# DeepLabV3+
def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a DeepLabV3+ model with a ResNet-50 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
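Putting this together, a minimal sketch of building the selected model and restoring the --ckpt weights. The "model_state" checkpoint key and the call to network.convert_to_separable_conv are assumptions about the surrounding training script, not code shown in this post:

import os
import torch

model = model_map[opts.model](num_classes=opts.num_classes,
                              output_stride=opts.output_stride)

# Optionally replace 3x3 convs in the decoder/ASPP with separable convs (see the last section)
if opts.separable_conv and 'plus' in opts.model:
    network.convert_to_separable_conv(model.classifier)

if opts.ckpt is not None and os.path.isfile(opts.ckpt):
    checkpoint = torch.load(opts.ckpt, map_location='cpu')
    model.load_state_dict(checkpoint["model_state"])   # assumed checkpoint layout
    print("Restored model from %s" % opts.ckpt)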

The loading helpers

def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
    if backbone == 'mobilenetv2':
        model = _segm_mobilenet(arch_type, backbone, num_classes, output_stride=output_stride,
                                pretrained_backbone=pretrained_backbone)
    elif backbone.startswith('resnet'):
        model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride,
                             pretrained_backbone=pretrained_backbone)
    else:
        raise NotImplementedError
    return model


def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    if output_stride == 8:
        replace_stride_with_dilation = [False, True, True]
        aspp_dilate = [12, 24, 36]
    else:
        replace_stride_with_dilation = [False, False, True]
        aspp_dilate = [6, 12, 18]

    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)

    inplanes = 2048
    low_level_planes = 256

    if name == 'deeplabv3plus':
        return_layers = {'layer4': 'out', 'layer1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    elif name == 'deeplabv3':
        return_layers = {'layer4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)

    # Extract the outputs of the selected backbone layers and give each an alias
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    model = DeepLabV3(backbone, classifier)
    return model
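To illustrate what IntermediateLayerGetter does, the snippet below wraps a plain ResNet-50 and prints the shapes of the two tapped feature maps. It uses torchvision's own IntermediateLayerGetter (the repo ships an equivalent one) and a recent torchvision API; the shapes shown are for a 224x224 input without dilation:

import torch
from torchvision.models import resnet50
from torchvision.models._utils import IntermediateLayerGetter

backbone = resnet50(weights=None)
body = IntermediateLayerGetter(backbone, return_layers={'layer4': 'out', 'layer1': 'low_level'})
feats = body(torch.randn(1, 3, 224, 224))
print(feats['low_level'].shape)   # torch.Size([1, 256, 56, 56])  -> fed to the 1x1 projection
print(feats['out'].shape)         # torch.Size([1, 2048, 7, 7])   -> fed to ASPP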

DeepLabV3+

import torch
from torch import nn
from torch.nn import functional as F


class DeepLabHeadV3Plus(nn.Module):
    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHeadV3Plus, self).__init__()
        # 1x1 projection of the low-level feature (e.g. 256 -> 48 channels)
        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        self.aspp = ASPP(in_channels, aspp_dilate)

        # 304 = 256 (ASPP output) + 48 (projected low-level feature)
        self.classifier = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=(1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )
        self._init_weight()

    def forward(self, feature):
        # feature is the dict produced by IntermediateLayerGetter:
        # return_layers = {'layer4': 'out', 'layer1': 'low_level'}
        low_level_feature = self.project(feature['low_level'])
        output_feature = self.aspp(feature['out'])
        output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:],
                                       mode='bilinear', align_corners=False)
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

The ASPP (Atrous Spatial Pyramid Pooling) module used in the head above:

class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))

        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        res = []
        for conv in self.convs:
            # print(conv(x).shape)
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
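ASPPConv and ASPPPooling are referenced above but not shown in this excerpt. A sketch consistent with how they are used here (a 3x3 dilated-conv branch, and an image-level pooling branch upsampled back to the input size), following the standard torchvision-style implementation:

import torch
from torch import nn
from torch.nn import functional as F


class ASPPConv(nn.Sequential):
    """3x3 convolution with the given dilation (atrous rate), followed by BN + ReLU."""
    def __init__(self, in_channels, out_channels, dilation):
        super(ASPPConv, self).__init__(
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))


class ASPPPooling(nn.Sequential):
    """Image-level features: global average pooling + 1x1 conv, upsampled to the input size."""
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

With the 1x1 branch, the three atrous branches, and the pooling branch, ASPP concatenates five 256-channel maps, which is why the projection layer takes 5 * out_channels inputs.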

Converting convolutions to depthwise-separable convolutions

def convert_to_separable_conv(module):
    """Recursively replace every Conv2d with kernel size > 1 by an atrous separable convolution."""
    new_module = module
    if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
        new_module = AtrousSeparableConvolution(module.in_channels,
                                                module.out_channels,
                                                module.kernel_size,
                                                module.stride,
                                                module.padding,
                                                module.dilation,
                                                module.bias)
    for name, child in module.named_children():
        new_module.add_module(name, convert_to_separable_conv(child))
    return new_module


class AtrousSeparableConvolution(nn.Module):
    """ Atrous Separable Convolution """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super(AtrousSeparableConvolution, self).__init__()
        self.body = nn.Sequential(
            # Depthwise conv (one filter per input channel, with the given dilation)
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=bias, groups=in_channels),
            # Pointwise conv (1x1 across channels)
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
        )
        self._init_weight()

    def forward(self, x):
        return self.body(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
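A quick sketch of applying the conversion to a decoder-like head (same layer shapes as the DeepLabV3+ classifier above, used purely for illustration) and comparing parameter counts. Note that convert_to_separable_conv mutates the module's children in place, so the count is taken before converting:

from torch import nn

head = nn.Sequential(
    nn.Conv2d(304, 256, 3, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 21, 1),
)

count = lambda m: sum(p.numel() for p in m.parameters())
before = count(head)
head = convert_to_separable_conv(head)
print(before, '->', count(head))   # the 3x3 conv becomes depthwise 3x3 + pointwise 1x1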
