diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py
index 972040264fec31..3d543503284489 100755
--- a/src/transformers/models/vivit/modeling_vivit.py
+++ b/src/transformers/models/vivit/modeling_vivit.py
@@ -104,9 +104,10 @@ def __init__(self, config):
             torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
         )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.tubelet_size[1:]
         self.config = config

-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+    # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -129,8 +130,8 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:

         dim = embeddings.shape[-1]

-        new_height = height // self.patch_size
-        new_width = width // self.patch_size
+        new_height = height // self.patch_size[0]
+        new_width = width // self.patch_size[1]

         sqrt_num_positions = torch_int(num_positions**0.5)
         patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
diff --git a/tests/models/vivit/test_modeling_vivit.py b/tests/models/vivit/test_modeling_vivit.py
index 19a179a6a3e030..7cce77e6fc0019 100644
--- a/tests/models/vivit/test_modeling_vivit.py
+++ b/tests/models/vivit/test_modeling_vivit.py
@@ -359,12 +359,12 @@ def test_inference_interpolate_pos_encoding(self):
         # allowing to interpolate the pre-trained position embeddings in order to use
         # the model on higher resolutions. The DINO model by Facebook AI leverages this
         # to visualize self-attention on higher resolution images.
-        model = VivitModel.from_pretrained("google/vivit-b-16x2").to(torch_device)
+        model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400").to(torch_device)

-        image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
+        image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
         video = prepare_video()
         inputs = image_processor(
-            video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt"
+            video, size={"shortest_edge": 480}, crop_size={"height": 232, "width": 232}, return_tensors="pt"
         )

         pixel_values = inputs.pixel_values.to(torch_device)
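
For context, a minimal usage sketch (not part of the patch) of the feature this diff enables: running the Vivit checkpoint at a spatial resolution other than the one it was trained on, with the position embeddings interpolated to the new grid. It assumes `VivitModel.forward` exposes an `interpolate_pos_encoding` flag, as the updated test exercises; the random `video` list is a hypothetical stand-in for the test's `prepare_video()` helper.

import numpy as np
import torch
from transformers import VivitImageProcessor, VivitModel

# Hypothetical stand-in for prepare_video(): 32 RGB frames.
video = [np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(32)]

image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")

# Crop to a resolution other than the 224x224 the checkpoint was trained on,
# mirroring the updated test (232x232).
inputs = image_processor(
    video, size={"shortest_edge": 480}, crop_size={"height": 232, "width": 232}, return_tensors="pt"
)

with torch.no_grad():
    # interpolate_pos_encoding=True (assumed forward kwarg) resizes the pre-trained position
    # embeddings to the new spatial grid: height // patch_size[0] by width // patch_size[1].
    outputs = model(**inputs, interpolate_pos_encoding=True)

print(outputs.last_hidden_state.shape)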