Skip to content

Commit

Permalink
Merge pull request caikit#217 from gkumbhat/add_pt_random_seed
Browse files Browse the repository at this point in the history
🐛 Add support for setting random seed for prompt tuning training
  • Loading branch information
gkumbhat authored Oct 2, 2023
2 parents 46af26e + e6b69a2 commit f10f415
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from transformers.optimization import get_linear_schedule_with_warmup
import numpy as np
import torch
import transformers

# First Party
from caikit.core.data_model import DataStream
Expand Down Expand Up @@ -107,6 +108,7 @@ class PeftPromptTuning(ModuleBase):
# TuningType.LORA: PeftType.LORA,
}

RANDOM_SEED = 73
supported_resources = [HFAutoCausalLM, HFAutoSeq2SeqLM]

################################ Constructor / Destructor #####################################
Expand Down Expand Up @@ -298,6 +300,7 @@ def train(
accumulate_steps: Optional[int] = 32,
torch_dtype: Optional[str] = None, # TODO: Optional[Union[torch.dtype, str]]
silence_progress_bars: Optional[bool] = True,
random_seed: int = RANDOM_SEED,
**kwargs,
) -> "PeftPromptTuning":
"""Run prompt tuning (vanilla or MPT) through PEFT on a CausalLM or Seq2seq model
Expand Down Expand Up @@ -343,11 +346,19 @@ def train(
underpinning the resource will be converted in place to the correct torch dtype.
silence_progress_bars: bool
Silences TQDM progress bars at train time. Default: True.
random_seed: int
Integer to be used as random seed for training.
Returns:
PeftPromptTuning
Instance of this class with tuned prompt vectors.
"""

# Configure random seed
transformers.set_seed(random_seed)
# NOTE: Following can be uncommented to allow full determinism
# but it can have impact on performance.
# transformers.enable_full_determinism(random_seed)

# HACK - These things can't be passed through the train API currently

metric = kwargs.get("metric")
Expand Down

0 comments on commit f10f415

Please sign in to comment.