From dd17bb907a57df114bb92a7a32c74ec9e9d47228 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 16 Jan 2024 08:04:08 +0100 Subject: [PATCH] Remove `task` arg in `load_dataset` in image-classification example (#28408) * Remove `task` arg in `load_dataset` in image-classification example * Manage case where "train" is not in dataset * Add new args to manage image and label column names * Similar to audio-classification example * Fix README * Update tests --- .../pytorch/image-classification/README.md | 3 +- .../run_image_classification.py | 43 ++++++++++++++----- .../run_image_classification_no_trainer.py | 41 +++++++++++++++--- examples/pytorch/test_accelerate_examples.py | 1 + examples/pytorch/test_pytorch_examples.py | 1 + 5 files changed, 71 insertions(+), 18 deletions(-) diff --git a/examples/pytorch/image-classification/README.md b/examples/pytorch/image-classification/README.md index 04b4748774ddf7..c95f180d4502cb 100644 --- a/examples/pytorch/image-classification/README.md +++ b/examples/pytorch/image-classification/README.md @@ -41,6 +41,7 @@ python run_image_classification.py \ --dataset_name beans \ --output_dir ./beans_outputs/ \ --remove_unused_columns False \ + --label_column_name labels \ --do_train \ --do_eval \ --push_to_hub \ @@ -197,7 +198,7 @@ accelerate test that will check everything is ready for training. Finally, you can launch training with ```bash -accelerate launch run_image_classification_trainer.py +accelerate launch run_image_classification_no_trainer.py --image_column_name img ``` This command is the same and will work for: diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py index 07942aa7e242e8..db13dc988ed591 100755 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -111,6 +111,14 @@ class DataTrainingArguments: ) }, ) + image_column_name: str = field( + default="image", + metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."}, + ) + label_column_name: str = field( + default="label", + metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."}, + ) def __post_init__(self): if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None): @@ -175,12 +183,6 @@ class ModelArguments: ) -def collate_fn(examples): - pixel_values = torch.stack([example["pixel_values"] for example in examples]) - labels = torch.tensor([example["labels"] for example in examples]) - return {"pixel_values": pixel_values, "labels": labels} - - def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. @@ -255,7 +257,6 @@ def main(): data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, - task="image-classification", token=model_args.token, ) else: @@ -268,9 +269,27 @@ def main(): "imagefolder", data_files=data_files, cache_dir=model_args.cache_dir, - task="image-classification", ) + dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names + if data_args.image_column_name not in dataset_column_names: + raise ValueError( + f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--image_column_name` to the correct audio column - one of " + f"{', '.join(dataset_column_names)}." + ) + if data_args.label_column_name not in dataset_column_names: + raise ValueError( + f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--label_column_name` to the correct text column - one of " + f"{', '.join(dataset_column_names)}." + ) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + labels = torch.tensor([example[data_args.label_column_name] for example in examples]) + return {"pixel_values": pixel_values, "labels": labels} + # If we don't have a validation split, split off a percentage of train as validation. data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: @@ -280,7 +299,7 @@ def main(): # Prepare label mappings. # We'll include these in the model's config to get human readable labels in the Inference API. - labels = dataset["train"].features["labels"].names + labels = dataset["train"].features[data_args.label_column_name].names label2id, id2label = {}, {} for i, label in enumerate(labels): label2id[label] = str(i) @@ -354,13 +373,15 @@ def compute_metrics(p): def train_transforms(example_batch): """Apply _train_transforms across a batch.""" example_batch["pixel_values"] = [ - _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"] + _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name] ] return example_batch def val_transforms(example_batch): """Apply _val_transforms across a batch.""" - example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]] + example_batch["pixel_values"] = [ + _val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name] + ] return example_batch if training_args.do_train: diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py index 186bbfd507540d..963a01b77cf7fc 100644 --- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py +++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py @@ -189,6 +189,18 @@ def parse_args(): action="store_true", help="Whether or not to enable to load a pretrained model whose head dimensions are different.", ) + parser.add_argument( + "--image_column_name", + type=str, + default="image", + help="The name of the dataset column containing the image data. Defaults to 'image'.", + ) + parser.add_argument( + "--label_column_name", + type=str, + default="label", + help="The name of the dataset column containing the labels. Defaults to 'label'.", + ) args = parser.parse_args() # Sanity checks @@ -272,7 +284,7 @@ def main(): # download the dataset. if args.dataset_name is not None: # Downloading and loading a dataset from the hub. - dataset = load_dataset(args.dataset_name, task="image-classification") + dataset = load_dataset(args.dataset_name) else: data_files = {} if args.train_dir is not None: @@ -282,11 +294,24 @@ def main(): dataset = load_dataset( "imagefolder", data_files=data_files, - task="image-classification", ) # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder. + dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names + if args.image_column_name not in dataset_column_names: + raise ValueError( + f"--image_column_name {args.image_column_name} not found in dataset '{args.dataset_name}'. " + "Make sure to set `--image_column_name` to the correct audio column - one of " + f"{', '.join(dataset_column_names)}." + ) + if args.label_column_name not in dataset_column_names: + raise ValueError( + f"--label_column_name {args.label_column_name} not found in dataset '{args.dataset_name}'. " + "Make sure to set `--label_column_name` to the correct text column - one of " + f"{', '.join(dataset_column_names)}." + ) + # If we don't have a validation split, split off a percentage of train as validation. args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split if isinstance(args.train_val_split, float) and args.train_val_split > 0.0: @@ -296,7 +321,7 @@ def main(): # Prepare label mappings. # We'll include these in the model's config to get human readable labels in the Inference API. - labels = dataset["train"].features["labels"].names + labels = dataset["train"].features[args.label_column_name].names label2id = {label: str(i) for i, label in enumerate(labels)} id2label = {str(i): label for i, label in enumerate(labels)} @@ -355,12 +380,16 @@ def main(): def preprocess_train(example_batch): """Apply _train_transforms across a batch.""" - example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]] + example_batch["pixel_values"] = [ + train_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name] + ] return example_batch def preprocess_val(example_batch): """Apply _val_transforms across a batch.""" - example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]] + example_batch["pixel_values"] = [ + val_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name] + ] return example_batch with accelerator.main_process_first(): @@ -376,7 +405,7 @@ def preprocess_val(example_batch): # DataLoaders creation: def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) - labels = torch.tensor([example["labels"] for example in examples]) + labels = torch.tensor([example[args.label_column_name] for example in examples]) return {"pixel_values": pixel_values, "labels": labels} train_dataloader = DataLoader( diff --git a/examples/pytorch/test_accelerate_examples.py b/examples/pytorch/test_accelerate_examples.py index 8749c8add77950..fc485cf59a2ebb 100644 --- a/examples/pytorch/test_accelerate_examples.py +++ b/examples/pytorch/test_accelerate_examples.py @@ -322,6 +322,7 @@ def test_run_image_classification_no_trainer(self): --output_dir {tmp_dir} --with_tracking --checkpointing_steps 1 + --label_column_name labels """.split() run_command(self._launch_args + testargs) diff --git a/examples/pytorch/test_pytorch_examples.py b/examples/pytorch/test_pytorch_examples.py index a0781b356595ba..0aabbb4bcb881c 100644 --- a/examples/pytorch/test_pytorch_examples.py +++ b/examples/pytorch/test_pytorch_examples.py @@ -398,6 +398,7 @@ def test_run_image_classification(self): --max_steps 10 --train_val_split 0.1 --seed 42 + --label_column_name labels """.split() if is_torch_fp16_available_on_device(torch_device):