-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
94 lines (78 loc) · 2.78 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from typing import Optional
import torch
import random
import numpy as np
from dataclasses import dataclass, field
from transformers import HfArgumentParser, TrainingArguments
def set_seed(seed=42):
    """Seed every RNG used in this project for reproducible runs.

    Covers Python's ``random``, NumPy, and PyTorch (CPU and all CUDA
    devices — the CUDA call is a safe no-op on CPU-only machines).

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def get_device():
    """Return the preferred torch device string: ``'cuda'`` when CUDA is
    available, otherwise ``'cpu'``."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'
def get_hf_args():
    """Parse command-line arguments into the three argument dataclasses.

    Returns:
        Tuple of ``(model_args, data_args, training_args)`` parsed by
        ``HfArgumentParser`` from ``sys.argv``.

    When ``--quick_run`` is set, the parsed values are overridden with a
    tiny configuration (2 epochs, 2% of the data, a small ALBERT model)
    so the whole pipeline can be smoke-tested quickly.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    parsed = parser.parse_args_into_dataclasses()
    model_args, data_args, training_args = parsed
    if data_args.quick_run:
        # Shrink the run to a fast end-to-end sanity check.
        training_args.num_train_epochs = 2
        data_args.data_fraction = 0.02
        model_args.model_name_or_path = "albert-base-v2"
    return model_args, data_args, training_args
def debug_print_hf_dataset_sample(hf_dataset, tokenizer):
    """Print the first sample of a dataset in raw, encoded, and decoded form.

    Debug helper: shows the full sample dict, its ``input_ids``, and the
    text those ids decode back to. Prints nothing if the dataset is empty.

    Args:
        hf_dataset: Any iterable of sample dicts containing ``'input_ids'``.
        tokenizer: Object exposing ``decode(ids)`` (e.g. a HF tokenizer).
    """
    try:
        sample = next(iter(hf_dataset))
    except StopIteration:
        # Empty dataset: the original loop body never ran either.
        return
    print("First sample:")
    print(sample)
    token_ids = sample['input_ids']
    print("Encoded ids:")
    print(token_ids)
    print("Decoded ids:")
    print(tokenizer.decode(token_ids))
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Populated from the command line via ``HfArgumentParser``; the
    ``metadata["help"]`` strings become the CLI ``--help`` text.
    """
    # Required (no default): HF hub identifier or local path of the model.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # Optional override when the config is stored separately from the model.
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    # Optional override when the tokenizer is stored separately from the model.
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    # When True, run a hyper-parameter grid search instead of a single training run.
    hp_search: Optional[bool] = field(
        default=False, metadata={"help": "Whether to perform grid-search or just train with given parameters"}
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Populated from the command line via ``HfArgumentParser``; the
    ``metadata["help"]`` strings become the CLI ``--help`` text.
    """
    # Required (no default): which task/dataset to run.
    task: str = field(
        metadata={
            "help": "the task/dataset to train/evaluate on, e.g. SweParaphrase"
        }
    )
    # Maximum tokenized sequence length; longer samples are truncated.
    max_input_length: int = field(
        default=512,
        metadata={
            "help": "The sequence length of samples, longer samples than this value will be truncated."
        },
    )
    # Fraction of each dataset split to use (1.0 = full data).
    data_fraction: Optional[float] = field(
        default=1.0,
        metadata={
            "help": "The fraction of the datasets to use, e.g. 0.2 will only use 20% of each dataset split."
        }
    )
    # Smoke-test mode flag; consumed by get_hf_args to shrink the run.
    quick_run: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If set to true, run with small search-space and a small subset of the data."
        }
    )