diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index bf1fc5afbaac9e..65fac8b3a0e0ce 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -606,12 +606,15 @@ def __init__(self, config): def forward(self, pixel_values): if self.visual_prompter_apply != "add": - visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype) + visual_prompt_mask = torch.ones( + [self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device + ) visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0 pixel_values *= visual_prompt_mask if self.visual_prompter_apply != "remove": prompt = torch.zeros( - [pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size] + [pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size], + device=pixel_values.device, ) start_point = self.max_img_size - self.visual_prompt_size prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down @@ -667,10 +670,12 @@ def forward(self, pixel_values): if self.visual_prompter_apply not in ("add", "remove", "replace"): raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}") if self.visual_prompter_apply in ("replace", "remove"): - visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype) + visual_prompt_mask = torch.ones( + [self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device + ) pixel_values *= visual_prompt_mask if self.visual_prompter_apply in ("replace", "add"): - base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size) + base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device) prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4) prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3) prompt = torch.cat(pixel_values.size(0) * [prompt]) diff --git a/tests/models/tvp/test_modeling_tvp.py b/tests/models/tvp/test_modeling_tvp.py index b81635888c74d5..ebdd4fb0b7694b 100644 --- a/tests/models/tvp/test_modeling_tvp.py +++ b/tests/models/tvp/test_modeling_tvp.py @@ -176,6 +176,9 @@ class TVPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): else {} ) + # TODO: Enable this once this model gets more usage + test_torchscript = False + def setUp(self): self.model_tester = TVPModelTester(self)