From 9eb3b88c5f3b6009560990a55c4aca7d62e8af60 Mon Sep 17 00:00:00 2001 From: Zhijian Liu <5782437+zhijian-liu@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:15:10 -0400 Subject: [PATCH] Fix the dimension index (#321) --- torchsparse/nn/modules/bev.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchsparse/nn/modules/bev.py b/torchsparse/nn/modules/bev.py index dac3a33..10f270b 100644 --- a/torchsparse/nn/modules/bev.py +++ b/torchsparse/nn/modules/bev.py @@ -93,7 +93,7 @@ def forward(self, input: SparseTensor) -> torch.Tensor: self.kernel, 0, torch.div(coords[:, self.dim], stride).trunc().long() ) feats = (feats.unsqueeze(dim=-1) * kernel).sum(1) + self.bias - coords = (coords - self.offset).t()[[3] + self.bev_dims].long() + coords = (coords - self.offset).t()[[0] + self.bev_dims].long() coords[1:] = torch.div(coords[1:], stride).trunc().long() indices = ( coords[0] * int(self.bev_shape.prod()) @@ -197,7 +197,7 @@ def forward(self, input: SparseTensor) -> torch.Tensor: assert isinstance(stride, torch.Tensor), type(stride) # [b, x, y, z] - coords = (coords - self.offset).t()[[3] + self.bev_dims + [self.dim]].long() + coords = (coords - self.offset).t()[[0] + self.bev_dims + [self.dim]].long() shape = self.shape[self.bev_dims + [self.dim]] # now stride must be torch.Tensor since input.s is tuple.