
Distill RTM Detectors Based on MMRazor

Description

To further improve model accuracy without adding much computation cost, we apply feature-based distillation when training these RTM detectors. In summary, our distillation strategy is threefold:

(1) Inspired by PKD, we first normalize the intermediate feature maps to zero mean and unit variance before calculating the distillation loss.

(2) Inspired by CWD, we adopt the channel-wise distillation paradigm, which turns each channel's activations into a distribution over spatial locations and therefore pays more attention to the most salient regions of each channel.

(3) Inspired by DAMO-YOLO, the distillation process is split into two stages: 1) in the first stage (280 epochs), the teacher distills the student with strong mosaic augmentation; 2) in the second stage (20 epochs), the student fine-tunes itself without mosaic augmentation.
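The snippet below is a minimal, framework-free sketch of how steps (1) and (2) combine into one loss for a single image. It assumes the student features have already been mapped to the teacher's channel count, and it normalizes each channel over its spatial positions (the config in the Configs section instead uses BN statistics over the whole batch). It then applies a channel-wise divergence: each channel becomes a softmax distribution over spatial positions with temperature tau, the KL divergence from teacher to student is averaged over channels, and the result is scaled by tau squared. The helper names and toy feature values are illustrative assumptions, not MMRazor's API.

```python
import math
from typing import List

# One image's feature map: [channels][H*W], spatial positions flattened.
Feature = List[List[float]]


def normalize(feat: Feature, eps: float = 1e-5) -> Feature:
    """PKD-style step: shift/scale each channel to zero mean and unit variance."""
    out = []
    for ch in feat:
        mean = sum(ch) / len(ch)
        var = sum((x - mean) ** 2 for x in ch) / len(ch)
        out.append([(x - mean) / math.sqrt(var + eps) for x in ch])
    return out


def log_softmax(xs: List[float], tau: float) -> List[float]:
    """Numerically stable log-softmax of xs / tau."""
    scaled = [x / tau for x in xs]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def channel_wise_divergence(feat_s: Feature, feat_t: Feature, tau: float = 1.0) -> float:
    """CWD-style step: per channel, KL(p_T || p_S) over spatial positions, averaged over channels."""
    total = 0.0
    for ch_s, ch_t in zip(feat_s, feat_t):
        log_p_t = log_softmax(ch_t, tau)
        log_p_s = log_softmax(ch_s, tau)
        total += sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    return total * tau ** 2 / len(feat_t)


# Toy 2-channel, 4-position features (student already projected to the teacher's channels).
student = [[0.1, 0.5, 2.0, -1.0], [3.0, 3.1, 2.9, 3.0]]
teacher = [[10.0, 12.0, 30.0, -5.0], [0.0, 1.0, 0.0, 1.0]]
print(channel_wise_divergence(normalize(student), normalize(teacher)))
```

Normalizing first keeps the loss from being dominated by the very different activation magnitudes of the two detectors; the channel-wise softmax then only compares where each channel fires, not how strongly.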

Results and Models

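| Location | Dataset | Teacher | Student | mAP (distilled student) | mAP (teacher) | mAP (baseline student) | Config | Download |
| :------: | :-----: | :-----: | :-----: | :---------------------: | :-----------: | :--------------------: | :----: | :------: |
| FPN | COCO | RTMDet-s | RTMDet-tiny | 41.8 (+0.8) | 44.6 | 41.0 | config | teacher \| model \| log |
| FPN | COCO | RTMDet-m | RTMDet-s | 45.7 (+1.1) | 49.3 | 44.6 | config | teacher \| model \| log |
| FPN | COCO | RTMDet-l | RTMDet-m | 50.2 (+0.9) | 51.4 | 49.3 | config | teacher \| model \| log |
| FPN | COCO | RTMDet-x | RTMDet-l | 52.3 (+0.9) | 52.8 | 51.4 | config | teacher \| model \| log |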
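The number in parentheses is the distilled student's mAP minus the mAP of the same student trained without distillation. The short script below recomputes that gain from the table, together with the gap that remains to the teacher; it only re-derives numbers already listed in the table.

```python
# Rows of the results table: (teacher, student, distilled mAP, teacher mAP, baseline student mAP).
ROWS = [
    ('RTMDet-s', 'RTMDet-tiny', 41.8, 44.6, 41.0),
    ('RTMDet-m', 'RTMDet-s', 45.7, 49.3, 44.6),
    ('RTMDet-l', 'RTMDet-m', 50.2, 51.4, 49.3),
    ('RTMDet-x', 'RTMDet-l', 52.3, 52.8, 51.4),
]


def summarize(rows):
    """Return (student, gain over baseline, remaining gap to teacher), rounded to 0.1 mAP."""
    return [(student, round(m - m_s, 1), round(m_t - m, 1))
            for _, student, m, m_t, m_s in rows]


for student, gain, gap in summarize(ROWS):
    print(f'{student}: +{gain} mAP from distillation, {gap} mAP below its teacher')
```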

Usage

Prerequisites

Install MMRazor from source

git clone -b dev-1.x https://github.com/open-mmlab/mmrazor.git
cd mmrazor
# Install MMRazor
mim install -v -e .

Training commands

In MMYOLO's root directory, run the following command to train RTMDet-tiny on 8 GPUs, using RTMDet-s as the teacher:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py 8

Testing commands

In MMYOLO's root directory, run the following command to test the model:

CUDA_VISIBLE_DEVICES=0 PORT=29500 ./tools/dist_test.sh configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py ${CHECKPOINT_PATH} 1

Getting a student-only checkpoint

After training, the checkpoint contains the parameters of both the student and the teacher. Run the following command to convert it into a student-only checkpoint:

python ./tools/model_converters/convert_kd_ckpt_to_student.py ${CHECKPOINT_PATH} --out-path ${OUTPUT_CHECKPOINT_PATH}
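For intuition, the conversion amounts to filtering the checkpoint's state dict. The sketch below is not the converter script itself; it assumes the student's weights are stored under an `architecture.` prefix (to the best of our knowledge, the attribute name MMRazor's single-teacher distillation wrapper uses for the student) and that teacher and connector weights live under other prefixes. The example keys are hypothetical.

```python
from typing import Any, Dict

STUDENT_PREFIX = 'architecture.'  # assumption: prefix of the student's weights in the KD checkpoint


def extract_student_state_dict(state_dict: Dict[str, Any],
                               prefix: str = STUDENT_PREFIX) -> Dict[str, Any]:
    """Keep only the student's entries and strip the prefix so keys match a plain RTMDet model."""
    return {key[len(prefix):]: value
            for key, value in state_dict.items() if key.startswith(prefix)}


# Hypothetical keys; real checkpoints hold tensors instead of strings.
kd_state = {
    'architecture.backbone.stem.0.conv.weight': 'w_student_backbone',
    'architecture.neck.out_layers.0.conv.weight': 'w_student_neck',
    'teacher.backbone.stem.0.conv.weight': 'w_teacher',
    'distiller.connectors.fpn0_s.conv.weight': 'w_connector',
}
student_state = extract_student_state_dict(kd_state)
print(student_state)
# {'backbone.stem.0.conv.weight': 'w_student_backbone', 'neck.out_layers.0.conv.weight': 'w_student_neck'}
```

Dropping the connector weights is harmless for inference, since in this setup connectors only feed the distillation loss.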

Configs

Here we provide detection configs and models for MMRazor in MMYOLO. For clarity, we take ./kd_tiny_rtmdet_s_neck_300e_coco.py as an example to show how to distill an RTM detector with MMRazor.

Here is the main part of ./kd_tiny_rtmdet_s_neck_300e_coco.py.

# BN without learnable affine parameters or running statistics: it always
# normalizes each channel to zero mean and unit variance with the current
# batch's statistics (the PKD-style normalization).
norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

distiller=dict(
    type='ConfigurableDistiller',
    # Record the outputs of the three FPN output convs of the student (RTMDet-tiny).
    student_recorders=dict(
        fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
        fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
        fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv'),
    ),
    # Record the same three outputs of the teacher (RTMDet-s).
    teacher_recorders=dict(
        fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
        fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
        fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
    connectors=dict(
        # Student side: project 96 channels (RTMDet-tiny neck) to
        # 128 channels (RTMDet-s neck), then normalize.
        fpn0_s=dict(
            type='ConvModuleConnector', in_channel=96,
            out_channel=128, bias=False, norm_cfg=norm_cfg,
            act_cfg=None),
        # Teacher side: normalization only, no learnable projection.
        fpn0_t=dict(
            type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
        fpn1_s=dict(
            type='ConvModuleConnector', in_channel=96,
            out_channel=128, bias=False, norm_cfg=norm_cfg,
            act_cfg=None),
        fpn1_t=dict(
            type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
        fpn2_s=dict(
            type='ConvModuleConnector', in_channel=96,
            out_channel=128, bias=False, norm_cfg=norm_cfg,
            act_cfg=None),
        fpn2_t=dict(
            type='NormConnector', in_channels=128, norm_cfg=norm_cfg)),
    # One channel-wise distillation (CWD) loss per FPN level.
    distill_losses=dict(
        loss_fpn0=dict(type='ChannelWiseDivergence', loss_weight=1),
        loss_fpn1=dict(type='ChannelWiseDivergence', loss_weight=1),
        loss_fpn2=dict(type='ChannelWiseDivergence', loss_weight=1)),
    # Feed each recorded feature, passed through its connector, to the
    # preds_S / preds_T arguments of the matching loss.
    loss_forward_mappings=dict(
        loss_fpn0=dict(
            preds_S=dict(from_student=True, recorder='fpn0', connector='fpn0_s'),
            preds_T=dict(from_student=False, recorder='fpn0', connector='fpn0_t')),
        loss_fpn1=dict(
            preds_S=dict(from_student=True, recorder='fpn1', connector='fpn1_s'),
            preds_T=dict(from_student=False, recorder='fpn1', connector='fpn1_t')),
        loss_fpn2=dict(
            preds_S=dict(from_student=True, recorder='fpn2', connector='fpn2_s'),
            preds_T=dict(from_student=False, recorder='fpn2', connector='fpn2_t'))))

recorders record intermediate results during the model's forward pass. In this example, they record the outputs of three nn.Module instances (the FPN output convs) of both the teacher and the student. Details are listed in Recorder and, for readers of Chinese, MMRazor Distillation.
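As far as we know, MMRazor's ModuleOutputs recorder is built on PyTorch forward hooks. The dependency-free toy below mimics the idea: it resolves a dotted source path such as neck.out_layers.0.conv attribute by attribute and appends the module's output to a list every time that module runs. ToyModule and ModuleOutputsRecorderSketch are hypothetical stand-ins, not MMRazor classes.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: children are attributes, calls fire hooks."""

    def __init__(self, fn: Callable[[Any], Any] = lambda x: x,
                 children: Optional[Dict[str, 'ToyModule']] = None) -> None:
        self.fn = fn
        self.hooks: List[Callable[[Any], None]] = []
        for name, child in (children or {}).items():
            setattr(self, name, child)  # e.g. '0' for the first entry of a layer list

    def __call__(self, x: Any) -> Any:
        out = self.fn(x)
        for hook in self.hooks:
            hook(out)  # like a forward hook, it receives the module's output
        return out


class ModuleOutputsRecorderSketch:
    """Resolve a dotted source path and store every output of that module."""

    def __init__(self, source: str) -> None:
        self.source = source
        self.data: List[Any] = []

    def attach(self, model: ToyModule) -> None:
        module = model
        for name in self.source.split('.'):
            module = getattr(module, name)
        module.hooks.append(self.data.append)


# Toy model with the same path structure as the config: neck.out_layers.0.conv
conv = ToyModule(lambda x: [2 * v for v in x])
out_layer0 = ToyModule(lambda x: conv(x), {'conv': conv})
out_layers = ToyModule(children={'0': out_layer0})
neck = ToyModule(lambda x: out_layer0(x), {'out_layers': out_layers})
model = ToyModule(lambda x: neck(x), {'neck': neck})

recorder = ModuleOutputsRecorderSketch('neck.out_layers.0.conv')
recorder.attach(model)
model([1.0, 2.0])
print(recorder.data)  # [[2.0, 4.0]]
```

Because the recorder only observes outputs, the student and teacher forward passes stay unchanged; the distiller reads the recorded tensors afterwards.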

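connectors are adaptive layers that usually map the teacher's and the student's features to the same dimension. Here the student connectors project the 96-channel RTMDet-tiny features to 128 channels, and the teacher connectors only normalize.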
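The student connectors here are 1x1 convolutions followed by normalization. A 1x1 convolution is simply a per-position linear map across channels, as this minimal sketch shows. The conv1x1 helper is hypothetical, has no bias (matching bias=False), uses toy sizes of 2 to 3 channels instead of 96 to 128, and omits the normalization that follows it.

```python
from typing import List

# One image's feature map: [channels][H*W], spatial positions flattened.
Feature = List[List[float]]


def conv1x1(feat: Feature, weight: List[List[float]]) -> Feature:
    """1x1 conv without bias: out[o][p] = sum_i weight[o][i] * feat[i][p]."""
    num_pos = len(feat[0])
    return [[sum(w_oi * feat[i][p] for i, w_oi in enumerate(w_o)) for p in range(num_pos)]
            for w_o in weight]


feat = [[1.0, 2.0, 3.0], [0.0, 1.0, -1.0]]      # 2 channels, 3 positions (student)
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # maps 2 -> 3 channels (teacher width)
print(conv1x1(feat, weight))
# [[1.0, 2.0, 3.0], [0.0, 1.0, -1.0], [1.0, 3.0, 2.0]]
```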

distill_losses configures the distillation losses; here there is one ChannelWiseDivergence loss per FPN level.

loss_forward_mappings map the forward arguments of each distillation loss (preds_S and preds_T here) to recorded features, each optionally passed through a connector.
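The sketch below shows how loss_forward_mappings ties recorders, connectors and losses together, under simplifying assumptions: records are plain values already collected by the recorders, connectors and losses are ordinary callables, and each mapping entry only uses the from_student, recorder and connector keys that appear in the config. compute_distill_losses is a hypothetical helper, not MMRazor's implementation, and the squared-difference loss stands in for ChannelWiseDivergence.

```python
from typing import Any, Callable, Dict


def compute_distill_losses(
    student_records: Dict[str, Any],
    teacher_records: Dict[str, Any],
    connectors: Dict[str, Callable[[Any], Any]],
    distill_losses: Dict[str, Callable[..., float]],
    loss_forward_mappings: Dict[str, Dict[str, Dict[str, Any]]],
) -> Dict[str, float]:
    """Resolve each loss's keyword arguments from records (plus optional connector), then call it."""
    results = {}
    for loss_name, arg_mappings in loss_forward_mappings.items():
        kwargs = {}
        for arg_name, spec in arg_mappings.items():
            records = student_records if spec['from_student'] else teacher_records
            value = records[spec['recorder']]
            connector = spec.get('connector')
            if connector is not None:
                value = connectors[connector](value)
            kwargs[arg_name] = value
        results[loss_name] = distill_losses[loss_name](**kwargs)
    return results


student_records = {'fpn0': 1.5}
teacher_records = {'fpn0': 4.0}
connectors = {'fpn0_s': lambda x: 2 * x, 'fpn0_t': lambda x: x}
distill_losses = {'loss_fpn0': lambda preds_S, preds_T: (preds_S - preds_T) ** 2}
loss_forward_mappings = {
    'loss_fpn0': {
        'preds_S': dict(from_student=True, recorder='fpn0', connector='fpn0_s'),
        'preds_T': dict(from_student=False, recorder='fpn0', connector='fpn0_t'),
    },
}
losses = compute_distill_losses(student_records, teacher_records, connectors,
                                distill_losses, loss_forward_mappings)
print(losses)  # {'loss_fpn0': 1.0}
```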

In addition, the student fine-tunes itself without mosaic augmentation during the last 20 epochs, so we add a hook named StopDistillHook that stops distillation at the right time. Add this hook to the custom_hooks list as follows:

custom_hooks = [..., dict(type='mmrazor.StopDistillHook', stop_epoch=280)]
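Conceptually, such a hook only flips a flag at the start of each training epoch. The toy below assumes 0-indexed epochs and a model whose loss adds the distillation term only while the flag is unset; ToyDistillModel and StopDistillHookSketch are hypothetical names, not MMRazor or MMEngine classes. With 300 epochs and a stop epoch of 280, distillation stays active for exactly the first 280 epochs.

```python
from dataclasses import dataclass


@dataclass
class ToyDistillModel:
    """Hypothetical model wrapper holding a distillation on/off flag."""
    distillation_stopped: bool = False

    def total_loss(self, task_loss: float, distill_loss: float) -> float:
        # Once the flag is set, only the detection (task) loss trains the student.
        if self.distillation_stopped:
            return task_loss
        return task_loss + distill_loss


class StopDistillHookSketch:
    """Hypothetical hook: switch distillation off once epoch >= stop_epoch."""

    def __init__(self, stop_epoch: int) -> None:
        self.stop_epoch = stop_epoch

    def before_train_epoch(self, epoch: int, model: ToyDistillModel) -> None:
        if epoch >= self.stop_epoch:
            model.distillation_stopped = True


model = ToyDistillModel()
hook = StopDistillHookSketch(stop_epoch=280)
distilled_epochs = 0
for epoch in range(300):  # 300 epochs in total, 0-indexed
    hook.before_train_epoch(epoch, model)
    if not model.distillation_stopped:
        distilled_epochs += 1
print(distilled_epochs)  # 280
print(model.total_loss(task_loss=1.0, distill_loss=0.5))  # 1.0
```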