Add stopping callback, should work now
tomaarsen committed Nov 28, 2024
1 parent 4be4fba commit 286affd
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions examples/training/data_augmentation/train_sts_seed_optimization.py
@@ -25,12 +25,14 @@

import logging
import math
import pprint
import random
import sys

import numpy as np
import torch
from datasets import load_dataset
from transformers import TrainerCallback, TrainerControl, TrainerState

from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
@@ -43,14 +45,15 @@
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)


# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased"
seed_count = int(sys.argv[2]) if len(sys.argv) > 2 else 10
stop_after = float(sys.argv[3]) if len(sys.argv) > 3 else 0.3

logging.info(f"Train and Evaluate: {seed_count} Random Seeds")

scores_per_seed = {}

for seed in range(seed_count):
# Setting seed for all random initializations
logging.info(f"##### Seed {seed} #####")
@@ -60,7 +63,6 @@

# Read the dataset
train_batch_size = 16
num_epochs = 1
model_save_path = "output/bi-encoder/training_stsbenchmark_" + model_name + "/seed-" + str(seed)

# Use Hugging Face/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
@@ -91,42 +93,66 @@
scores=eval_dataset["score"],
main_similarity=SimilarityFunction.COSINE,
name="sts-dev",
show_progress_bar=True,
)

# Stopping and Evaluating after 30% of training data (less than 1 epoch)
# Dodge et al. find that 20-30% of the training data is often enough to compare the convergence of random seeds
steps_per_epoch = math.ceil(len(train_dataset) * stop_after)
num_steps_until_stop = math.ceil(len(train_dataset) / train_batch_size * stop_after)
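# Worked example (assuming the standard 5,749-pair STS Benchmark train split):
# math.ceil(5749 / 16 * 0.3) = math.ceil(107.79) = 108 optimizer steps.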

logging.info(f"Early-stopping: {stop_after:.2%} ({num_steps_until_stop} steps) of the training-data")

# 5. Create a Training Callback that stops training after a certain number of steps
class SeedTestingEarlyStoppingCallback(TrainerCallback):
def __init__(self, num_steps_until_stop: int):
self.num_steps_until_stop = num_steps_until_stop

logging.info(f"Early-stopping: {int(stop_after * 100)}% of the training-data")
def on_step_end(
self, args: SentenceTransformerTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
):
if state.global_step >= self.num_steps_until_stop:
control.should_training_stop = True
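# Note: the Hugging Face Trainer checks `control.should_training_stop` after
# every optimizer step and exits the training loop once it is True, so no
# exception or manual break is needed to stop mid-epoch.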

# 5. Define the training arguments
seed_testing_early_stopping_callback = SeedTestingEarlyStoppingCallback(num_steps_until_stop)

# 6. Define the training arguments
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=model_save_path,
# Optional training parameters:
num_train_epochs=num_epochs,
num_train_epochs=1,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=train_batch_size,
warmup_ratio=0.1,
max_steps=steps_per_epoch,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
# Optional tracking/debugging parameters:
evaluation_strategy="steps",
eval_steps=1000,
save_strategy="steps",
save_steps=1000,
logging_steps=1000,
run_name="sts", # Will be used in W&B if `wandb` is installed
logging_steps=num_steps_until_stop // 10, # Log every 10% of the steps
seed=seed,
run_name=f"sts-{seed}", # Will be used in W&B if `wandb` is installed
)

# 6. Create the trainer & start training
# 7. Create the trainer & start training
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=train_loss,
evaluator=dev_evaluator,
callbacks=[seed_testing_early_stopping_callback],
)
trainer.train()

# 8. With the partial train, evaluate this seed on the dev set
dev_score = dev_evaluator(model)
logging.info(f"Evaluator Scores for Seed {seed} after early stopping: {dev_score}")
primary_dev_score = dev_score[dev_evaluator.primary_metric]
scores_per_seed[seed] = primary_dev_score
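# With main_similarity=SimilarityFunction.COSINE, the primary metric should be
# the Spearman correlation of the cosine similarities
# (e.g. "sts-dev_spearman_cosine"); higher is better.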
scores_per_seed = dict(sorted(scores_per_seed.items(), key=lambda item: item[1], reverse=True))
logging.info(
f"Current {dev_evaluator.primary_metric} Scores per Seed:\n{pprint.pformat(scores_per_seed, sort_dicts=False)}"
)

# 9. Save the model for this seed
model.save_pretrained(model_save_path)
