Commit aeb4576

Set cache_dir for evaluate.load() in example scripts (huggingface#28422)

While using `run_clm.py`,[^1] I noticed that some files were being added
to my global cache instead of the local cache I had specified. Setting the
`cache_dir` parameter on the script's one call to `evaluate.load()`
partially solved the problem. Since I was already fixing that one script
upstream, I applied the same fix to every other example script I could.
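
For illustration, the per-script change boils down to the following sketch; the metric name and the cache path here are placeholders, not code taken from any particular example script:

```python
import evaluate

# Without cache_dir, the metric's files end up in the global Hugging Face
# cache (by default somewhere under ~/.cache/huggingface).
metric = evaluate.load("accuracy")

# With cache_dir, the files go to the directory the script was pointed at,
# i.e. whatever its --cache_dir argument (model_args.cache_dir) resolves to.
local_cache = "./hf_cache"  # placeholder for model_args.cache_dir
metric = evaluate.load("accuracy", cache_dir=local_cache)
```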

Some files are still being added to my global cache, but that appears to
be a bug in `evaluate` itself. This commit at least moves some of the
files into the local cache, which is an improvement over the current
behavior.

To create this PR, I made the following regex-based transformation:
`evaluate\.load\((.*?)\)` -> `evaluate\.load\($1, cache_dir=model_args.cache_dir\)`.
After applying it, I manually fixed up all of the modified files, with
`ruff` serving as useful guidance. During that pass, I also removed one
existing usage of the `cache_dir` parameter in a script that did not have
a corresponding `--cache-dir` argument declared.
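
Roughly, that transformation corresponds to the sketch below (my approximation, not the exact tool or command used). Note that the single-line pattern misses `evaluate.load()` calls that already spanned multiple lines, which is part of why the manual cleanup was still needed:

```python
import pathlib
import re

# Approximate bulk edit: append cache_dir=model_args.cache_dir to any
# single-line evaluate.load(...) call found in the example scripts.
# Multi-line calls and scripts without a cache_dir argument still need
# manual attention afterwards.
PATTERN = re.compile(r"evaluate\.load\((.*?)\)")
REPLACEMENT = r"evaluate.load(\1, cache_dir=model_args.cache_dir)"

for path in pathlib.Path("examples").rglob("*.py"):
    source = path.read_text()
    updated = PATTERN.sub(REPLACEMENT, source)
    if updated != source:
        path.write_text(updated)
```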

[^1]: I specifically used `pytorch/language-modeling/run_clm.py` from
v4.34.1 of the library. For the original code, see the following URL:
https://github.com/huggingface/transformers/tree/acc394c4f5e1283c19783581790b3dc3105a3697/examples/pytorch/language-modeling/run_clm.py.
aphedges authored and wgifford committed Jan 21, 2024
1 parent 980e2f7 commit aeb4576
Showing 31 changed files with 47 additions and 38 deletions.
@@ -853,7 +853,7 @@ def blockwise_data_loader(
yield batch

# Metric
- metric = evaluate.load("rouge")
+ metric = evaluate.load("rouge", cache_dir=model_args.cache_dir)

def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]

4 changes: 3 additions & 1 deletion examples/flax/question-answering/run_qa.py
@@ -807,7 +807,9 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)

- metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load(
+     "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
+ )

def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)

@@ -577,7 +577,7 @@ def is_audio_in_length_range(length):
return

# 8. Load Metric
- metric = evaluate.load("wer")
+ metric = evaluate.load("wer", cache_dir=model_args.cache_dir)

def compute_metrics(preds, labels):
# replace padded labels by the padding token

2 changes: 1 addition & 1 deletion examples/flax/summarization/run_summarization_flax.py
@@ -710,7 +710,7 @@ def preprocess_function(examples):
)

# Metric
- metric = evaluate.load("rouge")
+ metric = evaluate.load("rouge", cache_dir=model_args.cache_dir)

def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]

4 changes: 2 additions & 2 deletions examples/flax/text-classification/run_flax_glue.py
@@ -599,9 +599,9 @@ def eval_step(state, batch):
p_eval_step = jax.pmap(eval_step, axis_name="batch")

if data_args.task_name is not None:
- metric = evaluate.load("glue", data_args.task_name)
+ metric = evaluate.load("glue", data_args.task_name, cache_dir=model_args.cache_dir)
else:
- metric = evaluate.load("accuracy")
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0

2 changes: 1 addition & 1 deletion examples/flax/token-classification/run_flax_ner.py
@@ -676,7 -676,7 @@ def eval_step(state, batch):

p_eval_step = jax.pmap(eval_step, axis_name="batch")

- metric = evaluate.load("seqeval")
+ metric = evaluate.load("seqeval", cache_dir=model_args.cache_dir)

def get_labels(y_pred, y_true):
# Transform predictions and references tensos to numpy arrays

@@ -349,7 +349,7 @@ def val_transforms(batch):
id2label[str(i)] = label

# Load the accuracy metric from the datasets package
- metric = evaluate.load("accuracy")
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
# `predictions` and `label_ids` fields) and has to return a dictionary string to float.

@@ -287,7 +287,7 @@ def main():
id2label[str(i)] = label

# Load the accuracy metric from the datasets package
- metric = evaluate.load("accuracy")
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.

@@ -282,7 +282,6 @@ def main():
dataset = load_dataset(
"imagefolder",
data_files=data_files,
- cache_dir=args.cache_dir,
task="image-classification",
)
# See more about loading custom images at

2 changes: 1 addition & 1 deletion examples/pytorch/language-modeling/run_clm.py
@@ -583,7 +583,7 @@ def preprocess_logits_for_metrics(logits, labels):
logits = logits[0]
return logits.argmax(dim=-1)

- metric = evaluate.load("accuracy")
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

def compute_metrics(eval_preds):
preds, labels = eval_preds

2 changes: 1 addition & 1 deletion examples/pytorch/language-modeling/run_mlm.py
@@ -590,7 +590,7 @@ def preprocess_logits_for_metrics(logits, labels):
logits = logits[0]
return logits.argmax(dim=-1)

- metric = evaluate.load("accuracy")
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

def compute_metrics(eval_preds):
preds, labels = eval_preds

4 changes: 3 additions & 1 deletion examples/pytorch/question-answering/run_qa.py
@@ -627,7 +627,9 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)

- metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load(
+     "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
+ )

def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)

4 changes: 3 additions & 1 deletion examples/pytorch/question-answering/run_qa_beam_search.py
@@ -647,7 +647,9 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)

- metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load(
+     "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
+ )

def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)

4 changes: 3 additions & 1 deletion examples/pytorch/question-answering/run_seq2seq_qa.py
@@ -631,7 +631,9 @@ def preprocess_validation_function(examples):
pad_to_multiple_of=8 if training_args.fp16 else None,
)

- metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load(
+     "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
+ )

def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)

@@ -366,7 +366,7 @@ def main():
label2id = {v: str(k) for k, v in id2label.items()}

# Load the mean IoU metric from the datasets package
- metric = evaluate.load("mean_iou")
+ metric = evaluate.load("mean_iou", cache_dir=model_args.cache_dir)

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.

@@ -530,7 +530,7 @@ def preprocess_val(example_batch):
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# Instantiate metric
- metric = evaluate.load("mean_iou")
+ metric = evaluate.load("mean_iou", cache_dir=args.cache_dir)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.

@@ -680,7 +680,7 @@ def is_audio_in_length_range(length):
# instantiate a data collator and the trainer

# Define evaluation metrics during training, *i.e.* word error rate, character error rate
- eval_metrics = {metric: evaluate.load(metric) for metric in data_args.eval_metrics}
+ eval_metrics = {metric: evaluate.load(metric, cache_dir=model_args.cache_dir) for metric in data_args.eval_metrics}

# for large datasets it is advised to run the preprocessing on a
# single machine first with ``args.preprocessing_only`` since there will mostly likely

@@ -702,7 +702,7 @@ def is_audio_in_length_range(length):
# instantiate a data collator and the trainer

# Define evaluation metrics during training, *i.e.* word error rate, character error rate
- eval_metrics = {metric: evaluate.load(metric) for metric in data_args.eval_metrics}
+ eval_metrics = {metric: evaluate.load(metric, cache_dir=model_args.cache_dir) for metric in data_args.eval_metrics}

# for large datasets it is advised to run the preprocessing on a
# single machine first with ``args.preprocessing_only`` since there will mostly likely

@@ -520,7 +520,7 @@ def is_audio_in_length_range(length):
return

# 8. Load Metric
- metric = evaluate.load("wer")
+ metric = evaluate.load("wer", cache_dir=model_args.cache_dir)

def compute_metrics(pred):
pred_ids = pred.predictions

2 changes: 1 addition & 1 deletion examples/pytorch/summarization/run_summarization.py
@@ -645,7 +645,7 @@ def preprocess_function(examples):
)

# Metric
- metric = evaluate.load("rouge")
+ metric = evaluate.load("rouge", cache_dir=model_args.cache_dir)

def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]

10 changes: 5 additions & 5 deletions examples/pytorch/text-classification/run_classification.py
@@ -633,23 +633,23 @@ def preprocess_function(examples):

if data_args.metric_name is not None:
metric = (
- evaluate.load(data_args.metric_name, config_name="multilabel")
+ evaluate.load(data_args.metric_name, config_name="multilabel", cache_dir=model_args.cache_dir)
if is_multi_label
- else evaluate.load(data_args.metric_name)
+ else evaluate.load(data_args.metric_name, cache_dir=model_args.cache_dir)
)
logger.info(f"Using metric {data_args.metric_name} for evaluation.")
else:
if is_regression:
- metric = evaluate.load("mse")
+ metric = evaluate.load("mse", cache_dir=model_args.cache_dir)
logger.info("Using mean squared error (mse) as regression score, you can use --metric_name to overwrite.")
else:
if is_multi_label:
- metric = evaluate.load("f1", config_name="multilabel")
+ metric = evaluate.load("f1", config_name="multilabel", cache_dir=model_args.cache_dir)
logger.info(
"Using multilabel F1 for multi-label classification task, you can use --metric_name to overwrite."
)
else:
- metric = evaluate.load("accuracy")
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
logger.info("Using accuracy as classification score, you can use --metric_name to overwrite.")

def compute_metrics(p: EvalPrediction):

6 changes: 3 additions & 3 deletions examples/pytorch/text-classification/run_glue.py
@@ -514,11 +514,11 @@ def preprocess_function(examples):

# Get the metric function
if data_args.task_name is not None:
- metric = evaluate.load("glue", data_args.task_name)
+ metric = evaluate.load("glue", data_args.task_name, cache_dir=model_args.cache_dir)
elif is_regression:
- metric = evaluate.load("mse")
+ metric = evaluate.load("mse", cache_dir=model_args.cache_dir)
else:
- metric = evaluate.load("accuracy")
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.

2 changes: 1 addition & 1 deletion examples/pytorch/text-classification/run_xnli.py
@@ -385,7 +385,7 @@ def preprocess_function(examples):
)

# Get the metric function
- metric = evaluate.load("xnli")
+ metric = evaluate.load("xnli", cache_dir=model_args.cache_dir)

# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.

2 changes: 1 addition & 1 deletion examples/pytorch/token-classification/run_ner.py
@@ -539,7 +539,7 @@ def tokenize_and_align_labels(examples):
data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)

# Metrics
- metric = evaluate.load("seqeval")
+ metric = evaluate.load("seqeval", cache_dir=model_args.cache_dir)

def compute_metrics(p):
predictions, labels = p

2 changes: 1 addition & 1 deletion examples/pytorch/translation/run_translation.py
@@ -564,7 +564,7 @@ def preprocess_function(examples):
)

# Metric
- metric = evaluate.load("sacrebleu")
+ metric = evaluate.load("sacrebleu", cache_dir=model_args.cache_dir)

def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]

@@ -440,7 +440,7 @@ def val_transforms(example_batch):
collate_fn = DefaultDataCollator(return_tensors="np")

# Load the accuracy metric from the datasets package
- metric = evaluate.load("accuracy")
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.

4 changes: 3 additions & 1 deletion examples/tensorflow/question-answering/run_qa.py
@@ -631,7 +631,9 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)

- metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load(
+     "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
+ )

def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)

2 changes: 1 addition & 1 deletion examples/tensorflow/summarization/run_summarization.py
@@ -627,7 +627,7 @@ def postprocess_text(preds, labels):

# region Metric and KerasMetricCallback
if training_args.do_eval:
- metric = evaluate.load("rouge")
+ metric = evaluate.load("rouge", cache_dir=model_args.cache_dir)

if data_args.val_max_target_length is None:
data_args.val_max_target_length = data_args.max_target_length

2 changes: 1 addition & 1 deletion examples/tensorflow/text-classification/run_glue.py
@@ -379,7 +379,7 @@ def preprocess_function(examples):
# endregion

# region Metric function
- metric = evaluate.load("glue", data_args.task_name)
+ metric = evaluate.load("glue", data_args.task_name, cache_dir=model_args.cache_dir)

def compute_metrics(preds, label_ids):
preds = preds["logits"]

2 changes: 1 addition & 1 deletion examples/tensorflow/token-classification/run_ner.py
@@ -511,7 +511,7 @@ def tokenize_and_align_labels(examples):
# endregion

# Metrics
- metric = evaluate.load("seqeval")
+ metric = evaluate.load("seqeval", cache_dir=model_args.cache_dir)

def get_labels(y_pred, y_true):
# Transform predictions and references tensos to numpy arrays

2 changes: 1 addition & 1 deletion examples/tensorflow/translation/run_translation.py
@@ -589,7 +589,7 @@ def preprocess_function(examples):

# region Metric and postprocessing
if training_args.do_eval:
- metric = evaluate.load("sacrebleu")
+ metric = evaluate.load("sacrebleu", cache_dir=model_args.cache_dir)

if data_args.val_max_target_length is None:
data_args.val_max_target_length = data_args.max_target_length