Skip to content

Commit

Permalink
multiloss pytorch model changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeleeuw1 authored Mar 27, 2024
1 parent 1788659 commit d7523b5
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions hooknet/models/torchmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch.nn as nn
from torchvision.transforms.functional import center_crop


class HookNet(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -45,13 +44,18 @@ def __init__(
hook_channels=mid_channels,
hook_to_index=3,
)
self.last_conv = nn.Conv2d(self.high_mag_branch.decoder._out_channels[0], n_classes, 1)

self.high_last_conv = nn.Conv2d(self.high_mag_branch.decoder._out_channels[0], n_classes, 1)
self.mid_last_conv = nn.Conv2d(self.mid_mag_branch.decoder._out_channels[0], n_classes, 1)
self.low_last_conv = nn.Conv2d(self.low_mag_branch.decoder._out_channels[0], n_classes, 1)

def forward(self, high_input, mid_input, low_input):
low_out = self.low_mag_branch(low_input)
mid_out = self.mid_mag_branch(mid_input, low_out)
high_out = self.high_mag_branch(high_input, mid_out)
return {'out': self.last_conv(high_out)}
low_out, low_hook_out = self.low_mag_branch(low_input)
mid_out, mid_hook_out = self.mid_mag_branch(mid_input, low_hook_out)
high_out, high_hook_out = self.high_mag_branch(high_input, mid_hook_out)
return {'high_out': self.high_last_conv(high_out),
'mid_out': self.mid_last_conv(mid_out),
'low_out': self.low_last_conv(low_out)}


class Branch(nn.Module):
Expand Down Expand Up @@ -91,8 +95,8 @@ def __init__(
def forward(self, x, hook_in=None):
out, residuals = self.encoder(x)
out = self.mid_conv_block(out)
out = self.decoder(out, residuals, hook_in)
return out
out, hook_out = self.decoder(out, residuals, hook_in)
return out, hook_out


class Encoder(nn.Module):
Expand Down Expand Up @@ -156,18 +160,22 @@ def __init__(

def forward(self, x, residuals, hook_in=None):
out = x
hook_true = False
for d in reversed(range(self._depth)):
if hook_in is not None and d == self._hook_to_index:
out = concatenator(out, hook_in)

out = self._decode_path[f"upsample{d}"](out)
out = concatenator(out, residuals[d])
out = self._decode_path[f"convblock{d}"](out)

if self._hook_from_index is not None and d == self._hook_from_index:
return out
hook_out = out
hook_true = True

return out
if hook_true == False:
hook_out = out
return out, hook_out


class ConvBlock(nn.Module):
Expand Down Expand Up @@ -211,4 +219,4 @@ def forward(self, x):
def concatenator(x, x2):
x2_cropped = center_crop(x2, x.shape[-1])
conc = torch.cat([x, x2_cropped], dim=1)
return conc
return conc

0 comments on commit d7523b5

Please sign in to comment.