Skip to content

Commit

Permalink
🎨 fmt
Browse files Browse the repository at this point in the history
Signed-off-by: Joe Runde <[email protected]>
  • Loading branch information
joerunde committed Aug 24, 2023
1 parent 674df4d commit 8faff38
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ def train(
error.value_check(
"<NLP11848053E>",
init_method in allowed_tuning_init_methods,
f"Init method [{init_method}] not in allowed init methods: [{allowed_tuning_init_methods}]",
f"Init method [{init_method}] not in allowed init methods: "
f"[{allowed_tuning_init_methods}]",
)

init_method = MultitaskPromptTuningInit(init_method)
Expand Down Expand Up @@ -442,7 +443,8 @@ def train(
error.value_check(
"<NLP30542004E>",
len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
f"Too many output model types. Got {len(output_model_types)}, maximum {base_model.MAX_NUM_TRANSFORMERS}",
f"Too many output model types. Got {len(output_model_types)}, "
f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
)
# Ensure that our verbalizer is a string and will not render to a hardcoded string
error.value_check(
Expand All @@ -460,7 +462,8 @@ def train(
error.value_check(
"<NLP65714994E>",
tuning_type in TuningType._member_names_,
f"Invalid tuning type [{tuning_type}]. Allowed types: [{TuningType._member_names_}]",
f"Invalid tuning type [{tuning_type}]. Allowed types: "
f"[{TuningType._member_names_}]",
)
tuning_type = TuningType(tuning_type)
error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
Expand Down Expand Up @@ -771,7 +774,9 @@ def get_exportable_prompt_vectors(
error.value_check(
"<NLP83444722E>",
prompt_vector.shape[0] == num_transformer_submodules * num_virtual_tokens,
f"Row mismatch: Expected num_transformer_submodules * num_virtual_tokens ({num_transformer_submodules * num_virtual_tokens} but got f{prompt_vector.shape[0]})",
f"Row mismatch: Expected num_transformer_submodules * num_virtual_tokens "
f"({num_transformer_submodules * num_virtual_tokens}) "
f"but got {prompt_vector.shape[0]}",
)

# Otherwise it depends on the number of transformer modules. See seq2seq forward()
Expand Down

0 comments on commit 8faff38

Please sign in to comment.