diff --git a/hooknet/models/torchmodel.py b/hooknet/models/torchmodel.py index 2b10148..1401933 100644 --- a/hooknet/models/torchmodel.py +++ b/hooknet/models/torchmodel.py @@ -2,7 +2,6 @@ import torch.nn as nn from torchvision.transforms.functional import center_crop - class HookNet(nn.Module): def __init__( self, @@ -45,13 +44,18 @@ def __init__( hook_channels=mid_channels, hook_to_index=3, ) - self.last_conv = nn.Conv2d(self.high_mag_branch.decoder._out_channels[0], n_classes, 1) + + self.high_last_conv = nn.Conv2d(self.high_mag_branch.decoder._out_channels[0], n_classes, 1) + self.mid_last_conv = nn.Conv2d(self.mid_mag_branch.decoder._out_channels[0], n_classes, 1) + self.low_last_conv = nn.Conv2d(self.low_mag_branch.decoder._out_channels[0], n_classes, 1) def forward(self, high_input, mid_input, low_input): - low_out = self.low_mag_branch(low_input) - mid_out = self.mid_mag_branch(mid_input, low_out) - high_out = self.high_mag_branch(high_input, mid_out) - return {'out': self.last_conv(high_out)} + low_out, low_hook_out = self.low_mag_branch(low_input) + mid_out, mid_hook_out = self.mid_mag_branch(mid_input, low_hook_out) + high_out, high_hook_out = self.high_mag_branch(high_input, mid_hook_out) + return {'high_out': self.high_last_conv(high_out), + 'mid_out': self.mid_last_conv(mid_out), + 'low_out': self.low_last_conv(low_out)} class Branch(nn.Module): @@ -91,8 +95,8 @@ def __init__( def forward(self, x, hook_in=None): out, residuals = self.encoder(x) out = self.mid_conv_block(out) - out = self.decoder(out, residuals, hook_in) - return out + out, hook_out = self.decoder(out, residuals, hook_in) + return out, hook_out class Encoder(nn.Module): @@ -156,6 +160,7 @@ def __init__( def forward(self, x, residuals, hook_in=None): out = x + hook_true = False for d in reversed(range(self._depth)): if hook_in is not None and d == self._hook_to_index: out = concatenator(out, hook_in) @@ -163,11 +168,14 @@ def forward(self, x, residuals, hook_in=None): out = self._decode_path[f"upsample{d}"](out) out = concatenator(out, residuals[d]) out = self._decode_path[f"convblock{d}"](out) - + if self._hook_from_index is not None and d == self._hook_from_index: - return out + hook_out = out + hook_true = True - return out + if hook_true == False: + hook_out = out + return out, hook_out class ConvBlock(nn.Module): @@ -211,4 +219,4 @@ def forward(self, x): def concatenator(x, x2): x2_cropped = center_crop(x2, x.shape[-1]) conc = torch.cat([x, x2_cropped], dim=1) - return conc \ No newline at end of file + return conc