uniformize kwargs for OneFormer #34547
base: main
Conversation
Force-pushed cd15539 to fad4111
Thanks for the initiative! Left a couple initial comments
Cool! But we should also check that the previous signature works as intended, so as not to break backwards compatibility - in this case, the previous behaviour where arguments were passed positionally should still be supported.
I added `*args`. I had to add it after `images`, so the signature will only match after another iteration of deprecation. WDYT?
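For illustration, a rough sketch of what that transitional signature could look like (hypothetical names, not the PR's actual code): positional calls keep working behind a deprecation warning while keyword use becomes the main path.

```python
import warnings


class OneFormerProcessorSketch:  # hypothetical stand-in for the real processor class
    def __call__(self, images=None, *args, task_inputs=None, segmentation_maps=None, **kwargs):
        if args:
            # Legacy positional use, e.g. processor(images, task_inputs, segmentation_maps)
            warnings.warn(
                "Passing task_inputs/segmentation_maps positionally is deprecated; "
                "pass them as keyword arguments instead.",
                FutureWarning,
            )
            legacy = dict(zip(("task_inputs", "segmentation_maps"), args))
            task_inputs = legacy.get("task_inputs", task_inputs)
            segmentation_maps = legacy.get("segmentation_maps", segmentation_maps)
        # ... actual preprocessing would follow here ...
        return images, task_inputs, segmentation_maps
```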
```python
if isinstance(task_inputs, str):
    task_inputs = [task_inputs]

if not isinstance(task_inputs, List) or not task_inputs:
    raise TypeError("task_inputs should be a string or a list of strings.")
```
Since this does more than checking the types, I suggest moving the conversion to list out of this function and renaming it to something explicit like `_validate_input_types`, for instance.
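A minimal sketch of that split (assumed names, not the PR's code): the string-to-list conversion stays at the call site and the renamed helper only validates.

```python
def _validate_input_types(task_inputs):
    # Pure validation: raises if task_inputs is not a non-empty list.
    if not isinstance(task_inputs, list) or not task_inputs:
        raise TypeError("task_inputs should be a string or a list of strings.")


# Call site (conversion happens before validation):
# if isinstance(task_inputs, str):
#     task_inputs = [task_inputs]
# _validate_input_types(task_inputs)
```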
Implemented.
```python
def _preprocess_text(self, text_list: PreTokenizedInput, max_length: int = 77):
```
So this tokenizes, which should be done with `self.tokenizer(..., output_kwargs['text_kwargs'])`, and it seems that it puts pad tokens where the attention mask is 0 - this should be covered by the tokenizer already; would be a nice refactor.
I will look into this...
Atm I don't see how this can be simplified, because the function is called with different arguments inside `__call__`, and it is also called in `encode_inputs`. Any suggestions?
Agreed that it would be best to use `self.tokenizer` directly, and pass in the other text kwargs as well. For `max_length`, you can do the following:
- if the `max_length` kwarg is not set, set it to `max_seq_length` or `task_seq_length` in the two different calls.
- if it is set, use its value for both task and seq.
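A minimal sketch of that idea, assuming the merged `text_kwargs` dict from the uniformized processing (the helper name and parameters are mine, not from the PR):

```python
from typing import List


def tokenize_with_fallback(tokenizer, texts: List[str], text_kwargs: dict, default_max_length: int):
    # Tokenize directly with the tokenizer so padding/truncation/attention masks are
    # handled there; fall back to the per-call default (task_seq_length or max_seq_length)
    # only when the user did not pass max_length.
    kwargs = dict(text_kwargs)
    if kwargs.get("max_length") is None:
        kwargs["max_length"] = default_max_length
    return tokenizer(texts, return_tensors="pt", **kwargs)["input_ids"]


# e.g. tokenize_with_fallback(self.tokenizer, task_token_inputs,
#                             output_kwargs["text_kwargs"], task_seq_length)
```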
I tried this (commit 3c7724d), but I don't see yet how it is simpler/more readable :)
Couple of observations:
- the function `_preprocess_text` is called both with `task_seq_length` and `max_seq_length`, so if I add `"max_length"` to the `text_kwargs` dict, I would need to add and then replace it in `__call__`.
- in order to do something similar to what the attention mask is doing (padding with 0s), I would need to change `self.tokenizer.pad_token_id` (and then change it back to what it was before, e.g. 49407, endoftext), because I can't pass `pad_token_id` to the tokenizer. I am most probably missing something here, so suggestions welcome.
```python
max_seq_length: int = 77,
task_seq_length: int = 77,
```
So the way it should be is that in `OneFormerProcessorKwargs`, these defaults should be passed to the relevant dictionary, which you can then get back. Say

```python
_defaults = {"max_seq_length": 77, "task_seq_length": 77}
```

and you can inform the types with a typed kwargs class:

```python
class OneFormerTextKwargs(TextKwargs, total=False):
    max_seq_length: Optional[int]
    task_seq_length: Optional[int]
```

Then your types are correctly informed as well as your defaults, and you can use both.
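For reference, a sketch of how those pieces could be wired together, assuming the `ProcessingKwargs`/`TextKwargs` helpers from `transformers.processing_utils` (note that this exact shape gets revised later in the thread, where the two lengths end up staying on the processor instance instead):

```python
from typing import Optional

from transformers.processing_utils import ProcessingKwargs, TextKwargs


class OneFormerTextKwargs(TextKwargs, total=False):
    max_seq_length: Optional[int]
    task_seq_length: Optional[int]


class OneFormerProcessorKwargs(ProcessingKwargs, total=False):
    # The typed class informs the types; _defaults carries the default values
    # for the relevant kwargs dictionary.
    text_kwargs: OneFormerTextKwargs
    _defaults = {
        "text_kwargs": {"max_seq_length": 77, "task_seq_length": 77},
    }
```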
This was a bigger refactor. Let me know what you think...
Hey! @molbap is off as he needs to rest. cc @yonigozlan, can you review this? 🤗
Thanks for working on this!
I mainly have comments about the handling of positional kwargs (like in the SAM PR), and about avoiding the wrapping of text processing inside `_preprocess_text`.
```python
def encode_inputs(
    self,
    images=None,
    task_inputs=None,
    segmentation_maps=None,
    max_seq_length: int = 77,
    task_seq_length: int = 77,
    **kwargs,
):
```
With this refactoring, it looks to me like `encode_inputs` could be made the same as `__call__` while preserving backward compatibility (if the args are handled correctly), so maybe we could just use `__call__` here?
Good catch! Done.
I tried this in commit 0270034; it took some time until I realized that there is a big difference between the two calls:
- `encode_inputs` is calling `OneFormerImageProcessor.encode_inputs`
- `__call__` is calling `OneFormerImageProcessor.__call__`, which will eventually also call `OneFormerImageProcessor.encode_inputs`, but only after quite some checks and massaging.

So atm I don't think there is a simple and elegant way to replace it and would just leave it as is. WDYT?
Oh I see, never mind then
```python
def __call__(
    self,
    images: Optional[ImageInput] = None,
    *args,  # to be deprecated
```
Yes, see SAM PR comment #34578 (comment)
```python
@staticmethod
def _add_args_for_backward_compatibility(args):
    """
    Remove this function once support for args is dropped in __call__
    """
    if len(args) > 2:
        raise ValueError("Too many positional arguments")
    return dict(zip(("task_inputs", "segmentation_maps"), args))
```
not needed #34578 (comment)
```diff
@@ -76,13 +100,46 @@ def _preprocess_text(self, text_list=None, max_length=77):
         token_inputs = torch.cat(token_inputs, dim=0)
         return token_inputs

     def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
```
These should be added to the `optional_call_args` attribute (see the UDOP processor).
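For context, a sketch of the `optional_call_args` pattern being referenced; it assumes `ProcessorMixin.prepare_and_validate_optional_call_args` as used for the UDOP processor, so check the current main branch for the exact helper and names.

```python
from transformers.processing_utils import ProcessorMixin


class OneFormerProcessorSketch(ProcessorMixin):  # illustrative only, not the real class
    optional_call_args = ["task_inputs", "segmentation_maps"]

    def __call__(self, images=None, *args, **kwargs):
        # Legacy positional arguments are mapped back onto their keyword names and merged
        # into kwargs, so only the keyword path has to be handled below.
        kwargs.update(self.prepare_and_validate_optional_call_args(*args))
        task_inputs = kwargs.pop("task_inputs", None)
        segmentation_maps = kwargs.pop("segmentation_maps", None)
        # ... preprocessing continues with images, task_inputs, segmentation_maps ...
        return images, task_inputs, segmentation_maps
```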
Force-pushed 40cf888 to cbddaf4
There's a breaking change in the init that needs to be fixed. Otherwise it is looking good!
Just like for SAM, you'll also have to add the `ProcessorTesterMixin` to the processor test class, make sure all the tests pass (overriding them if needed), and rebase on main.
```diff
 def __init__(
-    self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs
+    self,
+    image_processor=None,
+    tokenizer=None,
 ):
```
Sorry I hadn't caught that before, but removing `max_seq_length` and `task_seq_length` from the init looks like a breaking change. We should add them back, and use `self.max_seq_length` and `self.task_seq_length` in the processing when the `max_length` kwarg is not defined. No need to add `max_seq_length` and `task_seq_length` to `OneFormerTextKwargs` as they weren't accepted before.
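A sketch of the non-breaking `__init__` this asks for (class attributes are abbreviated and assumed; the point is keeping the two lengths in the signature and storing them on the instance for the fallback):

```python
from transformers.processing_utils import ProcessorMixin


class OneFormerProcessor(ProcessorMixin):
    # attribute/class declarations abbreviated for this sketch
    attributes = ["image_processor", "tokenizer"]

    def __init__(self, image_processor=None, tokenizer=None,
                 max_seq_length: int = 77, task_seq_length: int = 77, **kwargs):
        # Keep both lengths in the signature so existing callers do not break, and store
        # them so processing can fall back to self.max_seq_length / self.task_seq_length
        # when no "max_length" text kwarg is given.
        self.max_seq_length = max_seq_length
        self.task_seq_length = task_seq_length
        super().__init__(image_processor, tokenizer)
```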
```python
text_kwargs = {}
tokens = self.tokenizer(
    text_list,
    max_length=max_length if max_length is not None else text_kwargs.get("max_length", 77),
```
I would say the opposite:

```diff
-    max_length=max_length if max_length is not None else text_kwargs.get("max_length", 77),
+    max_length=text_kwargs.get("max_length") if text_kwargs.get("max_length") is not None else max_length,
```

So that the `max_length` kwarg overrides the default kwarg when it is specified.
```python
class OneFormerTextKwargs(TextKwargs):
    max_seq_length: int
    task_seq_length: int
```
No need for that (see comments after)
```python
padding=text_kwargs.get("padding", "max_length"),
truncation=text_kwargs.get("truncation", True),
```
No need to redefine the default kwargs here in the `get`: if they are not specified, they will necessarily be `"max_length"` and `True` respectively, so you can just have:

```diff
-    padding=text_kwargs.get("padding", "max_length"),
-    truncation=text_kwargs.get("truncation", True),
+    padding=text_kwargs.get("padding"),
+    truncation=text_kwargs.get("truncation"),
```
I believe this is needed because `encode_inputs` does not pass in any dict, so `text_kwargs` will become `{}`, without defaults.
Oh yes I see
```python
    task_token_inputs.append(task_input)
encoded_inputs["task_inputs"] = self._preprocess_text(
    task_token_inputs,
    max_length=output_kwargs["text_kwargs"]["task_seq_length"],
```
```diff
-    max_length=output_kwargs["text_kwargs"]["task_seq_length"],
+    max_length=self.task_seq_length,
```
```python
text_inputs = [
    self._preprocess_text(
        texts,
        max_length=output_kwargs["text_kwargs"]["max_seq_length"],
```
```diff
-    max_length=output_kwargs["text_kwargs"]["max_seq_length"],
+    max_length=self.max_seq_length,
```
Force-pushed 3c7724d to 9ccaa40
Force-pushed 9ccaa40 to b579834
@yonigozlan I had to repeat some of the tests, e.g.
LGTM after putting the kwargs back in the init! Thanks for iterating on this :)
```python
image_processor=None,
tokenizer=None,
max_seq_length: int = 77,
task_seq_length: int = 77,
```
Nit: removing the `**kwargs` here would also be a breaking change, even if it's not used anywhere.
Yes, I think that's the best way to do it :)
Adds uniformized processors for OneFormer following #31911.
Small changes to simplify code.
Additional check via `test_processor_oneformer.py` that the `inputs` are the same before and after the changes (with fixed random seeds).

@qubvel @molbap
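A hypothetical sketch of the kind of before/after check described; the helper name, the `return_tensors` argument, and the seeding are assumptions, not the actual contents of `test_processor_oneformer.py`.

```python
import numpy as np
import torch


def assert_same_processor_outputs(old_processor, new_processor, images, task_inputs, seed: int = 0):
    # Run both processor versions under the same fixed seeds and compare every output tensor.
    torch.manual_seed(seed)
    np.random.seed(seed)
    old = old_processor(images=images, task_inputs=task_inputs, return_tensors="pt")

    torch.manual_seed(seed)
    np.random.seed(seed)
    new = new_processor(images=images, task_inputs=task_inputs, return_tensors="pt")

    assert set(old.keys()) == set(new.keys())
    for key in old.keys():
        torch.testing.assert_close(old[key], new[key])
```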