Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove task arg in load_dataset in image-classification example #28408

Merged
3 changes: 2 additions & 1 deletion examples/pytorch/image-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ python run_image_classification.py \
--dataset_name beans \
--output_dir ./beans_outputs/ \
--remove_unused_columns False \
--label_column_name labels \
--do_train \
--do_eval \
--push_to_hub \
Expand Down Expand Up @@ -197,7 +198,7 @@ accelerate test
that will check everything is ready for training. Finally, you can launch training with

```bash
accelerate launch run_image_classification_trainer.py
accelerate launch run_image_classification_no_trainer.py --image_column_name img
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the default should not need this no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I set the default to image because that's the name of the image column for most image datasets. But Cifar10, which is used in this example, has it named img for some reason.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it thanks

```

This command is the same and will work for:
Expand Down
43 changes: 32 additions & 11 deletions examples/pytorch/image-classification/run_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ class DataTrainingArguments:
)
},
)
image_column_name: str = field(
default="image",
metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
)
label_column_name: str = field(
default="label",
metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
)

def __post_init__(self):
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
Expand Down Expand Up @@ -175,12 +183,6 @@ class ModelArguments:
)


def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
Expand Down Expand Up @@ -255,7 +257,6 @@ def main():
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
token=model_args.token,
)
else:
Expand All @@ -268,9 +269,27 @@ def main():
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
)

dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
if data_args.image_column_name not in dataset_column_names:
raise ValueError(
f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--image_column_name` to the correct audio column - one of "
f"{', '.join(dataset_column_names)}."
)
if data_args.label_column_name not in dataset_column_names:
raise ValueError(
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(dataset_column_names)}."
)

def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example[data_args.label_column_name] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}

# If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
Expand All @@ -280,7 +299,7 @@ def main():

# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
labels = dataset["train"].features[data_args.label_column_name].names
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
Expand Down Expand Up @@ -354,13 +373,15 @@ def compute_metrics(p):
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
example_batch["pixel_values"] = [
_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

if training_args.do_train:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ def parse_args():
action="store_true",
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
parser.add_argument(
"--image_column_name",
type=str,
default="image",
help="The name of the dataset column containing the image data. Defaults to 'image'.",
)
parser.add_argument(
"--label_column_name",
type=str,
default="label",
help="The name of the dataset column containing the labels. Defaults to 'label'.",
)
args = parser.parse_args()

# Sanity checks
Expand Down Expand Up @@ -272,7 +284,7 @@ def main():
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(args.dataset_name, task="image-classification")
dataset = load_dataset(args.dataset_name)
else:
data_files = {}
if args.train_dir is not None:
Expand All @@ -282,11 +294,24 @@ def main():
dataset = load_dataset(
"imagefolder",
data_files=data_files,
task="image-classification",
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.

dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
if args.image_column_name not in dataset_column_names:
raise ValueError(
f"--image_column_name {args.image_column_name} not found in dataset '{args.dataset_name}'. "
"Make sure to set `--image_column_name` to the correct audio column - one of "
f"{', '.join(dataset_column_names)}."
)
if args.label_column_name not in dataset_column_names:
raise ValueError(
f"--label_column_name {args.label_column_name} not found in dataset '{args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(dataset_column_names)}."
)

# If we don't have a validation split, split off a percentage of train as validation.
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
Expand All @@ -296,7 +321,7 @@ def main():

# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
labels = dataset["train"].features[args.label_column_name].names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

Expand Down Expand Up @@ -355,12 +380,16 @@ def main():

def preprocess_train(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
example_batch["pixel_values"] = [
train_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
]
return example_batch

def preprocess_val(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
example_batch["pixel_values"] = [
val_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
]
return example_batch

with accelerator.main_process_first():
Expand All @@ -376,7 +405,7 @@ def preprocess_val(example_batch):
# DataLoaders creation:
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
labels = torch.tensor([example[args.label_column_name] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}

train_dataloader = DataLoader(
Expand Down
1 change: 1 addition & 0 deletions examples/pytorch/test_accelerate_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def test_run_image_classification_no_trainer(self):
--output_dir {tmp_dir}
--with_tracking
--checkpointing_steps 1
--label_column_name labels
""".split()

run_command(self._launch_args + testargs)
Expand Down
1 change: 1 addition & 0 deletions examples/pytorch/test_pytorch_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def test_run_image_classification(self):
--max_steps 10
--train_val_split 0.1
--seed 42
--label_column_name labels
""".split()

if is_torch_fp16_available_on_device(torch_device):
Expand Down
Loading