-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
74 lines (64 loc) · 2.45 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from trainer import MatryoshkaTrainer
from utils import create_evaluation_pairs
import logging
import os
import torch
# Configure the root logger through basicConfig: records at INFO and above,
# each line prefixed with a timestamp and the level name. (basicConfig does
# nothing if the root logger already has handlers.)
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger for this module's own progress messages.
logger = logging.getLogger(__name__)
def main():
    """Fine-tune a Matryoshka embedding model on a small hand-written corpus.

    Steps, in order:
      1. seed torch's RNG with 42 so runs are repeatable;
      2. ensure the ``trained-matryoshka-model`` output directory exists;
      3. turn the sample sentences into training data via
         ``create_evaluation_pairs``;
      4. build a ``MatryoshkaTrainer`` and call ``train``, writing results
         into the output directory.
    """
    # Seed first, before anything else that might draw random numbers.
    torch.manual_seed(42)

    output_dir = "trained-matryoshka-model"
    os.makedirs(output_dir, exist_ok=True)  # fine if it already exists

    # Sample sentences grouped by theme. Within each theme they appear to be
    # arranged as semantically related pairs. They are flattened below in
    # exactly the listed order.
    sentences_by_theme = {
        "weather": (
            "The weather is beautiful today",
            "It's a sunny and pleasant day",
            "The sky is clear and bright",
            "The temperature is perfect outside",
        ),
        "animals": (
            "The cat is sleeping on the couch",
            "A feline is resting on the sofa",
            "The dog is playing in the yard",
            "The puppy is running around outside",
        ),
        "programming": (
            "Python is a great programming language",
            "Programming in Python is wonderful",
            "Java is used for enterprise development",
            "Software development using Java is common",
        ),
        "misc": (
            "The book was very interesting",
            "This novel is quite engaging",
            "The movie was fantastic",
            "The film was excellent",
            "I love cooking Italian food",
            "Making pasta dishes is my passion",
            "Playing guitar is fun",
            "Learning music is enjoyable",
        ),
    }
    texts = [sentence for group in sentences_by_theme.values() for sentence in group]

    logger.info("Creating training pairs...")
    train_data = create_evaluation_pairs(texts)

    logger.info("Initializing trainer...")
    # Embedding sizes to train, largest first, with one weight per size.
    # Presumably weights[i] applies to dims[i]; confirm in MatryoshkaTrainer.
    dims = [768, 512, 256, 128, 64]
    dim_weights = [1.0, 0.8, 0.6, 0.4, 0.2]
    # The base model and the hyperparameters below are starting points to tune.
    trainer = MatryoshkaTrainer(
        base_model="tomaarsen/mpnet-base-nli-matryoshka",
        matryoshka_dims=dims,
        weights=dim_weights,
    )

    logger.info("Starting training...")
    # NOTE(review): the corpus has only 20 sentences and batch_size is 32, so
    # each epoch is likely a handful of batches at most. 10 epochs may then
    # total fewer than 200 steps, meaning warmup never completes. Confirm how
    # MatryoshkaTrainer counts warmup_steps.
    trainer.train(
        train_data=train_data,
        batch_size=32,
        epochs=10,
        output_path=output_dir,
        warmup_steps=200,
    )
    logger.info("Training completed!")
if __name__ == "__main__":
    # Disable tokenizer parallelism to silence its parallelism warning. This
    # happens only when the file runs as a script, before training starts.
    os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
    main()