From 5a08db83d711aa7f924866fe6b882c61f828be2c Mon Sep 17 00:00:00 2001
From: JINO-ROHIT
Date: Sun, 15 Dec 2024 18:18:06 +0530
Subject: [PATCH] update tsdae example with SentenceTransformerTrainer

---
 .../TSDAE/train_tsdae_from_file.py | 156 +++++++++---------
 1 file changed, 80 insertions(+), 76 deletions(-)

diff --git a/examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py b/examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py
index 13cc5eaef..8621019a3 100644
--- a/examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py
+++ b/examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py
@@ -1,83 +1,87 @@
-"""
-This file loads sentences from a provided text file. It is expected, that the there is one sentence per line in that text file.
-
-TSDAE will be training using these sentences. Checkpoints are stored every 500 steps to the output folder.
-
-Usage:
-python train_tsdae_from_file.py path/to/sentences.txt
-
-"""
-
-import gzip
-import logging
-import sys
-from datetime import datetime
-
-import tqdm
-from torch.utils.data import DataLoader
-
-from sentence_transformers import LoggingHandler, SentenceTransformer, datasets, losses, models
-
-#### Just some code to print debug information to stdout
-logging.basicConfig(
-    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
+import numpy as np
+from datasets import load_dataset
+from nltk import word_tokenize
+from nltk.tokenize.treebank import TreebankWordDetokenizer
+
+from sentence_transformers import (
+    SentenceTransformer,
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
 )
-#### /print debug information to stdout
-
-# Train Parameters
-model_name = "bert-base-uncased"
-batch_size = 8
-
-# Input file path (a text file, each line a sentence)
-if len(sys.argv) < 2:
-    print(f"Run this script with: python {sys.argv[0]} path/to/sentences.txt")
-    exit()
-
-filepath = sys.argv[1]
+from sentence_transformers.losses import DenoisingAutoEncoderLoss

-# Save path to store our model
-output_name = ""
-if len(sys.argv) >= 3:
-    output_name = "-" + sys.argv[2].replace(" ", "_").replace("/", "_").replace("\\", "_")
-
-model_output_path = "output/train_tsdae{}-{}".format(output_name, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
-
-
-################# Read the train corpus #################
-train_sentences = []
-with (
-    gzip.open(filepath, "rt", encoding="utf8") if filepath.endswith(".gz") else open(filepath, encoding="utf8") as fIn
-):
-    for line in tqdm.tqdm(fIn, desc="Read file"):
-        line = line.strip()
-        if len(line) >= 10:
-            train_sentences.append(line)
-
-
-logging.info(f"{len(train_sentences)} train sentences")
+# 1. Load a model to finetune with 2. (Optional) model card data
+model = SentenceTransformer(
+    "bert-base-cased",
+)

-################# Initialize an SBERT model #################
+# 3. Load a dataset to finetune on
+dataset = load_dataset("sentence-transformers/all-nli", "triplet")
+train_dataset = dataset["train"].select_columns(["anchor"]).select(range(100_000))
+eval_dataset = dataset["dev"].select_columns(["anchor"])
+test_dataset = dataset["test"].select_columns(["anchor"])
+# Now we have 3 datasets, each with one column of text (called "anchor", but the name doesn't matter)
+# Now we need to convert the dataset into 2 columns: (damaged_sentence, original_sentence), see https://sbert.net/docs/sentence_transformer/loss_overview.html
+
+def noise_fn(text, del_ratio=0.6):
+    words = word_tokenize(text)
+    n = len(words)
+    if n == 0:
+        return text
+
+    keep_or_not = np.random.rand(n) > del_ratio
+    if sum(keep_or_not) == 0:
+        keep_or_not[np.random.choice(n)] = True  # guarantee that at least one word remains
+    words_processed = TreebankWordDetokenizer().detokenize(np.array(words)[keep_or_not])
+    return {
+        "damaged": words_processed,
+        "original": text,
+    }
+
+train_dataset = train_dataset.map(noise_fn, input_columns="anchor", remove_columns="anchor")
+eval_dataset = eval_dataset.map(noise_fn, input_columns="anchor", remove_columns="anchor")
+test_dataset = test_dataset.map(noise_fn, input_columns="anchor", remove_columns="anchor")
+# Now we have datasets with 2 columns, damaged & original (in that order). The "anchor" column is removed
+
+# 4. Define a loss function
+loss = DenoisingAutoEncoderLoss(model, decoder_name_or_path="bert-base-cased", tie_encoder_decoder=True)
+
+# 5. (Optional) Specify training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir="models/bert-base-cased-nli-tsdae",
+    # Optional training parameters:
+    num_train_epochs=1,
+    per_device_train_batch_size=16,
+    per_device_eval_batch_size=16,
+    learning_rate=2e-5,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=100,
+    save_strategy="steps",
+    save_steps=100,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="bert-base-cased-nli-tsdae",  # Will be used in W&B if `wandb` is installed
+)

-word_embedding_model = models.Transformer(model_name)
-# Apply **cls** pooling to get one fixed sized sentence vector
-pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
-model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+# 6. (Optional) Make an evaluator to evaluate before, during, and after training

-################# Train and evaluate the model (it needs about 1 hour for one epoch of AskUbuntu) #################
-# We wrap our training sentences in the DenoisingAutoEncoderDataset to add deletion noise on the fly
-train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
-train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
-train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)

+# 7. Create a trainer & train
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=loss,
+)
+trainer.train()
+# 8. Save the trained model
+model.save_pretrained("models/bert-base-cased-nli-tsdae/final")

-logging.info("Start training")
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    epochs=1,
-    weight_decay=0,
-    scheduler="constantlr",
-    optimizer_params={"lr": 3e-5},
-    show_progress_bar=True,
-    checkpoint_path=model_output_path,
-    use_amp=False,  # Set to True, if your GPU supports FP16 cores
-)
+# 9. (Optional) Push it to the Hugging Face Hub
+model.push_to_hub("bert-base-cased-nli-tsdae")
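
For reference, a minimal usage sketch of the model this example saves (not part of the patch). It assumes the output path "models/bert-base-cased-nli-tsdae/final" used above and the encode/similarity API of sentence-transformers v3+; loading the saved directory and comparing a couple of sentences is a quick sanity check of the TSDAE embeddings.

from sentence_transformers import SentenceTransformer

# Load the TSDAE-finetuned encoder saved by the example script (path taken from the patch above)
model = SentenceTransformer("models/bert-base-cased-nli-tsdae/final")

# Encode a couple of sentences and inspect the embedding shape
sentences = ["A man is eating food.", "A man is eating a piece of bread."]
embeddings = model.encode(sentences)
print(embeddings.shape)  # e.g. (2, 768) for a bert-base encoder

# Cosine similarity matrix between the encoded sentences
print(model.similarity(embeddings, embeddings))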