diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index 0d2223b9..3898f16d 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -147,8 +147,8 @@ def train( lr: float = 2e-5, # Directory where model predictions and checkpoints will be written checkpoint_dir: str = "/tmp", - **training_arguments, - ): + **kwargs, + ) -> "TextGeneration": """ Fine-tune a CausalLM or Seq2seq text generation model. @@ -177,7 +177,7 @@ def train( Learning rate to be used while tuning model. Default: 2e-5. checkpoint_dir: str Directory where model predictions and checkpoints will be written - **training_arguments: + **kwargs: Arguments supported by HF Training Arguments. TrainingArguments: https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments @@ -274,7 +274,7 @@ def train( "eval_accumulation_steps": accumulate_steps, # eval_steps=1, # load_best_model_at_end - **training_arguments, + **kwargs, **dtype_based_params, }