From c205f7ec617b90c92872628db62ab8661bcfd079 Mon Sep 17 00:00:00 2001
From: Florian Vahl <7vahl@informatik.uni-hamburg.de>
Date: Sat, 23 Nov 2024 16:00:30 +0100
Subject: [PATCH] Add deconvolution

---
Reviewer notes (this section is ignored by `git am`):
 * create_modules: the activation checks become one if/elif/else chain;
   "linear" is an explicit no-op and any other value raises ValueError.
 * create_modules: new "deconvolutional" block type built on
   nn.ConvTranspose2d, followed by optional BatchNorm2d and the same
   activation chain as the convolutional branch.
 * forward: the `loss = 0` initialisation is removed and
   "deconvolutional" modules are applied directly like convolutional ones.
 * Line breaks follow the hunk headers and diffstat; indentation was
   restored from the Python nesting -- verify against yoeo/models.py
   before applying.

 yoeo/models.py | 37 ++++++++++++++++++++++++++++++++++---
 1 file changed, 34 insertions(+), 3 deletions(-)

diff --git a/yoeo/models.py b/yoeo/models.py
index 1b69cd8..cd964c0 100644
--- a/yoeo/models.py
+++ b/yoeo/models.py
@@ -60,8 +60,40 @@ def create_modules(module_defs):
                                    nn.BatchNorm2d(filters, momentum=0.1, eps=1e-5))
             if module_def["activation"] == "leaky":
                 modules.add_module(f"leaky_{module_i}", nn.LeakyReLU(0.1))
-            if module_def["activation"] == "mish":
+            elif module_def["activation"] == "mish":
                 modules.add_module(f"mish_{module_i}", Mish())
+            elif module_def["activation"] == "linear":
+                pass
+            else:
+                raise ValueError(f"Unknown activation: {module_def['activation']}")
+
+        elif module_def["type"] == "deconvolutional":
+            bn = int(module_def["batch_normalize"])
+            filters = int(module_def["filters"])
+            kernel_size = int(module_def["size"])
+            pad = int(module_def["pad"])
+            modules.add_module(
+                f"deconv_{module_i}",
+                nn.ConvTranspose2d(
+                    in_channels=output_filters[-1],
+                    out_channels=filters,
+                    kernel_size=kernel_size,
+                    stride=int(module_def["stride"]),
+                    padding=pad,
+                    bias=not bn,
+                ),
+            )
+            if bn:
+                modules.add_module(f"batch_norm_{module_i}",
+                                   nn.BatchNorm2d(filters, momentum=0.1, eps=1e-5))
+            if module_def["activation"] == "leaky":
+                modules.add_module(f"leaky_{module_i}", nn.LeakyReLU(0.1))
+            elif module_def["activation"] == "mish":
+                modules.add_module(f"mish_{module_i}", Mish())
+            elif module_def["activation"] == "linear":
+                pass
+            else:
+                raise ValueError(f"Unknown activation: {module_def['activation']}")
 
         elif module_def["type"] == "maxpool":
             kernel_size = int(module_def["size"])
@@ -197,10 +229,9 @@ def __init__(self, config_path):
 
     def forward(self, x, bb_targets=None, mask_targets=None):
         img_size = x.size(2)
-        loss = 0
         layer_outputs, yolo_outputs, segmentation_outputs = [], [], []
         for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):
-            if module_def["type"] in ["convolutional", "upsample", "maxpool"]:
+            if module_def["type"] in ["convolutional", "deconvolutional", "upsample", "maxpool"]:
                 x = module(x)
             elif module_def["type"] == "route":
                 combined_outputs = torch.cat([layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)