Skip to content

Commit

Permalink
add dice loss
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoon0717 committed Nov 21, 2024
1 parent 37078a3 commit f068bef
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
8 changes: 4 additions & 4 deletions module_base/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ save_dir: /data/ephemeral/home/data/result
# 모델 라이브러리 및 사용 모델 정의
model_type: smp
model_name: UnetPlusPlus
encoder_name: efficientnet-b5
encoder_name: tu-hrnet_w64
encoder_weights: imagenet
pretrained: True

Expand All @@ -19,11 +19,11 @@ lr: 2e-4
weight_decay: 1e-6
train_num_workers: 8
valid_num_workers: 0
max_epoch: 30
max_epoch: 50

# loss
loss:
name: BCEWithLogitsLoss
name: DiceBCELoss
params: null

# optimizer
Expand All @@ -37,7 +37,7 @@ threshold: 0.5
# random seed
random_seed: 2024

# augmentation
# resize
size: 1024

# wandb
Expand Down
35 changes: 35 additions & 0 deletions module_base/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class BCE(nn.Module):
def __init__(self, **kwargs):
Expand All @@ -8,10 +10,43 @@ def __init__(self, **kwargs):
def forward(self, preds, targets):
return self.loss(preds, targets)

class DiceLoss(nn.Module):
def __init__(self, eps=1e-6):
super(DiceLoss, self).__init__()
self.eps = eps

def forward(self, preds, targets):
preds = F.sigmoid(preds)

preds_f = preds.flatten(2)
targets_f = targets.flatten(2)
intersection = torch.sum(preds_f * targets_f, -1)

dice = (2. * intersection + self.eps) / (torch.sum(preds_f, -1) + torch.sum(targets_f, -1) + self.eps)
loss = 1 - dice

return loss.mean()

class DiceBCELoss(nn.Module):
def __init__(self, **kwargs):
super(DiceBCELoss, self).__init__(**kwargs)
self.bceWithLogitLoss = nn.BCEWithLogitsLoss(**kwargs)
self.diceLoss = DiceLoss()

def forward(self, preds, targets):
bce_loss = self.bceWithLogitLoss(preds, targets)
dice_loss = self.diceLoss(preds, targets)
dice_bce_loss = bce_loss + dice_loss

return dice_bce_loss


class LossSelector:
def __init__(self, loss, **kwargs):
if loss == 'BCEWithLogitsLoss':
self.loss = BCE(**kwargs)
elif loss == 'DiceBCELoss':
self.loss = DiceBCELoss(**kwargs)

def get_loss(self):
return self.loss

0 comments on commit f068bef

Please sign in to comment.