forked from DFKI-NLP/cross-ling-adr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_zero_shot_classifier.py
101 lines (79 loc) · 2.5 KB
/
run_zero_shot_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Evaluate a fine-tuned classifier checkpoint on a test set.

Usage::

    python run_zero_shot_classifier.py CONFIG

CONFIG is a JSON file that is read for the keys ``batch_size``,
``max_length``, ``min_length``, ``model_name``, ``model_path`` and
``test_data``. The script turns the test data into input ids / attention
masks via ``data.prepare_test_data``, loads the fine-tuned model stored at
``model_path`` and runs ``evaluate.evaluate_on_testset`` on it. The run
settings are recorded in a Weights & Biases run.
"""
import argparse
import json
from datetime import datetime

import colorama
import torch
import wandb

import data
from evaluate import evaluate_on_testset
from utils import training_utils as train_utils

# Let colorama translate ANSI colour codes for the current terminal.
colorama.init()

# NOTE(review): this starts a W&B run as a side effect of *importing* this
# module, not only when it is executed as a script.
wandb.init(project="final_binary_classification", entity="lraithel")

# Shorthands for coloured terminal output.
# NOTE(review): none of these are referenced in this script itself.
GREEN = colorama.Fore.GREEN
MAGENTA = colorama.Fore.MAGENTA
RED = colorama.Fore.RED
YELLOW = colorama.Fore.YELLOW
RESET = colorama.Fore.RESET

# Start time of the run formatted as "DD_MM_YYYY_HH_MM" (naive local time).
# NOTE(review): not referenced in this script.
DATE = datetime.now().strftime("%d_%m_%Y_%H_%M")

# Prefer the GPU when one is available.
# NOTE(review): not referenced in this script; presumably device placement
# happens inside the project helpers — confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required positional argument (a ``default`` would never be used).
    parser.add_argument("config", help="Path to config file.")
    parser.add_argument(
        "-sp",
        "--save_probas",
        action="store_true",
        help="Save probas and true labels for visualization.",
    )
    parser.add_argument(
        "-debug", "--debug", action="store_true", help="Run with test data."
    )
    # NOTE(review): --save_probas and --debug are parsed but not used anywhere
    # below; wire them into the evaluation or remove them.
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as read_handle:
        config = json.load(read_handle)

    # Every key this script reads from the config; validated up front so a bad
    # config yields a clear CLI error instead of a KeyError traceback.
    required_keys = (
        "batch_size",
        "max_length",
        "min_length",
        "model_name",
        "model_path",
        "test_data",
    )
    if not isinstance(config, dict):
        parser.error(f"config file {args.config!r} must contain a JSON object")
    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        parser.error(
            f"config file {args.config!r} is missing required key(s): "
            + ", ".join(missing_keys)
        )

    batch_size = config["batch_size"]
    max_length = config["max_length"]
    min_length = config["min_length"]
    model_name = config["model_name"]
    model_path = config["model_path"]
    test_data = config["test_data"]

    # Record the run settings in the W&B run config (same keys, same order as
    # individual ``wandb.config.<key> = value`` assignments).
    for key in required_keys:
        setattr(wandb.config, key, config[key])

    # Get test data: input ids / attention masks / labels plus the per-example
    # languages, which are passed on to the evaluation.
    (
        test_input_ids,
        test_attention_masks,
        test_labels,
        langs_test,
    ) = data.prepare_test_data(
        model_name=model_name,
        test_data_file=test_data,
        min_num_tokens=min_length,
        max_num_tokens=max_length,
    )
    test_dataloader, num_test_sentences = data.get_test_data_loader(
        input_ids=test_input_ids,
        attention_masks=test_attention_masks,
        labels=test_labels,
        batch_size=batch_size,
    )

    # The config's model_path is used directly as the model identifier.
    loaded_model = train_utils.load_fine_tuned_model(
        model_id=model_path, model_name=model_name
    )

    # Run the model on the test data.
    evaluate_on_testset(
        model=loaded_model,
        prediction_dataloader=test_dataloader,
        num_sentences=num_test_sentences,
        model_name=model_name,
        languages=langs_test,
    )