Skip to content

Commit

Permalink
Update image classification example (#1708)
Browse files Browse the repository at this point in the history
  • Loading branch information
prathikr authored Feb 23, 2024
1 parent 88f1a9c commit cd81dbd
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 27 deletions.
5 changes: 2 additions & 3 deletions examples/onnxruntime/training/image-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
-->

# Language Modeling

## Image Classification Training
# Image Classification

By running the scripts [`run_image_classification.py`](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/image-classification/run_image_classification.py) we will be able to leverage the [`ONNX Runtime`](https://github.com/microsoft/onnxruntime) accelerator to train the language models from the
[HuggingFace hub](https://huggingface.co/models).
Expand All @@ -32,6 +30,7 @@ torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE run_image_classification.py \
--dataset_name beans \
--output_dir ./beans_outputs/ \
--remove_unused_columns False \
--label_column_name labels \
--do_train \
--do_eval \
--learning_rate 2e-5 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@
import evaluate
import numpy as np
import torch
import transformers
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import (
CenterCrop,
Compose,
Lambda,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)

import transformers
from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
AutoConfig,
Expand All @@ -47,17 +49,16 @@
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

from optimum import ORTTrainer, ORTTrainingArguments

""" Fine-tuning a 🤗 Transformers model for image classification"""

logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.34.0")
check_min_version("4.38.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
Expand Down Expand Up @@ -109,6 +110,14 @@ class DataTrainingArguments:
)
},
)
image_column_name: str = field(
default="image",
metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
)
label_column_name: str = field(
default="label",
metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
)

def __post_init__(self):
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
Expand Down Expand Up @@ -154,14 +163,14 @@ class ModelArguments:
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
)
Expand All @@ -173,12 +182,6 @@ class ModelArguments:
)


def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
Expand All @@ -193,7 +196,10 @@ def main():
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
FutureWarning,
)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
Expand Down Expand Up @@ -221,7 +227,7 @@ def main():

# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
Expand Down Expand Up @@ -250,7 +256,6 @@ def main():
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
token=model_args.token,
)
else:
Expand All @@ -263,9 +268,27 @@ def main():
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
)

dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
if data_args.image_column_name not in dataset_column_names:
raise ValueError(
f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--image_column_name` to the correct audio column - one of "
f"{', '.join(dataset_column_names)}."
)
if data_args.label_column_name not in dataset_column_names:
raise ValueError(
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(dataset_column_names)}."
)

def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example[data_args.label_column_name] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}

# If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
Expand All @@ -275,14 +298,14 @@ def main():

# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
labels = dataset["train"].features[data_args.label_column_name].names
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
id2label[str(i)] = label

# Load the accuracy metric from the datasets package
metric = evaluate.load("accuracy")
metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
Expand Down Expand Up @@ -324,7 +347,11 @@ def compute_metrics(p):
size = image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
_train_transforms = Compose(
[
RandomResizedCrop(size),
Expand All @@ -345,13 +372,15 @@ def compute_metrics(p):
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
example_batch["pixel_values"] = [
_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

if training_args.do_train:
Expand All @@ -374,7 +403,7 @@ def val_transforms(example_batch):
# Set the validation transforms
dataset["validation"].set_transform(val_transforms)

# Initalize our trainer
# Initialize our trainer
trainer = ORTTrainer(
model=model,
args=training_args,
Expand Down Expand Up @@ -418,4 +447,4 @@ def val_transforms(example_batch):


if __name__ == "__main__":
main()
main()

0 comments on commit cd81dbd

Please sign in to comment.