Commit ef4671d: Fix image processing in textnet

raghavanone committed Nov 27, 2023 · 1 parent 7e15fbd

This commit removes the `default_to_square` argument from the TextNet image processor: `resize` and `preprocess` now hard-code `default_to_square=False` instead of exposing it, and the model docstring example and the integration test switch to an explicit `size={"height": 640, "width": 640}`.
Showing 3 changed files with 4 additions and 13 deletions.
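For context, `default_to_square` is a flag on the shared `get_resize_output_image_size` helper in `transformers.image_transforms`. A minimal sketch of what it controls, assuming that helper's public behavior (the dummy image shape here is arbitrary):

    import numpy as np
    from transformers.image_transforms import get_resize_output_image_size

    # A dummy non-square image in (height, width, channels) layout.
    image = np.zeros((480, 800, 3), dtype=np.uint8)

    # With an integer size, default_to_square=True forces a square output, while
    # False scales the shortest edge to `size` and keeps the aspect ratio.
    print(get_resize_output_image_size(image, size=640, default_to_square=True))   # expected: (640, 640)
    print(get_resize_output_image_size(image, size=640, default_to_square=False))  # expected: (640, 1066)

The commit pins that flag to `False` inside the TextNet processor rather than exposing it as an argument.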
src/transformers/models/textnet/image_processing_textnet.py (2 additions, 9 deletions)
@@ -127,7 +127,6 @@ def resize(
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
-        default_to_square: bool = False,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -158,7 +157,7 @@ def resize(
             raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")

         output_size = get_resize_output_image_size(
-            image, size=size, input_data_format=input_data_format, default_to_square=default_to_square
+            image, size=size, input_data_format=input_data_format, default_to_square=False
         )
         height, weight = output_size
         if height % 32 != 0:
@@ -194,7 +193,6 @@ def preprocess(
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
-        default_to_square: bool = False,
         **kwargs,
     ) -> PIL.Image.Image:
         """
"""
Expand Down Expand Up @@ -247,14 +245,10 @@ def preprocess(
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
default_to_square (`bool`, *optional*, defaults to `False`):
The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
`size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or
not.Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=default_to_square)
size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
Expand Down Expand Up @@ -310,7 +304,6 @@ def preprocess(
size=size,
resample=resample,
input_data_format=input_data_format,
default_to_square=default_to_square,
)
for image in images
]
Expand Down
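Beyond the aspect-ratio choice, `resize` also snaps both output dimensions to multiples of 32 (the `height % 32` check visible in the second hunk; its body is truncated here). A rough standalone sketch of the combined computation, using a hypothetical helper and assuming the snapping rounds upward, which the truncated diff does not show:

    import math

    def textnet_output_size(height: int, width: int, shortest_edge: int) -> tuple[int, int]:
        # Keep the aspect ratio: scale so the shorter side equals shortest_edge
        # (this is what default_to_square=False selects).
        scale = shortest_edge / min(height, width)
        new_h, new_w = round(height * scale), round(width * scale)
        # Snap each side to a multiple of 32; rounding up is an assumption here.
        new_h = math.ceil(new_h / 32) * 32
        new_w = math.ceil(new_w / 32) * 32
        return new_h, new_w

    print(textnet_output_size(480, 800, 640))  # expected: (640, 1088)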
src/transformers/models/textnet/modeling_textnet.py (1 addition, 1 deletion)

@@ -352,7 +352,7 @@ def forward(
         >>> processor = TextNetImageProcessor.from_pretrained("Raghavan/textnet-base")
         >>> model = TextNetForImageClassification.from_pretrained("Raghavan/textnet-base")
-        >>> inputs = processor(images=image, return_tensors="pt", size={"shortest_edge": 640}, default_to_square=True)
+        >>> inputs = processor(images=image, return_tensors="pt", size={"height": 640, "width": 640})
         >>> outputs = model(**inputs)
         >>> outputs.logits.shape
         torch.Size([1, 2])
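With `default_to_square` pinned to `False`, a `{"shortest_edge": 640}` size no longer produces a square tensor for non-square inputs, which is why the example above switches to an explicit `{"height": 640, "width": 640}` dict. A hedged usage sketch, assuming a transformers build that includes TextNet and access to the `Raghavan/textnet-base` checkpoint:

    import numpy as np
    from PIL import Image
    from transformers import TextNetImageProcessor

    processor = TextNetImageProcessor.from_pretrained("Raghavan/textnet-base")
    image = Image.fromarray(np.zeros((480, 800, 3), dtype=np.uint8))  # dummy non-square image

    inputs = processor(images=image, return_tensors="pt", size={"height": 640, "width": 640})
    print(inputs["pixel_values"].shape)  # expected: torch.Size([1, 3, 640, 640])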
tests/models/textnet/test_modeling_textnet.py (1 addition, 3 deletions)

@@ -320,9 +320,7 @@ def test_inference_textnet_image_classification(self):
         image = Image.open(requests.get(url, stream=True).raw)
         processor = TextNetImageProcessor.from_pretrained("Raghavan/textnet-base")
         text = "This is a photo of a cat"
-        inputs = processor(
-            text=text, images=image, return_tensors="pt", size={"shortest_edge": 640}, default_to_square=True
-        )
+        inputs = processor(text=text, images=image, return_tensors="pt", size={"height": 640, "width": 640})

         # forward pass
         output = model(pixel_values=torch.tensor(inputs["pixel_values"]))
