Commit
it's Monday let's go
ydshieh committed Dec 16, 2024
1 parent 395a636 commit 8a058d9
Showing 1 changed file with 53 additions and 0 deletions.
tests/models/kosmos2_5/test_processor_kosmos2_5.py (53 additions, 0 deletions)
@@ -152,6 +152,59 @@ def test_model_input_names(self):
        with pytest.raises(ValueError):
            processor()

    @require_torch
    @require_vision
    def test_unstructured_kwargs(self):
        # Rewritten because the KOSMOS-2.5 processor doesn't use `rescale_factor`
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = self.prepare_text_inputs()
        image_input = self.prepare_image_inputs()
        inputs = processor(
            text=input_str,
            images=image_input,
            return_tensors="pt",
            max_patches=1024,
            padding="max_length",
            max_length=76,
        )
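        # The flat kwargs above are dispatched per component: `max_patches` goes to
        # the image processor, `padding`/`max_length` to the tokenizer; the assertions
        # below check both routes.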

        self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
        self.assertEqual(len(inputs["input_ids"][0]), 76)

    @require_torch
    @require_vision
    def test_unstructured_kwargs_batched(self):
        # Rewritten because the KOSMOS-2.5 processor doesn't use `rescale_factor`
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = self.prepare_text_inputs(batch_size=2)
        image_input = self.prepare_image_inputs(batch_size=2)
        inputs = processor(
            text=input_str,
            images=image_input,
            return_tensors="pt",
            max_patches=1024,
            padding="longest",
            max_length=76,
        )

        self.assertEqual(inputs["flattened_patches"].shape[1], 1024)

        self.assertEqual(len(inputs["input_ids"][0]), 76)

    @require_torch
    @require_vision
    def test_structured_kwargs_nested(self):
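The body of test_structured_kwargs_nested is collapsed in the diff view. For orientation only, here is a minimal sketch of the nested call style that test name refers to, assuming the common_kwargs / images_kwargs / text_kwargs grouping used by the shared processor test suite; the grouping and the snippet are an illustration, not this commit's code:

        # Hypothetical illustration, not part of the commit: the same routing as the
        # flat-kwargs tests above, with kwargs grouped per component instead.
        inputs = processor(
            text=input_str,
            images=image_input,
            common_kwargs={"return_tensors": "pt"},
            images_kwargs={"max_patches": 1024},
            text_kwargs={"padding": "max_length", "max_length": 76},
        )
        self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
        self.assertEqual(len(inputs["input_ids"][0]), 76)

These overrides can be run on their own with python -m pytest tests/models/kosmos2_5/test_processor_kosmos2_5.py -k kwargs.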
