diff --git a/docs/source/en/model_doc/mask2former.md b/docs/source/en/model_doc/mask2former.md
index bd5ab80728eb48..4faeed50311f69 100644
--- a/docs/source/en/model_doc/mask2former.md
+++ b/docs/source/en/model_doc/mask2former.md
@@ -41,6 +41,7 @@ This model was contributed by [Shivalika Singh](https://huggingface.co/shivi) an
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Mask2Former.
- Demo notebooks regarding inference + fine-tuning Mask2Former on custom data can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Mask2Former).
+- Scripts for finetuning [`Mask2Former`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/instance-segmentation).
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we will review it.
The resource should ideally demonstrate something new instead of duplicating an existing resource.
diff --git a/docs/source/en/model_doc/maskformer.md b/docs/source/en/model_doc/maskformer.md
index 4d31b2829d10f2..a0199f380ce647 100644
--- a/docs/source/en/model_doc/maskformer.md
+++ b/docs/source/en/model_doc/maskformer.md
@@ -51,6 +51,7 @@ This model was contributed by [francesco](https://huggingface.co/francesco). The
- All notebooks that illustrate inference as well as fine-tuning on custom data with MaskFormer can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/MaskFormer).
+- Scripts for finetuning [`MaskFormer`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/instance-segmentation).
## MaskFormer specific outputs
diff --git a/examples/pytorch/README.md b/examples/pytorch/README.md
index f1b3f37d44b930..178102ec092aeb 100644
--- a/examples/pytorch/README.md
+++ b/examples/pytorch/README.md
@@ -47,6 +47,7 @@ Coming soon!
| [**`image-classification`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) | [CIFAR-10](https://huggingface.co/datasets/cifar10) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)
| [**`semantic-segmentation`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/semantic-segmentation) | [SCENE_PARSE_150](https://huggingface.co/datasets/scene_parse_150) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/semantic_segmentation.ipynb)
| [**`object-detection`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection) | [CPPE-5](https://huggingface.co/datasets/cppe-5) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/object_detection.ipynb)
+| [**`instance-segmentation`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/instance-segmentation) | [ADE20K sample](https://huggingface.co/datasets/qubvel-hf/ade20k-mini) | ✅ | ✅ |✅ |
## Running quick tests
diff --git a/examples/pytorch/instance-segmentation/README.md b/examples/pytorch/instance-segmentation/README.md
new file mode 100644
index 00000000000000..72eb5a5befb4fb
--- /dev/null
+++ b/examples/pytorch/instance-segmentation/README.md
@@ -0,0 +1,235 @@
+
+
+# Instance Segmentation Examples
+
+This directory contains two scripts that demonstrate how to fine-tune [MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer) and [Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former) for instance segmentation using PyTorch.
+For other instance segmentation models, such as [DETR](https://huggingface.co/docs/transformers/model_doc/detr) and [Conditional DETR](https://huggingface.co/docs/transformers/model_doc/conditional_detr), the scripts need to be adjusted to properly handle input and output data.
+
+Content:
+- [PyTorch Version with Trainer](#pytorch-version-with-trainer)
+- [PyTorch Version with Accelerate](#pytorch-version-with-accelerate)
+- [Reload and Perform Inference](#reload-and-perform-inference)
+- [Note on Custom Data](#note-on-custom-data)
+
+## PyTorch Version with Trainer
+
+This example is based on the script [`run_instance_segmentation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/instance-segmentation/run_instance_segmentation.py).
+
+The script uses the [🤗 Trainer API](https://huggingface.co/docs/transformers/main_classes/trainer) to manage training automatically, including distributed environments.
+
+Here, we show how to fine-tune a [Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former) model on a subsample of the [ADE20K](https://huggingface.co/datasets/zhoubolei/scene_parse_150) dataset. We created a [small dataset](https://huggingface.co/datasets/qubvel-hf/ade20k-mini) with approximately 2,000 images containing only "person" and "car" annotations; all other pixels are marked as "background."
+
+Here is the `label2id` mapping for this dataset:
+
+```python
+label2id = {
+ "background": 0,
+ "person": 1,
+ "car": 2,
+}
+```
+
+Since the `background` label is not an instance and we don't want to predict it, we will use `do_reduce_labels` to remove it from the data.
+
+Run the training with the following command:
+
+```bash
+python run_instance_segmentation.py \
+ --model_name_or_path facebook/mask2former-swin-tiny-coco-instance \
+ --output_dir finetune-instance-segmentation-ade20k-mini-mask2former \
+ --dataset_name qubvel-hf/ade20k-mini \
+ --do_reduce_labels \
+ --image_height 256 \
+ --image_width 256 \
+ --do_train \
+ --fp16 \
+ --num_train_epochs 40 \
+ --learning_rate 1e-5 \
+ --lr_scheduler_type constant \
+ --per_device_train_batch_size 8 \
+ --gradient_accumulation_steps 2 \
+ --dataloader_num_workers 8 \
+ --dataloader_persistent_workers \
+ --dataloader_prefetch_factor 4 \
+ --do_eval \
+ --evaluation_strategy epoch \
+ --logging_strategy epoch \
+ --save_strategy epoch \
+ --save_total_limit 2 \
+ --push_to_hub
+```
+
+The resulting model can be viewed [here](https://huggingface.co/qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former). Always refer to the original paper for details on training hyperparameters. To improve model quality, consider:
+- Changing image size parameters (`--image_height`/`--image_width`)
+- Adjusting training parameters such as learning rate, batch size, warmup, optimizer, and more (see [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments))
+- Adding more image augmentations (we created a helpful [HF Space](https://huggingface.co/spaces/qubvel-hf/albumentations-demo) to choose some)
+
+You can also replace the model [checkpoint](https://huggingface.co/models?search=maskformer).
+
+## PyTorch Version with Accelerate
+
+This example is based on the script [`run_instance_segmentation_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py).
+
+The script uses [🤗 Accelerate](https://github.com/huggingface/accelerate) to write your own training loop in PyTorch and run it on various environments, including CPU, multi-CPU, GPU, multi-GPU, and TPU, with support for mixed precision.
+
+First, configure the environment:
+
+```bash
+accelerate config
+```
+
+Answer the questions regarding your training environment. Then, run:
+
+```bash
+accelerate test
+```
+
+This command ensures everything is ready for training. Finally, launch training with:
+
+```bash
+accelerate launch run_instance_segmentation_no_trainer.py \
+ --model_name_or_path facebook/mask2former-swin-tiny-coco-instance \
+ --output_dir finetune-instance-segmentation-ade20k-mini-mask2former-no-trainer \
+ --dataset_name qubvel-hf/ade20k-mini \
+ --do_reduce_labels \
+ --image_height 256 \
+ --image_width 256 \
+ --num_train_epochs 40 \
+ --learning_rate 1e-5 \
+ --lr_scheduler_type constant \
+ --per_device_train_batch_size 8 \
+ --gradient_accumulation_steps 2 \
+ --dataloader_num_workers 8 \
+ --push_to_hub
+```
+
+With this setup, you can train on multiple GPUs, log everything to trackers (like Weights and Biases, Tensorboard), and regularly push your model to the hub (with the repo name set to `args.output_dir` under your HF username).
+With the default settings, the script fine-tunes a [Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former) model on the sample of [ADE20K](https://huggingface.co/datasets/qubvel-hf/ade20k-mini) dataset. The resulting model can be viewed [here](https://huggingface.co/qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former-no-trainer).
+
+## Reload and Perform Inference
+
+After training, you can easily load your trained model and perform inference as follows:
+
+```python
+import torch
+import requests
+import matplotlib.pyplot as plt
+
+from PIL import Image
+from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
+
+# Load image
+image = Image.open(requests.get("http://farm4.staticflickr.com/3017/3071497290_31f0393363_z.jpg", stream=True).raw)
+
+# Load model and image processor
+device = "cuda"
+checkpoint = "qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former"
+
+model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint, device_map=device)
+image_processor = Mask2FormerImageProcessor.from_pretrained(checkpoint)
+
+# Run inference on image
+inputs = image_processor(images=[image], return_tensors="pt").to(device)
+with torch.no_grad():
+ outputs = model(**inputs)
+
+# Post-process outputs
+outputs = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])
+
+print("Mask shape: ", outputs[0]["segmentation"].shape)
+print("Mask values: ", outputs[0]["segmentation"].unique())
+for segment in outputs[0]["segments_info"]:
+ print("Segment: ", segment)
+```
+
+```
+Mask shape: torch.Size([427, 640])
+Mask values: tensor([-1., 0., 1., 2., 3., 4., 5., 6.])
+Segment: {'id': 0, 'label_id': 0, 'was_fused': False, 'score': 0.946127}
+Segment: {'id': 1, 'label_id': 1, 'was_fused': False, 'score': 0.961582}
+Segment: {'id': 2, 'label_id': 1, 'was_fused': False, 'score': 0.968367}
+Segment: {'id': 3, 'label_id': 1, 'was_fused': False, 'score': 0.819527}
+Segment: {'id': 4, 'label_id': 1, 'was_fused': False, 'score': 0.655761}
+Segment: {'id': 5, 'label_id': 1, 'was_fused': False, 'score': 0.531299}
+Segment: {'id': 6, 'label_id': 1, 'was_fused': False, 'score': 0.929477}
+```
+
+Use the following code to visualize the results:
+
+```python
+import numpy as np
+import matplotlib.pyplot as plt
+
+segmentation = outputs[0]["segmentation"].numpy()
+
+plt.figure(figsize=(10, 10))
+plt.subplot(1, 2, 1)
+plt.imshow(np.array(image))
+plt.axis("off")
+plt.subplot(1, 2, 2)
+plt.imshow(segmentation)
+plt.axis("off")
+plt.show()
+```
+
+![Result](https://i.imgur.com/rZmaRjD.png)
+
+## Note on Custom Data
+
+Here is a short script demonstrating how to create your own dataset for instance segmentation and push it to the hub:
+
+> Note: Annotations should be represented as 3-channel images (similar to the [scene_parsing_150](https://huggingface.co/datasets/zhoubolei/scene_parse_150#instance_segmentation-1) dataset). The first channel is a semantic-segmentation map with values corresponding to `label2id`, the second is an instance-segmentation map where each instance has a unique value, and the third channel should be empty (filled with zeros).
+
+```python
+from datasets import Dataset, DatasetDict
+from datasets import Image as DatasetImage
+
+label2id = {
+ "background": 0,
+ "person": 1,
+ "car": 2,
+}
+
+train_split = {
+ "image": [, , , ...],
+ "annotation": [, , , ...],
+}
+
+validation_split = {
+ "image": [, , , ...],
+ "annotation": [, , , ...],
+}
+
+def create_instance_segmentation_dataset(label2id, **splits):
+ dataset_dict = {}
+ for split_name, split in splits.items():
+ split["semantic_class_to_id"] = [label2id] * len(split["image"])
+ dataset_split = (
+ Dataset.from_dict(split)
+ .cast_column("image", DatasetImage())
+ .cast_column("annotation", DatasetImage())
+ )
+ dataset_dict[split_name] = dataset_split
+ return DatasetDict(dataset_dict)
+
+dataset = create_instance_segmentation_dataset(label2id, train=train_split, validation=validation_split)
+dataset.push_to_hub("qubvel-hf/ade20k-nano")
+```
+
+Use this dataset for fine-tuning by specifying its name with `--dataset_name `.
+
+See also: [Dataset Creation Guide](https://huggingface.co/docs/datasets/image_dataset#create-an-image-dataset)
\ No newline at end of file
diff --git a/examples/pytorch/instance-segmentation/requirements.txt b/examples/pytorch/instance-segmentation/requirements.txt
new file mode 100644
index 00000000000000..2aa0d9bcf01672
--- /dev/null
+++ b/examples/pytorch/instance-segmentation/requirements.txt
@@ -0,0 +1,5 @@
+albumentations >= 1.4.5
+timm
+datasets
+torchmetrics
+pycocotools
diff --git a/examples/pytorch/instance-segmentation/run_instance_segmentation.py b/examples/pytorch/instance-segmentation/run_instance_segmentation.py
new file mode 100644
index 00000000000000..3a5d08b250595f
--- /dev/null
+++ b/examples/pytorch/instance-segmentation/run_instance_segmentation.py
@@ -0,0 +1,469 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+"""Finetuning 🤗 Transformers model for instance segmentation leveraging the Trainer API."""
+
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Any, Dict, List, Mapping, Optional
+
+import albumentations as A
+import numpy as np
+import torch
+from datasets import load_dataset
+from torchmetrics.detection.mean_ap import MeanAveragePrecision
+
+import transformers
+from transformers import (
+ AutoImageProcessor,
+ AutoModelForUniversalSegmentation,
+ HfArgumentParser,
+ Trainer,
+ TrainingArguments,
+)
+from transformers.image_processing_utils import BatchFeature
+from transformers.trainer import EvalPrediction
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version, send_example_telemetry
+from transformers.utils.versions import require_version
+
+
+logger = logging.getLogger(__name__)
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.42.0.dev0")
+require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
+
+
+@dataclass
+class Arguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+ Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify
+ them on the command line.
+ """
+
+ model_name_or_path: str = field(
+ default="facebook/mask2former-swin-tiny-coco-instance",
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
+ )
+ dataset_name: str = field(
+ default="qubvel-hf/ade20k-mini",
+ metadata={
+ "help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
+ },
+ )
+ image_height: Optional[int] = field(default=512, metadata={"help": "Image height after resizing."})
+ image_width: Optional[int] = field(default=512, metadata={"help": "Image width after resizing."})
+ token: str = field(
+ default=None,
+ metadata={
+ "help": (
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
+ )
+ },
+ )
+ do_reduce_labels: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "If background class is labeled as 0 and you want to remove it from the labels, set this flag to True."
+ )
+ },
+ )
+
+
+def augment_and_transform_batch(
+ examples: Mapping[str, Any], transform: A.Compose, image_processor: AutoImageProcessor
+) -> BatchFeature:
+ batch = {
+ "pixel_values": [],
+ "mask_labels": [],
+ "class_labels": [],
+ }
+
+ for pil_image, pil_annotation in zip(examples["image"], examples["annotation"]):
+ image = np.array(pil_image)
+ semantic_and_instance_masks = np.array(pil_annotation)[..., :2]
+
+ # Apply augmentations
+ output = transform(image=image, mask=semantic_and_instance_masks)
+
+ aug_image = output["image"]
+ aug_semantic_and_instance_masks = output["mask"]
+ aug_instance_mask = aug_semantic_and_instance_masks[..., 1]
+
+ # Create mapping from instance id to semantic id
+ unique_semantic_id_instance_id_pairs = np.unique(aug_semantic_and_instance_masks.reshape(-1, 2), axis=0)
+ instance_id_to_semantic_id = {
+ instance_id: semantic_id for semantic_id, instance_id in unique_semantic_id_instance_id_pairs
+ }
+
+ # Apply the image processor transformations: resizing, rescaling, normalization
+ model_inputs = image_processor(
+ images=[aug_image],
+ segmentation_maps=[aug_instance_mask],
+ instance_id_to_semantic_id=instance_id_to_semantic_id,
+ return_tensors="pt",
+ )
+
+ batch["pixel_values"].append(model_inputs.pixel_values[0])
+ batch["mask_labels"].append(model_inputs.mask_labels[0])
+ batch["class_labels"].append(model_inputs.class_labels[0])
+
+ return batch
+
+
+def collate_fn(examples):
+ batch = {}
+ batch["pixel_values"] = torch.stack([example["pixel_values"] for example in examples])
+ batch["class_labels"] = [example["class_labels"] for example in examples]
+ batch["mask_labels"] = [example["mask_labels"] for example in examples]
+ if "pixel_mask" in examples[0]:
+ batch["pixel_mask"] = torch.stack([example["pixel_mask"] for example in examples])
+ return batch
+
+
+@dataclass
+class ModelOutput:
+ class_queries_logits: torch.Tensor
+ masks_queries_logits: torch.Tensor
+
+
+def nested_cpu(tensors):
+ if isinstance(tensors, (list, tuple)):
+ return type(tensors)(nested_cpu(t) for t in tensors)
+ elif isinstance(tensors, Mapping):
+ return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()})
+ elif isinstance(tensors, torch.Tensor):
+ return tensors.cpu().detach()
+ else:
+ return tensors
+
+
+class Evaluator:
+ """
+ Compute metrics for the instance segmentation task.
+ """
+
+ def __init__(
+ self,
+ image_processor: AutoImageProcessor,
+ id2label: Mapping[int, str],
+ threshold: float = 0.0,
+ ):
+ """
+ Initialize evaluator with image processor, id2label mapping and threshold for filtering predictions.
+
+ Args:
+ image_processor (AutoImageProcessor): Image processor for
+ `post_process_instance_segmentation` method.
+ id2label (Mapping[int, str]): Mapping from class id to class name.
+ threshold (float): Threshold to filter predicted boxes by confidence. Defaults to 0.0.
+ """
+ self.image_processor = image_processor
+ self.id2label = id2label
+ self.threshold = threshold
+ self.metric = self.get_metric()
+
+ def get_metric(self):
+ metric = MeanAveragePrecision(iou_type="segm", class_metrics=True)
+ return metric
+
+ def reset_metric(self):
+ self.metric.reset()
+
+ def postprocess_target_batch(self, target_batch) -> List[Dict[str, torch.Tensor]]:
+ """Collect targets in a form of list of dictionaries with keys "masks", "labels"."""
+ batch_masks = target_batch[0]
+ batch_labels = target_batch[1]
+ post_processed_targets = []
+ for masks, labels in zip(batch_masks, batch_labels):
+ post_processed_targets.append(
+ {
+ "masks": masks.to(dtype=torch.bool),
+ "labels": labels,
+ }
+ )
+ return post_processed_targets
+
+ def get_target_sizes(self, post_processed_targets) -> List[List[int]]:
+ target_sizes = []
+ for target in post_processed_targets:
+ target_sizes.append(target["masks"].shape[-2:])
+ return target_sizes
+
+ def postprocess_prediction_batch(self, prediction_batch, target_sizes) -> List[Dict[str, torch.Tensor]]:
+ """Collect predictions in a form of list of dictionaries with keys "masks", "labels", "scores"."""
+
+ model_output = ModelOutput(class_queries_logits=prediction_batch[0], masks_queries_logits=prediction_batch[1])
+ post_processed_output = self.image_processor.post_process_instance_segmentation(
+ model_output,
+ threshold=self.threshold,
+ target_sizes=target_sizes,
+ return_binary_maps=True,
+ )
+
+ post_processed_predictions = []
+ for image_predictions, target_size in zip(post_processed_output, target_sizes):
+ if image_predictions["segments_info"]:
+ post_processed_image_prediction = {
+ "masks": image_predictions["segmentation"].to(dtype=torch.bool),
+ "labels": torch.tensor([x["label_id"] for x in image_predictions["segments_info"]]),
+ "scores": torch.tensor([x["score"] for x in image_predictions["segments_info"]]),
+ }
+ else:
+ # for void predictions, we need to provide empty tensors
+ post_processed_image_prediction = {
+ "masks": torch.zeros([0, *target_size], dtype=torch.bool),
+ "labels": torch.tensor([]),
+ "scores": torch.tensor([]),
+ }
+ post_processed_predictions.append(post_processed_image_prediction)
+
+ return post_processed_predictions
+
+ @torch.no_grad()
+ def __call__(self, evaluation_results: EvalPrediction, compute_result: bool = False) -> Mapping[str, float]:
+ """
+ Update metrics with current evaluation results and return metrics if `compute_result` is True.
+
+ Args:
+ evaluation_results (EvalPrediction): Predictions and targets from evaluation.
+ compute_result (bool): Whether to compute and return metrics.
+
+ Returns:
+ Mapping[str, float]: Metrics in a form of dictionary {: }
+ """
+ prediction_batch = nested_cpu(evaluation_results.predictions)
+ target_batch = nested_cpu(evaluation_results.label_ids)
+
+ # For metric computation we need to provide:
+ # - targets in a form of list of dictionaries with keys "masks", "labels"
+ # - predictions in a form of list of dictionaries with keys "masks", "labels", "scores"
+ post_processed_targets = self.postprocess_target_batch(target_batch)
+ target_sizes = self.get_target_sizes(post_processed_targets)
+ post_processed_predictions = self.postprocess_prediction_batch(prediction_batch, target_sizes)
+
+ # Compute metrics
+ self.metric.update(post_processed_predictions, post_processed_targets)
+
+ if not compute_result:
+ return
+
+ metrics = self.metric.compute()
+
+ # Replace list of per class metrics with separate metric for each class
+ classes = metrics.pop("classes")
+ map_per_class = metrics.pop("map_per_class")
+ mar_100_per_class = metrics.pop("mar_100_per_class")
+ for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
+ class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
+ metrics[f"map_{class_name}"] = class_map
+ metrics[f"mar_100_{class_name}"] = class_mar
+
+ metrics = {k: round(v.item(), 4) for k, v in metrics.items()}
+
+ # Reset metric for next evaluation
+ self.reset_metric()
+
+ return metrics
+
+
+def setup_logging(training_args: TrainingArguments) -> None:
+ """Setup logging according to `training_args`."""
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+ if training_args.should_log:
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+ transformers.utils.logging.set_verbosity_info()
+
+ log_level = training_args.get_process_log_level()
+ logger.setLevel(log_level)
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+
+def find_last_checkpoint(training_args: TrainingArguments) -> Optional[str]:
+ """Find the last checkpoint in the output directory according to parameters specified in `training_args`."""
+
+ checkpoint = None
+ if training_args.resume_from_checkpoint is not None:
+ checkpoint = training_args.resume_from_checkpoint
+ elif os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
+ checkpoint = get_last_checkpoint(training_args.output_dir)
+ if checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+ raise ValueError(
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+ "Use --overwrite_output_dir to overcome."
+ )
+ elif checkpoint is not None and training_args.resume_from_checkpoint is None:
+ logger.info(
+ f"Checkpoint detected, resuming training at {checkpoint}. To avoid this behavior, change "
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+ )
+
+ return checkpoint
+
+
+def main():
+ # See all possible arguments in https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
+ # or by passing the --help flag to this script.
+
+ parser = HfArgumentParser([Arguments, TrainingArguments])
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ # If we pass only one argument to the script and it's the path to a json file,
+ # let's parse it to get our arguments.
+ args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+ else:
+ args, training_args = parser.parse_args_into_dataclasses()
+
+ # Set default training arguments for instance segmentation
+ training_args.eval_do_concat_batches = False
+ training_args.batch_eval_metrics = True
+ training_args.remove_unused_columns = False
+
+ # # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_instance_segmentation", args)
+
+ # Setup logging and log on each process the small summary:
+ setup_logging(training_args)
+ logger.warning(
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+ + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
+ )
+ logger.info(f"Training/evaluation parameters {training_args}")
+
+ # Load last checkpoint from output_dir if it exists (and we are not overwriting it)
+ checkpoint = find_last_checkpoint(training_args)
+
+ # ------------------------------------------------------------------------------------------------
+ # Load dataset, prepare splits
+ # ------------------------------------------------------------------------------------------------
+
+ dataset = load_dataset(args.dataset_name)
+
+ # We need to specify the label2id mapping for the model
+ # it is a mapping from semantic class name to class index.
+ # In case your dataset does not provide it, you can create it manually:
+ # label2id = {"background": 0, "cat": 1, "dog": 2}
+ label2id = dataset["train"][0]["semantic_class_to_id"]
+
+ if args.do_reduce_labels:
+ label2id = {name: idx for name, idx in label2id.items() if idx != 0} # remove background class
+ label2id = {name: idx - 1 for name, idx in label2id.items()} # shift class indices by -1
+
+ id2label = {v: k for k, v in label2id.items()}
+
+ # ------------------------------------------------------------------------------------------------
+ # Load pretrained config, model and image processor
+ # ------------------------------------------------------------------------------------------------
+ model = AutoModelForUniversalSegmentation.from_pretrained(
+ args.model_name_or_path,
+ label2id=label2id,
+ id2label=id2label,
+ ignore_mismatched_sizes=True,
+ token=args.token,
+ )
+
+ image_processor = AutoImageProcessor.from_pretrained(
+ args.model_name_or_path,
+ do_resize=True,
+ size={"height": args.image_height, "width": args.image_width},
+ do_reduce_labels=args.do_reduce_labels,
+ reduce_labels=args.do_reduce_labels, # TODO: remove when mask2former support `do_reduce_labels`
+ token=args.token,
+ )
+
+ # ------------------------------------------------------------------------------------------------
+ # Define image augmentations and dataset transforms
+ # ------------------------------------------------------------------------------------------------
+ train_augment_and_transform = A.Compose(
+ [
+ A.HorizontalFlip(p=0.5),
+ A.RandomBrightnessContrast(p=0.5),
+ A.HueSaturationValue(p=0.1),
+ ],
+ )
+ validation_transform = A.Compose(
+ [A.NoOp()],
+ )
+
+ # Make transform functions for batch and apply for dataset splits
+ train_transform_batch = partial(
+ augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
+ )
+ validation_transform_batch = partial(
+ augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
+ )
+
+ dataset["train"] = dataset["train"].with_transform(train_transform_batch)
+ dataset["validation"] = dataset["validation"].with_transform(validation_transform_batch)
+
+ # ------------------------------------------------------------------------------------------------
+ # Model training and evaluation with Trainer API
+ # ------------------------------------------------------------------------------------------------
+
+ compute_metrics = Evaluator(image_processor=image_processor, id2label=id2label, threshold=0.0)
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset["train"] if training_args.do_train else None,
+ eval_dataset=dataset["validation"] if training_args.do_eval else None,
+ tokenizer=image_processor,
+ data_collator=collate_fn,
+ compute_metrics=compute_metrics,
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
+ trainer.save_model()
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+
+ # Final evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(eval_dataset=dataset["validation"], metric_key_prefix="test")
+ trainer.log_metrics("test", metrics)
+ trainer.save_metrics("test", metrics)
+
+ # Write model card and (optionally) push to hub
+ kwargs = {
+ "finetuned_from": args.model_name_or_path,
+ "dataset": args.dataset_name,
+ "tags": ["image-segmentation", "instance-segmentation", "vision"],
+ }
+ if training_args.push_to_hub:
+ trainer.push_to_hub(**kwargs)
+ else:
+ trainer.create_model_card(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py b/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py
new file mode 100644
index 00000000000000..f9b96389166eb0
--- /dev/null
+++ b/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py
@@ -0,0 +1,734 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+"""Finetuning 🤗 Transformers model for instance segmentation with Accelerate 🚀."""
+
+import argparse
+import json
+import logging
+import math
+import os
+import sys
+from functools import partial
+from pathlib import Path
+from typing import Any, Mapping
+
+import albumentations as A
+import datasets
+import numpy as np
+import torch
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+from datasets import load_dataset
+from huggingface_hub import HfApi
+from torch.utils.data import DataLoader
+from torchmetrics.detection.mean_ap import MeanAveragePrecision
+from tqdm import tqdm
+
+import transformers
+from transformers import (
+ AutoImageProcessor,
+ AutoModelForUniversalSegmentation,
+ SchedulerType,
+ get_scheduler,
+)
+from transformers.image_processing_utils import BatchFeature
+from transformers.utils import check_min_version, send_example_telemetry
+from transformers.utils.versions import require_version
+
+
+logger = logging.getLogger(__name__)
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.42.0.dev0")
+require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Finetune a transformers model for instance segmentation task")
+
+ parser.add_argument(
+ "--model_name_or_path",
+ type=str,
+ help="Path to a pretrained model or model identifier from huggingface.co/models.",
+ default="facebook/mask2former-swin-tiny-coco-instance",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ help="Name of the dataset on the hub.",
+ default="qubvel-hf/ade20k-mini",
+ )
+ parser.add_argument(
+ "--image_height",
+ type=int,
+ default=384,
+ help="The height of the images to feed the model.",
+ )
+ parser.add_argument(
+ "--image_width",
+ type=int,
+ default=384,
+ help="The width of the images to feed the model.",
+ )
+ parser.add_argument(
+ "--do_reduce_labels",
+ action="store_true",
+ help="Whether to reduce the number of labels by removing the background class.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ help="Path to a folder in which the model and dataset will be cached.",
+ )
+ parser.add_argument(
+ "--per_device_train_batch_size",
+ type=int,
+ default=8,
+ help="Batch size (per device) for the training dataloader.",
+ )
+ parser.add_argument(
+ "--per_device_eval_batch_size",
+ type=int,
+ default=8,
+ help="Batch size (per device) for the evaluation dataloader.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=4,
+ help="Number of workers to use for the dataloaders.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-5,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--adam_beta1",
+ type=float,
+ default=0.9,
+ help="Beta1 for AdamW optimizer",
+ )
+ parser.add_argument(
+ "--adam_beta2",
+ type=float,
+ default=0.999,
+ help="Beta2 for AdamW optimizer",
+ )
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-8,
+ help="Epsilon for AdamW optimizer",
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--lr_scheduler_type",
+ type=SchedulerType,
+ default="linear",
+ help="The scheduler type to use.",
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
+ )
+ parser.add_argument(
+ "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument(
+ "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
+ )
+ parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=str,
+ default=None,
+ help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help="If the training should continue from a checkpoint folder.",
+ )
+ parser.add_argument(
+ "--with_tracking",
+ required=False,
+ action="store_true",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
+ "Only applicable when `--with_tracking` is passed."
+ ),
+ )
+ args = parser.parse_args()
+
+ # Sanity checks
+ if args.push_to_hub or args.with_tracking:
+ if args.output_dir is None:
+ raise ValueError(
+ "Need an `output_dir` to create a repo when `--push_to_hub` or `with_tracking` is specified."
+ )
+
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ return args
+
+
+def augment_and_transform_batch(
+ examples: Mapping[str, Any], transform: A.Compose, image_processor: AutoImageProcessor
+) -> BatchFeature:
+ batch = {
+ "pixel_values": [],
+ "mask_labels": [],
+ "class_labels": [],
+ }
+
+ for pil_image, pil_annotation in zip(examples["image"], examples["annotation"]):
+ image = np.array(pil_image)
+ semantic_and_instance_masks = np.array(pil_annotation)[..., :2]
+
+ # Apply augmentations
+ output = transform(image=image, mask=semantic_and_instance_masks)
+
+ aug_image = output["image"]
+ aug_semantic_and_instance_masks = output["mask"]
+ aug_instance_mask = aug_semantic_and_instance_masks[..., 1]
+
+ # Create mapping from instance id to semantic id
+ unique_semantic_id_instance_id_pairs = np.unique(aug_semantic_and_instance_masks.reshape(-1, 2), axis=0)
+ instance_id_to_semantic_id = {
+ instance_id: semantic_id for semantic_id, instance_id in unique_semantic_id_instance_id_pairs
+ }
+
+ # Apply the image processor transformations: resizing, rescaling, normalization
+ model_inputs = image_processor(
+ images=[aug_image],
+ segmentation_maps=[aug_instance_mask],
+ instance_id_to_semantic_id=instance_id_to_semantic_id,
+ return_tensors="pt",
+ )
+
+ batch["pixel_values"].append(model_inputs.pixel_values[0])
+ batch["mask_labels"].append(model_inputs.mask_labels[0])
+ batch["class_labels"].append(model_inputs.class_labels[0])
+
+ return batch
+
+
+def collate_fn(examples):
+ batch = {}
+ batch["pixel_values"] = torch.stack([example["pixel_values"] for example in examples])
+ batch["class_labels"] = [example["class_labels"] for example in examples]
+ batch["mask_labels"] = [example["mask_labels"] for example in examples]
+ if "pixel_mask" in examples[0]:
+ batch["pixel_mask"] = torch.stack([example["pixel_mask"] for example in examples])
+ return batch
+
+
+def nested_cpu(tensors):
+ if isinstance(tensors, (list, tuple)):
+ return type(tensors)(nested_cpu(t) for t in tensors)
+ elif isinstance(tensors, Mapping):
+ return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()})
+ elif isinstance(tensors, torch.Tensor):
+ return tensors.cpu().detach()
+ else:
+ return tensors
+
+
+def evaluation_loop(model, image_processor, accelerator: Accelerator, dataloader, id2label):
+ metric = MeanAveragePrecision(iou_type="segm", class_metrics=True)
+
+ for inputs in tqdm(dataloader, total=len(dataloader), disable=not accelerator.is_local_main_process):
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ inputs = accelerator.gather_for_metrics(inputs)
+ inputs = nested_cpu(inputs)
+
+ outputs = accelerator.gather_for_metrics(outputs)
+ outputs = nested_cpu(outputs)
+
+ # For metric computation we need to provide:
+ # - targets in a form of list of dictionaries with keys "masks", "labels"
+ # - predictions in a form of list of dictionaries with keys "masks", "labels", "scores"
+
+ post_processed_targets = []
+ post_processed_predictions = []
+ target_sizes = []
+
+ # Collect targets
+ for masks, labels in zip(inputs["mask_labels"], inputs["class_labels"]):
+ post_processed_targets.append(
+ {
+ "masks": masks.to(dtype=torch.bool),
+ "labels": labels,
+ }
+ )
+ target_sizes.append(masks.shape[-2:])
+
+ # Collect predictions
+ post_processed_output = image_processor.post_process_instance_segmentation(
+ outputs,
+ threshold=0.0,
+ target_sizes=target_sizes,
+ return_binary_maps=True,
+ )
+
+ for image_predictions, target_size in zip(post_processed_output, target_sizes):
+ if image_predictions["segments_info"]:
+ post_processed_image_prediction = {
+ "masks": image_predictions["segmentation"].to(dtype=torch.bool),
+ "labels": torch.tensor([x["label_id"] for x in image_predictions["segments_info"]]),
+ "scores": torch.tensor([x["score"] for x in image_predictions["segments_info"]]),
+ }
+ else:
+ # for void predictions, we need to provide empty tensors
+ post_processed_image_prediction = {
+ "masks": torch.zeros([0, *target_size], dtype=torch.bool),
+ "labels": torch.tensor([]),
+ "scores": torch.tensor([]),
+ }
+ post_processed_predictions.append(post_processed_image_prediction)
+
+ # Update metric for batch targets and predictions
+ metric.update(post_processed_predictions, post_processed_targets)
+
+ # Compute metrics
+ metrics = metric.compute()
+
+ # Replace list of per class metrics with separate metric for each class
+ classes = metrics.pop("classes")
+ map_per_class = metrics.pop("map_per_class")
+ mar_100_per_class = metrics.pop("mar_100_per_class")
+ for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
+ class_name = id2label[class_id.item()] if id2label is not None else class_id.item()
+ metrics[f"map_{class_name}"] = class_map
+ metrics[f"mar_100_{class_name}"] = class_mar
+
+ metrics = {k: round(v.item(), 4) for k, v in metrics.items()}
+
+ return metrics
+
+
+def setup_logging(accelerator: Accelerator) -> None:
+ """Setup logging according to `training_args`."""
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_info()
+ logger.setLevel(logging.INFO)
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+
+
+def handle_repository_creation(accelerator: Accelerator, args: argparse.Namespace):
+ """Create a repository for the model and dataset if `args.push_to_hub` is set."""
+
+ repo_id = None
+ if accelerator.is_main_process:
+ if args.push_to_hub:
+ # Retrieve of infer repo_name
+ repo_name = args.hub_model_id
+ if repo_name is None:
+ repo_name = Path(args.output_dir).absolute().name
+ # Create repo and retrieve repo_id
+ api = HfApi()
+ repo_id = api.create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+ if "step_*" not in gitignore:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in gitignore:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+ accelerator.wait_for_everyone()
+
+ return repo_id
+
+
+def main():
+ args = parse_args()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_instance_segmentation_no_trainer", args)
+
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator_log_kwargs = {}
+
+ if args.with_tracking:
+ accelerator_log_kwargs["log_with"] = args.report_to
+ accelerator_log_kwargs["project_dir"] = args.output_dir
+
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
+ setup_logging(accelerator)
+
+ # If passed along, set the training seed now.
+ # We set device_specific to True as we want different data augmentation per device.
+ if args.seed is not None:
+ set_seed(args.seed, device_specific=True)
+
+ # Create repository if push ot hub is specified
+ repo_id = handle_repository_creation(accelerator, args)
+
+ if args.push_to_hub:
+ api = HfApi()
+
+ # ------------------------------------------------------------------------------------------------
+ # Load dataset, prepare splits
+ # ------------------------------------------------------------------------------------------------
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir)
+
+ # We need to specify the label2id mapping for the model
+ # it is a mapping from semantic class name to class index.
+ # In case your dataset does not provide it, you can create it manually:
+ # label2id = {"background": 0, "cat": 1, "dog": 2}
+ label2id = dataset["train"][0]["semantic_class_to_id"]
+
+ if args.do_reduce_labels:
+ label2id = {name: idx for name, idx in label2id.items() if idx != 0} # remove background class
+ label2id = {name: idx - 1 for name, idx in label2id.items()} # shift class indices by -1
+
+ id2label = {v: k for k, v in label2id.items()}
+
+ # ------------------------------------------------------------------------------------------------
+ # Load pretrained model and image processor
+ # ------------------------------------------------------------------------------------------------
+ model = AutoModelForUniversalSegmentation.from_pretrained(
+ args.model_name_or_path,
+ label2id=label2id,
+ id2label=id2label,
+ ignore_mismatched_sizes=True,
+ token=args.hub_token,
+ )
+
+ image_processor = AutoImageProcessor.from_pretrained(
+ args.model_name_or_path,
+ do_resize=True,
+ size={"height": args.image_height, "width": args.image_width},
+ do_reduce_labels=args.do_reduce_labels,
+ reduce_labels=args.do_reduce_labels, # TODO: remove when mask2former support `do_reduce_labels`
+ token=args.hub_token,
+ )
+
+ # ------------------------------------------------------------------------------------------------
+ # Define image augmentations and dataset transforms
+ # ------------------------------------------------------------------------------------------------
+ train_augment_and_transform = A.Compose(
+ [
+ A.HorizontalFlip(p=0.5),
+ A.RandomBrightnessContrast(p=0.5),
+ A.HueSaturationValue(p=0.1),
+ ],
+ )
+ validation_transform = A.Compose(
+ [A.NoOp()],
+ )
+
+ # Make transform functions for batch and apply for dataset splits
+ train_transform_batch = partial(
+ augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
+ )
+ validation_transform_batch = partial(
+ augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
+ )
+
+ with accelerator.main_process_first():
+ dataset["train"] = dataset["train"].with_transform(train_transform_batch)
+ dataset["validation"] = dataset["validation"].with_transform(validation_transform_batch)
+
+ dataloader_common_args = {
+ "num_workers": args.dataloader_num_workers,
+ "persistent_workers": True,
+ "collate_fn": collate_fn,
+ }
+ train_dataloader = DataLoader(
+ dataset["train"], shuffle=True, batch_size=args.per_device_train_batch_size, **dataloader_common_args
+ )
+ valid_dataloader = DataLoader(
+ dataset["validation"], shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args
+ )
+
+ # ------------------------------------------------------------------------------------------------
+ # Define optimizer, scheduler and prepare everything with the accelerator
+ # ------------------------------------------------------------------------------------------------
+
+ # Optimizer
+ optimizer = torch.optim.AdamW(
+ list(model.parameters()),
+ lr=args.learning_rate,
+ betas=[args.adam_beta1, args.adam_beta2],
+ eps=args.adam_epsilon,
+ )
+
+ # Figure out how many steps we should save the Accelerator states
+ checkpointing_steps = args.checkpointing_steps
+ if checkpointing_steps is not None and checkpointing_steps.isdigit():
+ checkpointing_steps = int(checkpointing_steps)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ name=args.lr_scheduler_type,
+ optimizer=optimizer,
+ num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps
+ if overrode_max_train_steps
+ else args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ model, optimizer, train_dataloader, valid_dataloader, lr_scheduler = accelerator.prepare(
+ model, optimizer, train_dataloader, valid_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if args.with_tracking:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("instance_segmentation_no_trainer", experiment_config)
+
+ # ------------------------------------------------------------------------------------------------
+ # Run training with evaluation on each epoch
+ # ------------------------------------------------------------------------------------------------
+
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(dataset['train'])}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+ completed_steps = 0
+ starting_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
+ checkpoint_path = args.resume_from_checkpoint
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
+ dirs.sort(key=os.path.getctime)
+ path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
+ checkpoint_path = path
+ path = os.path.basename(checkpoint_path)
+
+ accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
+ accelerator.load_state(checkpoint_path)
+ # Extract `epoch_{i}` or `step_{i}`
+ training_difference = os.path.splitext(path)[0]
+
+ if "epoch" in training_difference:
+ starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+ resume_step = None
+ completed_steps = starting_epoch * num_update_steps_per_epoch
+ else:
+ # need to multiply `gradient_accumulation_steps` to reflect real steps
+ resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
+ starting_epoch = resume_step // len(train_dataloader)
+ completed_steps = resume_step // args.gradient_accumulation_steps
+ resume_step -= starting_epoch * len(train_dataloader)
+
+ # update the progress_bar if load from checkpoint
+ progress_bar.update(completed_steps)
+
+ for epoch in range(starting_epoch, args.num_train_epochs):
+ model.train()
+ if args.with_tracking:
+ total_loss = 0
+ if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
+ # We skip the first `n` batches in the dataloader when resuming from a checkpoint
+ active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+ else:
+ active_dataloader = train_dataloader
+
+ for step, batch in enumerate(active_dataloader):
+ with accelerator.accumulate(model):
+ outputs = model(**batch)
+ loss = outputs.loss
+ # We keep track of the loss at each epoch
+ if args.with_tracking:
+ total_loss += loss.detach().float()
+ accelerator.backward(loss)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ completed_steps += 1
+
+ if isinstance(checkpointing_steps, int):
+ if completed_steps % checkpointing_steps == 0:
+ output_dir = f"step_{completed_steps}"
+ if args.output_dir is not None:
+ output_dir = os.path.join(args.output_dir, output_dir)
+ accelerator.save_state(output_dir)
+
+ if args.push_to_hub and epoch < args.num_train_epochs - 1:
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ unwrapped_model.save_pretrained(
+ args.output_dir,
+ is_main_process=accelerator.is_main_process,
+ save_function=accelerator.save,
+ )
+ if accelerator.is_main_process:
+ image_processor.save_pretrained(args.output_dir)
+ api.upload_folder(
+ repo_id=repo_id,
+ commit_message=f"Training in progress epoch {epoch}",
+ folder_path=args.output_dir,
+ repo_type="model",
+ token=args.hub_token,
+ )
+
+ if completed_steps >= args.max_train_steps:
+ break
+
+ logger.info("***** Running evaluation *****")
+ metrics = evaluation_loop(model, image_processor, accelerator, valid_dataloader, id2label)
+
+ logger.info(f"epoch {epoch}: {metrics}")
+
+ if args.with_tracking:
+ accelerator.log(
+ {
+ "train_loss": total_loss.item() / len(train_dataloader),
+ **metrics,
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
+ )
+
+ if args.push_to_hub and epoch < args.num_train_epochs - 1:
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ unwrapped_model.save_pretrained(
+ args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+ )
+ if accelerator.is_main_process:
+ image_processor.save_pretrained(args.output_dir)
+ api.upload_folder(
+ commit_message=f"Training in progress epoch {epoch}",
+ folder_path=args.output_dir,
+ repo_id=repo_id,
+ repo_type="model",
+ token=args.hub_token,
+ )
+
+ if args.checkpointing_steps == "epoch":
+ output_dir = f"epoch_{epoch}"
+ if args.output_dir is not None:
+ output_dir = os.path.join(args.output_dir, output_dir)
+ accelerator.save_state(output_dir)
+
+ # ------------------------------------------------------------------------------------------------
+ # Run evaluation on test dataset and save the model
+ # ------------------------------------------------------------------------------------------------
+
+ logger.info("***** Running evaluation on test dataset *****")
+ metrics = evaluation_loop(model, image_processor, accelerator, valid_dataloader, id2label)
+ metrics = {f"test_{k}": v for k, v in metrics.items()}
+
+ logger.info(f"Test metrics: {metrics}")
+
+ if args.with_tracking:
+ accelerator.end_training()
+
+ if args.output_dir is not None:
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ unwrapped_model.save_pretrained(
+ args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+ )
+ if accelerator.is_main_process:
+ with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+ json.dump(metrics, f, indent=2)
+
+ image_processor.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ api.upload_folder(
+ commit_message="End of training",
+ folder_path=args.output_dir,
+ repo_id=repo_id,
+ repo_type="model",
+ token=args.hub_token,
+ ignore_patterns=["epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/pytorch/test_accelerate_examples.py b/examples/pytorch/test_accelerate_examples.py
index 346b5cda63bf6a..f3695b4ad1178d 100644
--- a/examples/pytorch/test_accelerate_examples.py
+++ b/examples/pytorch/test_accelerate_examples.py
@@ -355,3 +355,28 @@ def test_run_object_detection_no_trainer(self):
run_command(self._launch_args + testargs)
result = get_results(tmp_dir)
self.assertGreaterEqual(result["test_map"], 0.10)
+
+ @slow
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
+ def test_run_instance_segmentation_no_trainer(self):
+ stream_handler = logging.StreamHandler(sys.stdout)
+ logger.addHandler(stream_handler)
+
+ tmp_dir = self.get_auto_remove_tmp_dir()
+ testargs = f"""
+ {self.examples_dir}/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py
+ --model_name_or_path qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former
+ --output_dir {tmp_dir}
+ --dataset_name qubvel-hf/ade20k-nano
+ --do_reduce_labels
+ --image_height 256
+ --image_width 256
+ --num_train_epochs 1
+ --per_device_train_batch_size 2
+ --per_device_eval_batch_size 1
+ --seed 1234
+ """.split()
+
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["test_map"], 0.1)
diff --git a/examples/pytorch/test_pytorch_examples.py b/examples/pytorch/test_pytorch_examples.py
index e7cc2d51c0065f..dab2148a728908 100644
--- a/examples/pytorch/test_pytorch_examples.py
+++ b/examples/pytorch/test_pytorch_examples.py
@@ -49,6 +49,7 @@
"image-pretraining",
"semantic-segmentation",
"object-detection",
+ "instance-segmentation",
]
]
sys.path.extend(SRC_DIRS)
@@ -60,6 +61,7 @@
import run_generation
import run_glue
import run_image_classification
+ import run_instance_segmentation
import run_mae
import run_mlm
import run_ner
@@ -639,3 +641,33 @@ def test_run_object_detection(self):
run_object_detection.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["test_map"], 0.1)
+
+ @patch.dict(os.environ, {"WANDB_DISABLED": "true"})
+ def test_run_instance_segmentation(self):
+ tmp_dir = self.get_auto_remove_tmp_dir()
+ testargs = f"""
+ run_instance_segmentation.py
+ --model_name_or_path qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former
+ --output_dir {tmp_dir}
+ --dataset_name qubvel-hf/ade20k-nano
+ --do_reduce_labels
+ --image_height 256
+ --image_width 256
+ --do_train
+ --num_train_epochs 1
+ --learning_rate 1e-5
+ --lr_scheduler_type constant
+ --per_device_train_batch_size 2
+ --per_device_eval_batch_size 1
+ --do_eval
+ --evaluation_strategy epoch
+ --seed 32
+ """.split()
+
+ if is_torch_fp16_available_on_device(torch_device):
+ testargs.append("--fp16")
+
+ with patch.object(sys, "argv", testargs):
+ run_instance_segmentation.main()
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["test_map"], 0.1)
diff --git a/src/transformers/models/mask2former/image_processing_mask2former.py b/src/transformers/models/mask2former/image_processing_mask2former.py
index 5440584d25f28f..0daa4b88a0c76b 100644
--- a/src/transformers/models/mask2former/image_processing_mask2former.py
+++ b/src/transformers/models/mask2former/image_processing_mask2former.py
@@ -283,7 +283,12 @@ def convert_segmentation_map_to_binary_masks(
# Generate a binary mask for each object instance
binary_masks = [(segmentation_map == i) for i in all_labels]
- binary_masks = np.stack(binary_masks, axis=0) # (num_labels, height, width)
+
+ # Stack the binary masks
+ if binary_masks:
+ binary_masks = np.stack(binary_masks, axis=0)
+ else:
+ binary_masks = np.zeros((0, *segmentation_map.shape))
# Convert instance ids to class ids
if instance_id_to_semantic_id is not None:
@@ -969,11 +974,15 @@ def encode_inputs(
)
# We add an axis to make them compatible with the transformations library
# this will be removed in the future
- masks = [mask[None, ...] for mask in masks]
- masks = [
- self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks
- ]
- masks = np.concatenate(masks, axis=0)
+ if masks.shape[0] > 0:
+ masks = [mask[None, ...] for mask in masks]
+ masks = [
+ self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index)
+ for mask in masks
+ ]
+ masks = np.concatenate(masks, axis=0)
+ else:
+ masks = np.zeros((0, *pad_size), dtype=np.float32)
mask_labels.append(torch.from_numpy(masks))
class_labels.append(torch.from_numpy(classes))
diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py
index 3c854b35c76edb..530e11269d8425 100644
--- a/src/transformers/models/maskformer/image_processing_maskformer.py
+++ b/src/transformers/models/maskformer/image_processing_maskformer.py
@@ -286,7 +286,12 @@ def convert_segmentation_map_to_binary_masks(
# Generate a binary mask for each object instance
binary_masks = [(segmentation_map == i) for i in all_labels]
- binary_masks = np.stack(binary_masks, axis=0) # (num_labels, height, width)
+
+ # Stack the binary masks
+ if binary_masks:
+ binary_masks = np.stack(binary_masks, axis=0)
+ else:
+ binary_masks = np.zeros((0, *segmentation_map.shape))
# Convert instance ids to class ids
if instance_id_to_semantic_id is not None:
@@ -982,17 +987,20 @@ def encode_inputs(
)
# We add an axis to make them compatible with the transformations library
# this will be removed in the future
- masks = [mask[None, ...] for mask in masks]
- masks = [
- self._pad_image(
- image=mask,
- output_size=pad_size,
- constant_values=ignore_index,
- input_data_format=ChannelDimension.FIRST,
- )
- for mask in masks
- ]
- masks = np.concatenate(masks, axis=0)
+ if masks.shape[0] > 0:
+ masks = [mask[None, ...] for mask in masks]
+ masks = [
+ self._pad_image(
+ image=mask,
+ output_size=pad_size,
+ constant_values=ignore_index,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ for mask in masks
+ ]
+ masks = np.concatenate(masks, axis=0)
+ else:
+ masks = np.zeros((0, *pad_size), dtype=np.float32)
mask_labels.append(torch.from_numpy(masks))
class_labels.append(torch.from_numpy(classes))
diff --git a/src/transformers/models/oneformer/image_processing_oneformer.py b/src/transformers/models/oneformer/image_processing_oneformer.py
index 9f865f8efd9b94..df83e1cc1702e3 100644
--- a/src/transformers/models/oneformer/image_processing_oneformer.py
+++ b/src/transformers/models/oneformer/image_processing_oneformer.py
@@ -285,7 +285,12 @@ def convert_segmentation_map_to_binary_masks(
# Generate a binary mask for each object instance
binary_masks = [(segmentation_map == i) for i in all_labels]
- binary_masks = np.stack(binary_masks, axis=0) # (num_labels, height, width)
+
+ # Stack the binary masks
+ if binary_masks:
+ binary_masks = np.stack(binary_masks, axis=0)
+ else:
+ binary_masks = np.zeros((0, *segmentation_map.shape))
# Convert instance ids to class ids
if instance_id_to_semantic_id is not None: