-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_zslt.py
130 lines (110 loc) · 3.81 KB
/
main_zslt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import logging
import json
import torch
from torch.utils.data import DataLoader
import hydra
from omegaconf import DictConfig, OmegaConf
from meffi_prompt.utils import resolve_relative_path, seed_everything, aggregate_batch
from meffi_prompt.data import SmilerDataset
from meffi_prompt.prompt import SmilerPrompt, get_max_decode_length
from meffi_prompt.model import T5Model
from meffi_prompt.tokenizer import BatchTokenizer
from meffi_prompt.eval import train_and_eval
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Prompt template handed to `SmilerPrompt` and `get_max_decode_length` in `main`.
# Arrangement: "in-order", without reversal ("w/o reversed").
# The placeholder strings ("x", "eh", "et", "r", "[vN]") are interpreted by
# `SmilerPrompt`. Presumably: x = sentence, eh/et = head/tail entity,
# r = relation, [vN] = separator/verbalizer slot. TODO confirm in meffi_prompt.prompt.
# "<extra_id_N>" are T5 sentinel tokens. The target presumably fills the
# <extra_id_0> span of the input with the relation.
template = dict(
    input=["x", "[vN]", "eh", "[vN]", "<extra_id_0>", "[vN]", "et"],
    target=["<extra_id_0>", "r", "<extra_id_1>"],
)

# Alternative "post-order" variant: the sentinel comes after the tail entity.
# To use it, swap it in for the template above.
# template = dict(
#     input=["x", "[vN]", "eh", "[vN]", "et", "[vN]", "<extra_id_0>"],
#     target=["<extra_id_0>", "r", "<extra_id_1>"],
# )
def _exclude_eval_relations(train_dataset, eval_dataset) -> None:
    """
    Removes training examples whose relation also appears in the eval label set.

    Used for the tag-set-transfer setting, where relations that are evaluated
    must not be seen during training. ``"no_relation"`` is exempt, so training
    examples labelled ``"no_relation"`` are kept.

    Args
    ----
    `train_dataset`: Dataset wrapper whose ``dataset`` attribute is replaced by
        its filtered version (``.filter`` on the underlying dataset).
    `eval_dataset`: Dataset wrapper whose ``label_to_id`` keys define the eval
        tag set.
    """
    eval_tag_set = set(eval_dataset.label_to_id.keys())
    # "no_relation" is not a held-out relation type, so its training examples
    # stay. Filtering against the raw `label_to_id` keys would drop them as well.
    eval_tag_set.discard("no_relation")
    train_dataset.dataset = train_dataset.dataset.filter(
        lambda example: example["relation"] not in eval_tag_set
    )


def _make_loader(dataset, batch_size: int, shuffle: bool, pin_memory: bool) -> DataLoader:
    """
    Wraps `dataset` in a DataLoader that collates batches with `aggregate_batch`.

    Args
    ----
    `dataset`: The (prompted) dataset to iterate over.
    `batch_size`: Number of examples per batch.
    `shuffle`: Whether to reshuffle the data every epoch.
    `pin_memory`: Whether to copy batches into pinned host memory.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=aggregate_batch,
    )


@hydra.main(config_name="config_zslt", config_path="configs")
def main(cfg: DictConfig) -> None:
    """
    Trains and evaluates a T5 prompt model given the configuration.

    Loads the train/eval files as `SmilerDataset`s. If ``cfg.tag_set_transfer``
    is set, it removes eval relations (except ``"no_relation"``) from the
    training data. It then wraps both splits with the module-level prompt
    `template`, runs `train_and_eval`, and writes micro/macro F1 to
    ``results.json``.

    Args
    ----
    `cfg`: Hydra-format configuration given in a dict.

    Raises
    ------
    `ValueError`: If the evaluation set yields no relation labels to verbalize.
    """
    resolve_relative_path(cfg)
    print(OmegaConf.to_yaml(cfg))
    seed_everything(cfg.seed)
    # A negative `cuda_device` selects the CPU.
    device = (
        torch.device("cuda", cfg.cuda_device)
        if cfg.cuda_device > -1
        else torch.device("cpu")
    )

    # get raw dataset and do simple pre-processing such as convert special tokens
    train_dataset = SmilerDataset(cfg.train_file)
    eval_dataset = SmilerDataset(cfg.eval_file)
    if cfg.tag_set_transfer:
        _exclude_eval_relations(train_dataset, eval_dataset)

    # transform to prompted dataset, with appended inputs and verbalized labels
    prompt = SmilerPrompt(
        template=template,
        model_name=cfg.model,
        soft_token_length=0,
    )
    train_dataset = prompt(train_dataset, translate=False)
    eval_dataset, verbalizer = prompt(
        eval_dataset, translate=False, return_verbalizer=True
    )

    # Pinned host memory only speeds up host-to-GPU copies. On CPU it brings
    # nothing and PyTorch warns about it.
    pin_memory = device.type == "cuda"
    train_loader = _make_loader(
        train_dataset, cfg.batch_size, shuffle=True, pin_memory=pin_memory
    )
    eval_loader = _make_loader(
        eval_dataset, cfg.batch_size, shuffle=False, pin_memory=pin_memory
    )

    # instantiate tokenizer and model
    batch_processor = BatchTokenizer(
        tokenizer_name_or_path=cfg.model,
        max_length=cfg.max_length,
        num_soft_tokens=0,
    )
    # Token ids of each verbalized relation label, without special tokens.
    tokenized_verbalizer = {
        k: batch_processor.tokenizer(v, add_special_tokens=False)["input_ids"]
        for k, v in verbalizer.items()
    }
    if not tokenized_verbalizer:
        raise ValueError("The evaluation set yielded no relation labels to verbalize.")
    # The decoder must be able to emit the longest relation within the template.
    max_relation_length = max(len(ids) for ids in tokenized_verbalizer.values())
    max_decode_length = get_max_decode_length(template, max_relation_length)
    logger.info("Max decode length: %s.", max_decode_length)
    model = T5Model(
        cfg.model,
        max_decode_length=max_decode_length,
        tokenizer=batch_processor.tokenizer,
    )

    micro_f1, macro_f1 = train_and_eval(
        model=model,
        train_loader=train_loader,
        eval_loader=eval_loader,
        batch_processor=batch_processor,
        num_epochs=cfg.num_epochs,
        lr=cfg.lr,
        device=device,
        label_column_name=train_dataset.label_column_name,
        tokenized_verbalizer=tokenized_verbalizer,
    )
    logger.info("Evaluation micro-F1: %.4f, macro_f1: %.4f.", micro_f1, macro_f1)

    # save evaluation results to json. The path is relative to the current
    # working directory, which Hydra may have switched to the run's output
    # directory (depends on Hydra version / `hydra.job.chdir`).
    with open("./results.json", "w", encoding="utf-8") as f:
        json.dump({"micro_f1": micro_f1, "macro_f1": macro_f1}, f, indent=4)


if __name__ == "__main__":
    # Hydra parses command-line overrides and supplies `cfg`.
    main()