Skip to content

Commit

Permalink
Remove task arg in load_dataset in image-classification example (h…
Browse files Browse the repository at this point in the history
…uggingface#28408)

* Remove `task` arg in `load_dataset` in image-classification example

* Manage case where "train" is not in dataset

* Add new args to manage image and label column names

* Similar to audio-classification example

* Fix README

* Update tests
  • Loading branch information
regisss authored and AjayP13 committed Jan 22, 2024
1 parent 1447aa2 commit dd17bb9
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 18 deletions.
3 changes: 2 additions & 1 deletion examples/pytorch/image-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ python run_image_classification.py \
--dataset_name beans \
--output_dir ./beans_outputs/ \
--remove_unused_columns False \
--label_column_name labels \
--do_train \
--do_eval \
--push_to_hub \
Expand Down Expand Up @@ -197,7 +198,7 @@ accelerate test
that will check everything is ready for training. Finally, you can launch training with

```bash
accelerate launch run_image_classification_trainer.py
accelerate launch run_image_classification_no_trainer.py --image_column_name img
```

This command is the same and will work for:
Expand Down
43 changes: 32 additions & 11 deletions examples/pytorch/image-classification/run_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ class DataTrainingArguments:
)
},
)
image_column_name: str = field(
default="image",
metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
)
label_column_name: str = field(
default="label",
metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
)

def __post_init__(self):
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
Expand Down Expand Up @@ -175,12 +183,6 @@ class ModelArguments:
)


def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
Expand Down Expand Up @@ -255,7 +257,6 @@ def main():
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
token=model_args.token,
)
else:
Expand All @@ -268,9 +269,27 @@ def main():
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
)

dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
if data_args.image_column_name not in dataset_column_names:
raise ValueError(
f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--image_column_name` to the correct audio column - one of "
f"{', '.join(dataset_column_names)}."
)
if data_args.label_column_name not in dataset_column_names:
raise ValueError(
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(dataset_column_names)}."
)

def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example[data_args.label_column_name] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}

# If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
Expand All @@ -280,7 +299,7 @@ def main():

# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
labels = dataset["train"].features[data_args.label_column_name].names
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
Expand Down Expand Up @@ -354,13 +373,15 @@ def compute_metrics(p):
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
example_batch["pixel_values"] = [
_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

if training_args.do_train:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ def parse_args():
action="store_true",
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
parser.add_argument(
"--image_column_name",
type=str,
default="image",
help="The name of the dataset column containing the image data. Defaults to 'image'.",
)
parser.add_argument(
"--label_column_name",
type=str,
default="label",
help="The name of the dataset column containing the labels. Defaults to 'label'.",
)
args = parser.parse_args()

# Sanity checks
Expand Down Expand Up @@ -272,7 +284,7 @@ def main():
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(args.dataset_name, task="image-classification")
dataset = load_dataset(args.dataset_name)
else:
data_files = {}
if args.train_dir is not None:
Expand All @@ -282,11 +294,24 @@ def main():
dataset = load_dataset(
"imagefolder",
data_files=data_files,
task="image-classification",
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.

dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
if args.image_column_name not in dataset_column_names:
raise ValueError(
f"--image_column_name {args.image_column_name} not found in dataset '{args.dataset_name}'. "
"Make sure to set `--image_column_name` to the correct audio column - one of "
f"{', '.join(dataset_column_names)}."
)
if args.label_column_name not in dataset_column_names:
raise ValueError(
f"--label_column_name {args.label_column_name} not found in dataset '{args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(dataset_column_names)}."
)

# If we don't have a validation split, split off a percentage of train as validation.
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
Expand All @@ -296,7 +321,7 @@ def main():

# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
labels = dataset["train"].features[args.label_column_name].names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

Expand Down Expand Up @@ -355,12 +380,16 @@ def main():

def preprocess_train(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
example_batch["pixel_values"] = [
train_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
]
return example_batch

def preprocess_val(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
example_batch["pixel_values"] = [
val_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
]
return example_batch

with accelerator.main_process_first():
Expand All @@ -376,7 +405,7 @@ def preprocess_val(example_batch):
# DataLoaders creation:
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
labels = torch.tensor([example[args.label_column_name] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}

train_dataloader = DataLoader(
Expand Down
1 change: 1 addition & 0 deletions examples/pytorch/test_accelerate_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def test_run_image_classification_no_trainer(self):
--output_dir {tmp_dir}
--with_tracking
--checkpointing_steps 1
--label_column_name labels
""".split()

run_command(self._launch_args + testargs)
Expand Down
1 change: 1 addition & 0 deletions examples/pytorch/test_pytorch_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def test_run_image_classification(self):
--max_steps 10
--train_val_split 0.1
--seed 42
--label_column_name labels
""".split()

if is_torch_fp16_available_on_device(torch_device):
Expand Down

0 comments on commit dd17bb9

Please sign in to comment.