From 29d969ba6ce43859875123f27ceafb141477a56d Mon Sep 17 00:00:00 2001 From: jadechoghari Date: Sat, 21 Dec 2024 15:52:59 +0100 Subject: [PATCH] add changes --- .../models/textnet/configuration_textnet.py | 8 +-- .../textnet/image_processing_textnet.py | 24 ++++--- .../models/textnet/modeling_textnet.py | 63 ++++++++++--------- .../textnet/test_image_processing_textnet.py | 4 ++ tests/models/textnet/test_modeling_textnet.py | 14 +++-- 5 files changed, 64 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/textnet/configuration_textnet.py b/src/transformers/models/textnet/configuration_textnet.py index 235e946e78bbda..67182ac79330ac 100644 --- a/src/transformers/models/textnet/configuration_textnet.py +++ b/src/transformers/models/textnet/configuration_textnet.py @@ -21,17 +21,13 @@ logger = logging.get_logger(__name__) -TEXTNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "textnet-base": ("https://huggingface.co/Raghavan/textnet-base/blob/main/config.json"), -} - class TextNetConfig(BackboneConfigMixin, PretrainedConfig): r""" This is the configuration class to store the configuration of a [`TextNextModel`]. It is used to instantiate a TextNext model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the - [Raghavan/textnet-base](https://huggingface.co/Raghavan/textnet-base). Configuration objects inherit from + [jadechoghari/textnet-base](https://huggingface.co/jadechoghari/textnet-base). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.Read the documentation from [`PretrainedConfig`] for more information. @@ -85,7 +81,7 @@ class TextNetConfig(BackboneConfigMixin, PretrainedConfig): ```""" r""" - [Raghavan/textnet-base](https://huggingface.co/Raghavan/textnet-base) + [jadechoghari](https://huggingface.co/jadechoghari/textnet-base) """ model_type = "textnet" diff --git a/src/transformers/models/textnet/image_processing_textnet.py b/src/transformers/models/textnet/image_processing_textnet.py index 3a5d71e32f1683..cf2362762defc7 100644 --- a/src/transformers/models/textnet/image_processing_textnet.py +++ b/src/transformers/models/textnet/image_processing_textnet.py @@ -60,6 +60,8 @@ class TextNetImageProcessor(BaseImageProcessor): Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` method. + size_divisor (`int`, *optional*, defaults to 32): + Ensures height and width are rounded to a multiple of this value after resizing. resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. 
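The configuration hunk above drops `TEXTNET_PRETRAINED_CONFIG_ARCHIVE_MAP` and repoints the documented checkpoint to `jadechoghari/textnet-base`. A minimal sketch of exercising that change, assuming the checkpoint is published under that repo id and that `TextNetConfig` is exported from `transformers` once this model lands:

```python
from transformers import TextNetConfig

# Defaults reproduce the textnet-base architecture described in the class docstring.
config = TextNetConfig()

# Or pull the config straight from the Hub repo the docstring now points to
# (assumes the checkpoint exists and network access is available).
config = TextNetConfig.from_pretrained("jadechoghari/textnet-base")
print(config.model_type)  # "textnet"
```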
do_center_crop (`bool`, *optional*, defaults to `False`): @@ -93,6 +95,7 @@ def __init__( self, do_resize: bool = True, size: Dict[str, int] = None, + size_divisor: int = 32, resample: PILImageResampling = PILImageResampling.BILINEAR, do_center_crop: bool = False, crop_size: Dict[str, int] = None, @@ -112,6 +115,7 @@ def __init__( self.do_resize = do_resize self.size = size + self.size_divisor = size_divisor self.resample = resample self.do_center_crop = do_center_crop self.crop_size = crop_size @@ -126,6 +130,7 @@ def __init__( "images", "do_resize", "size", + "size_divisor", "resample", "do_center_crop", "crop_size", @@ -158,6 +163,8 @@ def resize( Image to resize. size (`Dict[str, int]`): Size of the output image. + size_divisor (`int`, *optional*, defaults to `32`): + Ensures height and width are rounded to a multiple of this value after resizing. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): Resampling filter to use when resiizing the image. data_format (`str` or `ChannelDimension`, *optional*): @@ -176,29 +183,29 @@ def resize( else: raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") - height, weight = get_resize_output_image_size( + height, width = get_resize_output_image_size( image, size=size, input_data_format=input_data_format, default_to_square=False ) - if height % 32 != 0: - height = height + (32 - height % 32) - if weight % 32 != 0: - weight = weight + (32 - weight % 32) + if height % self.size_divisor != 0: + height += self.size_divisor - (height % self.size_divisor) + if width % self.size_divisor != 0: + width += self.size_divisor - (width % self.size_divisor) return resize( image, - size=(height, weight), + size=(height, width), resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs, ) - # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.preprocess def preprocess( self, images: ImageInput, do_resize: bool = None, size: Dict[str, int] = None, + size_divisor: int = None, resample: PILImageResampling = None, do_center_crop: bool = None, crop_size: int = None, @@ -225,6 +232,8 @@ def preprocess( size (`Dict[str, int]`, *optional*, defaults to `self.size`): Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with the longest edge resized to keep the input aspect ratio. + size_divisor (`int`, *optional*, defaults to `32`): + Ensures height and width are rounded to a multiple of this value after resizing. resample (`int`, *optional*, defaults to `self.resample`): Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only has an effect if `do_resize` is set to `True`. 
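The `resize` change above replaces the hard-coded 32 with the new `size_divisor` attribute (and fixes the `weight`/`width` naming) while keeping the same round-up-to-a-multiple behaviour. A standalone sketch of that arithmetic; the example numbers are illustrative, not taken from the patch:

```python
def round_up_to_multiple(value: int, size_divisor: int = 32) -> int:
    # Mirrors the patched logic: pad up to the next multiple of `size_divisor`.
    if value % size_divisor != 0:
        value += size_divisor - (value % size_divisor)
    return value

# e.g. dimensions returned by get_resize_output_image_size for a shortest-edge resize
height, width = 633, 947
print(round_up_to_multiple(height), round_up_to_multiple(width))  # 640 960
```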
@@ -267,6 +276,7 @@ def preprocess( do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size size = get_size_dict(size, param_name="size", default_to_square=False) + size_divisor = size_divisor if size_divisor is not None else self.size_divisor resample = resample if resample is not None else self.resample do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop crop_size = crop_size if crop_size is not None else self.crop_size diff --git a/src/transformers/models/textnet/modeling_textnet.py b/src/transformers/models/textnet/modeling_textnet.py index d0ed0f6a12577e..f5d9cab61b811c 100644 --- a/src/transformers/models/textnet/modeling_textnet.py +++ b/src/transformers/models/textnet/modeling_textnet.py @@ -43,33 +43,9 @@ # General docstring _CONFIG_FOR_DOC = "TextNetConfig" -_CHECKPOINT_FOR_DOC = "Raghavan/textnet-base" +_CHECKPOINT_FOR_DOC = "jadechoghari/textnet-base" _EXPECTED_OUTPUT_SHAPE = [1, 512, 20, 27] -TEXTNET_START_DOCSTRING = r""" - This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it - as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`TextNetConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -TEXTNET_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`TextNetImageProcessor.__call__`] for details. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - class TextNetConvLayer(nn.Module): def __init__(self, config: TextNetConfig): @@ -116,7 +92,7 @@ def __init__(self, config: TextNetConfig, in_channels: int, out_channels: int, k padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) - self.nonlinearity = nn.ReLU() + self.activation_function = nn.ReLU() self.main_conv = nn.Conv2d( in_channels=in_channels, @@ -181,7 +157,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: id_out = self.rbr_identity(hidden_states) main_outputs = main_outputs + id_out - return self.nonlinearity(main_outputs) + return self.activation_function(main_outputs) class TextNetStage(nn.Module): @@ -237,6 +213,31 @@ def forward( return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) +TEXTNET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`TextNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
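With `size_divisor` now resolved in `preprocess` the same way as the other arguments (per-call value falling back to the instance attribute), it can also be overridden when the processor is called. A hedged usage sketch: it assumes the call-time keyword is forwarded to `preprocess`, and 64 is an arbitrary illustrative value:

```python
import numpy as np
from PIL import Image
from transformers import TextNetImageProcessor

image = Image.fromarray(np.zeros((600, 900, 3), dtype=np.uint8))  # dummy image

processor = TextNetImageProcessor(size={"shortest_edge": 640})  # size_divisor defaults to 32
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # height and width are multiples of 32

# Per-call override; falls back to self.size_divisor when not given.
inputs = processor(images=image, size_divisor=64, return_tensors="pt")
```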
+""" + +TEXTNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`TextNetImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + class TextNetPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -354,8 +355,8 @@ def forward( >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> processor = TextNetImageProcessor.from_pretrained("Raghavan/textnet-base") - >>> model = TextNetForImageClassification.from_pretrained("Raghavan/textnet-base") + >>> processor = TextNetImageProcessor.from_pretrained("jadechoghari/textnet-base") + >>> model = TextNetForImageClassification.from_pretrained("jadechoghari/textnet-base") >>> inputs = processor(images=image, return_tensors="pt", size={"height": 640, "width": 640}) >>> with torch.no_grad(): @@ -434,8 +435,8 @@ def forward( >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> processor = AutoImageProcessor.from_pretrained("Raghavan/textnet-base") - >>> model = AutoBackbone.from_pretrained("Raghavan/textnet-base") + >>> processor = AutoImageProcessor.from_pretrained("jadechoghari/textnet-base") + >>> model = AutoBackbone.from_pretrained("jadechoghari/textnet-base") >>> inputs = processor(image, return_tensors="pt") >>> with torch.no_grad(): diff --git a/tests/models/textnet/test_image_processing_textnet.py b/tests/models/textnet/test_image_processing_textnet.py index 781f71460ec240..4fcd93e872fcdf 100644 --- a/tests/models/textnet/test_image_processing_textnet.py +++ b/tests/models/textnet/test_image_processing_textnet.py @@ -37,6 +37,7 @@ def __init__( max_resolution=400, do_resize=True, size=None, + size_divisor=32, do_center_crop=True, crop_size=None, do_normalize=True, @@ -54,6 +55,7 @@ def __init__( self.max_resolution = max_resolution self.do_resize = do_resize self.size = size + self.size_divisor = size_divisor self.do_center_crop = do_center_crop self.crop_size = crop_size self.do_normalize = do_normalize @@ -65,6 +67,7 @@ def prepare_image_processor_dict(self): return { "do_resize": self.do_resize, "size": self.size, + "size_divisor": self.size_divisor, "do_center_crop": self.do_center_crop, "crop_size": self.crop_size, "do_normalize": self.do_normalize, @@ -105,6 +108,7 @@ def test_image_processor_properties(self): image_processing = self.image_processing_class(**self.image_processor_dict) self.assertTrue(hasattr(image_processing, "do_resize")) self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "size_divisor")) self.assertTrue(hasattr(image_processing, "do_center_crop")) self.assertTrue(hasattr(image_processing, "center_crop")) self.assertTrue(hasattr(image_processing, "do_normalize")) diff --git a/tests/models/textnet/test_modeling_textnet.py b/tests/models/textnet/test_modeling_textnet.py index ffcf6c2e13f4e5..95ec249f46f7a4 100644 --- a/tests/models/textnet/test_modeling_textnet.py +++ 
b/tests/models/textnet/test_modeling_textnet.py @@ -305,7 +305,7 @@ def test_for_image_classification(self): @slow def test_model_from_pretrained(self): - model_name = "Raghavan/textnet-base" + model_name = "jadechoghari/textnet-base" model = TextNetModel.from_pretrained(model_name) self.assertIsNotNone(model) @@ -314,9 +314,9 @@ def test_model_from_pretrained(self): @require_vision class TextNetModelIntegrationTest(unittest.TestCase): @slow - def test_inference_textnet_image_classification(self): - processor = TextNetImageProcessor.from_pretrained("Raghavan/textnet-base") - model = TextNetForImageClassification.from_pretrained("Raghavan/textnet-base").to(torch_device) + def test_inference_no_head(self): + processor = TextNetImageProcessor.from_pretrained("jadechoghari/textnet-base") + model = TextNetModel.from_pretrained("jadechoghari/textnet-base").to(torch_device) # prepare image url = "http://images.cocodataset.org/val2017/000000039769.jpg" @@ -329,7 +329,11 @@ def test_inference_textnet_image_classification(self): # verify logits self.assertEqual(output.logits.shape, torch.Size([1, 2])) - self.assertTrue(torch.allclose(output.logits[:3, :3]), torch.zeros_like((output.logits[:3, :3])), atol=1e-3) + expected_slice_backbone = torch.tensor( + [0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000], + device=torch_device, + ) + self.assertTrue(torch.allclose(output.feature_maps[-1][0][10][12][:10], expected_slice_backbone, atol=1e-3)) @require_torch
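For reference, the updated integration check can be reproduced outside the test harness. Since `feature_maps` is a backbone output, this sketch loads the checkpoint through `AutoBackbone` (as in the docstring example earlier in the patch); the repo id and expected values come from the diff, and exact numerical agreement is an assumption that depends on the published weights:

```python
import requests
import torch
from PIL import Image
from transformers import AutoBackbone, AutoImageProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("jadechoghari/textnet-base")
model = AutoBackbone.from_pretrained("jadechoghari/textnet-base")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)

expected_slice = torch.tensor(
    [0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000]
)
# Channel 10, row 12, first 10 columns of the last feature map, as in the test.
print(torch.allclose(output.feature_maps[-1][0, 10, 12, :10], expected_slice, atol=1e-3))
```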