update train_sts_seed_optimization with SentenceTransformerTrainer #3092

Merged 4 commits on Dec 2, 2024
Changes from all commits
135 changes: 80 additions & 55 deletions examples/training/data_augmentation/train_sts_seed_optimization.py
@@ -23,35 +23,27 @@
python train_sts_seed_optimization.py bert-base-uncased 10 0.3
"""

import csv
import gzip
import logging
import math
import os
import pprint
import random
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import TrainerCallback, TrainerControl, TrainerState

from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models, util
from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

#### Just some code to print debug information to stdout
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
#### /print debug information to stdout


# Check if dataset exists. If not, download and extract it
sts_dataset_path = "datasets/stsbenchmark.tsv.gz"

if not os.path.exists(sts_dataset_path):
util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)


# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased"
@@ -60,6 +52,8 @@

logging.info(f"Train and Evaluate: {seed_count} Random Seeds")

scores_per_seed = {}

for seed in range(seed_count):
# Setting seed for all random initializations
logging.info(f"##### Seed {seed} #####")
@@ -69,7 +63,6 @@

# Read the dataset
train_batch_size = 16
num_epochs = 1
model_save_path = "output/bi-encoder/training_stsbenchmark_" + model_name + "/seed-" + str(seed)

# Use Hugging Face/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
@@ -85,49 +78,81 @@

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Convert the dataset to a DataLoader ready for training
logging.info("Read STSbenchmark train dataset")

train_samples = []
dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
for row in reader:
score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1
inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)

if row["split"] == "dev":
dev_samples.append(inp_example)
elif row["split"] == "test":
test_samples.append(inp_example)
else:
train_samples.append(inp_example)

train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)
# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
train_dataset = load_dataset("sentence-transformers/stsb", split="train")
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
logging.info(train_dataset)

logging.info("Read STSbenchmark dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")
train_loss = losses.CosineSimilarityLoss(model=model)

# Configure the training. We skip evaluation in this example
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up
# 4. Define an evaluator for use during training.
dev_evaluator = EmbeddingSimilarityEvaluator(
sentences1=eval_dataset["sentence1"],
sentences2=eval_dataset["sentence2"],
scores=eval_dataset["score"],
main_similarity=SimilarityFunction.COSINE,
name="sts-dev",
show_progress_bar=True,
)

# Stop and evaluate after 30% of the training data (less than 1 epoch)
# Following (Dodge et al.), 20-30% of training is often enough to judge how well a random seed converges
steps_per_epoch = math.ceil(len(train_dataloader) * stop_after)

logging.info(f"Warmup-steps: {warmup_steps}")

logging.info(f"Early-stopping: {int(stop_after * 100)}% of the training-data")
num_steps_until_stop = math.ceil(len(train_dataset) / train_batch_size * stop_after)

logging.info(f"Early-stopping: {stop_after:.2%} ({num_steps_until_stop} steps) of the training-data")

# 5. Create a Training Callback that stops training after a certain number of steps
class SeedTestingEarlyStoppingCallback(TrainerCallback):
def __init__(self, num_steps_until_stop: int):
self.num_steps_until_stop = num_steps_until_stop

def on_step_end(
self, args: SentenceTransformerTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
):
if state.global_step >= self.num_steps_until_stop:
control.should_training_stop = True

seed_testing_early_stopping_callback = SeedTestingEarlyStoppingCallback(num_steps_until_stop)

# 6. Define the training arguments
args = SentenceTransformerTrainingArguments(
Collaborator:

I think stop_after isn't actually making training stop after that many steps. Normally you could use max_steps, but I think that messes with the scheduler; ideally the scheduler stays "normal" (as in a full run) while training still stops after the stop_after fraction of steps, though I'm not sure that matches the old behaviour either. (A sketch of the max_steps alternative follows the training arguments below.)

# Required parameter:
output_dir=model_save_path,
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=train_batch_size,
warmup_ratio=0.1,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
# Optional tracking/debugging parameters:
logging_steps=num_steps_until_stop // 10, # Log every 10% of the steps
seed=seed,
run_name=f"sts-{seed}", # Will be used in W&B if `wandb` is installed
)
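
A minimal sketch of the max_steps alternative mentioned in the comment above, assuming the standard transformers TrainingArguments semantics (max_steps overrides num_train_epochs and also sets the scheduler length); args_via_max_steps is an illustrative name, not part of this PR:

from sentence_transformers.training_args import SentenceTransformerTrainingArguments

# Sketch of the max_steps alternative the comment refers to. With max_steps set,
# the Trainer derives the schedule length from max_steps, so warmup_ratio=0.1
# warms up over 10% of num_steps_until_stop rather than 10% of the full epoch;
# the callback used in this PR avoids that by keeping num_train_epochs=1.
args_via_max_steps = SentenceTransformerTrainingArguments(
    output_dir=model_save_path,
    max_steps=num_steps_until_stop,  # hard stop; also shortens the LR schedule
    per_device_train_batch_size=train_batch_size,
    warmup_ratio=0.1,
    seed=seed,
)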

# Train the model
model.fit(
train_objectives=[(train_dataloader, train_loss)],
evaluator=evaluator,
epochs=num_epochs,
steps_per_epoch=steps_per_epoch,
evaluation_steps=1000,
warmup_steps=warmup_steps,
output_path=model_save_path,
# 7. Create the trainer & start training
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=train_loss,
evaluator=dev_evaluator,
callbacks=[seed_testing_early_stopping_callback],
)
trainer.train()

# 8. With the partial train, evaluate this seed on the dev set
dev_score = dev_evaluator(model)
logging.info(f"Evaluator Scores for Seed {seed} after early stopping: {dev_score}")
primary_dev_score = dev_score[dev_evaluator.primary_metric]
scores_per_seed[seed] = primary_dev_score
scores_per_seed = dict(sorted(scores_per_seed.items(), key=lambda item: item[1], reverse=True))
logging.info(
f"Current {dev_evaluator.primary_metric} Scores per Seed:\n{pprint.pformat(scores_per_seed, sort_dicts=False)}"
)

# 9. Save the model for this seed
model.save_pretrained(model_save_path)
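
As a hypothetical follow-up (not part of this PR), once all seeds have run, the winning seed can be read off the scores_per_seed dict the script logs and reused for a full, non-truncated training run:

# Hypothetical follow-up, not in this script: pick the best seed from the
# partial runs, then re-train for the full epoch with seed=best_seed and
# without the early-stopping callback.
best_seed = max(scores_per_seed, key=scores_per_seed.get)
logging.info(f"Best seed: {best_seed} (score: {scores_per_seed[best_seed]:.4f})")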