-
Notifications
You must be signed in to change notification settings - Fork 4
/
sweep.py
88 lines (70 loc) · 2.89 KB
/
sweep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from args import parse_arguments
from typing import Any
import torch
import pytorch_lightning as pl
from dataloader.dataloader import STSDataModule
import wandb
from pytorch_lightning.loggers import WandbLogger
def main(config: Any) -> None:
"""
WandB sweep 설정을 통한 하이퍼파라미터 튜닝 수행, config.json 파일로 설정 변경 가능
Args:
config: 사용자 정의 설정파일, sweep 조절 인자와 그렇지 않은 인자가 모두 포함됨
"""
# Sweep 통해 실행될 학습 코드 생성
def sweep_train(config: Any = config) -> None:
wandb.init(entity=config.wandb['entity'], # 기본값: 'salmons'
project=config.wandb['sweep_project_name'])
sweep_config = wandb.config
# 베이스라인 모델 혹은 GRU가 부가된 모델을 설정할 것인지 결정
model_class = (
"gru_model.GRUModel"
if config.arch['args']['gru_enabled']
else "model.Model"
)
module_name, class_name = model_class.split('.')
model_module = __import__('models.' + module_name,
fromlist=[class_name])
ModelClass = getattr(model_module, class_name)
# dataloader와 model을 정의합니다.
dataloader = STSDataModule(
model_name=config.arch['type'],
batch_size=sweep_config['batch_size'],
shuffle=config.dataloader['args']['shuffle'],
dataset_commit_hash=config.dataloader['args']
['dataset_commit_hash'],
num_workers=config.dataloader['args']['num_workers'],
)
model = ModelClass(config.arch['type'],
sweep_config.optimizer,
sweep_config.lr,
sweep_config.loss_function,
config.loss['args']['beta'],
config.loss['args']['bce'],
config.lr_scheduler['is_schedule'])
wandb_logger = WandbLogger()
# 가속기 설정
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
# Trainer 정의
trainer = pl.Trainer(accelerator=accelerator,
devices=1,
max_epochs=sweep_config.epochs,
log_every_n_steps=1,
logger=wandb_logger,
precision=16)
# 학습
trainer.fit(model=model, datamodule=dataloader)
# 평가
trainer.test(model=model, datamodule=dataloader)
# Sweep 생성
sweep_id = wandb.sweep(
sweep=config.sweep_config
)
wandb.agent(
sweep_id=sweep_id,
function=sweep_train,
count=config.wandb['sweep_count']
)
if __name__ == '__main__':
config = parse_arguments()
main(config)