run.py
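"""Training entry point (summary inferred from the code below).

Parses Trainer, data-module, and model arguments, seeds the run, builds a
MimicDataModule and the selected model from the registry, then runs fit/test
with a Weights & Biases logger, early stopping, and checkpointing on the
logged "score" metric.
"""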
from argparse import ArgumentParser
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# import Recall_models as Models
import AUROC_models as Models
from data_module import MimicDataModule

# Registry mapping the --model CLI name to its model class.
models = {
    "Lstm": Models.Lstm,
    "Star": Models.Star,
    "Encoder": Models.Encoder,
    "Cnn": Models.Cnn,
    "Bert": Models.Bert,
    "MBertLstm": Models.MBertLstm,
    "MBertStar": Models.MBertStar,
    "MBertEncoder": Models.MBertEncoder,
    "MBertCnn": Models.MBertCnn,
    "MLstmBert": Models.MLstmBert,
    "MStarBert": Models.MStarBert,
    "MEncoderBert": Models.MEncoderBert,
    "MCnnBert": Models.MCnnBert,
    "LstmBertAttn": Models.LstmBertAttn,
    "BertLstmAttn": Models.BertLstmAttn,
    "LstmBertOuter": Models.LstmBertOuter,
    # "EncoderBertAttn": Models.EncoderBertAttn,
    # "BertEncoderAttn": Models.BertEncoderAttn,
    # "BertEncoderOuter": Models.BertEncoderOuter,
    "Line": Models.Line,
}


def parse_args(args=None):
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = MimicDataModule.add_argparse_args(parser)
    parser.add_argument("--model", type=str, default="Bert")
    # Parse the known args first so the chosen model can register its own
    # model-specific arguments before the final parse.
    temp_args, _ = parser.parse_known_args()
    parser = models[temp_args.model].add_model_specific_args(parser)
    parser.add_argument("--seed", type=int, default=7)
    return parser.parse_args(args)


def main(args):
    pl.seed_everything(args.seed)
    dm = MimicDataModule.from_argparse_args(args)
    model = models[args.model](**vars(args))

    name = args.model
    project = f"{args.task}_{args.duration}_{args.model}"
    print(f"Run {project}: {name}")
    wandb_logger = WandbLogger(
        entity="abr-ehr", project=project, name=name, offline=False
    )

    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=wandb_logger,
        callbacks=[
            EarlyStopping(monitor="score", mode="max"),
            ModelCheckpoint(
                monitor="score",
                mode="max",
                dirpath="/l/users/mai.kassem/datasets/",
                filename="{epoch}--{step}--{name}",
            ),
        ],
        gpus=-1,
        strategy="dp",
        gradient_clip_val=1.0,
    )
    print("done with the trainer step")
    trainer.fit(model, datamodule=dm)
    print("Done with trainer.fit")
    trainer.test(model, datamodule=dm)


if __name__ == "__main__":
    arguments = parse_args()
    main(arguments)
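
# Example invocation (hypothetical values): --model and --seed are defined
# above; --max_epochs comes from pl.Trainer.add_argparse_args, and the
# task/duration options are assumed to be added by
# MimicDataModule.add_argparse_args.
#
#   python run.py --model Bert --seed 7 --max_epochs 50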