1. Convert official models to the MMSegmentation style

If you want to convert the pretrained weights from the official repository yourself, MMSegmentation also provides a script, swin2mmseg.py, in the tools directory that converts the checkpoint keys from the official repo into the MMSegmentation style.

python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
python tools/model_converters/swin2mmseg.py https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth pretrain/swin_base_patch4_window7_224.pth

This script converts the model found at PRETRAIN_PATH and stores the converted model at STORE_PATH.
In the default configs, each pretrained model and its corresponding original checkpoint are paired as listed in the following sections.
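As an optional sanity check (this snippet is my own illustration, not part of the official tooling; the local file names are hypothetical), you can compare the key names of the original and converted checkpoints:

import torch

official = torch.load('pretrain/swin_base_patch4_window7_224_official.pth', map_location='cpu')
converted = torch.load('pretrain/swin_base_patch4_window7_224.pth', map_location='cpu')

# Official Swin checkpoints usually store the weights under a 'model' key, while the
# converted file is a plain state_dict (or wraps it under 'state_dict').
official_keys = set(official.get('model', official).keys())
converted_keys = set(converted.get('state_dict', converted).keys())

print(len(official_keys), 'keys in the official checkpoint')
print(len(converted_keys), 'keys in the converted checkpoint')
print(sorted(converted_keys)[:5])  # spot-check the renamed MMSegmentation-style keys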

2. Download the ADE20K models

https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015-ee2fff1c.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340-593b0e13.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K_20210526_211650-762e2178.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k_20220318_015320-48d180dd.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k_20220318_091743-9ba68901.pth

3. Download the Swin Transformer pretrained backbones

# tiny
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth
# small
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth
# base
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_20220317-e9b98025.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth
# large
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220412-aeecf2aa.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth

4. Build a data directory in the ADE20K layout

ADE20K contains more than 25,000 images (20k train, 2k val, 3k test), densely annotated with an open-dictionary label set. For the 2017 Places Challenge, the 100 thing and 50 stuff categories that cover 89% of all pixels were selected,
giving 150 categories in total.

Idx  Ratio   Train   Val Name
1 0.1576    11664 1172  wall
2 0.1072    6046  612   building, edifice
3 0.0878    8265  796   sky
4 0.0621    9336  917   floor, flooring
5 0.0480    6678  641   tree
6 0.0450    6604  643   ceiling
7 0.0398    4023  408   road, route
8 0.0231    1906  199   bed
9 0.0198    4688  460   windowpane, window
10    0.0183    2423  225   grass
11    0.0181    2874  294   cabinet
12    0.0166    3068  310   sidewalk, pavement
13    0.0160    5075  526   person, individual, someone, somebody, mortal, soul
14    0.0151    1804  190   earth, ground
15    0.0118    6666  796   door, double door
16    0.0110    4269  411   table
17    0.0109    1691  160   mountain, mount
18    0.0104    3999  441   plant, flora, plant life
19    0.0104    2149  217   curtain, drape, drapery, mantle, pall
20    0.0103    3261  318   chair
21    0.0098    3164  306   car, auto, automobile, machine, motorcar
22    0.0074    709   75    water
23    0.0067    3296  315   painting, picture
24    0.0065    1191  106   sofa, couch, lounge
25    0.0061    1516  162   shelf
26    0.0060    667   69    house
27    0.0053    651   57    sea
28    0.0052    1847  224   mirror
29    0.0046    1158  128   rug, carpet, carpeting
30    0.0044    480   44    field
31    0.0044    1172  98    armchair
32    0.0044    1292  184   seat
33    0.0033    1386  138   fence, fencing
34    0.0031    698   61    desk
35    0.0030    781   73    rock, stone
36    0.0027    380   43    wardrobe, closet, press
37    0.0026    3089  302   lamp
38    0.0024    404   37    bathtub, bathing tub, bath, tub
39    0.0024    804   99    railing, rail
40    0.0023    1453  153   cushion
41    0.0023    411   37    base, pedestal, stand
42    0.0022    1440  162   box
43    0.0022    800   77    column, pillar
44    0.0020    2650  298   signboard, sign
45    0.0019    549   46    chest of drawers, chest, bureau, dresser
46    0.0019    367   36    counter
47    0.0018    311   30    sand
48    0.0018    1181  122   sink
49    0.0018    287   23    skyscraper
50    0.0018    468   38    fireplace, hearth, open fireplace
51    0.0018    402   43    refrigerator, icebox
52    0.0018    130   12    grandstand, covered stand
53    0.0018    561   64    path
54    0.0017    880   102   stairs, steps
55    0.0017    86    12    runway
56    0.0017    172   11    case, display case, showcase, vitrine
57    0.0017    198   18    pool table, billiard table, snooker table
58    0.0017    930   109   pillow
59    0.0015    139   18    screen door, screen
60    0.0015    564   52    stairway, staircase
61    0.0015    320   26    river
62    0.0015    261   29    bridge, span
63    0.0014    275   22    bookcase
64    0.0014    335   60    blind, screen
65    0.0014    792   75    coffee table, cocktail table
66    0.0014    395   49    toilet, can, commode, crapper, pot, potty, stool, throne
67    0.0014    1309  138   flower
68    0.0013    1112  113   book
69    0.0013    266   27    hill
70    0.0013    659   66    bench
71    0.0012    331   31    countertop
72    0.0012    531   56    stove, kitchen stove, range, kitchen range, cooking stove
73    0.0012    369   36    palm, palm tree
74    0.0012    144   9 kitchen island
75    0.0011    265   29    computer, computing machine, computing device, data processor, electronic computer, information processing system
76    0.0010    324   33    swivel chair
77    0.0009    304   27    boat
78    0.0009    170   20    bar
79    0.0009    68    6 arcade machine
80    0.0009    65    8 hovel, hut, hutch, shack, shanty
81    0.0009    248   25    bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle
82    0.0008    492   49    towel
83    0.0008    2510  269   light, light source
84    0.0008    440   39    truck, motortruck
85    0.0008    147   18    tower
86    0.0008    583   56    chandelier, pendant, pendent
87    0.0007    533   61    awning, sunshade, sunblind
88    0.0007    1989  239   streetlight, street lamp
89    0.0007    71    5 booth, cubicle, stall, kiosk
90    0.0007    618   53    television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box
91    0.0007    135   12    airplane, aeroplane, plane
92    0.0007    83    5 dirt track
93    0.0007    178   17    apparel, wearing apparel, dress, clothes
94    0.0006    1003  104   pole
95    0.0006    182   12    land, ground, soil
96    0.0006    452   50    bannister, banister, balustrade, balusters, handrail
97    0.0006    42    6 escalator, moving staircase, moving stairway
98    0.0006    307   31    ottoman, pouf, pouffe, puff, hassock
99    0.0006    965   114   bottle
100   0.0006    117   13    buffet, counter, sideboard
101   0.0006    354   35    poster, posting, placard, notice, bill, card
102   0.0006    108   9 stage
103   0.0006    557   55    van
104   0.0006    52    4 ship
105   0.0005    99    5 fountain
106   0.0005    57    4 conveyer belt, conveyor belt, conveyer, conveyor, transporter
107   0.0005    292   31    canopy
108   0.0005    77    9 washer, automatic washer, washing machine
109   0.0005    340   38    plaything, toy
110   0.0005    66    3 swimming pool, swimming bath, natatorium
111   0.0005    465   49    stool
112   0.0005    50    4 barrel, cask
113   0.0005    622   75    basket, handbasket
114   0.0005    80    9 waterfall, falls
115   0.0005    59    3 tent, collapsible shelter
116   0.0005    531   72    bag
117   0.0005    282   30    minibike, motorbike
118   0.0005    73    7 cradle
119   0.0005    435   44    oven
120   0.0005    136   25    ball
121   0.0005    116   24    food, solid food
122   0.0004    266   31    step, stair
123   0.0004    58    12    tank, storage tank
124   0.0004    418   83    trade name, brand name, brand, marque
125   0.0004    319   43    microwave, microwave oven
126   0.0004    1193  139   pot, flowerpot
127   0.0004    97    23    animal, animate being, beast, brute, creature, fauna
128   0.0004    347   36    bicycle, bike, wheel, cycle
129   0.0004    52    5 lake
130   0.0004    246   22    dishwasher, dish washer, dishwashing machine
131   0.0004    108   13    screen, silver screen, projection screen
132   0.0004    201   30    blanket, cover
133   0.0004    285   21    sculpture
134   0.0004    268   27    hood, exhaust hood
135   0.0003    1020  108   sconce
136   0.0003    1282  122   vase
137   0.0003    528   65    traffic light, traffic signal, stoplight
138   0.0003    453   57    tray
139   0.0003    671   100   ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
140   0.0003    397   44    fan
141   0.0003    92    8 pier, wharf, wharfage, dock
142   0.0003    228   18    crt screen
143   0.0003    570   59    plate
144   0.0003    217   22    monitor, monitoring device
145   0.0003    206   19    bulletin board, notice board
146   0.0003    130   14    shower
147   0.0003    178   28    radiator
148   0.0002    504   57    glass, drinking glass
149   0.0002    775   96    clock
150   0.0002    421   56    flag

mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── cityscapes
│ │ ├── leftImg8bit
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── gtFine
│ │ │ ├── train
│ │ │ ├── val
│ ├── VOCdevkit
│ │ ├── VOC2012
│ │ │ ├── JPEGImages
│ │ │ ├── SegmentationClass
│ │ │ ├── ImageSets
│ │ │ │ ├── Segmentation
│ │ ├── VOC2010
│ │ │ ├── JPEGImages
│ │ │ ├── SegmentationClassContext
│ │ │ ├── ImageSets
│ │ │ │ ├── SegmentationContext
│ │ │ │ │ ├── train.txt
│ │ │ │ │ ├── val.txt
│ │ │ ├── trainval_merged.json
│ │ ├── VOCaug
│ │ │ ├── dataset
│ │ │ │ ├── cls
│ ├── ade
│ │ ├── ADEChallengeData2016
│ │ │ ├── annotations
│ │ │ │ ├── training
│ │ │ │ ├── validation
│ │ │ ├── images
│ │ │ │ ├── training
│ │ │ │ ├── validation
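Before training, it is worth verifying that the ADE20K-style folders are in place and that every image has a matching mask. A small helper along these lines (my own sketch, assuming the default data_root and .jpg/.png suffixes):

import os
import os.path as osp

def check_ade_style_dataset(data_root='data/ade/ADEChallengeData2016'):
    """Check the images/annotations layout and report images without a mask."""
    for split in ('training', 'validation'):
        img_dir = osp.join(data_root, 'images', split)
        ann_dir = osp.join(data_root, 'annotations', split)
        assert osp.isdir(img_dir), f'missing {img_dir}'
        assert osp.isdir(ann_dir), f'missing {ann_dir}'
        imgs = {osp.splitext(f)[0] for f in os.listdir(img_dir) if f.endswith('.jpg')}
        anns = {osp.splitext(f)[0] for f in os.listdir(ann_dir) if f.endswith('.png')}
        missing = imgs - anns
        print(f'{split}: {len(imgs)} images, {len(anns)} masks, {len(missing)} images without a mask')

if __name__ == '__main__':
    check_ade_style_dataset()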

5. Modify the base config files

In this walkthrough we train the upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K model; its config file is shown below.

The full configuration is as follows:

_base_ = [
    '../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth'  # noqa
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        use_abs_pos_embed=False,
        drop_path_rate=0.3,
        patch_norm=True),
    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
    auxiliary_head=dict(in_channels=384, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

1. Set the number of classes and load the pretrained backbone (model config upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py)

_base_ = [
    '../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
# You can also download this checkpoint first and point checkpoint_file at the local path.
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth'  # noqa
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        use_abs_pos_embed=False,
        drop_path_rate=0.3,
        patch_norm=True),
    # num_classes: change to the number of classes in your own dataset, excluding the
    # background; the background is label 0 and is handled automatically.
    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
    auxiliary_head=dict(in_channels=384, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
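If your dataset has a different number of foreground classes, only the two heads need to be overridden. A minimal sketch, assuming a hypothetical custom dataset with 5 foreground classes (the count is an example only):

_base_ = ['./upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py']

model = dict(
    decode_head=dict(num_classes=5),     # your foreground classes, background excluded
    auxiliary_head=dict(num_classes=5))  # keep both heads consistent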

2. Modify the dataset settings (dataset type, data root, batch size, etc.) (../_base_/datasets/ade20k.py)

# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'  # 1. change to the path of your own data
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)  # 2. change to the crop size suited to your own data
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    # adjust img_scale according to the crop size
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))

3. Modify the class names (CLASSES), the file suffixes, and the label index ignored in the loss computation (mmseg/datasets/ade.py, mmseg/datasets/custom.py)

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
from PIL import Image

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class ADE20KDataset(CustomDataset):"""ADE20K dataset.In segmentation map annotation for ADE20K, 0 stands for background, whichis not included in 150 categories. ``reduce_zero_label`` is fixed to True.The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to'.png'."""CLASSES = ('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ','windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth','door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car','water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug','field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe','lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column','signboard', 'chest of drawers', 'counter', 'sand', 'sink','skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path','stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door','stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table','toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove','palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar','arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower','chandelier', 'awning', 'streetlight', 'booth', 'television receiver','airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister','escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van','ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything','swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent','bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank','trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake','dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce','vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen','plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass','clock', 'flag')#修改为自己数据集的类别名称PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 
160],[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],[102, 255, 0], [92, 0, 255]] #同理可以修改颜色def __init__(self, **kwargs):super(ADE20KDataset, self).__init__(img_suffix='.jpg', #可以修改数据集的后缀格式seg_map_suffix='.png',#可以修改数据集标签的后缀格式reduce_zero_label=True,**kwargs)def results2img(self, results, imgfile_prefix, to_label_id, indices=None):"""Write the segmentation results to images.Args:results (list[ndarray]): Testing results of thedataset.imgfile_prefix (str): The filename prefix of the png files.If the prefix is "somepath/xxx",the png files will be named "somepath/xxx.png".to_label_id (bool): whether convert output to label_id forsubmission.indices (list[int], optional): Indices of input results, if notset, all the indices of the dataset will be used.Default: None.Returns:list[str: str]: result txt files which contains correspondingsemantic segmentation images."""if indices is None:indices = list(range(len(self)))mmcv.mkdir_or_exist(imgfile_prefix)result_files = []for result, idx in zip(results, indices):filename = self.img_infos[idx]['filename']basename = osp.splitext(osp.basename(filename))[0]png_filename = osp.join(imgfile_prefix, f'{basename}.png')#这里可以修改.png# The  index range of official requirement is from 0 to 150.# But the index range of output is from 0 to 149.# That is because we set reduce_zero_label=True.result = result + 1output = Image.fromarray(result.astype(np.uint8))output.save(png_filename)result_files.append(png_filename)return result_filesdef format_results(self,results,imgfile_prefix,to_label_id=True,indices=None):"""Format the results into dir (standard format for ade20k evaluation).Args:results (list): Testing results of the dataset.imgfile_prefix (str | None): The prefix of images files. Itincludes the file path and the prefix of filename, e.g.,"a/b/prefix".to_label_id (bool): whether convert output to label_id forsubmission. Default: Falseindices (list[int], optional): Indices of input results, if notset, all the indices of the dataset will be used.Default: None.Returns:tuple: (result_files, tmp_dir), result_files is a list containingthe image paths, tmp_dir is the temporal directory createdfor saving json/png files when img_prefix is not specified."""if indices is None:indices = list(range(len(self)))assert isinstance(results, list), 'results must be a list.'assert isinstance(indices, list), 'indices must be a list.'result_files = self.results2img(results, imgfile_prefix, to_label_id,indices)return result_files

One thing to note: if your images are in .jpg format and your masks are in .png format, nothing needs to change; if they use other formats, you need to modify the image/mask suffixes in mmseg/datasets/custom.py.
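Instead of editing mmseg/datasets/custom.py in place, you can also register your own dataset subclass and set the suffixes there. A minimal sketch — the class name, categories, palette, and .tif suffix below are all hypothetical; save it, for example, as mmseg/datasets/my_dataset.py and import it in mmseg/datasets/__init__.py so the registry can find it:

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class MyDataset(CustomDataset):
    """Hypothetical 5-class custom dataset with non-default image suffix."""
    CLASSES = ('road', 'building', 'vegetation', 'water', 'vehicle')
    PALETTE = [[128, 64, 128], [70, 70, 70], [107, 142, 35], [0, 0, 142], [220, 20, 60]]

    def __init__(self, **kwargs):
        super(MyDataset, self).__init__(
            img_suffix='.tif',        # images that are not .jpg
            seg_map_suffix='.png',    # masks stay as single-channel .png
            reduce_zero_label=True,   # 0 in the mask is background and is ignored
            **kwargs)

The relevant parts of mmseg/datasets/custom.py are reproduced below for reference.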

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from prettytable import PrettyTable
from torch.utils.data import Dataset

from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose, LoadAnnotations


@DATASETS.register_module()
class CustomDataset(Dataset):"""Custom dataset for semantic segmentation. An example of file structureis as followed... code-block:: none├── data│   ├── my_dataset│   │   ├── img_dir│   │   │   ├── train│   │   │   │   ├── xxx{img_suffix}│   │   │   │   ├── yyy{img_suffix}│   │   │   │   ├── zzz{img_suffix}│   │   │   ├── val│   │   ├── ann_dir│   │   │   ├── train│   │   │   │   ├── xxx{seg_map_suffix}│   │   │   │   ├── yyy{seg_map_suffix}│   │   │   │   ├── zzz{seg_map_suffix}│   │   │   ├── valThe img/gt_semantic_seg pair of CustomDataset should be of the sameexcept suffix. A valid img/gt_semantic_seg filename pair should be like``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also includedin the suffix). If split is given, then ``xxx`` is specified in txt file.Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.Args:pipeline (list[dict]): Processing pipelineimg_dir (str): Path to image directoryimg_suffix (str): Suffix of images. Default: '.jpg'ann_dir (str, optional): Path to annotation directory. Default: Noneseg_map_suffix (str): Suffix of segmentation maps. Default: '.png'split (str, optional): Split txt file. If split is specified, onlyfile with suffix in the splits will be loaded. Otherwise, allimages in img_dir/ann_dir will be loaded. Default: Nonedata_root (str, optional): Data root for img_dir/ann_dir. Default:None.test_mode (bool): If test_mode=True, gt wouldn't be loaded.ignore_index (int): The label index to be ignored. Default: 255reduce_zero_label (bool): Whether to mark label zero as ignored.Default: Falseclasses (str | Sequence[str], optional): Specify classes to load.If is None, ``cls.CLASSES`` will be used. Default: None.palette (Sequence[Sequence[int]]] | np.ndarray | None):The palette of segmentation map. If None is given, andself.PALETTE is None, random palette will be generated.Default: Nonegt_seg_map_loader_cfg (dict, optional): build LoadAnnotations toload gt for evaluation, load from disk by default. 
Default: None.file_client_args (dict): Arguments to instantiate a FileClient.See :class:`mmcv.fileio.FileClient` for details.Defaults to ``dict(backend='disk')``."""CLASSES = NonePALETTE = Nonedef __init__(self,pipeline,img_dir,img_suffix='.jpg',#修改ann_dir=None,seg_map_suffix='.png',修改split=None,data_root=None,test_mode=False,ignore_index=255,reduce_zero_label=False,classes=None,palette=None,gt_seg_map_loader_cfg=None,file_client_args=dict(backend='disk')):self.pipeline = Compose(pipeline)self.img_dir = img_dirself.img_suffix = img_suffixself.ann_dir = ann_dirself.seg_map_suffix = seg_map_suffixself.split = splitself.data_root = data_rootself.test_mode = test_modeself.ignore_index = ignore_indexself.reduce_zero_label = reduce_zero_labelself.label_map = Noneself.CLASSES, self.PALETTE = self.get_classes_and_palette(classes, palette)self.gt_seg_map_loader = LoadAnnotations() if gt_seg_map_loader_cfg is None else LoadAnnotations(**gt_seg_map_loader_cfg)self.file_client_args = file_client_argsself.file_client = mmcv.FileClient.infer_client(self.file_client_args)if test_mode:assert self.CLASSES is not None, \'`cls.CLASSES` or `classes` should be specified when testing'# join paths if data_root is specifiedif self.data_root is not None:if not osp.isabs(self.img_dir):self.img_dir = osp.join(self.data_root, self.img_dir)if not (self.ann_dir is None or osp.isabs(self.ann_dir)):self.ann_dir = osp.join(self.data_root, self.ann_dir)if not (self.split is None or osp.isabs(self.split)):self.split = osp.join(self.data_root, self.split)# load annotationsself.img_infos = self.load_annotations(self.img_dir, self.img_suffix,self.ann_dir,self.seg_map_suffix, self.split)def __len__(self):"""Total number of samples of data."""return len(self.img_infos)def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,split):"""Load annotation from directory.Args:img_dir (str): Path to image directoryimg_suffix (str): Suffix of images.ann_dir (str|None): Path to annotation directory.seg_map_suffix (str|None): Suffix of segmentation maps.split (str|None): Split txt file. If split is specified, only filewith suffix in the splits will be loaded. Otherwise, all imagesin img_dir/ann_dir will be loaded. 
Default: NoneReturns:list[dict]: All image info of dataset."""img_infos = []if split is not None:lines = mmcv.list_from_file(split, file_client_args=self.file_client_args)for line in lines:img_name = line.strip()img_info = dict(filename=img_name + img_suffix)if ann_dir is not None:seg_map = img_name + seg_map_suffiximg_info['ann'] = dict(seg_map=seg_map)img_infos.append(img_info)else:for img in self.file_client.list_dir_or_file(dir_path=img_dir,list_dir=False,suffix=img_suffix,recursive=True):img_info = dict(filename=img)if ann_dir is not None:seg_map = img.replace(img_suffix, seg_map_suffix)img_info['ann'] = dict(seg_map=seg_map)img_infos.append(img_info)img_infos = sorted(img_infos, key=lambda x: x['filename'])print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())return img_infosdef get_ann_info(self, idx):"""Get annotation by index.Args:idx (int): Index of data.Returns:dict: Annotation info of specified index."""return self.img_infos[idx]['ann']def pre_pipeline(self, results):"""Prepare results dict for pipeline."""results['seg_fields'] = []results['img_prefix'] = self.img_dirresults['seg_prefix'] = self.ann_dirif self.custom_classes:results['label_map'] = self.label_mapdef __getitem__(self, idx):"""Get training/test data after pipeline.Args:idx (int): Index of data.Returns:dict: Training/test data (with annotation if `test_mode` is setFalse)."""if self.test_mode:return self.prepare_test_img(idx)else:return self.prepare_train_img(idx)def prepare_train_img(self, idx):"""Get training data and annotations after pipeline.Args:idx (int): Index of data.Returns:dict: Training data and annotation after pipeline with new keysintroduced by pipeline."""img_info = self.img_infos[idx]ann_info = self.get_ann_info(idx)results = dict(img_info=img_info, ann_info=ann_info)self.pre_pipeline(results)return self.pipeline(results)def prepare_test_img(self, idx):"""Get testing data after pipeline.Args:idx (int): Index of data.Returns:dict: Testing data after pipeline with new keys introduced bypipeline."""img_info = self.img_infos[idx]results = dict(img_info=img_info)self.pre_pipeline(results)return self.pipeline(results)def format_results(self, results, imgfile_prefix, indices=None, **kwargs):"""Place holder to format result to dataset specific output."""raise NotImplementedErrordef get_gt_seg_map_by_idx(self, index):"""Get one ground truth segmentation map for evaluation."""ann_info = self.get_ann_info(index)results = dict(ann_info=ann_info)self.pre_pipeline(results)self.gt_seg_map_loader(results)return results['gt_semantic_seg']def get_gt_seg_maps(self, efficient_test=None):"""Get ground truth segmentation maps for evaluation."""if efficient_test is not None:warnings.warn('DeprecationWarning: ``efficient_test`` has been deprecated ''since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory ''friendly by default. 
')for idx in range(len(self)):ann_info = self.get_ann_info(idx)results = dict(ann_info=ann_info)self.pre_pipeline(results)self.gt_seg_map_loader(results)yield results['gt_semantic_seg']def pre_eval(self, preds, indices):"""Collect eval result from each iteration.Args:preds (list[torch.Tensor] | torch.Tensor): the segmentation logitafter argmax, shape (N, H, W).indices (list[int] | int): the prediction related ground truthindices.Returns:list[torch.Tensor]: (area_intersect, area_union, area_prediction,area_ground_truth)."""# In order to compat with batch inferenceif not isinstance(indices, list):indices = [indices]if not isinstance(preds, list):preds = [preds]pre_eval_results = []for pred, index in zip(preds, indices):seg_map = self.get_gt_seg_map_by_idx(index)pre_eval_results.append(intersect_and_union(pred,seg_map,len(self.CLASSES),self.ignore_index,# as the labels has been converted when dataset initialized# in`get_palette_for_custom_classes ` this `label_map`# should be`dict()`, see# https://github.com/open-mmlab/mmsegmentation/issues/1415# for more ditailslabel_map=dict(),reduce_zero_label=self.reduce_zero_label))return pre_eval_resultsdef get_classes_and_palette(self, classes=None, palette=None):"""Get class names of current dataset.Args:classes (Sequence[str] | str | None): If classes is None, usedefault CLASSES defined by builtin dataset. If classes is astring, take it as a file name. The file contains the name ofclasses where each line contains one class name. If classes isa tuple or list, override the CLASSES defined by the dataset.palette (Sequence[Sequence[int]]] | np.ndarray | None):The palette of segmentation map. If None is given, randompalette will be generated. Default: None"""if classes is None:self.custom_classes = Falsereturn self.CLASSES, self.PALETTEself.custom_classes = Trueif isinstance(classes, str):# take it as a file pathclass_names = mmcv.list_from_file(classes)elif isinstance(classes, (tuple, list)):class_names = classeselse:raise ValueError(f'Unsupported type {type(classes)} of classes.')if self.CLASSES:if not set(class_names).issubset(self.CLASSES):raise ValueError('classes is not a subset of CLASSES.')# dictionary, its keys are the old label ids and its values# are the new label ids.# used for changing pixel labels in load_annotations.self.label_map = {}for i, c in enumerate(self.CLASSES):if c not in class_names:self.label_map[i] = -1else:self.label_map[i] = class_names.index(c)palette = self.get_palette_for_custom_classes(class_names, palette)return class_names, palettedef get_palette_for_custom_classes(self, class_names, palette=None):if self.label_map is not None:# return subset of palettepalette = []for old_id, new_id in sorted(self.label_map.items(), key=lambda x: x[1]):if new_id != -1:palette.append(self.PALETTE[old_id])palette = type(self.PALETTE)(palette)elif palette is None:if self.PALETTE is None:# Get random state before set seed, and restore# random state later.# It will prevent loss of randomness, as the palette# may be different in each iteration if not specified.# See: https://github.com/open-mmlab/mmdetection/issues/5844state = np.random.get_state()np.random.seed(42)# random palettepalette = np.random.randint(0, 255, size=(len(class_names), 3))np.random.set_state(state)else:palette = self.PALETTEreturn palettedef evaluate(self,results,metric='mIoU',logger=None,gt_seg_maps=None,**kwargs):"""Evaluate the dataset.Args:results (list[tuple[torch.Tensor]] | list[str]): per image pre_evalresults or predict segmentation map for computing 
evaluationmetric.metric (str | list[str]): Metrics to be evaluated. 'mIoU','mDice' and 'mFscore' are supported.logger (logging.Logger | None | str): Logger used for printingrelated information during evaluation. Default: None.gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,used in ConcatDatasetReturns:dict[str, float]: Default metrics."""if isinstance(metric, str):metric = [metric]allowed_metrics = ['mIoU', 'mDice', 'mFscore']if not set(metric).issubset(set(allowed_metrics)):raise KeyError('metric {} is not supported'.format(metric))eval_results = {}# test a list of filesif mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(results, str):if gt_seg_maps is None:gt_seg_maps = self.get_gt_seg_maps()num_classes = len(self.CLASSES)ret_metrics = eval_metrics(results,gt_seg_maps,num_classes,self.ignore_index,metric,label_map=dict(),reduce_zero_label=self.reduce_zero_label)# test a list of pre_eval_resultselse:ret_metrics = pre_eval_to_metrics(results, metric)# Because dataset.CLASSES is required for per-eval.if self.CLASSES is None:class_names = tuple(range(num_classes))else:class_names = self.CLASSES# summary tableret_metrics_summary = OrderedDict({ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)for ret_metric, ret_metric_value in ret_metrics.items()})# each class tableret_metrics.pop('aAcc', None)ret_metrics_class = OrderedDict({ret_metric: np.round(ret_metric_value * 100, 2)for ret_metric, ret_metric_value in ret_metrics.items()})ret_metrics_class.update({'Class': class_names})ret_metrics_class.move_to_end('Class', last=False)# for loggerclass_table_data = PrettyTable()for key, val in ret_metrics_class.items():class_table_data.add_column(key, val)summary_table_data = PrettyTable()for key, val in ret_metrics_summary.items():if key == 'aAcc':summary_table_data.add_column(key, [val])else:summary_table_data.add_column('m' + key, [val])print_log('per class results:', logger)print_log('\n' + class_table_data.get_string(), logger=logger)print_log('Summary:', logger)print_log('\n' + summary_table_data.get_string(), logger=logger)# each metric dictfor key, value in ret_metrics_summary.items():if key == 'aAcc':eval_results[key] = value / 100.0else:eval_results['m' + key] = value / 100.0ret_metrics_class.pop('Class', None)for key, value in ret_metrics_class.items():eval_results.update({key + '.' + str(name): value[idx] / 100.0for idx, name in enumerate(class_names)})return eval_results

1. The modified configuration in custom.py

2. Evaluation results for a VOC-type dataset after changing the label index ignored in the loss computation

4. Modify the runtime config (loading a pretrained model and resuming from a checkpoint) (configs/_base_/default_runtime.py)

# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook')  # enable the TensorboardLoggerHook
        # dict(type='PaviLoggerHook')  # for internal services
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None  # load the model at this path as a pretrained model; training is NOT resumed
resume_from = None  # load the model at this path as a checkpoint and resume training from it
workflow = [('train', 1)]
cudnn_benchmark = True

5. Modify the schedule config (maximum number of training iterations, checkpoint interval, evaluation interval, evaluation metric, automatically keeping the best model) (configs/_base_/schedules/schedule_40k.py, ../_base_/schedules/schedule_160k.py)

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)  # max_iters: maximum number of training iterations
checkpoint_config = dict(by_epoch=False, interval=16000)  # interval: save a checkpoint every 16000 iterations
evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)  # evaluate every 16000 iterations with mIoU; add save_best='auto' to keep the best checkpoint
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='TensorboardLoggerHook')
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = '/media/lhy/Swin-Transformer-Semantic-Segmentation/checkpoints/deeplabv3plus/deeplabv3plus_r101-d8_512x512_40k_voc12aug_20200613_205333-faf03387.pth'
resume_from = '/media/lhy/mmsegmentation-0.27.0/work_dirs/runs/train/road0.5m_1_deeplabv3plus_r101_exp2/best_mIoU_iter_44000.pth'
workflow = [('train', 1)]
cudnn_benchmark = True
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
# enable FP16 training
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
fp16 = dict()
lr_config = dict(policy='poly', power=0.9, min_lr=0.0001, by_epoch=False)
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric=['mIoU', 'mFscore'], pre_eval=True, save_best='mIoU')  # automatically keep the checkpoint with the best mIoU
work_dir = 'work_dirs/runs/train/road0.5m_1_deeplabv3plus_r101_exp2'
gpu_ids = range(0, 4)
auto_resume = False

When training on a single GPU, scale the learning rate linearly: lr = LR × (batch_size / 16), where LR is the learning rate used for the multi-GPU setting (here 4 GPUs × 4 images per GPU, i.e. a total batch size of 16).
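A quick arithmetic sketch of this linear scaling rule (the single-GPU batch size below is just an example):

base_lr = 6e-5        # lr of the AdamW optimizer in the Swin config above
base_batch = 16       # total batch size the base config assumes

gpus, samples_per_gpu = 1, 2                        # example single-GPU setting
scaled_lr = base_lr * (gpus * samples_per_gpu) / base_batch
print(f'scaled lr = {scaled_lr:.2e}')               # 7.50e-06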

6. Modify the inference mode and norm_cfg (../_base_/models/upernet_swin.py)

# model settings
# norm_cfg: use 'SyncBN' for multi-GPU training; for single-GPU training change type to 'BN'.
norm_cfg = dict(type='SyncBN', requires_grad=True)
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=96,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,
        qk_scale=None,
        patch_norm=True,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=backbone_norm_cfg),
    decode_head=dict(
        type='UPerHead',
        in_channels=[96, 192, 384, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=384,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))  # 'whole' means whole-image inference
# For overlapping sliding-window inference, change it to:
# test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))


Sliding-window inference code: mmsegmentation/mmseg/models/segmentors/encoder_decoder.py

    # TODO refactor
    def slide_inference(self, img, img_meta, rescale):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """
        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = img.size()
        num_classes = self.num_classes
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                crop_seg_logit = self.encode_decode(crop_img, img_meta)
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))
                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        if torch.onnx.is_in_onnx_export():
            # cast count_mat to constant while exporting to ONNX
            count_mat = torch.from_numpy(
                count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = preds / count_mat
        if rescale:
            # remove padding area
            resize_shape = img_meta[0]['img_shape'][:2]
            preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
            preds = resize(
                preds,
                size=img_meta[0]['ori_shape'][:2],
                mode='bilinear',
                align_corners=self.align_corners,
                warning=False)
        return preds
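For running inference with a trained checkpoint, the high-level mmseg 0.x API can be used. A hedged usage sketch follows; the config and checkpoint paths are placeholders for your own files:

from mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot

config_file = 'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py'
checkpoint_file = 'work_dirs/upernet_swin_tiny/latest.pth'  # hypothetical path to your trained weights

# test_cfg (whole-image vs. sliding-window) is taken from the config file,
# so set mode='slide' there if large images need sliding-window inference.
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
result = inference_segmentor(model, 'demo/demo.png')        # list with one H x W label map
show_result_pyplot(model, 'demo/demo.png', result, opacity=0.5)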

6. Model optimization tricks

1. Learning-rate tricks

In semantic segmentation, some methods use a larger LR for the heads than for the backbone to achieve better performance or faster convergence.
In MMSegmentation, you can add the following lines to the config to make the LR of the heads 10 times that of the backbone. With this modification, any parameter group whose name contains 'head' will have its LR multiplied by 10.

Different Learning Rate (LR) for Backbone and Heads

optimizer = dict(
    paramwise_cfg=dict(
        custom_keys={'head': dict(lr_mult=10.)}))
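Combined with the AdamW settings used above for Swin, the override could look like this (a sketch, not taken from the original configs):

optimizer = dict(
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'head': dict(lr_mult=10.),  # matches decode_head and auxiliary_head parameters
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))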

2. Online Hard Example Mining (OHEM)

MMSegmentation implements a pixel sampler for training-time sampling. Below is an example config that trains PSPNet with OHEM enabled.
With this setting, only pixels with a confidence score below 0.7 are used for training, and at least 100000 pixels are kept during training. If thresh is not specified, the min_kept pixels with the highest loss are selected.

_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    decode_head=dict(
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)))

3. Class-balanced loss

For datasets with an imbalanced class distribution, you can change the loss weight of each class. Below is an example for the Cityscapes dataset; class_weight is passed to CrossEntropyLoss as its weight argument.

_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    decode_head=dict(
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
            # DeepLab used this class weight for cityscapes
            class_weight=[
                0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
                0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
                1.0865, 1.1529, 1.0507
            ])))
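If you need class_weight values for your own dataset, one common recipe (not part of MMSegmentation itself) is median frequency balancing over the training masks; the label path pattern and the 19-class count below are assumptions for Cityscapes-style trainId labels:

import glob

import numpy as np
from PIL import Image

num_classes = 19                          # assumption: Cityscapes trainIds
pixel_counts = np.zeros(num_classes, dtype=np.int64)
for path in glob.glob('data/cityscapes/gtFine/train/*/*_labelTrainIds.png'):
    label = np.array(Image.open(path))
    ids, counts = np.unique(label[label < num_classes], return_counts=True)
    pixel_counts[ids] += counts

freq = pixel_counts / pixel_counts.sum()
class_weight = np.median(freq) / freq     # rarer classes get larger weights
print(np.round(class_weight, 4).tolist())  # paste into loss_decode.class_weight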

4. Multiple losses

For loss computation, training with multiple losses at the same time is supported. Below is an example config that trains unet on the DRIVE dataset, whose loss is a 1:3 weighted combination of CrossEntropyLoss and DiceLoss:

_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    decode_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
    auxiliary_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
)

In this way, loss_weight and loss_name become, respectively, the weight and the name of the corresponding loss in the training log.
Note: for a loss term to be included in the backward graph, its loss_name must start with the prefix loss_.

5. Ignoring a specified label index in the loss computation

mmseg already ships description files and loading code for the common public segmentation datasets, and anyone who has used PyTorch can read those dataset definitions without much trouble. The one option that tends to confuse newcomers is reduce_zero_label: when building your own mmseg dataset, the most frequent question is whether reduce_zero_label should be True or False.

What does it do? Taken literally, the name means "reduce the zero-valued label". In a multi-class segmentation task, if the value 0 in your label files stands for the background class, it is recommended to ignore it.

The data-loading source contains a snippet that handles reduce_zero_label: when it is enabled, every annotation value that was originally 0 is set to 255, which is the default value of the loss function's ignore_index parameter, and labels equal to 255 are excluded from the loss by default. This is exactly why the 150-class ADE dataset mentioned earlier does not include a background class: reduce_zero_label is enabled, so background pixels that were originally 0 are mapped to ignore_index.

# mmseg/datasets/pipelines/loading.py
...
# reduce zero_label
if self.reduce_zero_label:
    # avoid using underflow conversion
    gt_semantic_seg[gt_semantic_seg == 0] = 255
    gt_semantic_seg = gt_semantic_seg - 1
    gt_semantic_seg[gt_semantic_seg == 254] = 255
...
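A tiny demo of what that mapping does to a mask (the values are chosen purely for illustration):

import numpy as np

gt_semantic_seg = np.array([0, 1, 2, 150], dtype=np.uint8)  # raw ADE20K-style label values
gt_semantic_seg[gt_semantic_seg == 0] = 255
gt_semantic_seg = gt_semantic_seg - 1
gt_semantic_seg[gt_semantic_seg == 254] = 255
print(gt_semantic_seg)  # [255   0   1 149]: background -> ignore_index, other ids shift down by 1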

A common problem caused by reduce_zero_label

Take the ADE dataset source as an example, where reduce_zero_label defaults to True. Even someone who understands the previous section may know ADE only superficially and suspect that enabling reduce_zero_label in the config drops the first of the 150 classes — after all, num_classes is 150 — and then naively switch reduce_zero_label off.

Why that is wrong

# configs/_base_/datasets/ade20k.py
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),  # reduce_zero_label defaults to True for ADE
    dict(...),
    ...
]

Only 150 classes, the ones defined in CLASSES, actually take part in training, but the label files effectively contain 151 classes: the background class (everything left unlabelled or deliberately ignored, stored as 0 in the label) is not among the 150 CLASSES and must be turned into ignore_index at training time. That is what reduce_zero_label from the previous section does — it pulls the background out of the 151 classes and marks it as ignore_index. If you wrongly switch reduce_zero_label off, num_classes would have to become 151.

In the default setting, avg_non_ignore=False, which means every pixel counts towards the loss average, even pixels that carry the ignore-index label.
For loss computation, avg_non_ignore can be combined with ignore_index to skip certain labels; the average loss is then computed only over non-ignored labels, which may give better performance. Below is an example config for training unet on Cityscapes: it ignores label 0 (background) in the loss computation and averages the loss only over non-ignored labels:

_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
    decode_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
    auxiliary_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
)

Simply add ignore_index to the decode head or auxiliary head and set avg_non_ignore=True:

# model settings
...
loss_decode=dict(
    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
...
