Skip to content

Commit

Permalink
Merge pull request caikit#122 from gkumbhat/fix_train_service_gen
Browse files Browse the repository at this point in the history
🐛 Fix training arguments for service generation to work correctly
Signed-off-by: gkumbhat <[email protected]>
  • Loading branch information
gkumbhat authored Aug 10, 2023
2 parents 60a5ebb + 394189a commit 3b006a2
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions caikit_nlp/modules/text_generation/text_generation_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def train(
lr: float = 2e-5,
# Directory where model predictions and checkpoints will be written
checkpoint_dir: str = "/tmp",
**training_arguments,
):
**kwargs,
) -> "TextGeneration":
"""
Fine-tune a CausalLM or Seq2seq text generation model.
Expand Down Expand Up @@ -177,7 +177,7 @@ def train(
Learning rate to be used while tuning model. Default: 2e-5.
checkpoint_dir: str
Directory where model predictions and checkpoints will be written
**training_arguments:
**kwargs:
Arguments supported by HF Training Arguments.
TrainingArguments:
https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments
Expand Down Expand Up @@ -274,7 +274,7 @@ def train(
"eval_accumulation_steps": accumulate_steps,
# eval_steps=1,
# load_best_model_at_end
**training_arguments,
**kwargs,
**dtype_based_params,
}

Expand Down

0 comments on commit 3b006a2

Please sign in to comment.