Skip to content

Commit

Permalink
Add deconvolution
Browse files Browse the repository at this point in the history
  • Loading branch information
Flova committed Nov 23, 2024
1 parent 3d58235 commit c205f7e
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions yoeo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,40 @@ def create_modules(module_defs):
nn.BatchNorm2d(filters, momentum=0.1, eps=1e-5))
if module_def["activation"] == "leaky":
modules.add_module(f"leaky_{module_i}", nn.LeakyReLU(0.1))
if module_def["activation"] == "mish":
elif module_def["activation"] == "mish":
modules.add_module(f"mish_{module_i}", Mish())
elif module_def["activation"] == "linear":
pass
else:
raise ValueError(f"Unknown activation: {module_def['activation']}")

elif module_def["type"] == "deconvolutional":
bn = int(module_def["batch_normalize"])
filters = int(module_def["filters"])
kernel_size = int(module_def["size"])
pad = int(module_def["pad"])
modules.add_module(
f"deconv_{module_i}",
nn.ConvTranspose2d(
in_channels=output_filters[-1],
out_channels=filters,
kernel_size=kernel_size,
stride=int(module_def["stride"]),
padding=pad,
bias=not bn,
),
)
if bn:
modules.add_module(f"batch_norm_{module_i}",
nn.BatchNorm2d(filters, momentum=0.1, eps=1e-5))
if module_def["activation"] == "leaky":
modules.add_module(f"leaky_{module_i}", nn.LeakyReLU(0.1))
elif module_def["activation"] == "mish":
modules.add_module(f"mish_{module_i}", Mish())
elif module_def["activation"] == "linear":
pass
else:
raise ValueError(f"Unknown activation: {module_def['activation']}")

elif module_def["type"] == "maxpool":
kernel_size = int(module_def["size"])
Expand Down Expand Up @@ -197,10 +229,9 @@ def __init__(self, config_path):

def forward(self, x, bb_targets=None, mask_targets=None):
img_size = x.size(2)
loss = 0
layer_outputs, yolo_outputs, segmentation_outputs = [], [], []
for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):
if module_def["type"] in ["convolutional", "upsample", "maxpool"]:
if module_def["type"] in ["convolutional", "deconvolutional", "upsample", "maxpool"]:
x = module(x)
elif module_def["type"] == "route":
combined_outputs = torch.cat([layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)
Expand Down

0 comments on commit c205f7e

Please sign in to comment.