SandwichConv returns non-square image size #1

Open
nic-barbara opened this issue Apr 18, 2024 · 0 comments

If the input image to any SandwichConv layer is not n x n with n/2 even, the layer returns a non-square image. A minimal example is below.

import torch
import torch.nn as nn
import torch.nn.functional as F
from layer import SandwichConv

class CNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.l1 = SandwichConv(4, 32, 8, stride=2)
        self.l2 = SandwichConv(32, 64, 4, stride=2)
        self.l3 = SandwichConv(64, 64, 3, stride=1)
        
        def run_hidden(x):
            print("in: ", x.shape)
            x = self.l1(x)
            print("1: ", x.shape)
            x = self.l2(x)
            print("2: ", x.shape)
            x = self.l3(x)
            print("3: ", x.shape)
            return x

        self.model = run_hidden

    def forward(self, x):
        return self.model(x)
        
if __name__ == '__main__':
    batch, cin, n = 8, 4, 84
    model = CNN()
    x = torch.randn((batch, cin, n, n))
    y = model(x)

Running this code prints out the following image sizes between the layers.

in:  torch.Size([8, 4, 84, 84])
1:  torch.Size([8, 32, 42, 42])
2:  torch.Size([8, 64, 21, 20])
3:  torch.Size([8, 64, 21, 20])

A quick fix is to just pad the input image where required by changing the forward pass as follows:

    def forward(self, x):
        x = F.pad(x, (2, 2, 2, 2), "constant", 0)
        return self.model(x)

which returns

in:  torch.Size([8, 4, 88, 88])
1:  torch.Size([8, 32, 44, 44])
2:  torch.Size([8, 64, 22, 22])
3:  torch.Size([8, 64, 22, 22])
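
The (2, 2, 2, 2) padding is specific to 84 x 84 inputs. For other sizes, the padding could be computed rather than hard-coded. The following is a minimal sketch, not part of the repository: `symmetric_padding` and `pad_argument` are hypothetical helper names, and the choice `multiple=8` is an assumption inferred from the printed shapes (two stride-2 layers, each needing an input size whose half is even), not from the SandwichConv implementation.

```python
from typing import Tuple


def symmetric_padding(n: int, multiple: int) -> Tuple[int, int]:
    """Return (before, after) zero-padding that grows n to the next multiple."""
    if n <= 0 or multiple <= 0:
        raise ValueError("n and multiple must be positive")
    total = (-n) % multiple      # extra pixels needed; 0 if n is already a multiple
    before = total // 2          # split as evenly as possible
    return before, total - before


def pad_argument(height: int, width: int, multiple: int) -> Tuple[int, int, int, int]:
    """Build the 4-tuple that torch.nn.functional.pad takes for the last two dims."""
    top, bottom = symmetric_padding(height, multiple)
    left, right = symmetric_padding(width, multiple)
    # F.pad lists padding starting from the last dimension: (left, right, top, bottom)
    return (left, right, top, bottom)


if __name__ == "__main__":
    # Assumed multiple of 8 for the three-layer network in this issue.
    print(pad_argument(84, 84, 8))  # (2, 2, 2, 2)
```

In the forward pass this would be used as `x = F.pad(x, pad_argument(x.shape[-2], x.shape[-1], 8), "constant", 0)`, which reproduces the hard-coded fix for 84 x 84 inputs. The right multiple for a different stack of layers would need to be checked against the shapes that stack actually produces.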

We should probably document this somewhere in more detail.
