uniformize kwargs for OneFormer #34547

Open · tibor-reiss wants to merge 12 commits into base: main from fix-31811-oneformer

Conversation

Contributor

@tibor-reiss tibor-reiss commented Oct 31, 2024

Adds uniformized processors for OneFormer following #31911.

Small changes to simplify code.

Additional check via test_processor_oneformer.py that the inputs are the same before and after the changes (with fixed random seeds).

@qubvel @molbap

@tibor-reiss tibor-reiss force-pushed the fix-31811-oneformer branch 2 times, most recently from cd15539 to fad4111 Compare October 31, 2024 21:12
Contributor

@molbap molbap left a comment

Thanks for the initiative! Left a couple of initial comments.

Contributor

Cool! But we should also check that the previous signature works as intended so as not to break backwards compatibility - in this case, the previous behaviour where arguments were passed positionally should still be supported.

Contributor Author

@tibor-reiss tibor-reiss Nov 1, 2024

I added *args. I had to add it after images, so the signature will only match after another iteration of deprecation. WDYT?

Comment on lines 111 to 130
        if isinstance(task_inputs, str):
            task_inputs = [task_inputs]

        if not isinstance(task_inputs, List) or not task_inputs:
            raise TypeError("task_inputs should be a string or a list of strings.")
Contributor

Since this does more than checking the types, I suggest moving the conversion to a list out of this function and renaming it explicitly, to _validate_input_types for instance.
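A minimal sketch of how that split could look (the helper name _validate_input_types comes from this comment; moving the str-to-list conversion into the caller is an assumption about where it would end up):

def _validate_input_types(self, task_inputs):
    # Pure validation: no conversion side effects in here.
    if not isinstance(task_inputs, list) or not task_inputs:
        raise TypeError("task_inputs should be a string or a list of strings.")

# In the caller (e.g. __call__), convert first, then validate:
#     if isinstance(task_inputs, str):
#         task_inputs = [task_inputs]
#     self._validate_input_types(task_inputs)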

Contributor Author

Implemented.


def _preprocess_text(self, text_list: PreTokenizedInput, max_length: int = 77):
Contributor

So this tokenizes, which should be done with self.tokenizer(..., **output_kwargs['text_kwargs']), and it seems that it puts pad tokens where the attention mask is 0 - this should already be covered by the tokenizer. Would be a nice refactor.
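For illustration, a self-contained sketch of that direction, letting the tokenizer handle padding itself (the CLIP checkpoint here is only an illustrative stand-in for the CLIP-style tokenizer OneFormer's processor uses):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_list = ["a photo of a cat", "the task is semantic"]

# padding="max_length" already fills input_ids with the pad token up to max_length,
# and attention_mask marks those positions with 0, so no manual re-padding is needed.
tokens = tokenizer(
    text_list,
    max_length=77,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
print(tokens["input_ids"].shape)       # torch.Size([2, 77])
print(tokens["attention_mask"].shape)  # torch.Size([2, 77])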

Contributor Author

I will look into this...

Contributor Author

Atm I don't see how this can be simplified because the function is called with different arguments inside __call__, and it is also called in encode_inputs. Any suggestions?

Member

Agreed that it would be best to use self.tokenizer directly, and pass in the other text kwargs as well. For max_length, you can do the following:

  • if the max_length kwarg is not set, set it to max_seq_length or task_seq_length in the two different calls.
  • if it is set, use its value for both task and seq.
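A minimal sketch of that fallback, assuming max_length arrives via the merged text kwargs (the helper name is illustrative, not part of the PR):

def resolve_max_length(text_kwargs: dict, default_length: int) -> int:
    """Return the user-supplied max_length if set, otherwise the per-call default
    (max_seq_length for the text lists, task_seq_length for the task inputs)."""
    max_length = text_kwargs.get("max_length")
    return max_length if max_length is not None else default_length

# e.g. resolve_max_length(output_kwargs["text_kwargs"], 77) in the two tokenizer calls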

Contributor Author

I tried this (commit 3c7724d), but I don't yet see how it is simpler/more readable :)
A couple of observations:

  • The function _preprocess_text is called with both task_seq_length and max_seq_length, so if I add "max_length" to the text_kwargs dict, I would need to add it and then replace it in __call__.
  • In order to do something similar to what the attention mask is doing (padding with 0s), I would need to change self.tokenizer.pad_token_id (and then change it back to what it was before, e.g. 49407, endoftext), because I can't pass pad_token_id to the tokenizer. I am most probably missing something here, so suggestions are welcome.

Comment on lines +75 to +82
        max_seq_length: int = 77,
        task_seq_length: int = 77,
Contributor

So the way it should be is that in OneFormerProcessorKwargs, these defaults should be passed to the relevant dictionary, which you can then get back. Say

_defaults = {
    "max_seq_length": 77,
    "task_seq_length": 77,
}

and you can inform the types with a Kwargs typed class

class OneFormerTextKwargs(TextKwargs, total=False):
    max_seq_length: Optional[int]
    task_seq_length: Optional[int]

then your types are correctly informed as well as your defaults and you can use both.
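For reference, a hedged sketch of how this could fit the ProcessingKwargs pattern from #31911 (the nesting of _defaults under "text_kwargs" is an assumption based on other uniformized processors; a later review comment suggests not adding these keys to OneFormerTextKwargs at all):

from typing import Optional

from transformers.processing_utils import ProcessingKwargs, TextKwargs


class OneFormerTextKwargs(TextKwargs, total=False):
    max_seq_length: Optional[int]
    task_seq_length: Optional[int]


class OneFormerProcessorKwargs(ProcessingKwargs, total=False):
    text_kwargs: OneFormerTextKwargs
    _defaults = {
        "text_kwargs": {
            "max_seq_length": 77,
            "task_seq_length": 77,
        },
    }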

Contributor Author

This was a bigger refactor. Let me know what you think...

@ArthurZucker
Collaborator

Hey! @molbap is off as he needs to rest, cc @yonigozlan can you review this? 🤗

Member

@yonigozlan yonigozlan left a comment

Thanks for working on this!
I mainly have comments about the handling of positional kwargs (like in the SAM PR), and about avoiding the wrapping of text processing inside _preprocess_text



Comment on lines 203 to 224
    def encode_inputs(
        self,
        images=None,
        task_inputs=None,
        segmentation_maps=None,
        max_seq_length: int = 77,
        task_seq_length: int = 77,
        **kwargs,
    ):
Member

With this refactoring, it looks to me like encode_inputs could be made the same as __call__ while preserving backward compatibility (if the args are handled correctly), so maybe we could just use __call__ here?

Contributor Author

@tibor-reiss tibor-reiss Nov 29, 2024

Good catch! Done.

I tried this in commit 0270034; it took some time until I realized that there is a big difference between the two calls:

  • encode_inputs calls OneFormerImageProcessor.encode_inputs
  • __call__ calls OneFormerImageProcessor.__call__, which will eventually also call OneFormerImageProcessor.encode_inputs, but only after quite a few checks and massaging.

So at the moment I don't think there is a simple and elegant way to replace it, and I would just leave it as is. WDYT?

Member

Oh I see, never mind then

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        *args,  # to be deprecated
Member

Yes, see SAM PR comment #34578 (comment)

Comment on lines 119 to 126
    @staticmethod
    def _add_args_for_backward_compatibility(args):
        """
        Remove this function once support for args is dropped in __call__
        """
        if len(args) > 2:
            raise ValueError("Too many positional arguments")
        return dict(zip(("task_inputs", "segmentation_maps"), args))
Member

not needed #34578 (comment)

@@ -76,13 +100,46 @@ def _preprocess_text(self, text_list=None, max_length=77):
        token_inputs = torch.cat(token_inputs, dim=0)
        return token_inputs

    def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
Member

These should be added to the optional_call_args attribute (see the UDOP processor).
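Roughly what the optional_call_args pattern looks like, modeled on the UDOP processor (treat the exact ProcessorMixin helper names below as assumptions about the current base class API):

from transformers.processing_utils import ProcessorMixin


class OneFormerProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    # Positional arguments historically accepted after `images`; declaring them lets
    # the base class map stray positional args back onto their keyword names.
    optional_call_args = ["task_inputs", "segmentation_maps"]

    def __call__(self, images=None, *args, **kwargs):
        output_kwargs = self._merge_kwargs(
            OneFormerProcessorKwargs,  # the kwargs class added in this PR
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
            **self.prepare_and_validate_optional_call_args(*args),
        )
        ...  # processing continues with output_kwargs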

@tibor-reiss tibor-reiss force-pushed the fix-31811-oneformer branch 2 times, most recently from 40cf888 to cbddaf4 Compare November 29, 2024 09:58
Member

@yonigozlan yonigozlan left a comment

There's a breaking change in the init that needs to be fixed. Otherwise it is looking good!
Just like for SAM, you'll also have to add the ProcessorTesterMixin to the processor test class, make sure all the tests pass and override them if needed, and rebase on main.
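For context, wiring in the shared tests looks roughly like this (the import path and test class name are assumptions based on how other model test files are laid out):

import unittest

from transformers import OneFormerProcessor
from transformers.testing_utils import require_torch, require_vision

# ProcessorTesterMixin lives in tests/test_processing_common.py; the relative import
# assumes the usual tests/models/oneformer/ location.
from ...test_processing_common import ProcessorTesterMixin


@require_torch
@require_vision
class OneFormerProcessingTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = OneFormerProcessor
    # Mixin tests such as test_structured_kwargs_nested can be overridden here when
    # OneFormer's mandatory task_inputs argument makes the generic version fail.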


Comment on lines 84 to 83
     def __init__(
-        self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs
+        self,
+        image_processor=None,
+        tokenizer=None,
     ):
Member

Sorry I hadn't caught that before, but removing max_seq_length and task_seq_length from the init looks like a breaking change. We should add them back, and use self.max_seq_length and self.task_seq_length in the processing when the "max_length" kwarg is not defined. No need to add max_seq_length and task_seq_length to OneFormerTextKwargs as they weren't accepted before.
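A minimal sketch of that arrangement, keeping the init signature intact and falling back to the stored values only when no max_length kwarg was passed (the __call__ excerpt is an assumption about the final wiring):

def __init__(
    self,
    image_processor=None,
    tokenizer=None,
    max_seq_length: int = 77,
    task_seq_length: int = 77,
    **kwargs,
):
    # Keep accepting and storing the old kwargs so existing callers are unaffected;
    # the rest of __init__ stays as before.
    self.max_seq_length = max_seq_length
    self.task_seq_length = task_seq_length
    ...

# Later, e.g. in __call__:
#     max_length = output_kwargs["text_kwargs"].get("max_length")
#     task_length = max_length if max_length is not None else self.task_seq_length
#     text_length = max_length if max_length is not None else self.max_seq_length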

        text_kwargs = {}
        tokens = self.tokenizer(
            text_list,
            max_length=max_length if max_length is not None else text_kwargs.get("max_length", 77),
Member

I would say the opposite:

Suggested change:
-            max_length=max_length if max_length is not None else text_kwargs.get("max_length", 77),
+            max_length=text_kwargs.get("max_length") if text_kwargs.get("max_length") is not None else max_length,

So that the max_length kwarg overrides the per-call default when it is specified.

Comment on lines 32 to 34
class OneFormerTextKwargs(TextKwargs):
    max_seq_length: int
    task_seq_length: int
Member

No need for that (see comments after)

Comment on lines +107 to +106
padding=text_kwargs.get("padding", "max_length"),
truncation=text_kwargs.get("truncation", True),
Member

No need to redefine the default kwargs here in the get: if they are not specified, they will necessarily be "max_length" and True respectively, so you can just have:

Suggested change:
-            padding=text_kwargs.get("padding", "max_length"),
-            truncation=text_kwargs.get("truncation", True),
+            padding=text_kwargs.get("padding"),
+            truncation=text_kwargs.get("truncation"),

Contributor Author

I believe this is needed because encode_inputs does not pass in any dict, so text_kwargs will become {}, without the defaults.

Member

Oh yes I see

            task_token_inputs.append(task_input)
        encoded_inputs["task_inputs"] = self._preprocess_text(
            task_token_inputs,
            max_length=output_kwargs["text_kwargs"]["task_seq_length"],
Member

Suggested change:
-            max_length=output_kwargs["text_kwargs"]["task_seq_length"],
+            max_length=self.task_seq_length,

        text_inputs = [
            self._preprocess_text(
                texts,
                max_length=output_kwargs["text_kwargs"]["max_seq_length"],
Member

Suggested change:
-                max_length=output_kwargs["text_kwargs"]["max_seq_length"],
+                max_length=self.max_seq_length,

@tibor-reiss
Contributor Author

@yonigozlan I had to repeat some of the tests, e.g. test_structured_kwargs_nested, due to the mandatory task_inputs. It was either this or making a default for task_inputs, e.g. 'semantic'. I went with the former because it might be surprising for users that 'semantic' is picked as the default - currently it raises a ValueError if not specified.

Member

@yonigozlan yonigozlan left a comment

LGTM after putting the kwargs back in the init! Thanks for iterating on this :)

        image_processor=None,
        tokenizer=None,
        max_seq_length: int = 77,
        task_seq_length: int = 77,
Member

Nit: removing the **kwargs here would also be a breaking change, even if it's not used anywhere.

@yonigozlan
Member

> @yonigozlan I had to repeat some of the tests, e.g. test_structured_kwargs_nested, due to the mandatory task_inputs. It was either this or making a default for task_inputs, e.g. 'semantic'. I went with the former because it might be surprising for users that 'semantic' is picked as the default - currently it raises a ValueError if not specified.

Yes I think that's the best way to do it :)
