Commit a1556dd: fix docstrings

yonigozlan committed Nov 25, 2024
1 parent 0ddda21 commit a1556dd
Showing 2 changed files with 100 additions and 4 deletions.
95 changes: 92 additions & 3 deletions src/transformers/models/got_ocr2/modular_got_ocr2.py
@@ -46,8 +46,10 @@
 from ...modeling_outputs import CausalLMOutputWithPast
 from ...utils import (
     ModelOutput,
+    add_start_docstrings_to_model_forward,
     is_vision_available,
     logging,
+    replace_return_docstrings,
 )

@@ -58,6 +60,8 @@
 logger = logging.get_logger(__name__)
 
+_CONFIG_FOR_DOC = "GotOcr2Config"
+
 
 class GotOcr2VisionConfig(PretrainedConfig):
     r"""
@@ -167,6 +171,8 @@ class GotOcr2ImagesKwargs(ImagesKwargs, total=False):
     num_image_tokens: Optional[int]
     multi_page: Optional[bool]
     crop_to_patches: Optional[bool]
+    min_patches: Optional[int]
+    max_patches: Optional[int]


@@ -179,6 +185,8 @@ class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False):
         },
         "images_kwargs": {
             "num_image_tokens": 256,
+            "min_patches": 1,
+            "max_patches": 6,
         },
     }
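For context, the values above are class-level defaults; the processor merges them with whatever the caller passes, and caller kwargs win. A minimal sketch of that merge (a simplification of what the processing-kwargs machinery does internally, not the actual implementation):

    # Defaults from GotOcr2ProcessorKwargs._defaults are applied first,
    # then overridden by caller-supplied kwargs.
    defaults = {"num_image_tokens": 256, "min_patches": 1, "max_patches": 6}
    caller_kwargs = {"max_patches": 12}            # hypothetical override
    images_kwargs = {**defaults, **caller_kwargs}  # caller wins
    assert images_kwargs["max_patches"] == 12 and images_kwargs["min_patches"] == 1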

@@ -359,6 +367,12 @@ def __call__(
                 If set, will enable multi-page inference. The model will return the OCR result across multiple pages.
             crop_to_patches (`bool`, *optional*):
                 If set, will crop the image to patches. The model will return the OCR result upon the patch reference.
+            min_patches (`int`, *optional*):
+                The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
+                `True`.
+            max_patches (`int`, *optional*):
+                The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
+                `True`.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
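For illustration, a usage sketch of the new kwargs documented above; the checkpoint id and image path are placeholders rather than values taken from this commit:

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("your-org/got-ocr2-checkpoint")  # placeholder id
    inputs = processor(
        "document.png",        # path, URL, or PIL image
        return_tensors="pt",
        crop_to_patches=True,  # enables the two bounds below
        min_patches=1,
        max_patches=6,
    )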
@@ -429,7 +443,10 @@ def __call__(
         for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
             if crop_to_patches:
                 image_group = self.image_processor.crop_image_to_patches(
-                    image_group, size=output_kwargs["images_kwargs"].get("size", None)
+                    image_group,
+                    size=output_kwargs["images_kwargs"].get("size"),
+                    min_num=output_kwargs["images_kwargs"].get("min_patches"),
+                    max_num=output_kwargs["images_kwargs"].get("max_patches"),
                 )
                 images[index] = image_group
             num_images = len(image_group) if (multi_page or crop_to_patches) else 1
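For readers unfamiliar with `crop_image_to_patches`, the `min_num`/`max_num` bounds typically constrain a dynamic tiling step that picks a patch grid matching the image's aspect ratio. A simplified sketch of that idea, assuming such a tiling heuristic (not the actual implementation):

    from PIL import Image

    def crop_image_to_patches_sketch(image, min_num=1, max_num=6, patch_size=1024):
        """Split `image` into between min_num and max_num square patches."""
        width, height = image.size
        # Candidate (cols, rows) grids whose patch count lies within the bounds.
        grids = [(c, r) for c in range(1, max_num + 1) for r in range(1, max_num + 1)
                 if min_num <= c * r <= max_num]
        # Pick the grid whose aspect ratio best matches the image.
        cols, rows = min(grids, key=lambda g: abs(g[0] / g[1] - width / height))
        resized = image.resize((cols * patch_size, rows * patch_size))
        return [
            resized.crop((c * patch_size, r * patch_size,
                          (c + 1) * patch_size, (r + 1) * patch_size))
            for r in range(rows) for c in range(cols)
        ]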
@@ -468,14 +485,14 @@ def __call__(
 
     def batch_decode(self, *args, **kwargs):
         """
-        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
         refer to the docstring of this method for more information.
         """
         return self.tokenizer.batch_decode(*args, **kwargs)
 
     def decode(self, *args, **kwargs):
         """
-        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
         the docstring of this method for more information.
         """
         return self.tokenizer.decode(*args, **kwargs)
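Both wrappers simply delegate to the underlying tokenizer. A typical call after generation, assuming `model` and `inputs` from an earlier setup:

    generated_ids = model.generate(**inputs, max_new_tokens=64)
    texts = processor.batch_decode(generated_ids, skip_special_tokens=True)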
@@ -572,6 +589,76 @@ class GotOcr2Model(Qwen2Model):
     pass
 
 
+GOT_OCR2_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)`):
+            The tensors corresponding to the input images. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`GotOcr2Processor`] uses
+            [`GotOcr2ImageProcessor`] for processing images.
+"""
+
+
 class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
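A sketch of a forward pass wiring the documented inputs together, assuming `model` is a loaded `GotOcr2ForConditionalGeneration` and `inputs` came from the processor (not an excerpt from this commit):

    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # 1 = attend, 0 = padding
        pixel_values=inputs["pixel_values"],      # image tensor for the vision encoder
        use_cache=True,                           # return past_key_values for fast decoding
    )
    logits = outputs.logits                       # field of CausalLMOutputWithPast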

Expand Down Expand Up @@ -620,6 +707,8 @@ def _update_model_kwargs_for_generation(

return model_kwargs

@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
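Roughly, `add_start_docstrings_to_model_forward` attaches the shared inputs docstring to `forward.__doc__`, and `replace_return_docstrings` fills in the Returns section from the given output class and config. A simplified sketch of the first decorator; the real utilities in `transformers.utils` do more (e.g. interpolating the model class name):

    def add_start_docstrings_to_model_forward_sketch(*docstrings):
        # Prepend the shared argument documentation to the decorated method's docstring.
        def decorator(fn):
            fn.__doc__ = "".join(docstrings) + (fn.__doc__ or "")
            return fn
        return decorator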
9 changes: 8 additions & 1 deletion src/transformers/models/got_ocr2/processing_got_ocr2.py
@@ -50,6 +50,8 @@ class GotOcr2ImagesKwargs(ImagesKwargs, total=False):
     num_image_tokens: Optional[int]
     multi_page: Optional[bool]
     crop_to_patches: Optional[bool]
+    min_patches: Optional[int]
+    max_patches: Optional[int]


@@ -62,6 +64,8 @@ class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False):
         },
         "images_kwargs": {
             "num_image_tokens": 256,
+            "min_patches": 1,
+            "max_patches": 6,
         },
     }

@@ -213,7 +217,10 @@ def __call__(
         for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
             if crop_to_patches:
                 image_group = self.image_processor.crop_image_to_patches(
-                    image_group, size=output_kwargs["images_kwargs"].get("size", None)
+                    image_group,
+                    size=output_kwargs["images_kwargs"].get("size"),
+                    min_num=output_kwargs["images_kwargs"].get("min_patches"),
+                    max_num=output_kwargs["images_kwargs"].get("max_patches"),
                 )
                 images[index] = image_group
             num_images = len(image_group) if (multi_page or crop_to_patches) else 1
