Cannot reproduce published performance after training #84

Open
LaCandela opened this issue Jun 20, 2024 · 2 comments

@LaCandela

I am trying to reproduce the published model quality, but so far I haven't been able to. I retrained the M0 model and got only 0.215 NDS and 0.04 mAP instead of the published 0.411 NDS and 0.277 mAP.

I've only made small modifications to the config file:

  • switched to BN instead of SyncBN because I only use one GPU
  • file_client_args = dict(backend='disk')
  • samples_per_gpu=20, workers_per_gpu=5,
  • updated the load_from variable to point to a pretrained model

I started the training without Slurm:

CUDA_VISIBLE_DEVICES=0 python tools/train.py ~/Fast-BEV/configs/fastbev/exp/paper/fastbev_m0_r18_s256x704_v200x200x4_c192_d2_f4.py --work-dir /Fast-BEV/runs/train_repro/

However, I can reproduce the NDS and mAP values on the validation set with the pretrained models published in this repository, which suggests that the environment and data setup are fine (at least for validation).

Do you have any idea how I can improve the results of the training pipeline? Is there any hyperparameter I could tune?

For reference, here is the config file I used:

# -*- coding: utf-8 -*-

model = dict(
    type='FastBEV',
    style="v1",
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        #norm_cfg=dict(type='SyncBN', requires_grad=True),
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
        style='pytorch'
    ),
    neck=dict(
        type='FPN',
        #norm_cfg=dict(type='SyncBN', requires_grad=True),
        norm_cfg=dict(type='BN', requires_grad=True),
        in_channels=[64, 128, 256, 512],
        out_channels=64,
        num_outs=4),
    neck_fuse=dict(in_channels=[256], out_channels=[64]),
    neck_3d=dict(
        type='M2BevNeck',
        in_channels=64*4,
        out_channels=192,
        num_layers=2,
        stride=2,
        is_transpose=False,
        fuse=dict(in_channels=64*4*4, out_channels=64*4),
        #norm_cfg=dict(type='SyncBN', requires_grad=True)),
        norm_cfg=dict(type='BN', requires_grad=True)),
    seg_head=None,
    bbox_head=dict(
        type='FreeAnchor3DHead',
        is_transpose=True,
        num_classes=10,
        in_channels=192,
        feat_channels=192,
        num_convs=0,
        use_direction_classifier=True,
        pre_anchor_topk=25,
        bbox_thr=0.5,
        gamma=2.0,
        alpha=0.5,
        anchor_generator=dict(
            type='AlignedAnchor3DRangeGenerator',
            ranges=[[-50, -50, -1.8, 50, 50, -1.8]],
            # scales=[1, 2, 4],
            sizes=[
                [0.8660, 2.5981, 1.],  # 1.5/sqrt(3)
                [0.5774, 1.7321, 1.],  # 1/sqrt(3)
                [1., 1., 1.],
                [0.4, 0.4, 1],
            ],
            custom_values=[0, 0],
            rotations=[0, 1.57],
            reshape_out=True),
        assigner_per_size=False,
        diff_rad_by_sin=True,
        dir_offset=0.7854,  # pi/4
        dir_limit_offset=0,
        bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder', code_size=9),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.8),
        loss_dir=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.8)),
    multi_scale_id=[0],
    n_voxels=[[200, 200, 4]],
    voxel_size=[[0.5, 0.5, 1.5]],
    # model training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            iou_calculator=dict(type='BboxOverlapsNearest3D'),
            pos_iou_thr=0.6,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        allowed_border=0,
        code_weight=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2],
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        score_thr=0.05,
        min_bbox_size=0,
        nms_pre=1000,
        max_num=500,
        use_scale_nms=True,
        use_tta=False,
        # Normal-NMS
        nms_across_levels=False,
        use_rotate_nms=True,
        nms_thr=0.2,
        # Scale-NMS
        nms_type_list=[
            'rotate', 'rotate', 'rotate', 'rotate', 'rotate', 'rotate', 'rotate', 'rotate', 'rotate', 'circle'],
        nms_thr_list=[0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.5, 0.5, 0.2],
        nms_radius_thr_list=[4, 12, 10, 10, 12, 0.85, 0.85, 0.175, 0.175, 1],
        nms_rescale_factor=[1.0, 0.7, 0.55, 0.4, 0.7, 1.0, 1.0, 4.5, 9.0, 1.0],
    )
)

# If point cloud range is changed, the models should also change their point cloud range accordingly
point_cloud_range = [-50, -50, -5, 50, 50, 3]

# For nuScenes we usually do 10-class detection
class_names = [
    'car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
    'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'
]
dataset_type = 'NuScenesMultiView_Map_Dataset2'
data_root = './data/nuscenes/'

# Input modality for nuScenes dataset, this is consistent with the submission
# format which requires the information in input_modality.
input_modality = dict(
    use_lidar=False,
    use_camera=True,
    use_radar=False,
    use_map=False,
    use_external=True)

img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data_config = {
    'src_size': (900, 1600),
    'input_size': (256, 704),
    # train-aug
    'resize': (-0.06, 0.11),
    'crop': (-0.05, 0.05),
    'rot': (-5.4, 5.4),
    'flip': True,
    # test-aug
    'test_input_size': (256, 704),
    'test_resize': 0.0,
    'test_rotate': 0.0,
    'test_flip': False,
    # top, right, bottom, left
    'pad': (0, 0, 0, 0),
    'pad_divisor': 32,
    'pad_color': (0, 0, 0),
}

file_client_args = dict(backend='disk')

# file_client_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         data_root: 'public-1424:s3://openmmlab/datasets/detection3d/nuscenes/'}))

train_pipeline = [
    dict(type='MultiViewPipeline', sequential=True, n_images=6, n_times=4, transforms=[
        dict(
            type='LoadImageFromFile',
            file_client_args=file_client_args)]),
    dict(type='LoadAnnotations3D',
         with_bbox=True,
         with_label=True,
         with_bev_seg=True),
    dict(
        type='LoadPointsFromFile',
        dummy=True,
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5),
    dict(
        type='RandomFlip3D',
        flip_2d=False,
        sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
        flip_ratio_bev_vertical=0.5,
        update_img2lidar=True),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.3925, 0.3925],
        scale_ratio_range=[0.95, 1.05],
        translation_std=[0.05, 0.05, 0.05],
        update_img2lidar=True),
    dict(type='RandomAugImageMultiViewImage', data_config=data_config),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='KittiSetOrigin', point_cloud_range=point_cloud_range),
    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['img', 'gt_bboxes', 'gt_labels',
                                 'gt_bboxes_3d', 'gt_labels_3d',
                                 'gt_bev_seg'])]
test_pipeline = [
    dict(type='MultiViewPipeline', sequential=True, n_images=6, n_times=4, transforms=[
        dict(
            type='LoadImageFromFile',
            file_client_args=file_client_args)]),
    dict(
        type='LoadPointsFromFile',
        dummy=True,
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5),
    dict(type='RandomAugImageMultiViewImage', data_config=data_config, is_train=False),
    # dict(type='TestTimeAugImageMultiViewImage', data_config=data_config, is_train=False),
    dict(type='KittiSetOrigin', point_cloud_range=point_cloud_range),
    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
    dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
    dict(type='Collect3D', keys=['img'])]

data = dict(
    samples_per_gpu=20,
    workers_per_gpu=5,
    train=dict(
        type='CBGSDataset',
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            pipeline=train_pipeline,
            classes=class_names,
            modality=input_modality,
            test_mode=False,
            with_box2d=True,
            box_type_3d='LiDAR',
            ann_file='data/nuscenes/nuscenes_infos_train_4d_interval3_max60.pkl',
            load_interval=1,
            sequential=True,
            n_times=4,
            train_adj_ids=[1, 3, 5],
            speed_mode='abs_velo',
            max_interval=10,
            min_interval=0,
            fix_direction=True,
            prev_only=True,
            test_adj='prev',
            test_adj_ids=[1, 3, 5],
            test_time_id=None,
        )
    ),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        pipeline=test_pipeline,
        classes=class_names,
        modality=input_modality,
        test_mode=True,
        with_box2d=True,
        box_type_3d='LiDAR',
        ann_file='data/nuscenes/nuscenes_infos_val_4d_interval3_max60.pkl',
        load_interval=1,
        sequential=True,
        n_times=4,
        train_adj_ids=[1, 3, 5],
        speed_mode='abs_velo',
        max_interval=10,
        min_interval=0,
        fix_direction=True,
        test_adj='prev',
        test_adj_ids=[1, 3, 5],
        test_time_id=None,
    ),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        pipeline=test_pipeline,
        classes=class_names,
        modality=input_modality,
        test_mode=True,
        with_box2d=True,
        box_type_3d='LiDAR',
        ann_file='data/nuscenes/nuscenes_infos_val_4d_interval3_max60.pkl',
        load_interval=1,
        sequential=True,
        n_times=4,
        train_adj_ids=[1, 3, 5],
        speed_mode='abs_velo',
        max_interval=10,
        min_interval=0,
        fix_direction=True,
        test_adj='prev',
        test_adj_ids=[1, 3, 5],
        test_time_id=None,
    )
)

optimizer = dict(
    type='AdamW2',
    lr=0.0004,
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2))

# learning policy
lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0,
    by_epoch=False
)

total_epochs = 20
checkpoint_config = dict(interval=1)
log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook'),
    ])
evaluation = dict(interval=2)
dist_params = dict(backend='nccl')
find_unused_parameters = True  # todo: fix number of FPN outputs
log_level = 'INFO'

load_from = "/Fast-BEV/checkpoints/cascade_mask_rcnn_r18_fpn_coco-mstrain_3x_20e_nuim_bbox_mAP_0.5110_segm_mAP_0.4070.pth"
resume_from = None
workflow = [('train', 1), ('val', 1)]

# fp16 settings, the loss scale is specifically tuned to avoid Nan
fp16 = dict(loss_scale='dynamic')
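The grid size and the `neck_3d` channel counts in this config can be sanity-checked with a little arithmetic. The sketch below assumes (my reading of the config, not something stated in this thread) that the Z voxels are folded into the channel axis and that the `n_times` frames are concatenated along channels before `neck_3d.fuse`; `bev_dims` is a hypothetical helper.

```python
# Sanity check of the BEV grid size and the neck_3d channel counts.
# Assumption: Z voxels are folded into channels and the n_times frames are
# concatenated along channels before neck_3d.fuse; bev_dims is a hypothetical helper.

def bev_dims(fpn_channels, n_voxels, voxel_size, n_times):
    nx, ny, nz = n_voxels
    sx, sy, _ = voxel_size
    per_frame = fpn_channels * nz  # channels of one frame after folding Z into C
    return {
        "extent_x_m": nx * sx,  # metric size of the BEV grid along x
        "extent_y_m": ny * sy,  # metric size of the BEV grid along y
        "per_frame_channels": per_frame,  # compare with neck_3d in_channels / fuse out_channels
        "fused_in_channels": per_frame * n_times,  # compare with fuse in_channels
    }

# Values taken from the config: FPN out_channels, n_voxels, voxel_size, n_times.
dims = bev_dims(fpn_channels=64, n_voxels=[200, 200, 4],
                voxel_size=[0.5, 0.5, 1.5], n_times=4)
print(dims)
# {'extent_x_m': 100.0, 'extent_y_m': 100.0, 'per_frame_channels': 256, 'fused_in_channels': 1024}
```

Under that assumption the numbers line up: a 100 m by 100 m grid matches the x/y span of `point_cloud_range`, and 256 and 1024 equal `64*4` and `64*4*4`.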

@Seyd2

Seyd2 commented Sep 18, 2024

Hi @LaCandela!
I think you can improve the detection score by scaling the learning rate lr according to your batch size. As a starting point, I'd try scaling it linearly relative to the authors' total batch size of 32 (1 sample per GPU). Since you set the batch size to 20 on a single GPU, the lr should be lr=(0.0004/32)*20 = 0.00025.
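A minimal sketch of that linear scaling rule, followed by how the config's `lr_config` (poly decay with linear warmup, `by_epoch=False`) would then shape the learning rate per iteration. The schedule formula is my reading of mmcv 1.x's poly policy and linear warmup and should be checked against the installed mmcv version; `max_iters` is a made-up placeholder, and the backbone's `lr_mult=0.1` is not modeled.

```python
# Linear LR scaling plus an approximation of mmcv-style poly decay with linear warmup.

def scale_lr(base_lr, base_batch, new_batch):
    """Linear scaling rule: the lr changes proportionally with the total batch size."""
    return base_lr * new_batch / base_batch

def poly_warmup_lr(it, base_lr, max_iters, power=1.0, min_lr=0.0,
                   warmup_iters=1000, warmup_ratio=1e-6):
    """LR at iteration `it` for poly decay with linear warmup (assumed mmcv 1.x behaviour)."""
    regular = (base_lr - min_lr) * (1 - it / max_iters) ** power + min_lr
    if it < warmup_iters:
        # warmup starts at regular * warmup_ratio and ramps linearly up to regular
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        return regular * (1 - k)
    return regular

lr = scale_lr(0.0004, base_batch=32, new_batch=20)
print(f"scaled lr: {lr:.6f}")  # scaled lr: 0.000250

max_iters = 50_000  # hypothetical; depends on dataset size, batch size and total_epochs
for it in (0, 500, 1000, 25_000, 49_999):
    print(it, f"{poly_warmup_lr(it, lr, max_iters):.3e}")
```

With `power=1.0` and `min_lr=0` the decay is linear to zero, so a smaller total batch mainly changes the peak value reached after warmup.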

I hope you are still working on FastBEV and can help me figure out how to set up the environment to reach the published inference speed. For the M0 model I only get 2 FPS on a single RTX4090. I haven't compiled the files in the "script/view_transform_cuda" folder and wonder whether that is the issue or whether a TensorRT implementation is necessary.
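Before comparing FPS numbers, it may help to rule out the measurement itself. The sketch below is a generic timing harness, not Fast-BEV's own benchmark script: it runs a few untimed warm-up calls and accepts an optional `sync` callable (for example `torch.cuda.synchronize` when timing a PyTorch model on GPU), since CUDA kernels are launched asynchronously. `dummy_forward` is a hypothetical stand-in for one inference step.

```python
import time

def measure_fps(fn, warmup=10, iters=50, sync=None):
    """Average calls per second of `fn`, measured after `warmup` untimed calls.

    `sync` should block until queued device work has finished (e.g.
    torch.cuda.synchronize for GPU models); without it, asynchronous
    kernel launches can make the timing meaningless.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return iters / elapsed

def dummy_forward():
    # hypothetical placeholder for one inference step of the model
    time.sleep(0.01)

print(f"{measure_fps(dummy_forward, warmup=2, iters=20):.1f} FPS")
```

Timing only the model call (with data already on the device) separates network speed from data loading and pre-processing.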

@BaophanN

Hi @LaCandela!
Thanks for the helpful issue report. Can you help me with this FileNotFoundError:

  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/workspace/source/Fast-BEV/mmdet3d/datasets/dataset_wrappers.py", line 71, in __getitem__
    return self.dataset[ori_idx]
  File "/workspace/source/Fast-BEV/mmdet3d/datasets/nuscenes_monocular_dataset_map_2.py", line 57, in __getitem__
    data = self.prepare_train_data(idx)
  File "/workspace/source/Fast-BEV/mmdet3d/datasets/custom_3d.py", line 156, in prepare_train_data
    example = self.pipeline(input_dict)
  File "/opt/conda/lib/python3.7/site-packages/mmdet/datasets/pipelines/compose.py", line 40, in __call__
    data = t(data)
  File "/workspace/source/Fast-BEV/mmdet3d/datasets/pipelines/multi_view.py", line 44, in __call__
    _results = self.transforms(_results)
  File "/opt/conda/lib/python3.7/site-packages/mmdet/datasets/pipelines/compose.py", line 40, in __call__
    data = t(data)
  File "/opt/conda/lib/python3.7/site-packages/mmdet/datasets/pipelines/loading.py", line 59, in __call__
    img_bytes = self.file_client.get(filename)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/fileio/file_client.py", line 992, in get
    return self.client.get(filepath)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/fileio/file_client.py", line 517, in get
    with open(filepath, 'rb') as f:
FileNotFoundError: [Errno 2] No such file or directory: './data/nuscenes/sweeps/CAM_FRONT_RIGHT/n015-2018-08-01-16-32-59+0800__CAM_FRONT_RIGHT__1533112540770339.jpg'

Maybe I did not download enough scenes of the nuScenes data, since I cannot find the above file in my dataset folder. But I do not understand where in the code this file path comes from. It may come from some metadata file, but I am really confused by the error trace above.
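The traceback shows `LoadImageFromFile` opening a path taken from the dataset's annotation entries, so the path most likely originates in the info file passed as `ann_file` (here `data/nuscenes/nuscenes_infos_train_4d_interval3_max60.pkl`); the missing file lives under `sweeps/`, which suggests the multi-frame (`n_times=4`) setup also needs the sweep images. A rough check that assumes nothing about the internal layout of the info dicts is to walk the loaded object and test every image path it contains. `iter_image_paths`, `find_missing_images` and `check_ann_file` are hypothetical helpers, and loading with plain `pickle` assumes the file was written with pickle.

```python
import os
import pickle

def iter_image_paths(obj, suffixes=(".jpg", ".jpeg", ".png")):
    """Yield every string in a nested dict/list/tuple structure that looks like an image path."""
    if isinstance(obj, str):
        if obj.lower().endswith(suffixes):
            yield obj
    elif isinstance(obj, dict):
        for value in obj.values():
            yield from iter_image_paths(value, suffixes)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from iter_image_paths(item, suffixes)

def find_missing_images(infos, root="."):
    """Return the sorted referenced image paths that do not exist on disk.

    Relative paths are resolved against `root`; run from the repo root so that
    './data/nuscenes/...' resolves the same way it does during training.
    """
    missing = set()
    for path in iter_image_paths(infos):
        full = path if os.path.isabs(path) else os.path.join(root, path)
        if not os.path.exists(full):
            missing.add(path)
    return sorted(missing)

def check_ann_file(ann_file, root="."):
    """Load a pickled info file and print how many referenced images are missing."""
    with open(ann_file, "rb") as f:
        infos = pickle.load(f)
    missing = find_missing_images(infos, root)
    print(f"{len(missing)} referenced image files are missing")
    for path in missing[:20]:  # show only the first few
        print(path)
    return missing
```

For example, from the Fast-BEV root: `check_ann_file("data/nuscenes/nuscenes_infos_train_4d_interval3_max60.pkl")`. A non-empty result would point to an incomplete download (for example, missing sweeps) rather than a code problem.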


3 participants