Skip to content

Commit

Permalink
Fix TVPModelTest (#27695)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Nov 24, 2023
1 parent 29c9480 commit 35551f9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/transformers/models/tvp/modeling_tvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,12 +606,15 @@ def __init__(self, config):

def forward(self, pixel_values):
if self.visual_prompter_apply != "add":
visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype)
visual_prompt_mask = torch.ones(
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
)
visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0
pixel_values *= visual_prompt_mask
if self.visual_prompter_apply != "remove":
prompt = torch.zeros(
[pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size]
[pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size],
device=pixel_values.device,
)
start_point = self.max_img_size - self.visual_prompt_size
prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down
Expand Down Expand Up @@ -667,10 +670,12 @@ def forward(self, pixel_values):
if self.visual_prompter_apply not in ("add", "remove", "replace"):
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
if self.visual_prompter_apply in ("replace", "remove"):
visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype)
visual_prompt_mask = torch.ones(
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
)
pixel_values *= visual_prompt_mask
if self.visual_prompter_apply in ("replace", "add"):
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size)
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
prompt = torch.cat(pixel_values.size(0) * [prompt])
Expand Down
3 changes: 3 additions & 0 deletions tests/models/tvp/test_modeling_tvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ class TVPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else {}
)

# TODO: Enable this once this model gets more usage
test_torchscript = False

def setUp(self):
self.model_tester = TVPModelTester(self)

Expand Down

0 comments on commit 35551f9

Please sign in to comment.