Commit
update tsdae example with SentenceTransformerTrainer
1 parent 03db339 · commit 5a08db8
Showing 1 changed file with 80 additions and 76 deletions.
156 changes: 80 additions & 76 deletions
examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py
@@ -1,83 +1,87 @@
-"""
-This file loads sentences from a provided text file. It is expected, that the there is one sentence per line in that text file.
-TSDAE will be training using these sentences. Checkpoints are stored every 500 steps to the output folder.
-Usage:
-python train_tsdae_from_file.py path/to/sentences.txt
-"""

-import gzip
-import logging
-import sys
-from datetime import datetime

-import tqdm
-from torch.utils.data import DataLoader

-from sentence_transformers import LoggingHandler, SentenceTransformer, datasets, losses, models

-#### Just some code to print debug information to stdout
-logging.basicConfig(
-    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
+import numpy as np
+from datasets import load_dataset
+from nltk import word_tokenize
+from nltk.tokenize.treebank import TreebankWordDetokenizer

+from sentence_transformers import (
+    SentenceTransformer,
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
)
-#### /print debug information to stdout

-# Train Parameters
-model_name = "bert-base-uncased"
-batch_size = 8

-# Input file path (a text file, each line a sentence)
-if len(sys.argv) < 2:
-    print(f"Run this script with: python {sys.argv[0]} path/to/sentences.txt")
-    exit()

-filepath = sys.argv[1]
+from sentence_transformers.losses import DenoisingAutoEncoderLoss

-# Save path to store our model
-output_name = ""
-if len(sys.argv) >= 3:
-    output_name = "-" + sys.argv[2].replace(" ", "_").replace("/", "_").replace("\\", "_")

-model_output_path = "output/train_tsdae{}-{}".format(output_name, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))


-################# Read the train corpus #################
-train_sentences = []
-with (
-    gzip.open(filepath, "rt", encoding="utf8") if filepath.endswith(".gz") else open(filepath, encoding="utf8") as fIn
-):
-    for line in tqdm.tqdm(fIn, desc="Read file"):
-        line = line.strip()
-        if len(line) >= 10:
-            train_sentences.append(line)


-logging.info(f"{len(train_sentences)} train sentences")
+# 1. Load a model to finetune with 2. (Optional) model card data
+model = SentenceTransformer(
+    "bert-base-cased",
+)

-################# Initialize an SBERT model #################
+# 3. Load a dataset to finetune on
+dataset = load_dataset("sentence-transformers/all-nli", "triplet")
+train_dataset = dataset["train"].select_columns(["anchor"]).select(range(100_000))
+eval_dataset = dataset["dev"].select_columns(["anchor"])
+test_dataset = dataset["test"].select_columns(["anchor"])
+# Now we have 3 datasets, each with one column of text (called "anchor", but the name doesn't matter)
+# Now we need to convert the dataset into 2 columns: (damaged_sentence, original_sentence), see https://sbert.net/docs/sentence_transformer/loss_overview.html

+def noise_fn(text, del_ratio=0.6):
+    words = word_tokenize(text)
+    n = len(words)
+    if n == 0:
+        return text

+    keep_or_not = np.random.rand(n) > del_ratio
+    if sum(keep_or_not) == 0:
+        keep_or_not[np.random.choice(n)] = True  # guarantee that at least one word remains
+    words_processed = TreebankWordDetokenizer().detokenize(np.array(words)[keep_or_not])
+    return {
+        "damaged": words_processed,
+        "original": text,
+    }

+train_dataset = train_dataset.map(noise_fn, input_columns="anchor", remove_columns="anchor")
+eval_dataset = eval_dataset.map(noise_fn, input_columns="anchor", remove_columns="anchor")
+test_dataset = test_dataset.map(noise_fn, input_columns="anchor", remove_columns="anchor")
+# Now we have datasets with 2 columns, damaged & original (in that order). The "anchor" column is removed

+# 4. Define a loss function
+loss = DenoisingAutoEncoderLoss(model, decoder_name_or_path="bert-base-cased", tie_encoder_decoder=True)

+# 5. (Optional) Specify training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir="models/bert-base-cased-nli-tsdae",
+    # Optional training parameters:
+    num_train_epochs=1,
+    per_device_train_batch_size=16,
+    per_device_eval_batch_size=16,
+    learning_rate=2e-5,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=100,
+    save_strategy="steps",
+    save_steps=100,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="bert-base-cased-nli-tsdae",  # Will be used in W&B if `wandb` is installed
+)

-word_embedding_model = models.Transformer(model_name)
-# Apply **cls** pooling to get one fixed sized sentence vector
-pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
-model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+# 6. (Optional) Make an evaluator to evaluate before, during, and after training

-################# Train and evaluate the model (it needs about 1 hour for one epoch of AskUbuntu) #################
-# We wrap our training sentences in the DenoisingAutoEncoderDataset to add deletion noise on the fly
-train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
-train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
-train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)
+# 7. Create a trainer & train
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=loss,
+)
+trainer.train()

+# 8. Save the trained model
+model.save_pretrained("models/bert-base-cased-nli-tsdae/final")

-logging.info("Start training")
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    epochs=1,
-    weight_decay=0,
-    scheduler="constantlr",
-    optimizer_params={"lr": 3e-5},
-    show_progress_bar=True,
-    checkpoint_path=model_output_path,
-    use_amp=False,  # Set to True, if your GPU supports FP16 cores
-)
+# 9. (Optional) Push it to the Hugging Face Hub
+model.push_to_hub("bert-base-cased-nli-tsdae")
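
Step "# 6. (Optional) Make an evaluator" is left empty in the updated script. A minimal sketch of how it could be filled in follows; it assumes the sentence-transformers/stsb dataset on the Hugging Face Hub, the EmbeddingSimilarityEvaluator / SimilarityFunction evaluation API, and the trainer's evaluator argument, none of which are part of this commit.

# 6. (Optional) A possible evaluator: semantic similarity on the STSb dev split (a sketch, not part of this commit)
from datasets import load_dataset
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction

stsb_dev = load_dataset("sentence-transformers/stsb", split="validation")  # columns: sentence1, sentence2, score
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=stsb_dev["sentence1"],
    sentences2=stsb_dev["sentence2"],
    scores=stsb_dev["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)
dev_evaluator(model)  # evaluate the untrained base model as a reference point

# The evaluator would then be passed to the trainer next to the loss, e.g.:
# trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset,
#                                      eval_dataset=eval_dataset, loss=loss, evaluator=dev_evaluator)

After training, the checkpoint saved in step 8 can be loaded and used like any other Sentence Transformer model, e.g. SentenceTransformer("models/bert-base-cased-nli-tsdae/final").encode(["some sentence"]).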