使用mmsegmentation训练自己的数据集

Author Avatar
patrickcty 10月 14, 2020

总体流程

  • 安装
  • 注册数据集
  • 编写配置文件
  • 运行
  • 测试

安装

pip install mmcv
# 注意只是 clone 是不行的,还要 install 一下产生版本文件
pip install git+https://github.com/open-mmlab/mmsegmentation.git # install the master branch

更多安装方法参考官方文档

注册数据集

  • mmseg/datasets 目录下添加自己的数据集的 .py 文件,这里主要是让框架知道模型的类别,下面 suffix 根据自己实际情况修改
import os.path as osp

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class SatelliteDataset(CustomDataset):
    """Satellite dataset.

    The ``img_suffix`` is fixed to '.tif' and ``seg_map_suffix`` is
    fixed to '.png'.
    """

    CLASSES = ('ford', 'transportation', 'building', 'farmland', 'grassland',
               'woodland', 'bare_soil',  'others')

    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255]]

    def __init__(self, **kwargs):
        super(SatelliteDataset, self).__init__(
            img_suffix='.tif',
            seg_map_suffix='.png',
            reduce_zero_label=False,
            **kwargs)
        assert osp.exists(self.img_dir)
  • mmseg/datasets/__init__.py 中导入你自定义的类,并在 __all__ 变量中添加你的类名

编写配置文件

configs/你要用的方法/ 下创建一个 .py 文件,配置文件主要由四部分组成:

  • 使用模型
  • 数据集及数据处理流程
  • 模型调度方法
  • runtime 配置

可以引入已有的配置,如果要修改配置就新建一个 dict 覆盖掉原来的配置项(不需要全部字段都有),如下所示:

model = dict(
    # 修改类别
    decode_head=dict(num_classes=8, norm_cfg=norm_cfg),
    auxiliary_head=dict(num_classes=8, norm_cfg=norm_cfg),
    # 修改预训练路径
    pretrained='open-mmlab://resnet101_v1c',
    # 修改训练 backbone
    backbone=dict(depth=101, norm_cfg=norm_cfg)
)

放一个完整的配置:

_base_ = [
    '../_base_/models/pspnet_r50-d8.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py'
]
# norm_cfg = dict(type='BN', requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    decode_head=dict(num_classes=8, norm_cfg=norm_cfg),
    auxiliary_head=dict(num_classes=8, norm_cfg=norm_cfg),
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(depth=101, norm_cfg=norm_cfg)
)

dataset_type = 'SatelliteDataset'
data_root = '/home/sse/data4T/common_datasets/satelite_dataset/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (256, 256)
# crop_size = (224, 224)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=12,
    workers_per_gpu=0,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='img_dir/train',
        ann_dir='ann_dir/train',
        # split='ImageSets/SegmentationContext/train.txt',
        pipeline=train_pipeline),
    val=dict(  # 训练到一定轮次会自动验证
        type=dataset_type,
        data_root=data_root,
        img_dir='img_dir/test',
        ann_dir='ann_dir/test',
        # split='ImageSets/SegmentationContext/val.txt',
        pipeline=test_pipeline),
    test=dict(  # 测试的时候才用到
        type=dataset_type,
        data_root=data_root,
        img_dir='image_A/image_A_9',
        # ann_dir='ann_dir/test',
        # split='ImageSets/SegmentationContext/val.txt',
        pipeline=test_pipeline))

total_iters = 100000
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric='mIoU')

# 训练结果保存路径
work_dir = '/home/sse/mmsegmentation/run/satellite-10-12'

详细配置还是参考官方文档

运行

单卡运行:

python tools/train.py 配置文件路径名

分布式运行

./tools/dist_train.sh 配置文件路径名 GPU数 [optional arguments]

从已有参数从头开始运行

./tools/dist_train.sh 配置文件路径名 GPU数 --load-from 参数路径

从已有参数从头继续运行

./tools/dist_train.sh 配置文件路径名 GPU数 --resume-from 参数路径

更详细的依然参考官方文档

运行时候的==一个坑==

验证数据集在验证时会把所有的预测结果和 GT 保存在内存中,如果验证集太大很可能进程会挂掉,测试同

测试

测试的时候默认使用配置文件中指定的测试集

# single-gpu testing
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]

# multi-gpu testing
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]

可选参数

  • RESULT_FILE: pickle 文件,结果保存在其中
  • EVAL_METRICS: 评价指标,制定之后需要标签文件
  • –show: 结果会在新窗口中打开
  • –show-dir: 可视化结果保存到指定文件夹中,注意保存的不是神经网络出来的结果,而是经过 RGB 调色之后和原图叠加的结果