'''
This example shows how to train a basic bi-encoder on any BEIR dataset without mined hard negatives or triplets.
Queries and passages are passed independently through the transformer network to produce fixed-size embeddings.
These embeddings can then be compared with cosine similarity to find matching passages for a given query.
For training we use MultipleNegativesRankingLoss: we pass pairs in the format (query, positive_passage),
and the positive passages of the other pairs in the same batch serve as in-batch negatives.
We do not mine hard negatives or train on triplets in this example.

Running this script:
python train_sbert.py
'''
import logging
import os
import pathlib

import torch
from sentence_transformers import losses, SentenceTransformer

from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.train import TrainRetriever
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
logging.info("Using Device: {}".format(device))
#### Download and unzip the dataset
dataset = "scifact"
out_dir = "datasets"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), out_dir)
data_path = util.download_and_unzip(url, out_dir)
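#### data_path points to the unzipped dataset folder (here datasets/scifact), which in the standard
#### BEIR layout contains corpus.jsonl, queries.jsonl and a qrels/ folder with the split files.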
#### Provide the data_path where scifact has been downloaded and unzipped
corpus, queries, qrels = GenericDataLoader(data_path).load(split="train")
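#### GenericDataLoader returns plain dictionaries: corpus maps doc_id -> {"title": ..., "text": ...},
#### queries maps query_id -> query text, and qrels maps query_id -> {doc_id: relevance_score}.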
#### Please note: not all datasets contain a dev split; uncomment the line below if your dataset has one
# dev_corpus, dev_queries, dev_qrels = GenericDataLoader(data_path).load(split="dev")
#### Provide any sentence-transformers or HF model
model_name = "sentence-transformers/all-distilroberta-v1"
model = SentenceTransformer(model_name_or_path=model_name, device=device)
#### Or provide pretrained sentence-transformer model
# model = SentenceTransformer("msmarco-distilbert-base-v3")
retriever = TrainRetriever(model=model, batch_size=16)
#### Prepare training samples
train_samples = retriever.load_train(corpus, queries, qrels)
train_dataloader = retriever.prepare_train(train_samples, shuffle=True)
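#### load_train builds one sentence-transformers InputExample per relevant (query, passage) pair in qrels,
#### and prepare_train wraps these examples in a shuffled DataLoader using the retriever's batch_size.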
#### Training SBERT with cosine similarity (the default for MultipleNegativesRankingLoss)
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
#### Or train SBERT with dot-product; note that util above is the BEIR util module,
#### so import dot_score from sentence_transformers instead
# from sentence_transformers.util import dot_score
# train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=dot_score)
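#### To illustrate how MultipleNegativesRankingLoss uses in-batch negatives (a sketch with made-up
#### example texts, not part of the original script): in a batch of (query, positive) pairs, every
#### other pair's positive passage acts as a negative for the current query.
# from sentence_transformers import InputExample
# batch = [
#     InputExample(texts=["what causes rain", "Rain forms when water vapour condenses ..."]),
#     InputExample(texts=["symptoms of flu", "Influenza typically causes fever, cough ..."]),
# ]
# # For the first query, the second passage is scored as a negative, and vice versa.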
#### Prepare dev evaluator
# ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)
#### If no dev set is present, use a dummy evaluator instead
ir_evaluator = retriever.load_dummy_evaluator()
#### Provide model save path
model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "output",
                               "{}-v1-{}".format(model_name, dataset))
os.makedirs(model_save_path, exist_ok=True)
#### Configure Train params
num_epochs = 5
evaluation_steps = 5000
warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)
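#### The warmup covers roughly the first 10% of total optimizer steps: with, say, 8,000 training pairs
#### (a hypothetical figure, not the actual size of scifact), 5 epochs and batch size 16 this gives
#### int(8000 * 5 / 16 * 0.1) = 250 warmup steps, during which the learning rate is ramped up
#### (with the library's default warmup-linear schedule).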
logging.info("Warmup-steps: {}".format(warmup_steps))
logging.info("Warmup-steps: {}".format(int(len(train_samples) * num_epochs / (retriever.batch_size * 0.1))))
retriever.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=ir_evaluator,
              epochs=num_epochs,
              output_path=model_save_path,
              warmup_steps=warmup_steps,
              evaluation_steps=evaluation_steps,
              use_amp=True)
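#### Optional follow-up (a sketch, not part of the original script): after training, the saved model
#### can be evaluated on the test split with BEIR's EvaluateRetrieval. Uncomment to run; exact-search
#### retrieval over the full corpus can be slow on CPU.
# from beir.retrieval import models
# from beir.retrieval.evaluation import EvaluateRetrieval
# from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
#
# test_corpus, test_queries, test_qrels = GenericDataLoader(data_path).load(split="test")
# dense_search = DRES(models.SentenceBERT(model_save_path), batch_size=16)
# evaluator = EvaluateRetrieval(dense_search, score_function="cos_sim")
# results = evaluator.retrieve(test_corpus, test_queries)
# ndcg, _map, recall, precision = evaluator.evaluate(test_qrels, results, evaluator.k_values)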