-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
39 lines (35 loc) · 1.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import argparse
import json
import sys

from ngpt.train import NGPTTrainer
from bart.train import run_train
from pegasus.pipeline import SummarizationPipeline
# Command-line interface: --MODEL picks which model to train, --PARAMS optionally
# points at a JSON file whose contents are passed to that model's trainer.
parser = argparse.ArgumentParser(description="Train a model with specified parameters.")
parser.add_argument(
    "--MODEL",
    type=str,
    required=True,
    help="Model name, case-insensitive (one of: ngpt, bart, pegasus)",
)
parser.add_argument(
    "--PARAMS",
    type=str,
    required=False,
    help="Path to the JSON file containing model parameters. "
    "Examples are in the params directory at the project root.",
)
args = parser.parse_args()
# Optional model parameters read from the --PARAMS JSON file. Loading is
# best-effort: on any problem a diagnostic goes to stderr and `params` stays
# empty ("Using default parameters").
param_file = args.PARAMS
params = {}
if param_file:
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(param_file, "r", encoding="utf-8") as file:
            loaded = json.load(file)
    except FileNotFoundError:
        print(f"Error: File {param_file} not found. Using default parameters.", file=sys.stderr)
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        print(
            f"Error: Invalid JSON in {param_file} ({err}). Using default parameters.",
            file=sys.stderr,
        )
    except OSError as err:
        # e.g. the path is a directory or is not readable.
        print(
            f"Error: Could not read {param_file} ({err}). Using default parameters.",
            file=sys.stderr,
        )
    else:
        # The trainers receive `params` as a dict, so only a top-level JSON
        # object can be merged; a list/number would make dict.update fail.
        if isinstance(loaded, dict):
            params.update(loaded)
            print(f"Loaded parameters from JSON file: {param_file}")
        else:
            print(
                f"Error: {param_file} must contain a JSON object, got "
                f"{type(loaded).__name__}. Using default parameters.",
                file=sys.stderr,
            )
# Run the training entry point of the selected model. The name is matched
# case-insensitively, so "--MODEL NGPT" and "--MODEL ngpt" are equivalent.
model_name = args.MODEL.lower()
if model_name == "ngpt":
    nGPT = NGPTTrainer(params)
    nGPT.train()
elif model_name == "bart":
    run_train(params)
elif model_name == "pegasus":
    summ = SummarizationPipeline(params)
    summ.train_model()
else:
    # Fail loudly (usage message on stderr, exit status 2) rather than
    # finishing successfully without training anything.
    parser.error(f"unknown model {args.MODEL!r}; expected one of: ngpt, bart, pegasus")