使用mmsegmentation训练自己的数据集

总体流程

  • 安装
  • 注册数据集
  • 编写配置文件
  • 运行
  • 测试

安装

1
2
3
pip install mmcv
# 注意只是 clone 是不行的,还要 install 一下产生版本文件
pip install git+https://github.com/open-mmlab/mmsegmentation.git # install the master branch

更多安装方法参考官方文档

注册数据集

  • mmseg/datasets 目录下添加自己的数据集的 .py 文件,这里主要是让框架知道模型的类别,下面 suffix 根据自己实际情况修改
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os.path as osp

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class SatelliteDataset(CustomDataset):
"""Satellite dataset.

The ``img_suffix`` is fixed to '.tif' and ``seg_map_suffix`` is
fixed to '.png'.
"""

CLASSES = ('ford', 'transportation', 'building', 'farmland', 'grassland',
'woodland', 'bare_soil', 'others')

PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255]]

def __init__(self, **kwargs):
super(SatelliteDataset, self).__init__(
img_suffix='.tif',
seg_map_suffix='.png',
reduce_zero_label=False,
**kwargs)
assert osp.exists(self.img_dir)
  • mmseg/datasets/__init__.py 中导入你自定义的类,并在 __all__ 变量中添加你的类名

编写配置文件

configs/你要用的方法/ 下创建一个 .py 文件,配置文件主要由四部分组成:

  • 使用模型
  • 数据集及数据处理流程
  • 模型调度方法
  • runtime 配置

可以引入已有的配置,如果要修改配置就新建一个 dict 覆盖掉原来的配置项(不需要全部字段都有),如下所示:

1
2
3
4
5
6
7
8
9
model = dict(
# 修改类别
decode_head=dict(num_classes=8, norm_cfg=norm_cfg),
auxiliary_head=dict(num_classes=8, norm_cfg=norm_cfg),
# 修改预训练路径
pretrained='open-mmlab://resnet101_v1c',
# 修改训练 backbone
backbone=dict(depth=101, norm_cfg=norm_cfg)
)

放一个完整的配置:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
_base_ = [
'../_base_/models/pspnet_r50-d8.py',
'../_base_/default_runtime.py',
'../_base_/schedules/schedule_40k.py'
]
# norm_cfg = dict(type='BN', requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
decode_head=dict(num_classes=8, norm_cfg=norm_cfg),
auxiliary_head=dict(num_classes=8, norm_cfg=norm_cfg),
pretrained='open-mmlab://resnet101_v1c',
backbone=dict(depth=101, norm_cfg=norm_cfg)
)

dataset_type = 'SatelliteDataset'
data_root = '/home/sse/data4T/common_datasets/satelite_dataset/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (256, 256)
# crop_size = (224, 224)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=12,
workers_per_gpu=0,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/train',
ann_dir='ann_dir/train',
# split='ImageSets/SegmentationContext/train.txt',
pipeline=train_pipeline),
val=dict( # 训练到一定轮次会自动验证
type=dataset_type,
data_root=data_root,
img_dir='img_dir/test',
ann_dir='ann_dir/test',
# split='ImageSets/SegmentationContext/val.txt',
pipeline=test_pipeline),
test=dict( # 测试的时候才用到
type=dataset_type,
data_root=data_root,
img_dir='image_A/image_A_9',
# ann_dir='ann_dir/test',
# split='ImageSets/SegmentationContext/val.txt',
pipeline=test_pipeline))

total_iters = 100000
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric='mIoU')

# 训练结果保存路径
work_dir = '/home/sse/mmsegmentation/run/satellite-10-12'

详细配置还是参考官方文档

运行

单卡运行:

1
python tools/train.py 配置文件路径名

分布式运行

1
./tools/dist_train.sh 配置文件路径名 GPU数 [optional arguments]

从已有参数从头开始运行

1
./tools/dist_train.sh 配置文件路径名 GPU数 --load-from 参数路径

从已有参数从头继续运行

1
./tools/dist_train.sh 配置文件路径名 GPU数 --resume-from 参数路径

更详细的依然参考官方文档

运行时候的==一个坑==

验证数据集在验证时会把所有的预测结果和 GT 保存在内存中,如果验证集太大很可能进程会挂掉,测试同

测试

测试的时候默认使用配置文件中指定的测试集

1
2
3
4
5
# single-gpu testing
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]

# multi-gpu testing
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]

可选参数

  • RESULT_FILE: pickle 文件,结果保存在其中
  • EVAL_METRICS: 评价指标,制定之后需要标签文件
  • –show: 结果会在新窗口中打开
  • –show-dir: 可视化结果保存到指定文件夹中,注意保存的不是神经网络出来的结果,而是经过 RGB 调色之后和原图叠加的结果