currently object_detection.py has OOM error
SangbumChoi committed Aug 8, 2024
1 parent 1b156b5 commit ffce43c
Showing 2 changed files with 71 additions and 11 deletions.
4 changes: 2 additions & 2 deletions examples/pytorch/zero-shot/README.md
@@ -40,14 +40,14 @@ python run_zero_shot_object_detection.py \
    --do_train true \
    --do_eval true \
    --output_dir grounding-dino-tiny-finetuned-cppe-5-10k-steps \
-    --num_train_epochs 100 \
+    --num_train_epochs 10 \
    --image_square_size 600 \
    --fp16 true \
    --learning_rate 5e-5 \
    --weight_decay 1e-4 \
    --dataloader_num_workers 4 \
    --dataloader_prefetch_factor 2 \
-    --per_device_train_batch_size 8 \
+    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --remove_unused_columns false \
    --eval_do_concat_batches false \
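Note on the README change: dropping the per-device batch from 8 to 1 trades throughput for memory to avoid the OOM. If the smaller effective batch hurts convergence, one option (not part of this commit) is to compensate with gradient accumulation, e.g.:

    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \

This keeps the effective batch at 8 per device while only ever holding one sample's activations in GPU memory at a time.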
78 changes: 69 additions & 9 deletions examples/pytorch/zero-shot/run_zero_shot_object_detection.py
@@ -117,6 +117,37 @@ def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: Tuple[int, int]
    return boxes


+def convert_zero_shot_to_coco_format(predictions, label2id):
+    """
+    Convert zero-shot object detection output to the typical object detection format in order to calculate mAP.
+    Args:
+        predictions (List[Dict]): Output of zero-shot object detection, e.g.
+            [{'scores': tensor([0.4786, 0.4379, 0.4760], device='cuda:0'),
+              'labels': ['a cat', 'a cat', 'a remote control'],
+              'boxes': tensor([[344.6973,  23.1085, 637.1817, 374.2748],
+                               [ 12.2690,  51.9104, 316.8564, 472.4341],
+                               [ 38.5870,  70.0092, 176.7755, 118.1748]], device='cuda:0')}]
+        label2id (Dict): Dictionary of label to id mapping
+    Returns:
+        List[Dict]: The same predictions with text labels replaced by integer ids, e.g.
+            [{'scores': tensor([0.4786, 0.4379, 0.4760], device='cuda:0'),
+              'labels': tensor([1, 1, 2], dtype=torch.int32, device='cuda:0'),
+              'boxes': tensor([[344.6973,  23.1085, 637.1817, 374.2748],
+                               [ 12.2690,  51.9104, 316.8564, 472.4341],
+                               [ 38.5870,  70.0092, 176.7755, 118.1748]], device='cuda:0')}]
+    """
+    for prediction in predictions:
+        scores = prediction["scores"]
+        device = scores.device
+        labels = prediction["labels"]
+        # Map each text label to its class id; the list is reset for every prediction
+        torch_label = []
+        for label in labels:
+            if label in label2id:
+                torch_label.append(label2id[label])
+            else:
+                # Assign the background class to labels outside the dataset vocabulary
+                torch_label.append(0)
+        prediction["labels"] = torch.tensor(torch_label, dtype=torch.int32, device=device)
+
+    return predictions
+
+
+def to_label_list(id2label):
+    return list(id2label.values())

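For reference, a minimal usage sketch of the new helper (the label map and detections below are illustrative, not from the CPPE-5 dataset): text labels found in label2id map to their class ids, while out-of-vocabulary text falls back to the background class 0.

import torch

# Hypothetical label map and predictions, for illustration only
label2id = {"a cat": 1, "a remote control": 2}
predictions = [
    {
        "scores": torch.tensor([0.4786, 0.4379, 0.4760]),
        "labels": ["a cat", "a cat", "a dog"],  # "a dog" is out of vocabulary
        "boxes": torch.tensor(
            [[344.70, 23.11, 637.18, 374.27],
             [12.27, 51.91, 316.86, 472.43],
             [38.59, 70.01, 176.78, 118.17]]
        ),
    }
]
converted = convert_zero_shot_to_coco_format(predictions, label2id)
print(converted[0]["labels"])  # tensor([1, 1, 0], dtype=torch.int32)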
@@ -213,16 +244,19 @@ def collate_fn(batch: List[BatchFeature]) -> Mapping[str, Union[torch.Tensor, Li
@torch.no_grad()
def compute_metrics(
    evaluation_results: EvalPrediction,
-    image_processor: AutoProcessor,
-    threshold: float = 0.0,
+    processor: AutoProcessor,
+    box_threshold: float = 0.15,
+    text_threshold: float = 0.1,
    id2label: Optional[Mapping[int, str]] = None,
+    label2id: Optional[Mapping[str, int]] = None,
) -> Mapping[str, float]:
    """
    Compute mAP, mAR and their variants for the object detection task.
    Args:
        evaluation_results (EvalPrediction): Predictions and targets from evaluation.
-        threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.0.
+        box_threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.15.
+        text_threshold (float, optional): Threshold to filter predicted text by confidence. Defaults to 0.1.
        id2label (Optional[dict], optional): Mapping from class id to class name. Defaults to None.
+        label2id (Optional[dict], optional): Mapping from class name to class id. Defaults to None.
    Returns:
@@ -254,13 +288,14 @@ def compute_metrics(
        post_processed_targets.append({"boxes": boxes, "labels": labels})

    # Collect predictions in the required format for metric computation,
-    # model produce boxes in YOLO format, then image_processor convert them to Pascal VOC format
+    # the model produces boxes in YOLO format, then the processor converts them to Pascal VOC format
    for batch, target_sizes in zip(predictions, image_sizes):
        batch_logits, batch_boxes = batch[1], batch[2]
        output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes))
-        post_processed_output = image_processor.post_process_object_detection(
-            output, threshold=threshold, target_sizes=target_sizes
+        post_processed_output = processor.post_process_grounded_object_detection(
+            output, box_threshold=box_threshold, text_threshold=text_threshold, target_sizes=target_sizes
        )
+        post_processed_output = convert_zero_shot_to_coco_format(post_processed_output, label2id)
        post_processed_predictions.extend(post_processed_output)

    # Compute metrics
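The metric computation itself is below the fold of this hunk; as a standalone sketch (assuming torchmetrics provides the mAP here, as in the upstream object-detection example this script mirrors), the converted integer-label predictions can be scored against COCO-style targets like this:

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Illustrative tensors: one predicted box matched against one ground-truth box
preds = [{
    "scores": torch.tensor([0.4786]),
    "labels": torch.tensor([1]),
    "boxes": torch.tensor([[344.70, 23.11, 637.18, 374.27]]),
}]
targets = [{
    "labels": torch.tensor([1]),
    "boxes": torch.tensor([[340.00, 20.00, 640.00, 380.00]]),
}]
metric = MeanAveragePrecision(box_format="xyxy")
metric.update(preds, targets)
print(metric.compute()["map"])  # high, since the boxes overlap almost exactly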
@@ -372,6 +407,14 @@ class ModelArguments:
            )
        },
    )
+    freeze_backbone: bool = field(
+        default=True,
+        metadata={"help": "Whether to freeze the image backbone."},
+    )
+    freeze_text_backbone: bool = field(
+        default=True,
+        metadata={"help": "Whether to freeze the text encoder."},
+    )


def main():
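The new dataclass fields surface as --freeze_backbone / --freeze_text_backbone CLI flags, presumably via HfArgumentParser as in the other transformers example scripts. A self-contained sketch with a stand-in dataclass (FreezeArguments is hypothetical, not a name from this script):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class FreezeArguments:
    # Stand-in for the relevant slice of the script's ModelArguments
    freeze_backbone: bool = field(default=True, metadata={"help": "Whether to freeze the image backbone."})
    freeze_text_backbone: bool = field(default=True, metadata={"help": "Whether to freeze the text encoder."})

(args,) = HfArgumentParser(FreezeArguments).parse_args_into_dataclasses(["--freeze_backbone", "false"])
print(args.freeze_backbone, args.freeze_text_backbone)  # False True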
@@ -478,6 +521,13 @@ def main():
        model_args.image_processor_name or model_args.model_name_or_path,
    )

+    # Optionally freeze the image backbone and the text encoder
+    if model_args.freeze_backbone:
+        model.model.freeze_backbone()
+    if model_args.freeze_text_backbone:
+        for param in model.model.text_backbone.parameters():
+            param.requires_grad_(False)

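A quick sanity check one could add after this block (a sketch, not part of the commit): count trainable versus total parameters to confirm both backbones are excluded from training.

# Sketch: verify the freeze took effect before training starts
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")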
    # ------------------------------------------------------------------------------------------------
    # Define image augmentations and dataset transforms
    # ------------------------------------------------------------------------------------------------
@@ -513,10 +563,20 @@

    # Make batch transform functions and apply them to the dataset splits
    train_transform_batch = partial(
-        augment_and_transform_batch, transform=train_augment_and_transform, image_processor=processor
+        augment_and_transform_batch,
+        transform=train_augment_and_transform,
+        processor=processor,
+        id2label=id2label,
+        label2id=label2id,
+        random_text_prompt=False,
    )
    validation_transform_batch = partial(
-        augment_and_transform_batch, transform=validation_transform, image_processor=processor
+        augment_and_transform_batch,
+        transform=validation_transform,
+        processor=processor,
+        id2label=id2label,
+        label2id=label2id,
+        random_text_prompt=False,
    )

    dataset["train"] = dataset["train"].with_transform(train_transform_batch)
@@ -527,7 +587,7 @@
    # Model training and evaluation with Trainer API
    # ------------------------------------------------------------------------------------------------

-    eval_compute_metrics_fn = partial(compute_metrics, image_processor=processor, id2label=id2label, threshold=0.0)
+    eval_compute_metrics_fn = partial(compute_metrics, processor=processor, id2label=id2label, label2id=label2id)

    trainer = Trainer(
        model=model,
