Merge pull request caikit#148 from joerunde/value-errors
🔊 add value check logs
joerunde authored Aug 24, 2023
2 parents 69c396c + 8faff38 commit d7433d8
Showing 1 changed file with 18 additions and 3 deletions.
caikit_nlp/modules/text_generation/peft_prompt_tuning.py (18 additions, 3 deletions)
@@ -354,7 +354,10 @@ def train(
         init_method = tuning_config.prompt_tuning_init_method
 
         error.value_check(
-            "<NLP11848053E>", init_method in allowed_tuning_init_methods
+            "<NLP11848053E>",
+            init_method in allowed_tuning_init_methods,
+            f"Init method [{init_method}] not in allowed init methods: "
+            f"[{allowed_tuning_init_methods}]",
         )
 
         init_method = MultitaskPromptTuningInit(init_method)
@@ -438,7 +441,10 @@ def train(
         )
 
         error.value_check(
-            "<NLP30542004E>", len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS
+            "<NLP30542004E>",
+            len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
+            f"Too many output model types. Got {len(output_model_types)}, "
+            f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
         )
         # Ensure that our verbalizer is a string and will not render to a hardcoded string
         error.value_check(
@@ -456,6 +462,8 @@ def train(
             error.value_check(
                 "<NLP65714994E>",
                 tuning_type in TuningType._member_names_,
+                f"Invalid tuning type [{tuning_type}]. Allowed types: "
+                f"[{TuningType._member_names_}]",
             )
             tuning_type = TuningType(tuning_type)
         error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
@@ -721,7 +729,11 @@ def get_exportable_prompt_vectors(
         # Our model should only have one or two transformer modules; PEFT config lets you
         # arbitrarily configure these, but the slicing assumptions for the prompt tuning
         # seem to assume this...
-        error.value_check("<NLP83837722E>", 1 <= num_transformer_submodules <= 2)
+        error.value_check(
+            "<NLP83837722E>",
+            1 <= num_transformer_submodules <= 2,
+            f"Only 1 or 2 transformer submodules allowed. {num_transformer_submodules} detected.",
+        )
         # Get the prompt vectors.
         if tuning_type == TuningType.PROMPT_TUNING:  # Should also be done for prefix
             # NOTE; If this is done for MPT, we get the SHARED prompt vector.
@@ -762,6 +774,9 @@ def get_exportable_prompt_vectors(
             error.value_check(
                 "<NLP83444722E>",
                 prompt_vector.shape[0] == num_transformer_submodules * num_virtual_tokens,
+                f"Row mismatch: Expected num_transformer_submodules * num_virtual_tokens "
+                f"({num_transformer_submodules * num_virtual_tokens}) "
+                f"but got f{prompt_vector.shape[0]}",
             )
 
             # Otherwise it depends on the number of transformer modules. See seq2seq forward()
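
Every hunk follows the same pattern: the bare two-argument error.value_check(log_code, condition) calls gain a third, human-readable message, so a failed check logs and raises with an explanation instead of only an opaque log code. The sketch below is a minimal, self-contained illustration of that pattern; the ValueChecker class, its logging behavior, and the sample allowed_tuning_init_methods values are assumptions standing in for caikit core's real error handler, not its actual implementation.

import logging

log = logging.getLogger("peft_prompt_tuning_example")
logging.basicConfig(level=logging.ERROR)


class ValueChecker:
    """Toy stand-in for caikit core's error handler (assumed behavior only)."""

    def __init__(self, logger: logging.Logger):
        self._logger = logger

    def value_check(self, log_code: str, condition: bool, msg: str = "") -> None:
        # On failure, log the code together with the human-readable message,
        # then raise so the caller sees why validation failed.
        if not condition:
            full_msg = f"{log_code}: value check failed. {msg}".strip()
            self._logger.error(full_msg)
            raise ValueError(full_msg)


error = ValueChecker(log)

# Hypothetical values for illustration only.
init_method = "BAD_METHOD"
allowed_tuning_init_methods = ["TEXT", "RANDOM"]
try:
    error.value_check(
        "<NLP11848053E>",
        init_method in allowed_tuning_init_methods,
        f"Init method [{init_method}] not in allowed init methods: "
        f"[{allowed_tuning_init_methods}]",
    )
except ValueError as err:
    print(err)

With the message argument, the exception and log line read "Init method [BAD_METHOD] not in allowed init methods: ..." rather than just the log code, which is the point of this commit.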