diff --git a/examples/pytorch/zero-shot/README.md b/examples/pytorch/zero-shot/README.md
index 34099371549c05..40c22be8d44d5e 100644
--- a/examples/pytorch/zero-shot/README.md
+++ b/examples/pytorch/zero-shot/README.md
@@ -40,14 +40,14 @@ python run_zero_shot_object_detection.py \
     --do_train true \
     --do_eval true \
     --output_dir grounding-dino-tiny-finetuned-cppe-5-10k-steps \
-    --num_train_epochs 100 \
+    --num_train_epochs 10 \
     --image_square_size 600 \
     --fp16 true \
     --learning_rate 5e-5 \
     --weight_decay 1e-4 \
     --dataloader_num_workers 4 \
     --dataloader_prefetch_factor 2 \
-    --per_device_train_batch_size 8 \
+    --per_device_train_batch_size 1 \
     --gradient_accumulation_steps 1 \
     --remove_unused_columns false \
     --eval_do_concat_batches false \
diff --git a/examples/pytorch/zero-shot/run_zero_shot_object_detection.py b/examples/pytorch/zero-shot/run_zero_shot_object_detection.py
index 97ec6c012ca655..eaf512700e3caf 100644
--- a/examples/pytorch/zero-shot/run_zero_shot_object_detection.py
+++ b/examples/pytorch/zero-shot/run_zero_shot_object_detection.py
@@ -117,6 +117,37 @@ def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]
     return boxes
 
 
+def convert_zero_shot_to_coco_format(predictions, label2id):
+    """
+    Convert zero-shot object detection output to the standard object detection format in order to calculate mAP.
+
+    Args:
+        predictions (List[Dict]): Output of zero-shot object detection post-processing, one dict per image,
+            e.g. [{'scores': tensor([0.4786, 0.4379, 0.4760], device='cuda:0'), 'labels': ['a cat', 'a cat', 'a remote control'], 'boxes': tensor([[344.6973, 23.1085, 637.1817, 374.2748], [ 12.2690, 51.9104, 316.8564, 472.4341], [ 38.5870, 70.0092, 176.7755, 118.1748]], device='cuda:0')}]
+        label2id (Dict): Dictionary of label to id mapping
+
+    Returns:
+        List[Dict]: The same predictions with text labels replaced by integer class ids,
+            e.g. [{'scores': tensor([0.4786, 0.4379, 0.4760], device='cuda:0'), 'labels': tensor([1, 1, 2], device='cuda:0', dtype=torch.int32), 'boxes': tensor([[344.6973, 23.1085, 637.1817, 374.2748], [ 12.2690, 51.9104, 316.8564, 472.4341], [ 38.5870, 70.0092, 176.7755, 118.1748]], device='cuda:0')}]
+
+    """
+    for prediction in predictions:
+        scores = prediction["scores"]
+        device = scores.device
+        labels = prediction["labels"]
+        # map each text label to its integer class id
+        torch_label = []
+        for label in labels:
+            if label in label2id:
+                torch_label.append(label2id[label])
+            else:
+                # fall back to the background class
+                torch_label.append(0)
+        prediction["labels"] = torch.tensor(torch_label, dtype=torch.int32, device=device)
+
+    return predictions
+
+
 def to_label_list(id2label):
     return list(id2label.values())
 
@@ -213,16 +244,19 @@ def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, Li
 @torch.no_grad()
 def compute_metrics(
     evaluation_results: EvalPrediction,
-    image_processor: AutoProcessor,
-    threshold: float = 0.0,
+    processor: AutoProcessor,
+    box_threshold: float = 0.15,
+    text_threshold: float = 0.1,
     id2label: Optional[Mapping[int, str]] = None,
+    label2id: Optional[Mapping[str, int]] = None,
 ) -> Mapping[str, float]:
     """
     Compute mean average mAP, mAR and their variants for the object detection task.
 
     Args:
         evaluation_results (EvalPrediction): Predictions and targets from evaluation.
-        threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.0.
+        box_threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.15.
+        text_threshold (float, optional): Threshold to filter predicted text by confidence. Defaults to 0.1.
         id2label (Optional[dict], optional): Mapping from class id to class name. Defaults to None.
 
     Returns:
@@ -254,13 +288,14 @@ def compute_metrics(
         post_processed_targets.append({"boxes": boxes, "labels": labels})
 
     # Collect predictions in the required format for metric computation,
-    # model produce boxes in YOLO format, then image_processor convert them to Pascal VOC format
+    # the model produces boxes in YOLO format, then the processor converts them to Pascal VOC format
     for batch, target_sizes in zip(predictions, image_sizes):
         batch_logits, batch_boxes = batch[1], batch[2]
         output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes))
-        post_processed_output = image_processor.post_process_object_detection(
-            output, threshold=threshold, target_sizes=target_sizes
+        post_processed_output = processor.post_process_grounded_object_detection(
+            output, box_threshold=box_threshold, text_threshold=text_threshold, target_sizes=target_sizes
         )
+        post_processed_output = convert_zero_shot_to_coco_format(post_processed_output, label2id)
         post_processed_predictions.extend(post_processed_output)
 
     # Compute metrics
@@ -372,6 +407,14 @@ class ModelArguments:
             )
         },
     )
+    freeze_backbone: bool = field(
+        default=True,
+        metadata={"help": ("Whether to freeze the image backbone.")},
+    )
+    freeze_text_backbone: bool = field(
+        default=True,
+        metadata={"help": ("Whether to freeze the text encoder.")},
+    )
 
 
 def main():
@@ -478,6 +521,13 @@ def main():
         model_args.image_processor_name or model_args.model_name_or_path,
     )
 
+    # Optionally freeze the image and text backbones
+    if model_args.freeze_backbone:
+        model.model.freeze_backbone()
+    if model_args.freeze_text_backbone:
+        for param in model.model.text_backbone.parameters():
+            param.requires_grad_(False)
+
     # ------------------------------------------------------------------------------------------------
     # Define image augmentations and dataset transforms
     # ------------------------------------------------------------------------------------------------
@@ -513,10 +563,20 @@ def main():
 
     # Make transform functions for batch and apply for dataset splits
     train_transform_batch = partial(
-        augment_and_transform_batch, transform=train_augment_and_transform, image_processor=processor
+        augment_and_transform_batch,
+        transform=train_augment_and_transform,
+        processor=processor,
+        id2label=id2label,
+        label2id=label2id,
+        random_text_prompt=False,
     )
     validation_transform_batch = partial(
-        augment_and_transform_batch, transform=validation_transform, image_processor=processor
+        augment_and_transform_batch,
+        transform=validation_transform,
+        processor=processor,
+        id2label=id2label,
+        label2id=label2id,
+        random_text_prompt=False,
     )
 
     dataset["train"] = dataset["train"].with_transform(train_transform_batch)
@@ -527,7 +587,7 @@ def main():
     # Model training and evaluation with Trainer API
     # ------------------------------------------------------------------------------------------------
 
-    eval_compute_metrics_fn = partial(compute_metrics, image_processor=processor, id2label=id2label, threshold=0.0)
+    eval_compute_metrics_fn = partial(compute_metrics, processor=processor, id2label=id2label, label2id=label2id)
 
     trainer = Trainer(
         model=model,
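
As a quick sanity check of the new helper, here is a minimal sketch (not part of the patch) of how `convert_zero_shot_to_coco_format` is intended to behave. The `label2id` mapping and the prediction values are toy examples taken from the function's docstring; the helper is assumed to be importable from the example script:

```python
import torch

from run_zero_shot_object_detection import convert_zero_shot_to_coco_format  # assumes the script is on PYTHONPATH

# toy label-to-id mapping; real runs build this from the dataset classes
label2id = {"a cat": 1, "a remote control": 2}

# one post-processed prediction per image, as returned by
# processor.post_process_grounded_object_detection
predictions = [
    {
        "scores": torch.tensor([0.4786, 0.4379, 0.4760]),
        "labels": ["a cat", "a cat", "a remote control"],
        "boxes": torch.tensor(
            [
                [344.6973, 23.1085, 637.1817, 374.2748],
                [12.2690, 51.9104, 316.8564, 472.4341],
                [38.5870, 70.0092, 176.7755, 118.1748],
            ]
        ),
    }
]

converted = convert_zero_shot_to_coco_format(predictions, label2id)
print(converted[0]["labels"])  # tensor([1, 1, 2], dtype=torch.int32)
```

Any text label the processor returns that is not present in `label2id` is mapped to the background class (0), so stray phrases cannot break the mAP computation.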
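
Similarly, a hedged sketch of what the new freezing flags do: with `freeze_backbone` and `freeze_text_backbone` both true (their defaults), only the remaining layers (fusion and detection head) stay trainable. The model id is illustrative, and `freeze_backbone()` / `text_backbone` are assumed to be available on the base model, as the patch's own calls imply:

```python
from transformers import AutoModelForZeroShotObjectDetection

model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")

# mirror what main() does when both freeze flags are set
model.model.freeze_backbone()  # image backbone
for param in model.model.text_backbone.parameters():  # text encoder
    param.requires_grad_(False)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} parameters")
```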