Commit
[#21, #39] Feature : add scheduler
Yoon0717 committed Nov 21, 2024
1 parent f068bef commit 1ee5c6b
Showing 6 changed files with 33 additions and 7 deletions.
5 changes: 4 additions & 1 deletion module_base/config.yaml
```diff
@@ -23,12 +23,15 @@ max_epoch: 50
 
 # loss
 loss:
-  name: DiceBCELoss
+  name: BCEWithLogitsLoss
   params: null
 
 # optimizer
 optim: Adam
 
+# scheduler
+scheduler: cosine
+
 # validation arguments
 kfold: 0
 val_every: 2
```
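The new `scheduler` key is consumed by the training entry points below via `cfg.scheduler`. A minimal sketch of reading it, assuming a plain PyYAML loader (the repository's actual config loader is not shown in this diff):

```python
import yaml

# Hypothetical loader; the project may use OmegaConf or a custom Config class instead.
with open("module_base/config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["optim"])      # "Adam"
print(cfg["scheduler"])  # "cosine" -> selects CosineAnnealingLR in module_base/scheduler.py
```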
8 changes: 6 additions & 2 deletions module_base/loss.py
```diff
@@ -24,12 +24,16 @@ def forward(self, preds, targets):
 
         dice = (2. * intersection + self.eps) / (torch.sum(preds_f, -1) + torch.sum(targets_f, -1) + self.eps)
         loss = 1 - dice
 
         return loss.mean()
 
+class IoULoss(nn.Module):
+    def __init__(self, **kwargs):
+        super(IoULoss, self).__init__()
+
 class DiceBCELoss(nn.Module):
     def __init__(self, **kwargs):
-        super(DiceBCELoss, self).__init__(**kwargs)
+        super(DiceBCELoss, self).__init__()
         self.bceWithLogitLoss = nn.BCEWithLogitsLoss(**kwargs)
         self.diceLoss = DiceLoss()
```
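The `super(DiceBCELoss, self).__init__()` fix matters because `nn.Module.__init__` accepts no keyword arguments, so forwarding `**kwargs` to it raises a TypeError; the kwargs belong to the wrapped `nn.BCEWithLogitsLoss` instead. The diff cuts off before `DiceBCELoss.forward`; a typical way to combine the two terms, sketched here as an assumption rather than the file's actual code:

```python
import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    # Mirrors the dice formula visible in the hunk above.
    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, preds, targets):
        preds_f = preds.flatten(1)
        targets_f = targets.flatten(1)
        intersection = torch.sum(preds_f * targets_f, -1)
        dice = (2. * intersection + self.eps) / (torch.sum(preds_f, -1) + torch.sum(targets_f, -1) + self.eps)
        return (1 - dice).mean()

class DiceBCELossSketch(nn.Module):
    # Hypothetical reconstruction; the real forward in module_base/loss.py is not shown above.
    def __init__(self, **kwargs):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(**kwargs)
        self.dice = DiceLoss()

    def forward(self, preds, targets):
        # BCEWithLogitsLoss takes raw logits; Dice is computed on sigmoid probabilities here.
        return self.bce(preds, targets) + self.dice(torch.sigmoid(preds), targets)
```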
13 changes: 13 additions & 0 deletions module_base/scheduler.py
```diff
@@ -0,0 +1,13 @@
+from torch.optim import lr_scheduler
+
+class SchedulerSelector:
+    def __init__(self, sched, optimizer, epoch):
+        self.optimizer = optimizer
+        self.epoch = epoch
+        if sched == 'step':
+            self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.6)
+        elif sched == 'cosine':
+            self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.epoch, eta_min=1e-6)
+
+    def get_sched(self):
+        return self.scheduler
```

Note: `StepLR`'s keyword is `step_size`, not `step` (passing `step=5` raises a TypeError), so the `'step'` branch is written with `step_size=5` here.
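A quick usage check of the new selector, with a throwaway model and optimizer standing in for the project's real ones (illustrative names only, not from the repository):

```python
import torch
from torch import nn, optim
from scheduler import SchedulerSelector  # module_base/scheduler.py, added above

# Throwaway stand-ins, not from the repository.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

max_epoch = 50
scheduler = SchedulerSelector('cosine', optimizer, max_epoch).get_sched()

for epoch in range(max_epoch):
    # ... per-batch forward/backward and optimizer.step() would run here ...
    scheduler.step()  # one step per epoch, matching T_max=max_epoch

print(optimizer.param_groups[0]['lr'])  # annealed close to eta_min=1e-6
```

Since `CosineAnnealingLR` is constructed with `T_max` equal to `max_epoch`, calling `step()` once per epoch completes exactly one cosine decay from the base lr down to `eta_min` over the run.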
3 changes: 2 additions & 1 deletion module_base/train.py
```diff
@@ -17,6 +17,7 @@
 from model import ModelSelector
 from transform import TransformSelector
 from loss import LossSelector
+from scheduler import SchedulerSelector
 
 import warnings
 warnings.filterwarnings('ignore')
@@ -138,7 +139,7 @@ def train(model, train_loader, val_loader, criterion, optimizer, save_dir, rando
             scaler.update()
 
             # Print the loss every fixed number of steps.
-            if (step + 1) % 25 == 0:
+            if (step + 1) % 80 == 0:
                 print(
                     f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | '
                     f'Epoch [{epoch+1}/{max_epoch}], '
```
2 changes: 1 addition & 1 deletion module_base/transform.py
```diff
@@ -8,7 +8,7 @@ def __init__(self, is_train, resize):
             self.transform = A.Compose(
                 [
                     A.HorizontalFlip(0.2),
-                    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=1.0)
+                    # A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=1.0)
                 ]+ common_transform)
         else:
             self.transform = A.Compose(common_transform)
```
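The commit disables CLAHE by commenting it out rather than deleting it. If occasional contrast equalization were still wanted, albumentations can apply the transform probabilistically instead; a sketch of that alternative (a design option, not what this commit does):

```python
import albumentations as A

# Apply CLAHE to roughly half of the training samples instead of all or none.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.2),
    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),
])
```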
9 changes: 7 additions & 2 deletions module_base/wandb_train.py
```diff
@@ -18,6 +18,7 @@
 from model import ModelSelector
 from transform import TransformSelector
 from loss import LossSelector
+from scheduler import SchedulerSelector
 
 import warnings
 warnings.filterwarnings('ignore')
@@ -111,7 +112,7 @@ def validation(epoch, model, val_loader, criterion, model_type, thr=0.5):
 
     return avg_dice, val_loss
 
-def train(model, train_loader, val_loader, criterion, optimizer, cfg):
+def train(model, train_loader, val_loader, criterion, optimizer, scheduler, cfg):
     print(f'Start training..')
     logger = WandbLogger(name=cfg.wandb_run_name)
 
@@ -158,6 +159,7 @@ def train(model, train_loader, val_loader, criterion, optimizer, cfg):
                     f'Step [{step+1}/{len(train_loader)}], '
                     f'Loss: {round(loss.item(),4)}'
                 )
+        scheduler.step()
         epoch_time = datetime.timedelta(seconds=time.time() - epoch_start)
         dataset_size = len(train_loader.dataset)
         epoch_loss = epoch_loss / dataset_size
@@ -231,8 +233,11 @@ def main(cfg):
     criterion = loss.get_loss()
 
     optimizer = optim.Adam(params=model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
+
+    sched = SchedulerSelector(cfg.scheduler, optimizer, cfg.max_epoch)
+    scheduler = sched.get_sched()
 
-    train(model, train_loader, valid_loader, criterion, optimizer, cfg)
+    train(model, train_loader, valid_loader, criterion, optimizer, scheduler, cfg)
 
 if __name__ == '__main__':
     parser = ArgumentParser()
```
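The ordering matters here: `scheduler.step()` sits after the inner batch loop, so the learning rate changes once per epoch and the cosine cycle spans exactly `cfg.max_epoch` epochs. A self-contained sketch of that ordering with dummy data (stand-in model and loader, not the repository's):

```python
import torch
from torch import nn, optim
from torch.optim import lr_scheduler

# Dummy stand-ins so the epoch/step ordering is runnable on its own.
model = nn.Linear(8, 1)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
max_epoch = 5
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epoch, eta_min=1e-6)
loader = [(torch.randn(4, 8), torch.randint(0, 2, (4, 1)).float()) for _ in range(10)]

for epoch in range(max_epoch):
    for images, masks in loader:
        loss = criterion(model(images), masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()        # per-batch parameter update
    scheduler.step()            # per-epoch lr update, after all batches
    print(epoch, scheduler.get_last_lr()[0])  # current lr, e.g. for wandb logging
```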
