Skip to content

Commit

Permalink
convert tensor to list
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Oct 30, 2024
1 parent 38b0707 commit b50d68d
Showing 1 changed file with 2 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torc
A tensor of shape (batch_size, hidden_size, n_patches_height,
n_patches_width) which is feature map of the decoder head.
"""
if isinstance(features, torch.Tensor):
features = [features]
if not isinstance(features, list) or features[0].ndim != 4:
raise ValueError(
"Input features should be a list of four (4) dimensional inputs of "
Expand Down

0 comments on commit b50d68d

Please sign in to comment.