Commit: add changes
jadechoghari committed Dec 21, 2024
1 parent e8a5b97 · commit 29d969b
Showing 5 changed files with 64 additions and 49 deletions.
8 changes: 2 additions & 6 deletions src/transformers/models/textnet/configuration_textnet.py
@@ -21,17 +21,13 @@

logger = logging.get_logger(__name__)

-TEXTNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-"textnet-base": ("https://huggingface.co/Raghavan/textnet-base/blob/main/config.json"),
-}
-

class TextNetConfig(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`TextNetModel`]. It is used to instantiate a
TextNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the
-[Raghavan/textnet-base](https://huggingface.co/Raghavan/textnet-base). Configuration objects inherit from
+[jadechoghari/textnet-base](https://huggingface.co/jadechoghari/textnet-base). Configuration objects inherit from
[`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
@@ -85,7 +81,7 @@ class TextNetConfig(BackboneConfigMixin, PretrainedConfig):
```"""

r"""
-[Raghavan/textnet-base](https://huggingface.co/Raghavan/textnet-base)
+[jadechoghari/textnet-base](https://huggingface.co/jadechoghari/textnet-base)
"""
model_type = "textnet"

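As a quick sanity check on the renamed checkpoint, here is a minimal sketch (assuming the `jadechoghari/textnet-base` repo is public and the TextNet classes from this branch are installed):

```python
from transformers import TextNetConfig

# Load the config from the hub id this commit points the docs at.
config = TextNetConfig.from_pretrained("jadechoghari/textnet-base")
print(config.model_type)  # expected: "textnet"
```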
24 changes: 17 additions & 7 deletions src/transformers/models/textnet/image_processing_textnet.py
@@ -60,6 +60,8 @@ class TextNetImageProcessor(BaseImageProcessor):
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
method.
+size_divisor (`int`, *optional*, defaults to 32):
+Ensures height and width are rounded up to a multiple of this value after resizing.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `False`):
@@ -93,6 +95,7 @@ def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
+size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_center_crop: bool = False,
crop_size: Dict[str, int] = None,
@@ -112,6 +115,7 @@ def __init__(

self.do_resize = do_resize
self.size = size
+self.size_divisor = size_divisor
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
@@ -126,6 +130,7 @@ def __init__(
"images",
"do_resize",
"size",
"size_divisor",
"resample",
"do_center_crop",
"crop_size",
@@ -158,6 +163,8 @@ def resize(
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
+size_divisor (`int`, *optional*, defaults to `32`):
+Ensures height and width are rounded up to a multiple of this value after resizing.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
@@ -176,29 +183,29 @@
else:
raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")

-height, weight = get_resize_output_image_size(
+height, width = get_resize_output_image_size(
image, size=size, input_data_format=input_data_format, default_to_square=False
)
-if height % 32 != 0:
-height = height + (32 - height % 32)
-if weight % 32 != 0:
-weight = weight + (32 - weight % 32)
+if height % self.size_divisor != 0:
+height += self.size_divisor - (height % self.size_divisor)
+if width % self.size_divisor != 0:
+width += self.size_divisor - (width % self.size_divisor)

return resize(
image,
-size=(height, weight),
+size=(height, width),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

-# Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.preprocess
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
+size_divisor: int = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: int = None,
@@ -225,6 +232,8 @@ def preprocess(
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio.
+size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
+Ensures height and width are rounded up to a multiple of this value after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
@@ -267,6 +276,7 @@ def preprocess(
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=False)
+size_divisor = size_divisor if size_divisor is not None else self.size_divisor
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
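To make the new `size_divisor` behavior concrete, here is a small sketch of the round-up rule `resize` now applies; the helper name `round_up` is mine, not part of the library:

```python
def round_up(value: int, divisor: int = 32) -> int:
    """Round value up to the nearest multiple of divisor."""
    if value % divisor != 0:
        value += divisor - (value % divisor)
    return value

# Resized dimensions get padded up to divisor-aligned values:
print(round_up(600))  # 608
print(round_up(403))  # 416
print(round_up(640))  # 640, already a multiple of 32
```

The default of 32 presumably matches the backbone's overall stride, so every stage produces integral feature-map sizes.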
63 changes: 32 additions & 31 deletions src/transformers/models/textnet/modeling_textnet.py
@@ -43,33 +43,9 @@

# General docstring
_CONFIG_FOR_DOC = "TextNetConfig"
_CHECKPOINT_FOR_DOC = "Raghavan/textnet-base"
_CHECKPOINT_FOR_DOC = "jadechoghari/textnet-base"
_EXPECTED_OUTPUT_SHAPE = [1, 512, 20, 27]

-TEXTNET_START_DOCSTRING = r"""
-This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
-as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
-behavior.
-Parameters:
-config ([`TextNetConfig`]): Model configuration class with all the parameters of the model.
-Initializing with a config file does not load the weights associated with the model, only the
-configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-TEXTNET_INPUTS_DOCSTRING = r"""
-Args:
-pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
-[`TextNetImageProcessor.__call__`] for details.
-output_hidden_states (`bool`, *optional*):
-Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-more detail.
-return_dict (`bool`, *optional*):
-Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""


class TextNetConvLayer(nn.Module):
def __init__(self, config: TextNetConfig):
@@ -116,7 +92,7 @@ def __init__(self, config: TextNetConfig, in_channels: int, out_channels: int, k

padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)

-self.nonlinearity = nn.ReLU()
+self.activation_function = nn.ReLU()

self.main_conv = nn.Conv2d(
in_channels=in_channels,
@@ -181,7 +157,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
id_out = self.rbr_identity(hidden_states)
main_outputs = main_outputs + id_out

-return self.nonlinearity(main_outputs)
+return self.activation_function(main_outputs)


class TextNetStage(nn.Module):
@@ -237,6 +213,31 @@ def forward(
return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)


+TEXTNET_START_DOCSTRING = r"""
+This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+behavior.
+Parameters:
+config ([`TextNetConfig`]): Model configuration class with all the parameters of the model.
+Initializing with a config file does not load the weights associated with the model, only the
+configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TEXTNET_INPUTS_DOCSTRING = r"""
+Args:
+pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+[`TextNetImageProcessor.__call__`] for details.
+output_hidden_states (`bool`, *optional*):
+Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+more detail.
+return_dict (`bool`, *optional*):
+Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""


class TextNetPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -354,8 +355,8 @@ def forward(
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
->>> processor = TextNetImageProcessor.from_pretrained("Raghavan/textnet-base")
->>> model = TextNetForImageClassification.from_pretrained("Raghavan/textnet-base")
+>>> processor = TextNetImageProcessor.from_pretrained("jadechoghari/textnet-base")
+>>> model = TextNetForImageClassification.from_pretrained("jadechoghari/textnet-base")
>>> inputs = processor(images=image, return_tensors="pt", size={"height": 640, "width": 640})
>>> with torch.no_grad():
@@ -434,8 +435,8 @@ def forward(
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
->>> processor = AutoImageProcessor.from_pretrained("Raghavan/textnet-base")
->>> model = AutoBackbone.from_pretrained("Raghavan/textnet-base")
+>>> processor = AutoImageProcessor.from_pretrained("jadechoghari/textnet-base")
+>>> model = AutoBackbone.from_pretrained("jadechoghari/textnet-base")
>>> inputs = processor(image, return_tensors="pt")
>>> with torch.no_grad():
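Building on the updated docstring example, a hedged sketch of inspecting the backbone outputs; the printed shapes are extrapolated from `_EXPECTED_OUTPUT_SHAPE` above, not verified against the checkpoint:

```python
import requests
import torch
from PIL import Image

from transformers import AutoBackbone, AutoImageProcessor

# Same COCO image used in the docstring examples above.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("jadechoghari/textnet-base")
model = AutoBackbone.from_pretrained("jadechoghari/textnet-base")

inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One tensor per configured out_feature; the deepest stage should be
# 512 channels at stride 32, cf. _EXPECTED_OUTPUT_SHAPE = [1, 512, 20, 27].
for feature_map in outputs.feature_maps:
    print(feature_map.shape)
```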
4 changes: 4 additions & 0 deletions tests/models/textnet/test_image_processing_textnet.py
@@ -37,6 +37,7 @@ def __init__(
max_resolution=400,
do_resize=True,
size=None,
+size_divisor=32,
do_center_crop=True,
crop_size=None,
do_normalize=True,
@@ -54,6 +55,7 @@ def __init__(
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
+self.size_divisor = size_divisor
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
@@ -65,6 +67,7 @@ def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"size_divisor": self.size_divisor,
"do_center_crop": self.do_center_crop,
"crop_size": self.crop_size,
"do_normalize": self.do_normalize,
@@ -105,6 +108,7 @@ def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "size_divisor"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
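The property test above follows the usual tester pattern; a condensed sketch of what it exercises, with the kwargs assumed from `prepare_image_processor_dict`:

```python
from transformers import TextNetImageProcessor

# The new attribute should be settable at construction time...
image_processor = TextNetImageProcessor(size_divisor=32)
assert hasattr(image_processor, "size_divisor")
assert image_processor.size_divisor == 32

# ...and survive a serialization round trip.
restored = TextNetImageProcessor.from_dict(image_processor.to_dict())
assert restored.size_divisor == 32
```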
14 changes: 9 additions & 5 deletions tests/models/textnet/test_modeling_textnet.py
@@ -305,7 +305,7 @@ def test_for_image_classification(self):

@slow
def test_model_from_pretrained(self):
model_name = "Raghavan/textnet-base"
model_name = "jadechoghari/textnet-base"
model = TextNetModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@@ -314,9 +314,9 @@ def test_model_from_pretrained(self):
@require_vision
class TextNetModelIntegrationTest(unittest.TestCase):
@slow
-def test_inference_textnet_image_classification(self):
-processor = TextNetImageProcessor.from_pretrained("Raghavan/textnet-base")
-model = TextNetForImageClassification.from_pretrained("Raghavan/textnet-base").to(torch_device)
+def test_inference_no_head(self):
+processor = TextNetImageProcessor.from_pretrained("jadechoghari/textnet-base")
+model = TextNetModel.from_pretrained("jadechoghari/textnet-base").to(torch_device)

# prepare image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
@@ -329,7 +329,11 @@ def test_inference_textnet_image_classification(self):

# verify logits
self.assertEqual(output.logits.shape, torch.Size([1, 2]))
-self.assertTrue(torch.allclose(output.logits[:3, :3]), torch.zeros_like((output.logits[:3, :3])), atol=1e-3)
+expected_slice_backbone = torch.tensor(
+[0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000],
+device=torch_device,
+)
+self.assertTrue(torch.allclose(output.feature_maps[-1][0][10][12][:10], expected_slice_backbone, atol=1e-3))


@require_torch
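Worth noting: the assertion removed above had misplaced parentheses, so `torch.zeros_like(...)` and `atol` were being passed to `assertTrue` rather than `torch.allclose`, which raises a `TypeError` before anything is compared. A minimal illustration of the correct call shape, as the new assertion uses:

```python
import torch

logits = torch.zeros(3, 3)
# allclose takes (input, other, atol=...) in a single call; closing the
# parenthesis after the first tensor is what broke the old assertion.
assert torch.allclose(logits, torch.zeros_like(logits), atol=1e-3)
```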
