Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove task arg in load_dataset in image-classification example #28408

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ def main():
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
token=model_args.token,
)
else:
Expand All @@ -268,9 +267,14 @@ def main():
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
)

# Rename image and label columns if needed (e.g. Cifar10)
if "img" in dataset["train"].features:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that if data_args.dataset_name is None and data_args.train_dir is None, then the dataset dict will not have a "train" key, and the line above will raise a KeyError.

This could be avoided by replacing the line above with:

Suggested change
if "img" in dataset["train"].features:
if "img" in list(dataset.column_names.values())[0]:

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or:

Suggested change
if "img" in dataset["train"].features:
if "img" in next(iter(dataset.column_names.values())):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe more readable ?

Suggested change
if "img" in dataset["train"].features:
if "img" in (dataset["train"].features if "train" in dataset else dataset["validation"].features):

and also compatible with your suggestion at huggingface/datasets#6571

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually it seems this script always do training no ? in this case you can assume "train" is always present

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed it's probably mostly used for training, but I'm going to add the suggestion anyway in case.

dataset = dataset.rename_column("img", "image")
if "label" in dataset["train"].features:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as above.

dataset = dataset.rename_column("label", "labels")

# If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def main():
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(args.dataset_name, task="image-classification")
dataset = load_dataset(args.dataset_name)
else:
data_files = {}
if args.train_dir is not None:
Expand All @@ -283,11 +283,16 @@ def main():
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
task="image-classification",
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.

# Rename image and label columns if needed (e.g. Cifar10)
if "img" in dataset["train"].features:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as above.

dataset = dataset.rename_column("img", "image")
if "label" in dataset["train"].features:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as above.

dataset = dataset.rename_column("label", "labels")

# If we don't have a validation split, split off a percentage of train as validation.
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
Expand Down
Loading