Skip to content

Commit

Permalink
[xxxTrainer] multi-tags support for tagging (#1133)
Browse files Browse the repository at this point in the history
* multi-tags support for tagging

* oops
  • Loading branch information
younesbelkada authored Dec 22, 2023
1 parent 17ec68d commit 0c4edb7
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 12 deletions.
4 changes: 2 additions & 2 deletions trl/trainer/ddpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class DDPOTrainer(BaseTrainer):
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
"""

_tag_name = "trl-ddpo"
_tag_names = ["trl", "ddpo"]

def __init__(
self,
Expand Down Expand Up @@ -585,6 +585,6 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)
kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
4 changes: 2 additions & 2 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class DPOTrainer(Trainer):
Dict of Optional kwargs to pass when instantiating the ref model from a string
"""

_tag_name = "trl-dpo"
_tag_names = ["trl", "dpo"]

def __init__(
self,
Expand Down Expand Up @@ -1144,6 +1144,6 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)
kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
4 changes: 2 additions & 2 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class PPOTrainer(BaseTrainer):
**lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
"""

_tag_name = "trl-ppo"
_tag_names = ["trl", "ppo"]

def __init__(
self,
Expand Down Expand Up @@ -1452,6 +1452,6 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)
kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
4 changes: 2 additions & 2 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class SFTTrainer(Trainer):
dataset_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when creating packed or non-packed datasets
"""
_tag_name = "trl-sft"
_tag_names = ["trl", "sft"]

def __init__(
self,
Expand Down Expand Up @@ -334,7 +334,7 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)
kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)

Expand Down
12 changes: 8 additions & 4 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,12 +639,16 @@ def peft_module_casting_to_bf16(model):
module = module.to(torch.bfloat16)


def trl_sanitze_kwargs_for_tagging(tag_name, kwargs=None):
def trl_sanitze_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]

if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = [tag_name]
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].append(tag_name)
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"], tag_name]
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs

0 comments on commit 0c4edb7

Please sign in to comment.