training.py (forked from AndriyMulyar/bert_document_classification)
"""
Unfortunately, one cannot include the exact data utilized to the train both the clinical models due to HIPPA constraints.
The data can be found here if you fill out the appropriate online forms:
https://portal.dbmi.hms.harvard.edu/data-challenges/
For training, simply alter the config.ini present in /examples file for your purposes. Relevant variables are:
model_storage_directory: directory to store logging information, tensorboard checkpoints, model checkpoints
bert_model_path: the file path to a pretrained bert mode. can be the pytorch-transformers alias.
labels: an ordered list of labels you are training against. this should match the order given in a .fit() instance.
"""
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from bert_document_classification.document_bert import BertForDocumentClassification

from pprint import pformat
import time, logging, torch, configargparse, os, socket

log = logging.getLogger()
def _initialize_arguments(p: configargparse.ArgParser):
    p.add('--model_storage_directory', help='The directory caching all model runs')
    p.add('--bert_model_path', help='Model path to BERT')
    p.add('--labels', help='Comma-separated list of labels to predict over', type=str)
    p.add('--architecture', help='Training architecture', type=str)
    # Note: argparse's type=bool treats any non-empty string (including "False") as True.
    p.add('--freeze_bert', help='Whether to freeze the BERT weights', type=bool)
    p.add('--batch_size', help='Batch size for training the multi-label document classifier', type=int)
    p.add('--bert_batch_size', help='Batch size for feeding 510-token subsets of documents through BERT', type=int)
    p.add('--epochs', help='Epochs to train', type=int)

    # Optimizer arguments
    p.add('--learning_rate', help='Optimizer step size', type=float)
    p.add('--weight_decay', help='Adam regularization', type=float)

    p.add('--evaluation_interval', help='Evaluate model on test set every evaluation_interval epochs', type=int)
    p.add('--checkpoint_interval', help='Save a model checkpoint to disk every checkpoint_interval epochs', type=int)

    # Non-config arguments
    p.add('--cuda', action='store_true', help='Utilize GPU for training or prediction')
    p.add('--device')
    p.add('--timestamp', help='Run-specific signature')
    p.add('--model_directory', help='The directory storing this model run, a sub-directory of model_storage_directory')
    p.add('--use_tensorboard', help='Use tensorboard logging', type=bool)

    args = p.parse_args()
    args.labels = [x.strip() for x in args.labels.split(',')]

    # Set run-specific environment configurations
    args.timestamp = time.strftime("run_%Y_%m_%d_%H_%M_%S") + "_{machine}".format(machine=socket.gethostname())
    args.model_directory = os.path.join(args.model_storage_directory, args.timestamp)  # run directory
    os.makedirs(args.model_directory, exist_ok=True)

    # Handle logging configurations
    log.handlers.clear()
    formatter = logging.Formatter('%(message)s')
    fh = logging.FileHandler(os.path.join(args.model_directory, "log.txt"))
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    log.addHandler(fh)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)
    log.setLevel(logging.INFO)
    log.addHandler(ch)
    log.info(p.format_values())

    # Set global GPU state
    if torch.cuda.is_available() and args.cuda:
        if torch.cuda.device_count() > 1:
            log.info("Using %i CUDA devices" % torch.cuda.device_count())
        else:
            log.info("Using CUDA device: {0}".format(torch.cuda.current_device()))
        args.device = 'cuda'
    else:
        log.info("Not using CUDA :(")
        args.device = 'cpu'

    return args
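
# For illustration: given the code above, each run gets its own subdirectory of
# model_storage_directory, named after the timestamp and hostname (the values
# below are hypothetical):
#
#   /tmp/model_runs/run_2020_01_01_12_00_00_myhost/
#       log.txt   <- file handler attached in _initialize_arguments
#       ...       <- checkpoints, presumably written here every checkpoint_interval epochs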
if __name__ == "__main__":
    p = configargparse.ArgParser(default_config_files=["config.ini"])
    args = _initialize_arguments(p)
    torch.cuda.empty_cache()

    # Documents and multi-hot label vectors for training; each vector is ordered
    # identically to args.labels.
    train_documents, train_labels = ["first train document", "second train document"], [[0, 1, 0, 0], [0, 0, 0, 1]]
    # Documents and labels for development
    dev_documents, dev_labels = ["first development document", "second development document"], [[0, 0, 1, 0], [0, 0, 0, 1]]

    model = BertForDocumentClassification(args=args)
    model.fit((train_documents, train_labels), (dev_documents, dev_labels))
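
# Example invocation (a sketch; flag values are illustrative, and since
# configargparse merges sources, any of these options can live in config.ini
# instead of being passed on the command line):
#
#   python training.py --model_storage_directory /tmp/model_runs \
#       --bert_model_path bert-base-uncased \
#       --labels "label_a, label_b, label_c, label_d" \
#       --batch_size 4 --bert_batch_size 7 --epochs 10 --cuda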