diff --git a/configs/_base_/datasets/semi_dotav15_detection.py b/configs/_base_/datasets/semi_dotav15_detection.py new file mode 100644 index 000000000..a0ada84eb --- /dev/null +++ b/configs/_base_/datasets/semi_dotav15_detection.py @@ -0,0 +1,166 @@ +custom_imports = dict( + imports=['mmpretrain.datasets.transforms'], allow_failed_imports=False) + +# dataset settings +dataset_type = 'DOTAv15Dataset' +data_root = 'data/split_ss_dota1_5/' +backend_args = None + +branch_field = ['sup', 'unsup_teacher', 'unsup_student'] +# pipeline used to augment labeled data, +# which will be sent to student model for supervised training. +sup_pipeline = [ + dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), + dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), + dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), + dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), + dict( + type='mmdet.RandomFlip', + prob=0.75, + direction=['horizontal', 'vertical', 'diagonal']), + # dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict(type='mmdet.Pad', size_divisor=32, pad_val=dict(img=(114, 114, 114))), + dict( + type='mmdet.MultiBranch', + branch_field=branch_field, + sup=dict(type='mmdet.PackDetInputs')) +] + +# pipeline used to augment unlabeled data weakly, +# which will be sent to teacher model for predicting pseudo instances. +weak_pipeline = [ + dict(type='mmdet.Pad', size_divisor=32, pad_val=dict(img=(114, 114, 114))), + dict( + type='mmdet.PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', + 'homography_matrix')), +] + +# pipeline used to augment unlabeled data strongly, +# which will be sent to student model for unsupervised training. 
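+# Note: the photometric transforms below (ColorJitter, RandomGrayscale,
+# GaussianBlur) come from mmpretrain, which is why `custom_imports` is set at
+# the top of this file. Only photometric augmentations differ between the two
+# unlabeled branches; the geometric transforms (resize, flip) are shared in
+# `unsup_pipeline` before branching, so the teacher's pseudo boxes stay
+# aligned with the student's view.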
+strong_pipeline = [
+    dict(
+        type='RandomApply',
+        transforms=dict(
+            type='mmpretrain.ColorJitter',
+            brightness=0.4,
+            contrast=0.4,
+            saturation=0.4,
+            hue=0.1),
+        prob=0.8),
+    dict(type='mmpretrain.RandomGrayscale', prob=0.2, keep_channels=True),
+    dict(
+        type='mmpretrain.GaussianBlur',
+        radius=None,
+        prob=0.5,
+        magnitude_level=1.9,
+        magnitude_range=[0.1, 2.0],
+        magnitude_std='inf',
+        total_level=1.9),
+    dict(type='mmdet.Pad', size_divisor=32, pad_val=dict(img=(114, 114, 114))),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor', 'flip', 'flip_direction',
+                   'homography_matrix')),
+]
+
+# pipeline used to augment unlabeled data into different views
+unsup_pipeline = [
+    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
+    dict(type='mmdet.LoadEmptyAnnotations'),
+    dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
+    dict(
+        type='mmdet.RandomFlip',
+        prob=0.75,
+        direction=['horizontal', 'vertical', 'diagonal']),
+    dict(
+        type='mmdet.MultiBranch',
+        branch_field=branch_field,
+        unsup_teacher=weak_pipeline,
+        unsup_student=strong_pipeline,
+    )
+]
+
+val_pipeline = [
+    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
+    dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
+    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
+    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+
+test_pipeline = [
+    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
+    dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+
+batch_size = 3
+num_workers = 6
+# There are two common semi-supervised learning settings on DOTA-v1.5:
+# (1) Divide the train set into labeled and unlabeled subsets by a fixed
+# percentage, such as 10%, 20% and 30%. The corresponding folders are
+# train_{percent}_labeled/ and train_{percent}_unlabeled/, where `percent`
+# is the proportion of labeled data in the train set
+# (see configs/sood/README.md for how the splits are generated).
+# (2) Use the full train set as the labeled dataset and the test set as
+# the unlabeled dataset, as in the `semi-full` config.
+# The 10% setting of (1) is used by default.
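+# For example, the default 10% setting below reads (illustrative file names,
+# following the DOTA split naming convention):
+#   data/split_ss_dota1_5/train_10_labeled/annfiles/P0001__1024__0___0.txt
+#   data/split_ss_dota1_5/train_10_labeled/images/P0001__1024__0___0.png
+#   data/split_ss_dota1_5/train_10_unlabeled/empty_annfiles/  # empty .txt per image
+#   data/split_ss_dota1_5/train_10_unlabeled/images/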
+labeled_dataset = dict( + type=dataset_type, + data_root=data_root, + ann_file='train_10_labeled/annfiles', + data_prefix=dict(img_path='train_10_labeled/images/'), + filter_cfg=dict(filter_empty_gt=True), + pipeline=sup_pipeline) + +unlabeled_dataset = dict( + type=dataset_type, + data_root=data_root, + ann_file='train_10_unlabeled/empty_annfiles/', + data_prefix=dict(img_path='train_10_unlabeled/images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=unsup_pipeline) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=num_workers, + persistent_workers=True, + sampler=dict( + type='mmdet.MultiSourceSampler', + batch_size=batch_size, + source_ratio=[2, 1]), + dataset=dict( + type='ConcatDataset', datasets=[labeled_dataset, unlabeled_dataset])) + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='val/annfiles/', + data_prefix=dict(img_path='val/images/'), + test_mode=True, + pipeline=val_pipeline)) + +test_dataloader = val_dataloader + +val_evaluator = dict(type='DOTAMetric', metric='mAP') + +test_evaluator = val_evaluator diff --git a/configs/sood/README.md b/configs/sood/README.md new file mode 100644 index 000000000..abce37523 --- /dev/null +++ b/configs/sood/README.md @@ -0,0 +1,96 @@ +# SOOD + +> [SOOD: Towards Semi-Supervised Oriented Object Detection](https://arxiv.org/abs/2304.04515) + + + +## Abstract + +Semi-Supervised Object Detection (SSOD), aiming to explore unlabeled data for boosting object detectors, has become an active task in recent years. However, existing SSOD approaches mainly focus on horizontal objects, leaving multi-oriented objects that are common in aerial images unexplored. This paper proposes a novel Semi-supervised Oriented Object Detection model, termed SOOD, built upon the mainstream pseudo-labeling framework. Towards oriented objects in aerial scenes, we design two loss functions to provide better supervision. Focusing on the orientations of objects, the first loss regularizes the consistency between each pseudo-label-prediction pair (includes a prediction and its corresponding pseudo label) with adaptive weights based on their orientation gap. Focusing on the layout of an image, the second loss regularizes the similarity and explicitly builds the many-to-many relation between the sets of pseudo-labels and predictions. Such a global consistency constraint can further boost semi-supervised learning. Our experiments show that when trained with the two proposed losses, SOOD surpasses the state-of-the-art SSOD methods under various settings on the DOTA-v1.5 benchmark. + +## Requirements + +- `mmpretrain>=1.0.0` + please refer to [mmpretrain](https://mmpretrain.readthedocs.io/en/latest/get_started.html) for installation. + +## Data Preparation + +Please refer to [data_preparation.md](tools/data/dota/README.md) to prepare the original data. 
After that, the data folder should be organized as follows:
+
+```
+├── data
+│   ├── split_ss_dota1_5
+│   │   ├── train
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── val
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── test
+│   │   │   ├── images
+│   │   │   ├── annfiles
+```
+
+For the partially labeled setting, we split DOTA-v1.5's train set with the author-released [split data lists](tools/misc/split_dota1.5_lists) and the [split tool](tools/misc/split_dota1.5_via_lists.py):
+
+```shell
+python tools/misc/split_dota1.5_via_lists.py
+```
+
+For the fully labeled setting, we use the DOTA-v1.5 train set as the labeled set and the DOTA-v1.5 test set as the unlabeled set.
+
+After splitting, the data folder should be organized as follows:
+
+```
+├── data
+│   ├── split_ss_dota1_5
+│   │   ├── train
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── train_10_labeled
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── train_10_unlabeled
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── train_20_labeled
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── train_20_unlabeled
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── train_30_labeled
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── train_30_unlabeled
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── val
+│   │   │   ├── images
+│   │   │   ├── annfiles
+│   │   ├── test
+│   │   │   ├── images
+│   │   │   ├── annfiles
+```
+
+## Results
+
+DOTA-v1.5
+
+|         Backbone         | Setting | mAP50 | Mem (GB) |                             Config                              |
+| :----------------------: | :-----: | :---: | :------: | :-------------------------------------------------------------: |
+| ResNet50 (1024,1024,200) |   10%   | 47.93 |   8.45   | [config](./sood_fcos_r50_fpn_2xb3-180000k_semi-0.1-dotav1.5.py) |
+| ResNet50 (1024,1024,200) |   20%   |       |          | [config](./sood_fcos_r50_fpn_2xb3-180000k_semi-0.2-dotav1.5.py) |
+| ResNet50 (1024,1024,200) |   30%   |       |          | [config](./sood_fcos_r50_fpn_2xb3-180000k_semi-0.3-dotav1.5.py) |
+
+## Citation
+
+```
+@inproceedings{hua2023sood,
+  title={SOOD: Towards Semi-Supervised Oriented Object Detection},
+  author={Hua, Wei and Liang, Dingkang and Li, Jingyu and Liu, Xiaolong and Zou, Zhikang and Ye, Xiaoqing and Bai, Xiang},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={15558--15567},
+  year={2023}
+}
+```
diff --git a/configs/sood/rotated-fcos-le90_r50_fpn_dotav15.py b/configs/sood/rotated-fcos-le90_r50_fpn_dotav15.py
new file mode 100644
index 000000000..9269cda01
--- /dev/null
+++ b/configs/sood/rotated-fcos-le90_r50_fpn_dotav15.py
@@ -0,0 +1,65 @@
+"""Copied from the rotated FCOS baseline config."""
+angle_version = 'le90'
+
+# model settings
+model = dict(
+    type='mmdet.FCOS',
+    data_preprocessor=dict(
+        type='mmdet.DetDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True,
+        pad_size_divisor=32,
+        boxtype2tensor=False),
+    backbone=dict(
+        type='mmdet.ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch',
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+    neck=dict(
+        type='mmdet.FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs='on_output',
+        num_outs=5,
+        relu_before_extra_convs=True),
+    bbox_head=dict(
+        type='RotatedFCOSHead',
+        num_classes=16,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        strides=[8, 16, 32, 64, 128],
+        center_sampling=True,
+        center_sample_radius=1.5,
+        norm_on_bbox=True,
+        centerness_on_reg=True,
+        use_hbbox_loss=False,
+        scale_angle=True,
+        bbox_coder=dict(
+
type='DistanceAnglePointCoder', angle_version=angle_version), + loss_cls=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='RotatedIoULoss', loss_weight=1.0), + loss_angle=None, + loss_centerness=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + ), + # training and testing settings + train_cfg=None, + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000)) diff --git a/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-0.1-dotav1.5.py b/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-0.1-dotav1.5.py new file mode 100644 index 000000000..25bfa3ee9 --- /dev/null +++ b/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-0.1-dotav1.5.py @@ -0,0 +1,99 @@ +_base_ = [ + 'rotated-fcos-le90_r50_fpn_dotav15.py', '../_base_/default_runtime.py', + '../_base_/datasets/semi_dotav15_detection.py' +] +# todo: fix this import issue +custom_imports = dict( + imports=['mmrotate.engine.hooks.mean_teacher_hook'], + allow_failed_imports=False) + +detector = _base_.model +model = dict( + _delete_=True, + type='SOOD', + detector=detector, + data_preprocessor=dict( + type='mmdet.MultiBranchDataPreprocessor', + data_preprocessor=detector.data_preprocessor), + semi_train_cfg=dict( + freeze_teacher=True, + iter_count=0, + burn_in_steps=6400, + sup_weight=1.0, + unsup_weight=1.0, + weight_suppress='linear', + logit_specific_weights=dict(), + symmertry_aware=False, + ), + semi_loss_cfg=dict( + pseudo_label_type='pr_origin_p5', + cls_pseudo_thr=0.5, + cls_loss_type='BCE', + bbox_loss_type='l1', + aux_loss='ot_loss_norm', + aux_loss_cfg=dict(loss_weight=1.0, cost_type='all', clamp_ot=True), + rbox_pts_ratio=0.25, + dynamic_weight='50ang', + ), + semi_test_cfg=dict(predict_on='teacher')) + +# 10% dotav1.5 train is set as labeled dataset +labeled_dataset = _base_.labeled_dataset +unlabeled_dataset = _base_.unlabeled_dataset + +batch_size = 3 +num_workers = 6 +train_dataloader = dict( + batch_size=batch_size, + num_workers=num_workers, + sampler=dict( + type='mmdet.MultiSourceSampler', + batch_size=batch_size, + source_ratio=[2, 1]), + dataset=dict(datasets=[labeled_dataset, unlabeled_dataset])) + +# training schedule for 180k +train_cfg = dict( + type='IterBasedTrainLoop', max_iters=180000, val_interval=3200) +val_cfg = dict(type='mmdet.TeacherStudentValLoop') +test_cfg = dict(type='TestLoop') + +# learning rate policy +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=1000, + end=180000, + by_epoch=False, + milestones=[120000, 160000], + gamma=0.1) +] + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +default_hooks = dict( + checkpoint=dict( + by_epoch=False, interval=3200, max_keep_ckpts=1000, save_best='auto'), + logger=dict(type='LoggerHook', interval=50), +) + +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False) + +custom_hooks = [ + dict(type='MeanTeacherHook', start_iter=3200, momentum=0.0004), +] + +vis_backends = [dict(type='TensorboardVisBackend')] + +visualizer = dict( + type='RotLocalVisualizer', + vis_backends=vis_backends, + name='visualizer', + save_dir='work_dirs/sood_fcos_r50_fpn_2xb3-180000k_semi-0.1-dotav1.5') diff --git a/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-0.2-dotav1.5.py 
b/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-0.2-dotav1.5.py new file mode 100644 index 000000000..ae2e7ac35 --- /dev/null +++ b/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-0.2-dotav1.5.py @@ -0,0 +1,104 @@ +_base_ = [ + 'rotated-fcos-le90_r50_fpn_dotav15.py', '../_base_/default_runtime.py', + '../_base_/datasets/semi_dotav15_detection.py' +] +# todo: fix this import issue +custom_imports = dict( + imports=['mmrotate.engine.hooks.mean_teacher_hook'], + allow_failed_imports=False) + +detector = _base_.model +model = dict( + _delete_=True, + type='SOOD', + detector=detector, + data_preprocessor=dict( + type='mmdet.MultiBranchDataPreprocessor', + data_preprocessor=detector.data_preprocessor), + semi_train_cfg=dict( + freeze_teacher=True, + iter_count=0, + burn_in_steps=12800, + sup_weight=1.0, + unsup_weight=1.0, + weight_suppress='linear', + logit_specific_weights=dict(), + symmertry_aware=False, + ), + semi_loss_cfg=dict( + pseudo_label_type='pr_origin_p5', + cls_pseudo_thr=0.5, + cls_loss_type='BCE', + bbox_loss_type='l1', + aux_loss='ot_loss_norm', + aux_loss_cfg=dict(loss_weight=1.0, cost_type='all', clamp_ot=True), + rbox_pts_ratio=0.25, + dynamic_weight='50ang', + ), + semi_test_cfg=dict(predict_on='teacher')) + +# 20% dotav1.5 train is set as labeled dataset +labeled_dataset = _base_.labeled_dataset +labeled_dataset.ann_file = 'train_20_labeled/annfiles' +labeled_dataset.data_prefix = dict(img_path='train_20_labeled/images/') + +unlabeled_dataset = _base_.unlabeled_dataset +unlabeled_dataset.ann_file = 'train_20_unlabeled/empty_annfiles/' +unlabeled_dataset.data_prefix = dict(img_path='train_20_unlabeled/images/') + +batch_size = 3 +num_workers = 6 +train_dataloader = dict( + batch_size=batch_size, + num_workers=num_workers, + sampler=dict( + type='mmdet.MultiSourceSampler', + batch_size=batch_size, + source_ratio=[2, 1]), + dataset=dict(datasets=[labeled_dataset, unlabeled_dataset])) + +# training schedule for 180k +train_cfg = dict( + type='IterBasedTrainLoop', max_iters=180000, val_interval=3200) +val_cfg = dict(type='mmdet.TeacherStudentValLoop') +test_cfg = dict(type='TestLoop') + +# learning rate policy +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=1000, + end=180000, + by_epoch=False, + milestones=[120000, 160000], + gamma=0.1) +] + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +default_hooks = dict( + checkpoint=dict( + by_epoch=False, interval=3200, max_keep_ckpts=1000, save_best='auto'), + logger=dict(type='LoggerHook', interval=50), +) + +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False) + +custom_hooks = [ + dict(type='MeanTeacherHook', start_iter=3200, momentum=0.0004), +] + +vis_backends = [dict(type='TensorboardVisBackend')] + +visualizer = dict( + type='RotLocalVisualizer', + vis_backends=vis_backends, + name='visualizer', + save_dir='work_dirs/sood_fcos_r50_fpn_2xb3-180000k_semi-0.2-dotav1.5') diff --git a/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-0.3-dotav1.5.py b/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-0.3-dotav1.5.py new file mode 100644 index 000000000..84a692e8d --- /dev/null +++ b/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-0.3-dotav1.5.py @@ -0,0 +1,104 @@ +_base_ = [ + 'rotated-fcos-le90_r50_fpn_dotav15.py', '../_base_/default_runtime.py', + 
'../_base_/datasets/semi_dotav15_detection.py' +] +# todo: fix this import issue +custom_imports = dict( + imports=['mmrotate.engine.hooks.mean_teacher_hook'], + allow_failed_imports=False) + +detector = _base_.model +model = dict( + _delete_=True, + type='SOOD', + detector=detector, + data_preprocessor=dict( + type='mmdet.MultiBranchDataPreprocessor', + data_preprocessor=detector.data_preprocessor), + semi_train_cfg=dict( + freeze_teacher=True, + iter_count=0, + burn_in_steps=12800, + sup_weight=1.0, + unsup_weight=1.0, + weight_suppress='linear', + logit_specific_weights=dict(), + symmertry_aware=False, + ), + semi_loss_cfg=dict( + pseudo_label_type='pr_origin_p5', + cls_pseudo_thr=0.5, + cls_loss_type='BCE', + bbox_loss_type='l1', + aux_loss='ot_loss_norm', + aux_loss_cfg=dict(loss_weight=1.0, cost_type='all', clamp_ot=True), + rbox_pts_ratio=0.25, + dynamic_weight='50ang', + ), + semi_test_cfg=dict(predict_on='teacher')) + +# 30% dotav1.5 train is set as labeled dataset +labeled_dataset = _base_.labeled_dataset +labeled_dataset.ann_file = 'train_30_labeled/annfiles' +labeled_dataset.data_prefix = dict(img_path='train_30_labeled/images/') + +unlabeled_dataset = _base_.unlabeled_dataset +unlabeled_dataset.ann_file = 'train_30_unlabeled/empty_annfiles/' +unlabeled_dataset.data_prefix = dict(img_path='train_30_unlabeled/images/') + +batch_size = 3 +num_workers = 6 +train_dataloader = dict( + batch_size=batch_size, + num_workers=num_workers, + sampler=dict( + type='mmdet.MultiSourceSampler', + batch_size=batch_size, + source_ratio=[2, 1]), + dataset=dict(datasets=[labeled_dataset, unlabeled_dataset])) + +# training schedule for 180k +train_cfg = dict( + type='IterBasedTrainLoop', max_iters=180000, val_interval=3200) +val_cfg = dict(type='mmdet.TeacherStudentValLoop') +test_cfg = dict(type='TestLoop') + +# learning rate policy +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=1000, + end=180000, + by_epoch=False, + milestones=[120000, 160000], + gamma=0.1) +] + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +default_hooks = dict( + checkpoint=dict( + by_epoch=False, interval=3200, max_keep_ckpts=1000, save_best='auto'), + logger=dict(type='LoggerHook', interval=50), +) + +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False) + +custom_hooks = [ + dict(type='MeanTeacherHook', start_iter=3200, momentum=0.0004), +] + +vis_backends = [dict(type='TensorboardVisBackend')] + +visualizer = dict( + type='RotLocalVisualizer', + vis_backends=vis_backends, + name='visualizer', + save_dir='work_dirs/sood_fcos_r50_fpn_2xb3-180000k_semi-0.3-dotav1.5') diff --git a/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-full-dotav1.5.py b/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-full-dotav1.5.py new file mode 100644 index 000000000..4b4740150 --- /dev/null +++ b/configs/sood/sood_fcos_r50_fpn_2xb3-180000k_semi-full-dotav1.5.py @@ -0,0 +1,104 @@ +_base_ = [ + 'rotated-fcos-le90_r50_fpn_dotav15.py', '../_base_/default_runtime.py', + '../_base_/datasets/semi_dotav15_detection.py' +] +# todo: fix this import issue +custom_imports = dict( + imports=['mmrotate.engine.hooks.mean_teacher_hook'], + allow_failed_imports=False) + +detector = _base_.model +model = dict( + _delete_=True, + type='SOOD', + detector=detector, + data_preprocessor=dict( + 
type='mmdet.MultiBranchDataPreprocessor', + data_preprocessor=detector.data_preprocessor), + semi_train_cfg=dict( + freeze_teacher=True, + iter_count=0, + burn_in_steps=12800, + sup_weight=1.0, + unsup_weight=1.0, + weight_suppress='linear', + logit_specific_weights=dict(), + symmertry_aware=False, + ), + semi_loss_cfg=dict( + pseudo_label_type='pr_origin_p5', + cls_pseudo_thr=0.5, + cls_loss_type='BCE', + bbox_loss_type='l1', + aux_loss='ot_loss_norm', + aux_loss_cfg=dict(loss_weight=1.0, cost_type='all', clamp_ot=True), + rbox_pts_ratio=0.25, + dynamic_weight='50ang', + ), + semi_test_cfg=dict(predict_on='teacher')) + +# 100% dotav1.5 train is set as labeled dataset +labeled_dataset = _base_.labeled_dataset +labeled_dataset.ann_file = 'train/annfiles' +labeled_dataset.data_prefix = dict(img_path='train/images/') + +unlabeled_dataset = _base_.unlabeled_dataset +unlabeled_dataset.ann_file = 'test/annfiles/' +unlabeled_dataset.data_prefix = dict(img_path='test/images/') + +batch_size = 3 +num_workers = 6 +train_dataloader = dict( + batch_size=batch_size, + num_workers=num_workers, + sampler=dict( + type='mmdet.MultiSourceSampler', + batch_size=batch_size, + source_ratio=[2, 1]), + dataset=dict(datasets=[labeled_dataset, unlabeled_dataset])) + +# training schedule for 180k +train_cfg = dict( + type='IterBasedTrainLoop', max_iters=180000, val_interval=3200) +val_cfg = dict(type='mmdet.TeacherStudentValLoop') +test_cfg = dict(type='TestLoop') + +# learning rate policy +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=1000, + end=180000, + by_epoch=False, + milestones=[120000, 160000], + gamma=0.1) +] + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +default_hooks = dict( + checkpoint=dict( + by_epoch=False, interval=3200, max_keep_ckpts=1000, save_best='auto'), + logger=dict(type='LoggerHook', interval=50), +) + +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False) + +custom_hooks = [ + dict(type='MeanTeacherHook', start_iter=3200, momentum=0.0004), +] + +vis_backends = [dict(type='TensorboardVisBackend')] + +visualizer = dict( + type='RotLocalVisualizer', + vis_backends=vis_backends, + name='visualizer', + save_dir='work_dirs/sood_fcos_r50_fpn_2xb3-180000k_semi-full-dotav1.5') diff --git a/mmrotate/__init__.py b/mmrotate/__init__.py index 41bd11ff4..4a57a4624 100644 --- a/mmrotate/__init__.py +++ b/mmrotate/__init__.py @@ -27,7 +27,7 @@ f'<{mmengine_maximum_version}.' mmdet_minimum_version = '3.0.0rc6' -mmdet_maximum_version = '3.2.0' +mmdet_maximum_version = '3.3.0' mmdet_version = digit_version(mmdet.__version__) assert (mmdet_version >= digit_version(mmdet_minimum_version) diff --git a/mmrotate/engine/__init__.py b/mmrotate/engine/__init__.py new file mode 100644 index 000000000..31785364b --- /dev/null +++ b/mmrotate/engine/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hooks import * # noqa: F401, F403 diff --git a/mmrotate/engine/hooks/__init__.py b/mmrotate/engine/hooks/__init__.py new file mode 100644 index 000000000..886979788 --- /dev/null +++ b/mmrotate/engine/hooks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .mean_teacher_hook import MeanTeacherHook
+
+__all__ = ['MeanTeacherHook']
diff --git a/mmrotate/engine/hooks/mean_teacher_hook.py b/mmrotate/engine/hooks/mean_teacher_hook.py
new file mode 100644
index 000000000..1c5a29c5f
--- /dev/null
+++ b/mmrotate/engine/hooks/mean_teacher_hook.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch.nn as nn
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+from mmengine.runner import Runner
+
+from mmrotate.registry import HOOKS
+
+
+@HOOKS.register_module()
+class MeanTeacherHook(Hook):
+    """Mean Teacher Hook.
+
+    Mean Teacher is an efficient semi-supervised learning method in
+    `Mean Teacher <https://arxiv.org/abs/1703.01780>`_.
+    This method requires two models with exactly the same structure:
+    a student model and a teacher model.
+    The student model updates its parameters through gradient descent,
+    and the teacher model updates its parameters through an
+    exponential moving average of the student model.
+    Compared with the student model, the teacher model
+    is smoother and accumulates more knowledge.
+
+    Args:
+        momentum (float): The momentum used for updating teacher's
+            parameters. Teacher's parameters are updated with the formula:
+            `teacher = (1-momentum) * teacher + momentum * student`.
+            Defaults to 0.001.
+        interval (int): Update teacher's parameters every interval
+            iterations. Defaults to 1.
+        start_iter (int): The iteration at which EMA updating starts.
+            Before it the teacher is left untouched; at `start_iter` the
+            teacher is reset to a copy of the student. Defaults to 0.
+        skip_buffer (bool): Whether to skip the model buffers, such as
+            batchnorm running stats (running_mean, running_var), so that
+            they do not take part in the EMA update. Defaults to True.
+    """
+
+    def __init__(self,
+                 momentum: float = 0.001,
+                 interval: int = 1,
+                 start_iter: int = 0,
+                 skip_buffer: bool = True) -> None:
+        assert 0 < momentum < 1
+        self.momentum = momentum
+        self.interval = interval
+
+        self.start_iter = start_iter
+        self.skip_buffers = skip_buffer
+
+    def before_train(self, runner: Runner) -> None:
+        """To check that teacher model and student model exist."""
+        model = runner.model
+        if is_model_wrapper(model):
+            model = model.module
+        assert hasattr(model, 'teacher')
+        assert hasattr(model, 'student')
+        # only do it at initial stage
+        if runner.iter == 0:
+            self.momentum_update(model, 1)
+
+    def after_train_iter(self,
+                         runner: Runner,
+                         batch_idx: int,
+                         data_batch: Optional[dict] = None,
+                         outputs: Optional[dict] = None) -> None:
+        """Update teacher's parameters every self.interval iterations."""
+        if (runner.iter + 1) % self.interval != 0:
+            return
+        model = runner.model
+        if is_model_wrapper(model):
+            model = model.module
+
+        if runner.iter < self.start_iter:
+            return
+        if runner.iter == self.start_iter:
+            print('start EMA update at iter', runner.iter)
+            self.momentum_update(model, 1)
+        else:
+            self.momentum_update(model, self.momentum)
+
+    def momentum_update(self, model: nn.Module, momentum: float) -> None:
+        """Update the teacher's parameters with an exponential moving
+        average of the student's parameters."""
+        if self.skip_buffers:
+            for (src_name, src_parm), (dst_name, dst_parm) in zip(
+                    model.student.named_parameters(),
+                    model.teacher.named_parameters()):
+                dst_parm.data.mul_(1 - momentum).add_(
+                    src_parm.data, alpha=momentum)
+        else:
+            for (src_parm,
+                 dst_parm) in zip(model.student.state_dict().values(),
+                                  model.teacher.state_dict().values()):
+                # exclude non-floating buffers such as num_batches_tracked
+                if dst_parm.dtype.is_floating_point:
+                    dst_parm.data.mul_(1 - momentum).add_(
+                        src_parm.data, alpha=momentum)
diff --git a/mmrotate/models/detectors/__init__.py b/mmrotate/models/detectors/__init__.py
index cbd025b5c..e34255b41 100644
--- a/mmrotate/models/detectors/__init__.py
+++ b/mmrotate/models/detectors/__init__.py
@@ -2,5 +2,10 @@
 from .h2rbox import H2RBoxDetector
 from .h2rbox_v2 import H2RBoxV2Detector
 from .refine_single_stage import RefineSingleStageDetector
+from .semi_base import RotatedSemiBaseDetector
+from .sood import SOOD
 
-__all__ = ['RefineSingleStageDetector', 'H2RBoxDetector', 'H2RBoxV2Detector']
+__all__ = [
+    'RefineSingleStageDetector', 'H2RBoxDetector', 'H2RBoxV2Detector',
+    'RotatedSemiBaseDetector', 'SOOD'
+]
diff --git a/mmrotate/models/detectors/semi_base.py b/mmrotate/models/detectors/semi_base.py
new file mode 100644
index 000000000..6cb27908f
--- /dev/null
+++ b/mmrotate/models/detectors/semi_base.py
@@ -0,0 +1,284 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from mmdet.models.detectors.base import BaseDetector
+from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
+                                reweight_loss_dict)
+from mmdet.structures import SampleList
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from torch import Tensor
+
+from mmrotate.registry import MODELS
+from mmrotate.structures import RotatedBoxes
+from mmrotate.structures.bbox import rbox_project
+
+
+@MODELS.register_module()
+class RotatedSemiBaseDetector(BaseDetector):
+    """Base class for semi-supervised detectors. Almost the same as the
+    original ``semi_base.py`` in mmdet, with some modifications to support
+    rotated bounding boxes and burn-in steps.
+
+    Semi-supervised detectors typically consist of a teacher model
+    updated by exponential moving average and a student model updated
+    by gradient descent.
+
+    Args:
+        detector (:obj:`ConfigDict` or dict): The detector config.
+        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
+            The semi-supervised training config.
+        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
+            The semi-supervised testing config.
+        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
+            :class:`DetDataPreprocessor` to process the input data.
+            Defaults to None.
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 detector: ConfigType,
+                 semi_train_cfg: OptConfigType = None,
+                 semi_test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None) -> None:
+        super().__init__(
+            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+        self.student = MODELS.build(detector)
+        self.teacher = MODELS.build(detector)
+        self.semi_train_cfg = semi_train_cfg
+        self.semi_test_cfg = semi_test_cfg
+        if self.semi_train_cfg.get('freeze_teacher', True) is True:
+            self.freeze(self.teacher)
+        self.iter_count = self.semi_train_cfg.get('iter_count', 0)
+        self.burn_in_steps = self.semi_train_cfg.get('burn_in_steps', 6400)
+        self.visual = self.semi_train_cfg.get('visual', False)
+        self.visual_interval = self.semi_train_cfg.get('visual_interval', 800)
+
+    @staticmethod
+    def freeze(model: nn.Module):
+        """Freeze the model."""
+        model.eval()
+        for param in model.parameters():
+            param.requires_grad = False
+
+    def loss(self, multi_batch_inputs: Dict[str, Tensor],
+             multi_batch_data_samples: Dict[str, SampleList]) -> dict:
+        """Calculate losses from multi-branch inputs and data samples.
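+        During burn-in (``iter_count <= burn_in_steps``) only the supervised
+        branch contributes; afterwards the teacher's pseudo instances are
+        projected onto the student's unsupervised view and the unsupervised
+        loss is added as well.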
+ + Args: + multi_batch_inputs (Dict[str, Tensor]): The dict of multi-branch + input images, each value with shape (N, C, H, W). + Each value should usually be mean centered and std scaled. + multi_batch_data_samples (Dict[str, List[:obj:`DetDataSample`]]): + The dict of multi-branch data samples. + + Returns: + dict: A dictionary of loss components + """ + losses = dict() + losses.update(**self.loss_by_gt_instances( + multi_batch_inputs['sup'], multi_batch_data_samples['sup'])) + if self.iter_count > self.burn_in_steps: + origin_pseudo_data_samples, batch_info = self.get_pseudo_instances( + multi_batch_inputs['unsup_teacher'], + multi_batch_data_samples['unsup_teacher']) + multi_batch_data_samples[ + 'unsup_student'] = self.project_pseudo_instances( + origin_pseudo_data_samples, + multi_batch_data_samples['unsup_student']) + losses.update(**self.loss_by_pseudo_instances( + multi_batch_inputs['unsup_student'], + multi_batch_data_samples['unsup_student'], batch_info)) + + self.iter_count += 1 + return losses + + def loss_by_gt_instances(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and ground-truth data + samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components + """ + + losses = self.student.loss(batch_inputs, batch_data_samples) + + sup_weight = self.semi_train_cfg.get('sup_weight', 1.) + return rename_loss_dict('sup_', reweight_loss_dict(losses, sup_weight)) + + def loss_by_pseudo_instances(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + batch_info: Optional[dict] = None) -> dict: + """Calculate losses from a batch of inputs and pseudo data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, + which are `pseudo_instance` or `pseudo_panoptic_seg` + or `pseudo_sem_seg` in fact. + batch_info (dict): Batch information of teacher model + forward propagation process. Defaults to None. + + Returns: + dict: A dictionary of loss components + """ + batch_data_samples = filter_gt_instances( + batch_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr) + + # convert tensor pseudo labels to RotatedBoxes pseudo labels + for data_samples in batch_data_samples: + data_samples.gt_instances.bboxes = RotatedBoxes( + data_samples.gt_instances.bboxes) + + losses = self.student.loss(batch_inputs, batch_data_samples) + pseudo_instances_num = sum([ + len(data_samples.gt_instances) + for data_samples in batch_data_samples + ]) + unsup_weight = self.semi_train_cfg.get( + 'unsup_weight', 1.) if pseudo_instances_num > 0 else 0. 
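+        # Linearly ramp up the unsupervised weight after burn-in: the factor
+        # (iter_count - burn_in_steps) / burn_in_steps grows from 0 at
+        # `burn_in_steps` to 1 at `2 * burn_in_steps`, after which the full
+        # configured weight is used.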
+        burn_in_steps = self.semi_train_cfg.get('burn_in_steps', 5000)
+        target = burn_in_steps * 2
+        if self.iter_count <= target:
+            unsup_weight *= (self.iter_count - burn_in_steps) / burn_in_steps
+        return rename_loss_dict('unsup_',
+                                reweight_loss_dict(losses, unsup_weight))
+
+    @torch.no_grad()
+    def get_pseudo_instances(
+            self, batch_inputs: Tensor, batch_data_samples: SampleList
+    ) -> Tuple[SampleList, Optional[dict]]:
+        """Get pseudo instances from teacher model."""
+        self.teacher.eval()
+        results_list = self.teacher.predict(
+            batch_inputs, batch_data_samples, rescale=False)
+        batch_info = {}
+        for data_samples, results in zip(batch_data_samples, results_list):
+            data_samples.gt_instances = copy.deepcopy(results.pred_instances)
+            data_samples.gt_instances.bboxes = rbox_project(
+                data_samples.gt_instances.bboxes,
+                torch.from_numpy(data_samples.homography_matrix).inverse().to(
+                    self.data_preprocessor.device), data_samples.ori_shape)
+        return batch_data_samples, batch_info
+
+    def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
+                                 batch_data_samples: SampleList) -> SampleList:
+        """Project pseudo instances."""
+        for pseudo_instances, data_samples in zip(batch_pseudo_instances,
+                                                  batch_data_samples):
+            data_samples.gt_instances = pseudo_instances.gt_instances
+            data_samples.gt_instances.bboxes = rbox_project(
+                data_samples.gt_instances.bboxes,
+                torch.tensor(data_samples.homography_matrix).to(
+                    self.data_preprocessor.device), data_samples.img_shape)
+        return batch_data_samples
+
+    def predict(self, batch_inputs: Tensor,
+                batch_data_samples: SampleList) -> SampleList:
+        """Predict results from a batch of inputs and data samples with post-
+        processing.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+        Returns:
+            list[:obj:`DetDataSample`]: Return the detection results of the
+            input images. The returned value is DetDataSample,
+            which usually contains 'pred_instances'. The
+            ``pred_instances`` usually contains the following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 5),
+              the last dimension 5 arranged as (cx, cy, w, h, t).
+        """
+        if self.semi_test_cfg.get('predict_on', 'teacher') == 'teacher':
+            return self.teacher(
+                batch_inputs, batch_data_samples, mode='predict')
+        else:
+            return self.student(
+                batch_inputs, batch_data_samples, mode='predict')
+
+    def _forward(self, batch_inputs: Tensor,
+                 batch_data_samples: SampleList) -> SampleList:
+        """Network forward process. Usually includes backbone, neck and head
+        forward without any post-processing.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+
+        Returns:
+            tuple: A tuple of features from the detector's head forward
+            (``bbox_head`` for single-stage detectors).
+        """
+        if self.semi_test_cfg.get('forward_on', 'teacher') == 'teacher':
+            return self.teacher(
+                batch_inputs, batch_data_samples, mode='tensor')
+        else:
+            return self.student(
+                batch_inputs, batch_data_samples, mode='tensor')
+
+    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
+        """Extract features.
+
+        Args:
+            batch_inputs (Tensor): Image tensor with shape (N, C, H, W).
+
+        Returns:
+            tuple[Tensor]: Multi-level features that may have
+            different resolutions.
+        """
+        if self.semi_test_cfg.get('extract_feat_on', 'teacher') == 'teacher':
+            return self.teacher.extract_feat(batch_inputs)
+        else:
+            return self.student.extract_feat(batch_inputs)
+
+    def _load_from_state_dict(self, state_dict: dict, prefix: str,
+                              local_metadata: dict, strict: bool,
+                              missing_keys: Union[List[str], str],
+                              unexpected_keys: Union[List[str], str],
+                              error_msgs: Union[List[str], str]) -> None:
+        """Add teacher and student prefixes to model parameter names."""
+        if not any([
+                'student' in key or 'teacher' in key
+                for key in state_dict.keys()
+        ]):
+            keys = list(state_dict.keys())
+            state_dict.update({'teacher.' + k: state_dict[k] for k in keys})
+            state_dict.update({'student.' + k: state_dict[k] for k in keys})
+            for k in keys:
+                state_dict.pop(k)
+        return super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
diff --git a/mmrotate/models/detectors/sood.py b/mmrotate/models/detectors/sood.py
new file mode 100644
index 000000000..688daffa1
--- /dev/null
+++ b/mmrotate/models/detectors/sood.py
@@ -0,0 +1,361 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import random
+from typing import Optional, Tuple
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
+                                reweight_loss_dict)
+from mmdet.structures import SampleList
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from torch import Tensor
+
+from mmrotate.models.detectors.semi_base import RotatedSemiBaseDetector
+from mmrotate.models.losses import OT_Loss
+from mmrotate.registry import MODELS
+from mmrotate.structures.bbox import rbox2qbox, rbox_project
+
+
+@MODELS.register_module()
+class SOOD(RotatedSemiBaseDetector):
+    """Implementation of `SOOD: Towards Semi-Supervised Oriented Object
+    Detection <https://arxiv.org/abs/2304.04515>`_
+
+    Args:
+        detector (:obj:`ConfigDict` or dict): The detector config.
+        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
+            The semi-supervised training config.
+        semi_loss_cfg (:obj:`ConfigDict` or dict, optional):
+            The semi-supervised loss config.
+        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
+            The semi-supervised testing config.
+        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
+            :class:`DetDataPreprocessor` to process the input data.
+            Defaults to None.
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
+            Defaults to None.
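+
+    A minimal usage sketch (values mirror the shipped configs; `detector` is
+    assumed to be a rotated FCOS detector config, and the MultiBranch data
+    preprocessor and remaining keys follow the configs above)::
+
+        model = dict(
+            type='SOOD',
+            detector=detector,
+            semi_train_cfg=dict(burn_in_steps=6400, unsup_weight=1.0),
+            semi_loss_cfg=dict(
+                aux_loss='ot_loss_norm',
+                aux_loss_cfg=dict(loss_weight=1.0)),
+            semi_test_cfg=dict(predict_on='teacher'))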
+ """ + + def __init__(self, + detector: ConfigType, + semi_train_cfg: OptConfigType = None, + semi_loss_cfg: OptConfigType = None, + semi_test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + detector=detector, + semi_train_cfg=semi_train_cfg, + semi_test_cfg=semi_test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # loss settings + self.semi_loss_cfg = semi_loss_cfg + self.cls_pseudo_thr = self.semi_loss_cfg.get('cls_pseudo_thr', 0.5) + self.cls_loss_type = self.semi_loss_cfg.get('cls_loss_type', 'BCE') + + self.reg_loss_type = self.semi_loss_cfg.get('reg_loss_type', + 'SmoothL1Loss') + assert self.reg_loss_type in ['SmoothL1Loss'] + self.loss_bbox = nn.SmoothL1Loss(reduction='none') + + # aux loss settings + self.aux_loss = self.semi_loss_cfg.get('aux_loss', None) + if self.aux_loss is not None: + assert self.aux_loss in ['ot_loss_norm', 'ot_ang_loss_norm'] + self.aux_loss_cfg = self.semi_loss_cfg.get('aux_loss_cfg', None) + assert self.aux_loss_cfg is not None, \ + 'aux_loss_cfg should be provided when aux_loss is not None.' + self.ot_weight = self.aux_loss_cfg.pop('loss_weight', 1.) + self.cost_type = self.aux_loss_cfg.pop('cost_type', 'all') + assert self.cost_type in ['all', 'dist', 'score'] + self.clamp_ot = self.aux_loss_cfg.pop('clamp_ot', False) + self.gc_loss = OT_Loss(**self.aux_loss_cfg) + + self.rbox_pts_ratio = self.semi_loss_cfg.get('rbox_pts_ratio', 0.25) + self.dynamic_weight = self.semi_loss_cfg.get('dynamic_weight', '50ang') + assert self.dynamic_weight in [ + 'None', 'ang', '10ang', '50ang', '100ang' + ] + + @torch.no_grad() + def get_pseudo_instances( + self, batch_inputs: Tensor, batch_data_samples: SampleList + ) -> Tuple[SampleList, Optional[dict]]: + """Get pseudo instances from teacher model.""" + dense_predicts = self.teacher(batch_inputs) + batch_info = {} + batch_info['dense_predicts'] = dense_predicts + + self.teacher.eval() + results_list = self.teacher.predict(batch_inputs, batch_data_samples) + + for data_samples, results in zip(batch_data_samples, results_list): + data_samples.gt_instances = copy.deepcopy(results.pred_instances) + data_samples.gt_instances.bboxes = rbox_project( + data_samples.gt_instances.bboxes, + torch.from_numpy(data_samples.homography_matrix).inverse().to( + self.data_preprocessor.device), data_samples.ori_shape) + return batch_data_samples, batch_info + + def loss_by_pseudo_instances(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + batch_info: Optional[dict] = None) -> dict: + """Calculate losses from a batch of inputs and pseudo data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, + which are `pseudo_instance` or `pseudo_panoptic_seg` + or `pseudo_sem_seg` in fact. + batch_info (dict): Batch information of teacher model + forward propagation process. Defaults to None. 
+        Returns:
+            dict: A dictionary of loss components
+        """
+
+        gpu_device = batch_inputs.device
+        # first filter the pseudo instances with cls scores
+        batch_data_samples = filter_gt_instances(
+            batch_data_samples, score_thr=self.cls_pseudo_thr)
+
+        # decide the dense pseudo label area
+        # according to the teacher predictions
+        masks = torch.zeros(
+            len(batch_data_samples),
+            batch_inputs.shape[-2],
+            batch_inputs.shape[-1],
+            device=gpu_device)
+        for img_idx, data_samples in enumerate(batch_data_samples):
+            if len(data_samples.gt_instances) > 0:
+                qboxes = rbox2qbox(data_samples.gt_instances.bboxes)
+                # different from the original implementation,
+                # we use round() rather than int()
+                # pts = qboxes.cpu().numpy().astype(int)
+                pts = np.round(qboxes.cpu().numpy()).astype(int)
+                a = []
+                for i in range(len(pts)):
+                    a.append(np.split(pts[i], 4))
+                pts = np.array(a)
+                mask = np.zeros(
+                    (batch_inputs.shape[-2], batch_inputs.shape[-1]),
+                    dtype=np.uint8)
+                cv2.fillPoly(mask, pts, 255)
+                valid_mask = np.zeros_like(mask)
+                valid_pts = np.transpose(mask.nonzero())
+                select_list = list(range(len(valid_pts)))
+                for _ in range(3):
+                    random.shuffle(select_list)
+                select_num = int(self.rbox_pts_ratio * len(valid_pts))
+                valid_pts = valid_pts[select_list[:select_num]]
+                valid_mask[valid_pts[:, 0], valid_pts[:, 1]] = 255
+                masks[img_idx] = torch.from_numpy(valid_mask) > 0
+
+        masks = masks.view(-1, 1, batch_inputs.shape[-2],
+                           batch_inputs.shape[-1])
+
+        # interpolate the dense pseudo label area to FPN P5
+        size_fpn_p5 = (int(batch_inputs.shape[-2] / 8),
+                       int(batch_inputs.shape[-1] / 8))
+
+        masks = F.interpolate(
+            masks.float(), size=size_fpn_p5).bool().squeeze(1)
+
+        num_valid = sum([_.sum() for _ in masks]) if isinstance(
+            masks, list) else masks.sum()
+
+        if num_valid == 0:
+            loss_cls = torch.tensor(0., device=gpu_device)
+            loss_bbox = torch.tensor(0., device=gpu_device)
+            loss_centerness = torch.tensor(0., device=gpu_device)
+            if self.aux_loss is not None:
+                loss_gc = torch.tensor(0., device=gpu_device)
+                losses = {
+                    'loss_cls': loss_cls,
+                    'loss_bbox': loss_bbox,
+                    'loss_centerness': loss_centerness,
+                    'loss_gc': loss_gc
+                }
+            else:
+                losses = {
+                    'loss_cls': loss_cls,
+                    'loss_bbox': loss_bbox,
+                    'loss_centerness': loss_centerness,
+                }
+            return rename_loss_dict('unsup_', losses)
+        else:
+            teacher_logit = batch_info['dense_predicts']
+            teacher_cls_scores_logits, teacher_bbox_preds, \
+                teacher_angle_pred, teacher_centernesses = teacher_logit
+
+            student_logit = self.student(batch_inputs)
+            student_cls_scores_logits, student_bbox_preds, \
+                student_angle_pred, student_centernesses = student_logit
+
+            loss_cls = torch.tensor(0., device=gpu_device)
+            loss_bbox = torch.tensor(0., device=gpu_device)
+            loss_centerness = torch.tensor(0., device=gpu_device)
+            for i in range(len(masks)):
+                if masks[i].sum() == 0:
+                    continue
+                teacher_cls_scores_logits_ = (
+                    teacher_cls_scores_logits[0][i]).permute(1, 2, 0)[masks[i]]
+                teacher_bbox_preds_ = (teacher_bbox_preds[0][i]).permute(
+                    1, 2, 0)[masks[i]]
+                teacher_angle_pred_ = (teacher_angle_pred[0][i]).permute(
+                    1, 2, 0)[masks[i]]
+                teacher_centernesses_ = (teacher_centernesses[0][i]).permute(
+                    1, 2, 0)[masks[i]]
+
+                student_cls_scores_logits_ = (
+                    student_cls_scores_logits[0][i]).permute(1, 2, 0)[masks[i]]
+                student_bbox_preds_ = (student_bbox_preds[0][i]).permute(
+                    1, 2, 0)[masks[i]]
+                student_angle_pred_ = (student_angle_pred[0][i]).permute(
+                    1, 2, 0)[masks[i]]
+                student_centernesses_ = (student_centernesses[0][i]).permute(
+                    1, 2, 0)[masks[i]]
+
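+                # Stack the 4-dim distance regression and the 1-dim angle
+                # prediction into 5-dim rotated-box deltas; the
+                # teacher-student angle gap computed below drives the
+                # dynamic per-location loss weight.
+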
teacher_bbox_preds_ = torch.cat( + [teacher_bbox_preds_, teacher_angle_pred_], dim=-1) + student_bbox_preds_ = torch.cat( + [student_bbox_preds_, student_angle_pred_], dim=-1) + + with torch.no_grad(): + if self.dynamic_weight in [ + 'None', 'ang', '10ang', '50ang', '100ang' + ]: + loss_weight = torch.abs( + teacher_bbox_preds_[:, -1] - + student_bbox_preds_[:, -1]) / np.pi + if self.dynamic_weight == 'None': + loss_weight = torch.ones_like( + loss_weight.unsqueeze(-1)) + elif self.dynamic_weight == 'ang': + loss_weight = torch.clamp( + loss_weight.unsqueeze(-1), 0, 1) + 1 + elif self.dynamic_weight == '10ang': + loss_weight = torch.clamp( + 10 * loss_weight.unsqueeze(-1), 0, 1) + 1 + elif self.dynamic_weight == '50ang': + loss_weight = torch.clamp( + 50 * loss_weight.unsqueeze(-1), 0, 1) + 1 + elif self.dynamic_weight == '100ang': + loss_weight = torch.clamp( + 100 * loss_weight.unsqueeze(-1), 0, 1) + 1 + else: + loss_weight = loss_weight.unsqueeze(-1) + 1 + else: + raise NotImplementedError + + # cls loss + if self.cls_loss_type == 'BCE': + loss_cls_ = F.binary_cross_entropy( + student_cls_scores_logits_.sigmoid(), + teacher_cls_scores_logits_.sigmoid(), + reduction='none') + else: + raise NotImplementedError + loss_cls_ = (loss_cls_ * loss_weight).mean() + # bbox loss + loss_bbox_ = self.loss_bbox( + student_bbox_preds_, + teacher_bbox_preds_) * teacher_centernesses_.sigmoid() + loss_bbox_ = (loss_bbox_ * loss_weight).mean() + + # centerness loss + loss_centerness_ = F.binary_cross_entropy( + student_centernesses_.sigmoid(), + teacher_centernesses_.sigmoid(), + reduction='none') + loss_centerness_ = (loss_centerness_ * loss_weight).mean() + + loss_cls += loss_cls_ + loss_bbox += loss_bbox_ + loss_centerness += loss_centerness_ + + loss_cls = loss_cls / len(masks) + loss_bbox = loss_bbox / len(masks) + loss_centerness = loss_centerness / len(masks) + + if self.aux_loss is not None: + loss_gc = torch.zeros(1, device=gpu_device) + if self.aux_loss == 'ot_ang_loss_norm': + teacher_score_map = teacher_logit[2][0] + student_score_map = student_logit[2][0] + elif self.aux_loss == 'ot_loss_norm': + teacher_score_map = teacher_logit[0][0] + student_score_map = student_logit[0][0] + + teacher_score_map = teacher_score_map.permute(0, 2, 3, 1) + student_score_map = student_score_map.permute(0, 2, 3, 1) + + if self.aux_loss == 'ot_ang_loss_norm': + teacher_score_map = torch.abs(teacher_score_map) / np.pi + student_score_map = torch.abs(student_score_map) / np.pi + elif self.aux_loss == 'ot_loss_norm': + teacher_score_map = torch.softmax(teacher_score_map, dim=-1) + student_score_map = torch.softmax(student_score_map, dim=-1) + + for i in range(teacher_score_map.shape[0]): + teacher_score, score_cls = torch.max( + teacher_score_map[i][masks[i]], dim=-1) + student_score = student_score_map[i][masks[i]][ + torch.arange(teacher_score.shape[0]), score_cls] + pts = masks[i].nonzero() + if len(pts) <= 1: + continue + loss_gc += self.gc_loss( + teacher_score, + student_score, + pts, + cost_type=self.cost_type, + clamp_ot=self.clamp_ot) + loss_gc = self.ot_weight * loss_gc / len(teacher_score_map) + + if self.aux_loss is not None: + losses = { + 'loss_cls': loss_cls, + 'loss_bbox': loss_bbox, + 'loss_centerness': loss_centerness, + 'loss_gc': loss_gc + } + else: + losses = { + 'loss_cls': loss_cls, + 'loss_bbox': loss_bbox, + 'loss_centerness': loss_centerness, + } + + # apply burnin strategy to reweight the unsupervised weights + burn_in_steps = self.semi_train_cfg.get('burn_in_steps', 5000) + unsup_weight 
= self.semi_train_cfg.get('unsup_weight', 1.)
+        self.weight_suppress = self.semi_train_cfg.get('weight_suppress',
+                                                       'linear')
+        if self.weight_suppress == 'exp':
+            target = burn_in_steps + 2000
+            if self.iter_count <= target:
+                scale = np.exp((self.iter_count - target) / 1000)
+                unsup_weight *= scale
+        elif self.weight_suppress == 'step':
+            target = burn_in_steps * 2
+            if self.iter_count <= target:
+                unsup_weight *= 0.25
+        elif self.weight_suppress == 'linear':
+            target = burn_in_steps * 2
+            if self.iter_count <= target:
+                unsup_weight *= (self.iter_count -
+                                 burn_in_steps) / burn_in_steps
+        return rename_loss_dict('unsup_',
+                                reweight_loss_dict(losses, unsup_weight))
diff --git a/mmrotate/models/losses/__init__.py b/mmrotate/models/losses/__init__.py
index 12e529e3e..f1502ebd7 100644
--- a/mmrotate/models/losses/__init__.py
+++ b/mmrotate/models/losses/__init__.py
@@ -5,6 +5,7 @@
 from .h2rbox_consistency_loss import H2RBoxConsistencyLoss
 from .h2rbox_v2_consistency_loss import H2RBoxV2ConsistencyLoss
 from .kf_iou_loss import KFLoss
+from .ot_loss import OT_Loss
 from .rotated_iou_loss import RotatedIoULoss
 from .smooth_focal_loss import SmoothFocalLoss
 from .spatial_border_loss import SpatialBorderLoss
@@ -12,5 +13,5 @@
 __all__ = [
     'GDLoss', 'GDLoss_v1', 'KFLoss', 'ConvexGIoULoss', 'BCConvexGIoULoss',
     'SmoothFocalLoss', 'RotatedIoULoss', 'SpatialBorderLoss',
-    'H2RBoxConsistencyLoss', 'H2RBoxV2ConsistencyLoss'
+    'H2RBoxConsistencyLoss', 'H2RBoxV2ConsistencyLoss', 'OT_Loss'
 ]
diff --git a/mmrotate/models/losses/ot_loss.py b/mmrotate/models/losses/ot_loss.py
new file mode 100644
index 000000000..9ab14b02c
--- /dev/null
+++ b/mmrotate/models/losses/ot_loss.py
@@ -0,0 +1,701 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn import Module
+
+M_EPS = 1e-16
+
+
+# This code was copied from
+# "https://github.com/cvlab-stonybrook/DM-Count/blob/master/losses/ot_loss.py"
+class OT_Loss(Module):
+
+    def __init__(self, num_of_iter_in_ot=100, reg=10.0, method='sinkhorn'):
+        super().__init__()
+        self.num_of_iter_in_ot = num_of_iter_in_ot
+        self.reg = reg
+        self.method = method
+
+    def forward(self,
+                t_scores,
+                s_scores,
+                pts,
+                cost_type='all',
+                clamp_ot=False,
+                aux_cost=None):
+        r"""
+        Calculate the optimal transport loss between the teacher's and the
+        student's score distributions.
+        The cost map is defined as: cost = dist(p_t, p_s) +
+        dist(score_t, score_s), where both terms are squared l2 distances,
+        each normalized by its maximum over the cost map.
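+        Concretely, for pseudo-label location i and prediction location j
+        (a restatement of the code below, not extra functionality):
+
+        .. math::
+            C_{ij} = \frac{\|p_i - p_j\|_2^2}{\max_{k,l}\|p_k - p_l\|_2^2}
+                + \frac{(t_i - s_j)^2}{\max_{k,l}(t_k - s_l)^2}
+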
+        Args:
+            t_scores (Tensor): Teacher scores with shape (N, ).
+            s_scores (Tensor): Student scores with shape (N, ).
+            pts (Tensor): Coordinates of the N sampled locations,
+                with shape (N, 2).
+            cost_type (str): Which terms enter the cost map, one of
+                'all', 'dist' or 'score'. Defaults to 'all'.
+            clamp_ot (bool): Whether to clamp the loss to be non-negative.
+                Defaults to False.
+            aux_cost (Tensor, optional): Extra cost added to the cost map.
+                Defaults to None.
+
+        Returns:
+            Tensor: The scalar optimal transport loss.
+        """
+        assert cost_type in ['all', 'dist', 'score']
+        with torch.no_grad():
+            t_scores_prob = torch.softmax(t_scores, dim=0)
+            s_scores_prob = torch.softmax(s_scores, dim=0)
+            score_cost = (t_scores.detach().unsqueeze(1) -
+                          s_scores.detach().unsqueeze(0))**2
+            score_cost = score_cost / score_cost.max()
+            if cost_type in ['all', 'dist']:
+                coord_x = pts[:, 0]
+                coord_y = pts[:, 1]
+                dist_x = (coord_x.reshape(1, -1) - coord_x.reshape(-1, 1))**2
+                dist_y = (coord_y.reshape(1, -1) - coord_y.reshape(-1, 1))**2
+                dist_cost = (dist_x + dist_y).to(t_scores_prob.device)
+                dist_cost = dist_cost / dist_cost.max()
+                if cost_type == 'all':
+                    cost_map = dist_cost + score_cost
+                else:
+                    cost_map = dist_cost
+            else:
+                cost_map = score_cost
+            if aux_cost is not None:
+                cost_map = cost_map + aux_cost
+            # cost_map = (dist_cost + score_cost) / 2
+            source_prob = s_scores_prob.detach().view(-1)
+            target_prob = t_scores_prob.detach().view(-1)
+            if t_scores.shape[0] < 2000:  # 2500
+                _, log = self.sinkhorn(
+                    target_prob,
+                    source_prob,
+                    cost_map,
+                    self.reg,
+                    maxIter=self.num_of_iter_in_ot,
+                    log=True,
+                    method=self.method)
+                # size is the same as source_prob: [#coord * #coord]
+                beta = log['beta']
+            else:
+                _, log = self.sinkhorn(
+                    target_prob.cpu(),
+                    source_prob.cpu(),
+                    cost_map.cpu(),
+                    self.reg,
+                    maxIter=self.num_of_iter_in_ot,
+                    log=True,
+                    method=self.method)
+                # size is the same as source_prob: [#coord * #coord]
+                beta = log['beta'].to(target_prob.device)
+            # compute the gradient of the Optimal Transport loss
+            # w.r.t. the predicted density (unnormed_density):
+            # im_grad = beta / source_count -
+            # <beta, source_density> / (source_count)^2
+            source_density = s_scores.detach().view(-1)
+            source_count = source_density.sum()
+            im_grad_1 = (source_count) / (
+                source_count * source_count +
+                1e-8) * beta  # size of [#coord * #coord]
+            im_grad_2 = (source_density * beta).sum() / (
+                source_count * source_count + 1e-8)  # size of 1
+            im_grad = im_grad_1 - im_grad_2
+            im_grad = im_grad.detach()
+        # Define loss = <im_grad, predicted density>.
+        # The gradient of the loss w.r.t. the predicted density is im_grad.
+        if clamp_ot:
+            return torch.clamp_min(torch.sum(s_scores * im_grad), 0)
+        return torch.sum(s_scores * im_grad)
+
+    # The code below was copied from SOOD
+    # (https://github.com/HamPerdredes/SOOD/blob/main/
+    # ssad/models/losses/utils/bregman_pytorch.py)
+    def sinkhorn(self,
+                 a,
+                 b,
+                 C,
+                 reg=1e-1,
+                 method='sinkhorn',
+                 maxIter=1000,
+                 tau=1e3,
+                 stopThr=1e-9,
+                 verbose=False,
+                 log=True,
+                 warm_start=None,
+                 eval_freq=10,
+                 print_freq=200,
+                 **kwargs):
+        r"""Solve the entropic regularization optimal transport problem.
+        The inputs should be PyTorch tensors. The function solves the
+        following optimization problem:
+
+        .. math::
+            \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
+            s.t. \gamma 1 = a
+                 \gamma^T 1 = b
+                 \gamma \geq 0
+        where :
+        - C is the (ns,nt) metric cost matrix
+        - :math:`\Omega` is the entropic regularization term :
+          math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+        - a and b are target and source measures (sum to 1)
+        The algorithm used for solving the problem is the
+        Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].
+
+        Parameters
+        ----------
+        a : torch.tensor (na,)
+            samples measure in the target domain
+        b : torch.tensor (nb,)
+            samples in the source domain
+        C : torch.tensor (na,nb)
+            loss matrix
+        reg : float
+            Regularization term > 0
+        method : str
+            method used for the solver, either 'sinkhorn',
+            'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling',
+            see those functions for specific parameters
+        maxIter : int, optional
+            Max number of iterations
+        stopThr : float, optional
+            Stop threshold on error ( > 0 )
+        verbose : bool, optional
+            Print information along iterations
+        log : bool, optional
+            record log if True
+
+        Returns
+        -------
+        gamma : (na x nb) torch.tensor
+            Optimal transportation matrix for the given parameters
+        log : dict
+            log dictionary returned only if log==True in parameters
+
+        References
+        ----------
+        [1] M. Cuturi, Sinkhorn Distances :
+            Lightspeed Computation of Optimal Transport,
+            Advances in Neural Information Processing Systems (NIPS) 26, 2013
+        See Also
+        --------
+        """
+
+        if method.lower() == 'sinkhorn':
+            return self.sinkhorn_knopp(
+                a,
+                b,
+                C,
+                reg,
+                maxIter=maxIter,
+                stopThr=stopThr,
+                verbose=verbose,
+                log=log,
+                warm_start=warm_start,
+                eval_freq=eval_freq,
+                print_freq=print_freq,
+                **kwargs)
+        elif method.lower() == 'sinkhorn_stabilized':
+            return self.sinkhorn_stabilized(
+                a,
+                b,
+                C,
+                reg,
+                maxIter=maxIter,
+                tau=tau,
+                stopThr=stopThr,
+                verbose=verbose,
+                log=log,
+                warm_start=warm_start,
+                eval_freq=eval_freq,
+                print_freq=print_freq,
+                **kwargs)
+        elif method.lower() == 'sinkhorn_epsilon_scaling':
+            return self.sinkhorn_epsilon_scaling(
+                a,
+                b,
+                C,
+                reg,
+                maxIter=maxIter,
+                maxInnerIter=100,
+                tau=tau,
+                scaling_base=0.75,
+                scaling_coef=None,
+                stopThr=stopThr,
+                verbose=False,
+                log=log,
+                warm_start=warm_start,
+                eval_freq=eval_freq,
+                print_freq=print_freq,
+                **kwargs)
+        else:
+            raise ValueError("Unknown method '%s'." % method)
+
+    def sinkhorn_knopp(self,
+                       a,
+                       b,
+                       C,
+                       reg=1e-1,
+                       maxIter=1000,
+                       stopThr=1e-9,
+                       verbose=False,
+                       log=False,
+                       warm_start=None,
+                       eval_freq=10,
+                       print_freq=200,
+                       **kwargs):
+        r"""Solve the entropic regularization optimal transport problem.
+
+        The inputs should be PyTorch tensors. The function solves the
+        following optimization problem:
+
+        .. math::
+            \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
+
+            s.t. \gamma 1 = a
+
+                 \gamma^T 1 = b
+
+                 \gamma \geq 0
+        where :
+        - C is the (ns,nt) metric cost matrix
+        - :math:`\Omega` is the entropic regularization term:
+          :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+        - a and b are target and source measures (sum to 1)
+        The algorithm used for solving the problem is the
+        Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].
+
+        Parameters
+        ----------
+        a : torch.tensor (na,)
+            samples measure in the target domain
+        b : torch.tensor (nb,)
+            samples in the source domain
+        C : torch.tensor (na,nb)
+            loss matrix
+        reg : float
+            Regularization term > 0
+        maxIter : int, optional
+            Max number of iterations
+        stopThr : float, optional
+            Stop threshold on error ( > 0 )
+        verbose : bool, optional
+            Print information along iterations
+        log : bool, optional
+            record log if True
+
+        Returns
+        -------
+        gamma : (na x nb) torch.tensor
+            Optimal transportation matrix for the given parameters
+        log : dict
+            log dictionary returned only if log==True in parameters
+
+        References
+        ----------
+        [1] M. Cuturi, Sinkhorn Distances :
+            Lightspeed Computation of Optimal Transport,
+            Advances in Neural Information Processing Systems (NIPS) 26, 2013
+        See Also
+        --------
+        """
+
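+        # A sketch of the iteration (cf. [1]): with the Gibbs kernel
+        # K = exp(-C / reg), alternately rescale v = b / (K^T u) and
+        # u = a / (K v); the transport plan is P = diag(u) K diag(v).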
+        device = a.device
+        na, nb = C.shape
+
+        assert na >= 1 and nb >= 1, 'C needs to be 2d'
+        assert na == a.shape[0] and nb == b.shape[
+            0], "Shape of a or b doesn't match that of C"
+        assert reg > 0, 'reg should be greater than 0'
+        assert a.min() >= 0. and b.min(
+        ) >= 0., 'Elements in a or b less than 0'
+
+        if log:
+            log = {'err': []}
+
+        if warm_start is not None:
+            u = warm_start['u']
+            v = warm_start['v']
+        else:
+            u = torch.ones(na, dtype=a.dtype).to(device) / na
+            v = torch.ones(nb, dtype=b.dtype).to(device) / nb
+
+        # Gibbs kernel K = exp(-C / reg), computed in place
+        K = torch.empty(C.shape, dtype=C.dtype).to(device)
+        torch.div(C, -reg, out=K)
+        torch.exp(K, out=K)
+
+        b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)
+
+        it = 1
+        err = 1
+
+        # allocate memory beforehand
+        KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
+        Kv = torch.empty(u.shape, dtype=u.dtype).to(device)
+
+        while (err > stopThr and it <= maxIter):
+            upre, vpre = u, v
+
+            KTu = torch.matmul(u, K).squeeze(0)
+            v = torch.div(b, KTu + M_EPS)
+            Kv = torch.matmul(K, v).squeeze(0)
+            u = torch.div(a, Kv + M_EPS)
+
+            # torch.matmul(u, K, out=KTu.unsqueeze(0))
+            # v = torch.div(b, KTu + M_EPS)
+            # torch.matmul(K, v, out=Kv)
+            # u = torch.div(a, Kv + M_EPS)
+
+            if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \
+                    torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
+                print('Warning: numerical errors at iteration', it)
+                u, v = upre, vpre
+                break
+
+            if log and it % eval_freq == 0:
+                # we can speed up the process by checking for the error
+                # only every eval_freq iterations
+                # below is equivalent to:
+                # b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0)
+                # but more memory efficient
+                b_hat = torch.matmul(u, K) * v
+                err = (b - b_hat).pow(2).sum().item()
+                # err = (b - b_hat).abs().sum().item()
+                log['err'].append(err)
+
+            if verbose and it % print_freq == 0:
+                print('iteration {:5d}, constraint error {:5e}'.format(
+                    it, err))
+
+            it += 1
+
+        if log:
+            log['u'] = u
+            log['v'] = v
+            log['alpha'] = reg * torch.log(u + M_EPS)
+            log['beta'] = reg * torch.log(v + M_EPS)
+
+        # transport plan
+        P = u.reshape(-1, 1) * K * v.reshape(1, -1)
+        if log:
+            return P, log
+        else:
+            return P
+
+    def sinkhorn_stabilized(self,
+                            a,
+                            b,
+                            C,
+                            reg=1e-1,
+                            maxIter=1000,
+                            tau=1e3,
+                            stopThr=1e-9,
+                            verbose=False,
+                            log=False,
+                            warm_start=None,
+                            eval_freq=10,
+                            print_freq=200,
+                            **kwargs):
+        r"""Solve the entropic regularization Optimal Transport problem
+        with log stabilization.
+
+        The function solves the following optimization problem:
+
+        .. math::
+            \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
+
+            s.t. \gamma 1 = a
+
+                 \gamma^T 1 = b
+
+                 \gamma \geq 0
+        where :
+        - C is the (ns,nt) metric cost matrix
+        - :math:`\Omega` is the entropic regularization term:
+          :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+        - a and b are target and source measures (sum to 1)
+
+        The algorithm used for solving the problem is the
+        Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]
+        but with the log stabilization proposed in [3] and defined in
+        [2] (Algo 3.1).
+
+        Parameters
+        ----------
+        a : torch.tensor (na,)
+            samples measure in the target domain
+        b : torch.tensor (nb,)
+            samples in the source domain
+        C : torch.tensor (na,nb)
+            loss matrix
+        reg : float
+            Regularization term > 0
+        tau : float
+            threshold for max value in u or v for log scaling
+        maxIter : int, optional
+            Max number of iterations
+        stopThr : float, optional
+            Stop threshold on error ( > 0 )
+        verbose : bool, optional
+            Print information along iterations
+        log : bool, optional
+            record log if True
+
+        Returns
+        -------
+        gamma : (na x nb) torch.tensor
+            Optimal transportation matrix for the given parameters
+        log : dict
+            log dictionary returned only if log==True in parameters
+
+        References
+        ----------
+        [1] M. Cuturi, Sinkhorn Distances :
+            Lightspeed Computation of Optimal Transport,
+            Advances in Neural Information Processing Systems (NIPS) 26, 2013
+        [2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for
+            Entropy Regularized Transport Problems.
+            SIAM Journal on Scientific Computing, 2019
+        [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+            Scaling algorithms for unbalanced transport problems.
+            arXiv preprint arXiv:1607.05816.
+
+        See Also
+        --------
+        """
+
+        device = a.device
+        na, nb = C.shape
+
+        assert na >= 1 and nb >= 1, 'C needs to be 2d'
+        assert na == a.shape[0] and nb == b.shape[
+            0], "Shape of a or b doesn't match that of C"
+        assert reg > 0, 'reg should be greater than 0'
+        assert a.min() >= 0. and b.min(
+        ) >= 0., 'Elements in a or b less than 0'
+
+        if log:
+            log = {'err': []}
+
+        if warm_start is not None:
+            alpha = warm_start['alpha']
+            beta = warm_start['beta']
+        else:
+            alpha = torch.zeros(na, dtype=a.dtype).to(device)
+            beta = torch.zeros(nb, dtype=b.dtype).to(device)
+
+        u = torch.ones(na, dtype=a.dtype).to(device) / na
+        v = torch.ones(nb, dtype=b.dtype).to(device) / nb
+
+        def update_K(alpha, beta):
+            """Log space computation (memory efficient)."""
+            torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=K)
+            torch.add(K, -C, out=K)
+            torch.div(K, reg, out=K)
+            torch.exp(K, out=K)
+
+        def update_P(alpha, beta, u, v, ab_updated=False):
+            """log space P (gamma) computation."""
+            torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=P)
+            torch.add(P, -C, out=P)
+            torch.div(P, reg, out=P)
+            if not ab_updated:
+                torch.add(P, torch.log(u + M_EPS).reshape(-1, 1), out=P)
+                torch.add(P, torch.log(v + M_EPS).reshape(1, -1), out=P)
+            torch.exp(P, out=P)
+
+        K = torch.empty(C.shape, dtype=C.dtype).to(device)
+        update_K(alpha, beta)
+
+        b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)
+
+        it = 1
+        err = 1
+        ab_updated = False
+
+        # allocate memory beforehand
+        KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
+        Kv = torch.empty(u.shape, dtype=u.dtype).to(device)
+        P = torch.empty(C.shape, dtype=C.dtype).to(device)
+
+        while (err > stopThr and it <= maxIter):
+            torch.matmul(u, K, out=KTu)
+            v = torch.div(b, KTu + M_EPS)
+            torch.matmul(K, v, out=Kv)
+            u = torch.div(a, Kv + M_EPS)
+
+            ab_updated = False
+            # absorb large scalings into alpha and beta to remove
+            # numerical problems, then update K
+            if u.abs().sum() > tau or v.abs().sum() > tau:
+                alpha += reg * torch.log(u + M_EPS)
+                beta += reg * torch.log(v + M_EPS)
+                u.fill_(1. / na)
+                v.fill_(1. / nb)
+                update_K(alpha, beta)
+                ab_updated = True
+
+            if log and it % eval_freq == 0:
+                # we can speed up the process by checking for
+                # the error only every eval_freq iterations
+                update_P(alpha, beta, u, v, ab_updated)
+                b_hat = torch.sum(P, 0)
+                err = (b - b_hat).pow(2).sum().item()
+                log['err'].append(err)
+
+            if verbose and it % print_freq == 0:
+                print('iteration {:5d}, constraint error {:5e}'.format(
+                    it, err))
+
+            it += 1
+
+        if log:
+            log['u'] = u
+            log['v'] = v
+            log['alpha'] = alpha + reg * torch.log(u + M_EPS)
+            log['beta'] = beta + reg * torch.log(v + M_EPS)
+
+        # transport plan
+        update_P(alpha, beta, u, v, False)
+
+        if log:
+            return P, log
+        else:
+            return P
+
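+    # Epsilon scaling, a sketch of the idea (cf. [2] in the docstring
+    # below): repeatedly call sinkhorn_stabilized while annealing the
+    # regularization toward the target `reg`, warm-starting each solve
+    # from the previous dual potentials alpha/beta.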
+    def sinkhorn_epsilon_scaling(self,
+                                 a,
+                                 b,
+                                 C,
+                                 reg=1e-1,
+                                 maxIter=100,
+                                 maxInnerIter=100,
+                                 tau=1e3,
+                                 scaling_base=0.75,
+                                 scaling_coef=None,
+                                 stopThr=1e-9,
+                                 verbose=False,
+                                 log=False,
+                                 warm_start=None,
+                                 eval_freq=10,
+                                 print_freq=200,
+                                 **kwargs):
+        r"""Solve the entropic regularization Optimal Transport problem
+        with log stabilization and epsilon scaling.
+
+        The function solves the following optimization problem:
+
+        .. math::
+            \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
+
+            s.t. \gamma 1 = a
+
+                 \gamma^T 1 = b
+
+                 \gamma \geq 0
+        where :
+        - C is the (ns,nt) metric cost matrix
+        - :math:`\Omega` is the entropic regularization term:
+          :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+        - a and b are target and source measures (sum to 1)
+
+        The algorithm used for solving the problem is the Sinkhorn-Knopp
+        matrix scaling algorithm as proposed in [1] but with the log
+        stabilization proposed in [3] and the epsilon scaling proposed
+        in [2] (Algo 3.2).
+
+        Parameters
+        ----------
+        a : torch.tensor (na,)
+            samples measure in the target domain
+        b : torch.tensor (nb,)
+            samples in the source domain
+        C : torch.tensor (na,nb)
+            loss matrix
+        reg : float
+            Regularization term > 0
+        tau : float
+            threshold for max value in u or v for log scaling
+        maxIter : int, optional
+            Max number of iterations
+        stopThr : float, optional
+            Stop threshold on error ( > 0 )
+        verbose : bool, optional
+            Print information along iterations
+        log : bool, optional
+            record log if True
+
+        Returns
+        -------
+        gamma : (na x nb) torch.tensor
+            Optimal transportation matrix for the given parameters
+        log : dict
+            log dictionary returned only if log==True in parameters
+
+        References
+        ----------
+        [1] M. Cuturi, Sinkhorn Distances :
+            Lightspeed Computation of Optimal Transport,
+            Advances in Neural Information Processing Systems (NIPS) 26, 2013
+        [2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for
+            Entropy Regularized Transport Problems.
+            SIAM Journal on Scientific Computing, 2019
+        [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+            Scaling algorithms for unbalanced transport problems.
+            arXiv preprint arXiv:1607.05816.
+
+        See Also
+        --------
+        """
+
+        na, nb = C.shape
+
+        assert na >= 1 and nb >= 1, 'C needs to be 2d'
+        assert na == a.shape[0] and nb == b.shape[
+            0], "Shape of a or b doesn't match that of C"
+        assert reg > 0, 'reg should be greater than 0'
+        assert a.min() >= 0. and b.min(
+        ) >= 0., 'Elements in a or b less than 0'
+
+        def get_reg(it, reg, pre_reg):
+            if it == 1:
+                return scaling_coef
+            else:
+                if (pre_reg - reg) * scaling_base < M_EPS:
+                    return reg
+                else:
+                    return (pre_reg - reg) * scaling_base + reg
+
+        if scaling_coef is None:
+            scaling_coef = C.max() + reg
+
+        it = 1
+        err = 1
+        running_reg = scaling_coef
+
+        if log:
+            log = {'err': []}
+
+        warm_start = None
+
+        while (err > stopThr and it <= maxIter):
+            running_reg = get_reg(it, reg, running_reg)
+            P, _log = self.sinkhorn_stabilized(
+                a,
+                b,
+                C,
+                running_reg,
+                maxIter=maxInnerIter,
+                tau=tau,
+                stopThr=stopThr,
+                verbose=False,
+                log=True,
+                warm_start=warm_start,
+                eval_freq=eval_freq,
+                print_freq=print_freq,
+                **kwargs)
+
+            warm_start = {}
+            warm_start['alpha'] = _log['alpha']
+            warm_start['beta'] = _log['beta']
+
+            primal_val = (
+                C * P).sum() + reg * (P * torch.log(P)).sum() - reg * P.sum()
+            dual_val = (_log['alpha'] * a).sum() + (_log['beta'] *
+                                                    b).sum() - reg * P.sum()
+            err = primal_val - dual_val
+            if log:
+                log['err'].append(err)
+
+            if verbose and it % print_freq == 0:
+                print('iteration {:5d}, constraint error {:5e}'.format(
+                    it, err))
+
+            it += 1
+
+        if log:
+            log['alpha'] = _log['alpha']
+            log['beta'] = _log['beta']
+            return P, log
+        else:
+            return P
diff --git a/mmrotate/structures/bbox/__init__.py b/mmrotate/structures/bbox/__init__.py
index 895ade012..37c58ac6d 100644
--- a/mmrotate/structures/bbox/__init__.py
+++ b/mmrotate/structures/bbox/__init__.py
@@ -4,10 +4,12 @@
                          rbox2hbox, rbox2qbox)
 from .quadri_boxes import QuadriBoxes
 from .rotated_boxes import RotatedBoxes
-from .transforms import distance2obb, gaussian2bbox, gt2gaussian, norm_angle
+from .transforms import (distance2obb, gaussian2bbox, gt2gaussian, norm_angle,
+                         rbox_project)
 
 __all__ = [
     'QuadriBoxes', 'RotatedBoxes', 'hbox2rbox', 'hbox2qbox', 'rbox2hbox',
     'rbox2qbox', 'qbox2hbox', 'qbox2rbox', 'gaussian2bbox', 'gt2gaussian',
-    'norm_angle', 'rbbox_overlaps', 'fake_rbbox_overlaps', 'distance2obb'
+    'norm_angle', 'rbbox_overlaps', 'fake_rbbox_overlaps', 'distance2obb',
+    'rbox_project'
 ]
diff --git a/mmrotate/structures/bbox/transforms.py b/mmrotate/structures/bbox/transforms.py
index a2edde07d..2d4ea8f23 100644
--- a/mmrotate/structures/bbox/transforms.py
+++ b/mmrotate/structures/bbox/transforms.py
@@ -1,7 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple, Union
+
 import numpy as np
 import torch
 
+from mmrotate.structures.bbox.box_converters import qbox2rbox, rbox2qbox
+
 
 def norm_angle(angle, angle_range):
     """Limit the range of angles.
@@ -112,3 +116,42 @@
     angle_regular = norm_angle(angle, angle_version)
 
     return torch.cat([ctr, wh, angle_regular], dim=-1)
+
+
+def rbox_project(
+    bboxes: Union[torch.Tensor, np.ndarray],
+    homography_matrix: Union[torch.Tensor, np.ndarray],
+    img_shape: Optional[Tuple[int, int]] = None
+) -> Union[torch.Tensor, np.ndarray]:
+    """Geometric transformation for rbox. Modified from
+    mmdet/structures/bbox/transforms.py/bbox_project to support rbox.
+
+    Args:
+        bboxes (Union[torch.Tensor, np.ndarray]): Shape (n, 5) for rboxes.
+        homography_matrix (Union[torch.Tensor, np.ndarray]):
+            Shape (3, 3) for geometric transformation.
+        img_shape (Tuple[int, int], optional): Image shape (h, w). If given,
+            the projected corners are clipped to it. Defaults to None.
+
+    Returns:
+        Union[torch.Tensor, np.ndarray]: Converted bboxes.
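+
+    Example:
+        An illustrative horizontal flip of a 1024x1024 image (the values
+        below are hypothetical):
+
+        >>> H = torch.tensor([[-1., 0., 1024.],
+        ...                   [0., 1., 0.],
+        ...                   [0., 0., 1.]])
+        >>> rboxes = torch.tensor([[100., 200., 60., 30., 0.3]])
+        >>> flipped = rbox_project(rboxes, H, (1024, 1024))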
+ """ + bboxes_type = type(bboxes) + if bboxes_type is np.ndarray: + bboxes = torch.from_numpy(bboxes) + if isinstance(homography_matrix, np.ndarray): + homography_matrix = torch.from_numpy(homography_matrix) + + corners = rbox2qbox(bboxes).reshape(-1, 2) + corners = torch.cat( + [corners, corners.new_ones(corners.shape[0], 1)], dim=1) + corners = torch.matmul(homography_matrix, corners.t()).t() + # Convert to homogeneous coordinates by normalization + corners = corners[:, :2] / corners[:, 2:3] + + corners = corners.reshape(-1, 8) + corners[:, 0::2] = corners[:, 0::2].clamp(0, img_shape[1]) + corners[:, 1::2] = corners[:, 1::2].clamp(0, img_shape[0]) + bboxes = qbox2rbox(corners) + + if bboxes_type is np.ndarray: + bboxes = bboxes.numpy() + return bboxes diff --git a/tools/misc/split_dota1.5_lists/10p_list.json b/tools/misc/split_dota1.5_lists/10p_list.json new file mode 100644 index 000000000..dfb2a820f --- /dev/null +++ b/tools/misc/split_dota1.5_lists/10p_list.json @@ -0,0 +1 @@ +["P1556.png", "P2483.png", "P2055.png", "P0163.png", "P1401.png", "P2532.png", "P0909.png", "P2517.png", "P0273.png", "P0738.png", "P1130.png", "P0373.png", "P2528.png", "P1804.png", "P1600.png", "P0010.png", "P2141.png", "P1646.png", "P2585.png", "P0800.png", "P1359.png", "P0310.png", "P1161.png", "P1660.png", "P0464.png", "P1535.png", "P1641.png", "P0363.png", "P0435.png", "P0461.png", "P0487.png", "P0782.png", "P2435.png", "P2437.png", "P0067.png", "P0769.png", "P2628.png", "P2163.png", "P1307.png", "P2792.png", "P0602.png", "P0714.png", "P0209.png", "P2221.png", "P0458.png", "P2527.png", "P2409.png", "P1152.png", "P2018.png", "P1871.png", "P1158.png", "P1673.png", "P1078.png", "P2321.png", "P2682.png", "P2746.png", "P0821.png", "P1243.png", "P1688.png", "P0495.png", "P0370.png", "P2066.png", "P0187.png", "P1928.png", "P0248.png", "P0517.png", "P0301.png", "P0555.png", "P2123.png", "P1198.png", "P1686.png", "P2543.png", "P1534.png", "P2190.png", "P0981.png", "P2134.png", "P2605.png", "P0049.png", "P0562.png", "P0938.png", "P2334.png", "P0061.png", "P0731.png", "P1709.png", "P1339.png", "P1059.png", "P2747.png", "P0100.png", "P2644.png", "P0603.png", "P0144.png", "P0631.png", "P1859.png", "P2248.png", "P1962.png", "P1395.png", "P0617.png", "P2671.png", "P1317.png", "P1524.png", "P1069.png", "P1639.png", "P1414.png", "P2759.png", "P0691.png", "P2554.png", "P1427.png", "P0883.png", "P1200.png", "P2616.png", "P0505.png", "P1449.png", "P1319.png", "P2651.png", "P2415.png", "P0282.png", "P1608.png", "P0966.png", "P0627.png", "P0453.png", "P2494.png", "P1181.png", "P0220.png", "P1299.png", "P1702.png", "P1912.png", "P1705.png", "P2642.png", "P0005.png", "P2206.png", "P2565.png", "P2103.png", "P1515.png", "P0881.png", "P0109.png", "P2770.png", "P0775.png", "P0867.png", "P2049.png", "P0455.png", "P2687.png"] diff --git a/tools/misc/split_dota1.5_lists/20p_list.json b/tools/misc/split_dota1.5_lists/20p_list.json new file mode 100644 index 000000000..c0f7a4fef --- /dev/null +++ b/tools/misc/split_dota1.5_lists/20p_list.json @@ -0,0 +1 @@ +["P0535.png", "P0867.png", "P2605.png", "P0100.png", "P1344.png", "P1661.png", "P2702.png", "P0553.png", "P2505.png", "P0596.png", "P2403.png", "P1515.png", "P2133.png", "P1413.png", "P2628.png", "P1317.png", "P2122.png", "P0282.png", "P2206.png", "P1361.png", "P1080.png", "P2545.png", "P0517.png", "P1495.png", "P2471.png", "P0417.png", "P0367.png", "P1152.png", "P2622.png", "P0631.png", "P0691.png", "P1109.png", "P2016.png", "P1962.png", "P0453.png", "P0310.png", "P0067.png", 
"P2483.png", "P1200.png", "P2123.png", "P2682.png", "P1392.png", "P1600.png", "P0289.png", "P0489.png", "P2409.png", "P0619.png", "P1822.png", "P1912.png", "P2670.png", "P0800.png", "P0064.png", "P0870.png", "P1493.png", "P2175.png", "P2705.png", "P0363.png", "P1341.png", "P2601.png", "P1652.png", "P0554.png", "P2055.png", "P0617.png", "P1255.png", "P0209.png", "P2611.png", "P0360.png", "P0769.png", "P0052.png", "P0248.png", "P0292.png", "P0000.png", "P0538.png", "P0909.png", "P2642.png", "P0603.png", "P0597.png", "P1465.png", "P2190.png", "P0005.png", "P1299.png", "P1739.png", "P0731.png", "P1139.png", "P2260.png", "P2433.png", "P0829.png", "P0565.png", "P2759.png", "P0220.png", "P0505.png", "P1414.png", "P0187.png", "P1198.png", "P1793.png", "P0023.png", "P0555.png", "P0850.png", "P1859.png", "P1181.png", "P0627.png", "P0133.png", "P1649.png", "P0840.png", "P0288.png", "P1158.png", "P2157.png", "P0438.png", "P2472.png", "P1078.png", "P0370.png", "P0652.png", "P2747.png", "P0993.png", "P0521.png", "P2397.png", "P0750.png", "P0762.png", "P2437.png", "P2321.png", "P2103.png", "P1130.png", "P0415.png", "P1891.png", "P2687.png", "P1645.png", "P0981.png", "P1788.png", "P2517.png", "P1804.png", "P1702.png", "P0223.png", "P2696.png", "P1401.png", "P0373.png", "P0201.png", "P0925.png", "P0978.png", "P1673.png", "P2049.png", "P2066.png", "P0487.png", "P2494.png", "P1047.png", "P2729.png", "P2362.png", "P2018.png", "P1497.png", "P2792.png", "P0461.png", "P2107.png", "P1709.png", "P1685.png", "P0821.png", "P2565.png", "P2651.png", "P0562.png", "P0495.png", "P2248.png", "P2746.png", "P1893.png", "P1998.png", "P2163.png", "P0250.png", "P1871.png", "P0085.png", "P1069.png", "P1675.png", "P2616.png", "P1556.png", "P0345.png", "P2118.png", "P1714.png", "P2134.png", "P2644.png", "P1705.png", "P0972.png", "P0001.png", "P2253.png", "P0779.png", "P0227.png", "P1276.png", "P1951.png", "P2302.png", "P1646.png", "P0455.png", "P1059.png", "P0804.png", "P2334.png", "P1935.png", "P0464.png", "P2151.png", "P2716.png", "P0109.png", "P2646.png", "P2203.png", "P1174.png", "P2189.png", "P2770.png", "P1319.png", "P2415.png", "P0782.png", "P1427.png", "P2554.png", "P1243.png", "P0938.png", "P2527.png", "P1587.png", "P0714.png", "P2204.png", "P0586.png", "P1258.png", "P1913.png", "P0049.png", "P0738.png", "P0921.png", "P0913.png", "P1339.png", "P1534.png", "P2532.png", "P1688.png", "P2221.png", "P0713.png", "P0232.png", "P1502.png", "P0716.png", "P0144.png", "P0883.png", "P0218.png", "P1987.png", "P0881.png", "P0435.png", "P0163.png", "P1928.png", "P0301.png", "P0966.png", "P0010.png", "P2457.png", "P1395.png", "P2141.png", "P2672.png", "P2028.png", "P1359.png", "P1079.png", "P1660.png", "P2380.png", "P2423.png", "P2543.png", "P2278.png", "P1325.png", "P1641.png", "P0700.png", "P2656.png", "P0845.png", "P0458.png", "P2435.png", "P0042.png", "P0355.png", "P2528.png", "P2671.png", "P2585.png", "P1686.png", "P1519.png", "P0273.png", "P1524.png", "P0093.png", "P1307.png", "P0775.png", "P1438.png", "P1535.png", "P2002.png", "P1639.png", "P2692.png", "P0602.png", "P2364.png", "P0593.png", "P1608.png", "P1161.png", "P0061.png", "P2057.png", "P1449.png", "P0780.png"] diff --git a/tools/misc/split_dota1.5_lists/30p_list.json b/tools/misc/split_dota1.5_lists/30p_list.json new file mode 100644 index 000000000..1f15960ec --- /dev/null +++ b/tools/misc/split_dota1.5_lists/30p_list.json @@ -0,0 +1 @@ +["P0677.png", "P2014.png", "P0987.png", "P0099.png", "P0826.png", "P2454.png", "P2179.png", "P0116.png", "P1193.png", "P1874.png", 
"P1868.png", "P1387.png", "P2522.png", "P0426.png", "P2777.png", "P1469.png", "P0322.png", "P1389.png", "P2735.png", "P0303.png", "P0222.png", "P2007.png", "P2250.png", "P1214.png", "P1340.png", "P1790.png", "P1599.png", "P1199.png", "P2750.png", "P1674.png", "P0707.png", "P2615.png", "P2356.png", "P2390.png", "P1010.png", "P1353.png", "P2114.png", "P0039.png", "P0605.png", "P1638.png", "P1846.png", "P0766.png", "P2067.png", "P0944.png", "P2377.png", "P2466.png", "P0332.png", "P2162.png", "P2560.png", "P1540.png", "P1251.png", "P1707.png", "P0371.png", "P2287.png", "P2073.png", "P0263.png", "P1357.png", "P0734.png", "P1482.png", "P2490.png", "P2783.png", "P2279.png", "P0149.png", "P2382.png", "P0699.png", "P1089.png", "P1851.png", "P2237.png", "P2034.png", "P2631.png", "P1191.png", "P0430.png", "P0973.png", "P0225.png", "P0786.png", "P0744.png", "P1977.png", "P1697.png", "P0753.png", "P1173.png", "P1052.png", "P0176.png", "P2395.png", "P1211.png", "P0203.png", "P0448.png", "P2251.png", "P2693.png", "P2289.png", "P1616.png", "P0848.png", "P1757.png", "P0183.png", "P2311.png", "P0914.png", "P2017.png", "P0221.png", "P1499.png", "P2775.png", "P0885.png", "P1343.png", "P2283.png", "P0252.png", "P1399.png", "P1954.png", "P0285.png", "P2306.png", "P2319.png", "P0340.png", "P1776.png", "P1727.png", "P2001.png", "P2304.png", "P0481.png", "P2804.png", "P0204.png", "P0450.png", "P0428.png", "P1507.png", "P1054.png", "P0308.png", "P1265.png", "P2793.png", "P2089.png", "P2444.png", "P2247.png", "P0940.png", "P0070.png", "P1308.png", "P2010.png", "P1794.png", "P1155.png", "P2076.png", "P0202.png", "P0783.png", "P1297.png", "P0591.png", "P0074.png", "P1607.png", "P2659.png", "P2080.png", "P0535.png", "P0867.png", "P2605.png", "P0100.png", "P1344.png", "P1661.png", "P2702.png", "P0553.png", "P2505.png", "P0596.png", "P2403.png", "P1515.png", "P2133.png", "P1413.png", "P2628.png", "P1317.png", "P2122.png", "P0282.png", "P2206.png", "P1361.png", "P1080.png", "P2545.png", "P0517.png", "P1495.png", "P2471.png", "P0417.png", "P0367.png", "P1152.png", "P2622.png", "P0631.png", "P0691.png", "P1109.png", "P2016.png", "P1962.png", "P0453.png", "P0310.png", "P0067.png", "P2483.png", "P1200.png", "P2123.png", "P2682.png", "P1392.png", "P1600.png", "P0289.png", "P0489.png", "P2409.png", "P0619.png", "P1822.png", "P1912.png", "P2670.png", "P0800.png", "P0064.png", "P0870.png", "P1493.png", "P2175.png", "P2705.png", "P0363.png", "P1341.png", "P2601.png", "P1652.png", "P0554.png", "P2055.png", "P0617.png", "P1255.png", "P0209.png", "P2611.png", "P0360.png", "P0769.png", "P0052.png", "P0248.png", "P0292.png", "P0000.png", "P0538.png", "P0909.png", "P2642.png", "P0603.png", "P0597.png", "P1465.png", "P2190.png", "P0005.png", "P1299.png", "P1739.png", "P0731.png", "P1139.png", "P2260.png", "P2433.png", "P0829.png", "P0565.png", "P2759.png", "P0220.png", "P0505.png", "P1414.png", "P0187.png", "P1198.png", "P1793.png", "P0023.png", "P0555.png", "P0850.png", "P1859.png", "P1181.png", "P0627.png", "P0133.png", "P1649.png", "P0840.png", "P0288.png", "P1158.png", "P2157.png", "P0438.png", "P2472.png", "P1078.png", "P0370.png", "P0652.png", "P2747.png", "P0993.png", "P0521.png", "P2397.png", "P0750.png", "P0762.png", "P2437.png", "P2321.png", "P2103.png", "P1130.png", "P0415.png", "P1891.png", "P2687.png", "P1645.png", "P0981.png", "P1788.png", "P2517.png", "P1804.png", "P1702.png", "P0223.png", "P2696.png", "P1401.png", "P0373.png", "P0201.png", "P0925.png", "P0978.png", "P1673.png", "P2049.png", "P2066.png", "P0487.png", 
"P2494.png", "P1047.png", "P2729.png", "P2362.png", "P2018.png", "P1497.png", "P2792.png", "P0461.png", "P2107.png", "P1709.png", "P1685.png", "P0821.png", "P2565.png", "P2651.png", "P0562.png", "P0495.png", "P2248.png", "P2746.png", "P1893.png", "P1998.png", "P2163.png", "P0250.png", "P1871.png", "P0085.png", "P1069.png", "P1675.png", "P2616.png", "P1556.png", "P0345.png", "P2118.png", "P1714.png", "P2134.png", "P2644.png", "P1705.png", "P0972.png", "P0001.png", "P2253.png", "P0779.png", "P0227.png", "P1276.png", "P1951.png", "P2302.png", "P1646.png", "P0455.png", "P1059.png", "P0804.png", "P2334.png", "P1935.png", "P0464.png", "P2151.png", "P2716.png", "P0109.png", "P2646.png", "P2203.png", "P1174.png", "P2189.png", "P2770.png", "P1319.png", "P2415.png", "P0782.png", "P1427.png", "P2554.png", "P1243.png", "P0938.png", "P2527.png", "P1587.png", "P0714.png", "P2204.png", "P0586.png", "P1258.png", "P1913.png", "P0049.png", "P0738.png", "P0921.png", "P0913.png", "P1339.png", "P1534.png", "P2532.png", "P1688.png", "P2221.png", "P0713.png", "P0232.png", "P1502.png", "P0716.png", "P0144.png", "P0883.png", "P0218.png", "P1987.png", "P0881.png", "P0435.png", "P0163.png", "P1928.png", "P0301.png", "P0966.png", "P0010.png", "P2457.png", "P1395.png", "P2141.png", "P2672.png", "P2028.png", "P1359.png", "P1079.png", "P1660.png", "P2380.png", "P2423.png", "P2543.png", "P2278.png", "P1325.png", "P1641.png", "P0700.png", "P2656.png", "P0845.png", "P0458.png", "P2435.png", "P0042.png", "P0355.png", "P2528.png", "P2671.png", "P2585.png", "P1686.png", "P1519.png", "P0273.png", "P1524.png", "P0093.png", "P1307.png", "P0775.png", "P1438.png", "P1535.png", "P2002.png", "P1639.png", "P2692.png", "P0602.png", "P2364.png", "P0593.png", "P1608.png", "P1161.png", "P0061.png", "P2057.png", "P1449.png", "P0780.png"] diff --git a/tools/misc/split_dota1.5_via_lists.py b/tools/misc/split_dota1.5_via_lists.py new file mode 100644 index 000000000..c03527ac7 --- /dev/null +++ b/tools/misc/split_dota1.5_via_lists.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# This code was modified from SOOD
+# (https://github.com/HamPerdredes/SOOD/blob/main/
+# tools/data/dota/split_data_via_list.py)
+import glob
+import json
+import os
+import shutil
+
+
+def split_img_with_list(list_dir, src_dir):
+    list_file = [
+        os.path.join(list_dir, '10p_list.json'),
+        os.path.join(list_dir, '20p_list.json'),
+        os.path.join(list_dir, '30p_list.json')
+    ]
+    assert all([os.path.exists(list_file_) for list_file_ in list_file])
+
+    file_list = [list(), list(), list()]
+    for i in range(len(list_file)):
+        with open(list_file[i], 'r', encoding='utf-8') as f:
+            file_list[i] = json.load(f)
+
+    all_files = dict()
+
+    train_dir = os.path.join(src_dir, 'train')
+    labeled10_out_dir = os.path.join(src_dir, 'train_10_labeled')
+    unlabeled10_out_dir = os.path.join(src_dir, 'train_10_unlabeled')
+    labeled20_out_dir = os.path.join(src_dir, 'train_20_labeled')
+    unlabeled20_out_dir = os.path.join(src_dir, 'train_20_unlabeled')
+    labeled30_out_dir = os.path.join(src_dir, 'train_30_labeled')
+    unlabeled30_out_dir = os.path.join(src_dir, 'train_30_unlabeled')
+
+    train_img_dir = os.path.join(train_dir, 'images')
+
+    for file_ in glob.glob(os.path.join(train_img_dir, '*.png')):
+        all_files[file_.split('/')[-1]] = file_
+    print(f'Total images: {len(all_files)}')
+
+    # recreate the output directories from scratch
+    if os.path.exists(labeled10_out_dir):
+        shutil.rmtree(labeled10_out_dir)
+    if os.path.exists(unlabeled10_out_dir):
+        shutil.rmtree(unlabeled10_out_dir)
+    if os.path.exists(labeled20_out_dir):
+        shutil.rmtree(labeled20_out_dir)
+    if os.path.exists(unlabeled20_out_dir):
+        shutil.rmtree(unlabeled20_out_dir)
+    if os.path.exists(labeled30_out_dir):
+        shutil.rmtree(labeled30_out_dir)
+    if os.path.exists(unlabeled30_out_dir):
+        shutil.rmtree(unlabeled30_out_dir)
+    os.makedirs(labeled10_out_dir + '/images')
+    os.makedirs(labeled10_out_dir + '/annfiles')
+    os.makedirs(unlabeled10_out_dir + '/images')
+    os.makedirs(unlabeled10_out_dir + '/annfiles')
+    os.makedirs(unlabeled10_out_dir + '/empty_annfiles')
+    os.makedirs(labeled20_out_dir + '/images')
+    os.makedirs(labeled20_out_dir + '/annfiles')
+    os.makedirs(unlabeled20_out_dir + '/images')
+    os.makedirs(unlabeled20_out_dir + '/annfiles')
+    os.makedirs(unlabeled20_out_dir + '/empty_annfiles')
+    os.makedirs(labeled30_out_dir + '/images')
+    os.makedirs(labeled30_out_dir + '/annfiles')
+    os.makedirs(unlabeled30_out_dir + '/images')
+    os.makedirs(unlabeled30_out_dir + '/annfiles')
+    os.makedirs(unlabeled30_out_dir + '/empty_annfiles')
+
+    for file_name, file_path in all_files.items():
+        # `file_name.split('__')[0]` recovers the original image name
+        # from a split patch name such as 'P0001__1024__0___0.png'.
+        if (file_name.split('__')[0] + '.png') in file_list[0]:
+            os.symlink(file_path,
+                       os.path.join(labeled10_out_dir, 'images', file_name))
+            os.symlink(
+                os.path.join('/'.join(file_path.split('/')[0:-2]), 'annfiles',
+                             file_path.split('/')[-1]).split('.')[0] + '.txt',
+                os.path.join(labeled10_out_dir, 'annfiles',
+                             file_name.split('.')[0] + '.txt'))
+        if (file_name.split('__')[0] + '.png') in file_list[1]:
+            os.symlink(file_path,
+                       os.path.join(labeled20_out_dir, 'images', file_name))
+            os.symlink(
+                os.path.join('/'.join(file_path.split('/')[0:-2]), 'annfiles',
+                             file_path.split('/')[-1]).split('.')[0] + '.txt',
+                os.path.join(labeled20_out_dir, 'annfiles',
+                             file_name.split('.')[0] + '.txt'))
+        if (file_name.split('__')[0] + '.png') in file_list[2]:
+            os.symlink(file_path,
+                       os.path.join(labeled30_out_dir, 'images', file_name))
+            os.symlink(
+                os.path.join('/'.join(file_path.split('/')[0:-2]), 'annfiles',
+                             file_path.split('/')[-1]).split('.')[0] + '.txt',
+                os.path.join(labeled30_out_dir, 'annfiles',
+                             file_name.split('.')[0] + '.txt'))
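+        # images whose original name is absent from a list are mirrored
+        # into the corresponding unlabeled split below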
+        if (file_name.split('__')[0] + '.png') not in file_list[0]:
+            os.symlink(file_path,
+                       os.path.join(unlabeled10_out_dir, 'images', file_name))
+            os.symlink(
+                os.path.join('/'.join(file_path.split('/')[0:-2]), 'annfiles',
+                             file_path.split('/')[-1]).split('.')[0] + '.txt',
+                os.path.join(unlabeled10_out_dir, 'annfiles',
+                             file_name.split('.')[0] + '.txt'))
+            open(
+                os.path.join(unlabeled10_out_dir, 'empty_annfiles',
+                             file_name.split('.')[0] + '.txt'), 'w').close()
+        if (file_name.split('__')[0] + '.png') not in file_list[1]:
+            os.symlink(file_path,
+                       os.path.join(unlabeled20_out_dir, 'images', file_name))
+            os.symlink(
+                os.path.join('/'.join(file_path.split('/')[0:-2]), 'annfiles',
+                             file_path.split('/')[-1]).split('.')[0] + '.txt',
+                os.path.join(unlabeled20_out_dir, 'annfiles',
+                             file_name.split('.')[0] + '.txt'))
+            open(
+                os.path.join(unlabeled20_out_dir, 'empty_annfiles',
+                             file_name.split('.')[0] + '.txt'), 'w').close()
+        if (file_name.split('__')[0] + '.png') not in file_list[2]:
+            os.symlink(file_path,
+                       os.path.join(unlabeled30_out_dir, 'images', file_name))
+            os.symlink(
+                os.path.join('/'.join(file_path.split('/')[0:-2]), 'annfiles',
+                             file_path.split('/')[-1]).split('.')[0] + '.txt',
+                os.path.join(unlabeled30_out_dir, 'annfiles',
+                             file_name.split('.')[0] + '.txt'))
+            open(
+                os.path.join(unlabeled30_out_dir, 'empty_annfiles',
+                             file_name.split('.')[0] + '.txt'), 'w').close()
+    print('Finished symlinking labeled and unlabeled images.')
+
+
+if __name__ == '__main__':
+    list_dir = 'tools/misc/split_dota1.5_lists'
+    src_dir = 'data/split_ss_dota1_5'
+
+    split_img_with_list(list_dir, src_dir)
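+
+# Expected layout after running (assuming the default paths above):
+#   data/split_ss_dota1_5/train_10_labeled/{images,annfiles}
+#   data/split_ss_dota1_5/train_10_unlabeled/{images,annfiles,empty_annfiles}
+# and likewise for the 20% and 30% splits.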