diff --git a/module_base/config.yaml b/module_base/config.yaml
index 3d0d52c..3569725 100644
--- a/module_base/config.yaml
+++ b/module_base/config.yaml
@@ -23,12 +23,15 @@ max_epoch: 50
 
 # loss
 loss:
-  name: DiceBCELoss
+  name: BCEWithLogitsLoss
   params: null
 
 # optimizer
 optim: Adam
 
+# scheduler
+scheduler: cosine
+
 # validation arguments
 kfold: 0
 val_every: 2
diff --git a/module_base/loss.py b/module_base/loss.py
index 938af30..9716b46 100644
--- a/module_base/loss.py
+++ b/module_base/loss.py
@@ -24,12 +24,16 @@ def forward(self, preds, targets):
 
         dice = (2. * intersection + self.eps) / (torch.sum(preds_f, -1) + torch.sum(targets_f, -1) + self.eps)
         loss = 1 - dice
-
+
         return loss.mean()
 
+class IoULoss(nn.Module):
+    def __init__(self, **kwargs):
+        super(IoULoss, self).__init__()
+
 
 class DiceBCELoss(nn.Module):
     def __init__(self, **kwargs):
-        super(DiceBCELoss, self).__init__(**kwargs)
+        super(DiceBCELoss, self).__init__()
         self.bceWithLogitLoss = nn.BCEWithLogitsLoss(**kwargs)
         self.diceLoss = DiceLoss()
diff --git a/module_base/scheduler.py b/module_base/scheduler.py
new file mode 100644
index 0000000..9aa167c
--- /dev/null
+++ b/module_base/scheduler.py
@@ -0,0 +1,13 @@
+from torch.optim import lr_scheduler
+
+class SchedulerSelector:
+    def __init__(self, sched, optimizer, epoch):
+        self.optimizer = optimizer
+        self.epoch = epoch
+        if sched == 'step':
+            self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.6)
+        elif sched == 'cosine':
+            self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.epoch, eta_min=1e-6)
+
+    def get_sched(self):
+        return self.scheduler
\ No newline at end of file
diff --git a/module_base/train.py b/module_base/train.py
index 0acbed3..d567e97 100644
--- a/module_base/train.py
+++ b/module_base/train.py
@@ -17,6 +17,7 @@ from model import ModelSelector
 from transform import TransformSelector
 from loss import LossSelector
+from scheduler import SchedulerSelector
 
 import warnings
 warnings.filterwarnings('ignore')
@@ -138,7 +139,7 @@ def train(model, train_loader, val_loader, criterion, optimizer, save_dir, rando
             scaler.update()
 
             # Print the loss at the configured step interval.
-            if (step + 1) % 25 == 0:
+            if (step + 1) % 80 == 0:
                 print(
                     f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | '
                     f'Epoch [{epoch+1}/{max_epoch}], '
diff --git a/module_base/transform.py b/module_base/transform.py
index cc2b6fa..37952bc 100644
--- a/module_base/transform.py
+++ b/module_base/transform.py
@@ -8,7 +8,7 @@ def __init__(self, is_train, resize):
             self.transform = A.Compose(
                 [
                     A.HorizontalFlip(0.2),
-                    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=1.0)
+                    # A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=1.0)
                 ]+ common_transform)
         else:
             self.transform = A.Compose(common_transform)
diff --git a/module_base/wandb_train.py b/module_base/wandb_train.py
index b7346d9..0ec1cdc 100644
--- a/module_base/wandb_train.py
+++ b/module_base/wandb_train.py
@@ -18,6 +18,7 @@ from model import ModelSelector
 from transform import TransformSelector
 from loss import LossSelector
+from scheduler import SchedulerSelector
 
 import warnings
 warnings.filterwarnings('ignore')
@@ -111,7 +112,7 @@ def validation(epoch, model, val_loader, criterion, model_type, thr=0.5):
     return avg_dice, val_loss
 
 
-def train(model, train_loader, val_loader, criterion, optimizer, cfg):
+def train(model, train_loader, val_loader, criterion, optimizer, scheduler, cfg):
     print(f'Start training..')
 
     logger = WandbLogger(name=cfg.wandb_run_name)
@@ -158,6 +159,7 @@ def train(model, train_loader, val_loader, criterion, optimizer, scheduler, cfg):
                     f'Step [{step+1}/{len(train_loader)}], '
                     f'Loss: {round(loss.item(),4)}'
                 )
+        scheduler.step()
         epoch_time = datetime.timedelta(seconds=time.time() - epoch_start)
         dataset_size = len(train_loader.dataset)
         epoch_loss = epoch_loss / dataset_size
@@ -231,8 +233,11 @@ def main(cfg):
     criterion = loss.get_loss()
 
     optimizer = optim.Adam(params=model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
+
+    sched = SchedulerSelector(cfg.scheduler, optimizer, cfg.max_epoch)
+    scheduler = sched.get_sched()
 
-    train(model, train_loader, valid_loader, criterion, optimizer, cfg)
+    train(model, train_loader, valid_loader, criterion, optimizer, scheduler, cfg)
 
 if __name__ == '__main__':
     parser = ArgumentParser()
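
Note on loss.py: IoULoss lands in this commit as a bare stub (an __init__ and nothing else), so selecting it would fail at train time for lack of a forward(). Below is a minimal sketch of one way to fill it in, mirroring the flatten-and-reduce layout and eps smoothing of the DiceLoss directly above it; the eps default, the sigmoid call, and the flatten(1) shape handling are assumptions for illustration, not part of this diff.

import torch
from torch import nn

class IoULoss(nn.Module):
    def __init__(self, eps=1.0, **kwargs):
        super(IoULoss, self).__init__()
        self.eps = eps  # smoothing term, assumed to mirror DiceLoss

    def forward(self, preds, targets):
        preds = torch.sigmoid(preds)       # assumes raw logits come in, as with DiceBCELoss
        preds_f = preds.flatten(1)         # (N, C*H*W); per-sample layout assumed from DiceLoss
        targets_f = targets.flatten(1)
        intersection = torch.sum(preds_f * targets_f, -1)
        # union = |A| + |B| - |A n B|; eps keeps the ratio finite for empty masks
        union = torch.sum(preds_f, -1) + torch.sum(targets_f, -1) - intersection
        iou = (intersection + self.eps) / (union + self.eps)
        return (1 - iou).mean()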
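
On the scheduler wiring in wandb_train.py: scheduler.step() sits at the end of the epoch loop, outside the inner batch loop, so CosineAnnealingLR with T_max=cfg.max_epoch traces exactly one half-cosine over the whole run, annealing from the base lr down to eta_min=1e-6. A standalone sketch to sanity-check that behavior; the Linear model and the 1e-4 learning rate are placeholders, not values from this repo.

from torch import nn, optim
from scheduler import SchedulerSelector  # new module from this diff

# Placeholder model/optimizer purely to exercise the selector.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
max_epoch = 50

scheduler = SchedulerSelector('cosine', optimizer, max_epoch).get_sched()
for epoch in range(max_epoch):
    # ... a full epoch of training/validation would run here ...
    scheduler.step()  # once per epoch, matching T_max=max_epoch
    if (epoch + 1) % 10 == 0:
        print(f'epoch {epoch+1}: lr={scheduler.get_last_lr()[0]:.2e}')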