Commit

Merge pull request caikit#267 from Ssukriti/add_preserve_input_text_inference

Add preserve input text inference
gkumbhat authored Nov 15, 2023
2 parents cc0d081 + 45fe0b0 commit b35f910
Showing 6 changed files with 159 additions and 2 deletions.
7 changes: 7 additions & 0 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -170,11 +170,16 @@ def run(
] = None,
stop_sequences: Optional[List[str]] = None,
seed: Optional[np.uint64] = None,
preserve_input_text: bool = True,
) -> GeneratedTextResult:
f"""
Run the full text generation model.
Args:
{GENERATE_FUNCTION_ARGS}
preserve_input_text: bool
Applicable only to causal LMs.
Whether the source string should be included in the generated output,
e.g., as a prefix. Default True (the source string appears as a prefix).
Returns:
GeneratedTextResult
Generated text result produced by PEFT / Transformers.
@@ -201,6 +206,8 @@ def run(
max_time=max_time,
exponential_decay_length_penalty=exponential_decay_length_penalty,
stop_sequences=stop_sequences,
preserve_input_text=preserve_input_text,
task_type=self.task_type,
)

# NOTE: We need to disable wip decorator here otherwise we get issues in
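As a quick illustration of what the new flag means for a causal LM (hypothetical prompt and completions, not taken from this PR):

# Illustrative sketch only -- hypothetical strings, not output from this repository.
prompt = "The capital of France is"
# preserve_input_text=True  -> generated_text == "The capital of France is Paris."
# preserve_input_text=False -> generated_text == " Paris."
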
9 changes: 8 additions & 1 deletion caikit_nlp/modules/text_generation/text_generation_local.py
@@ -534,13 +534,18 @@ def run(
temperature: Optional[float] = None,
repetition_penalty: Optional[float] = None,
max_time: Optional[float] = None,
preserve_input_text: bool = True,
**kwargs,
) -> GeneratedTextResult:

f"""
Run the full text generation model.
Args:
{GENERATE_FUNCTION_ARGS}
{GENERATE_FUNCTION_ARGS},
preserve_input_text: bool
Applicable only to causal LMs.
Whether the source string should be included in the generated output,
e.g., as a prefix. Default True (the source string appears as a prefix).
Returns:
GeneratedTextResult
Generated text result produced by the model.
@@ -565,6 +570,8 @@
temperature=temperature,
repetition_penalty=repetition_penalty,
max_time=max_time,
preserve_input_text=preserve_input_text,
task_type=self.model.TASK_TYPE,
**kwargs,
)

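A minimal usage sketch of the new parameter from the module API side, mirroring the tests added below; TextGeneration.bootstrap as the entry point and "gpt2" as the model name are assumptions for illustration, not part of this PR:

# Hedged usage sketch -- assumes TextGeneration.bootstrap can wrap a small causal LM.
from caikit_nlp.modules.text_generation import TextGeneration

model = TextGeneration.bootstrap("gpt2")  # assumed bootstrap entry point and model name
pred = model.run("What is 2 + 2?", preserve_input_text=False)
assert "What is 2 + 2?" not in pred.generated_text  # input text stripped from the output
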
40 changes: 39 additions & 1 deletion caikit_nlp/toolkit/text_generation/model_run_utils.py
@@ -152,6 +152,8 @@ def generate_text_func(
Union[Tuple[int, float], ExponentialDecayLengthPenalty]
] = None,
stop_sequences: Optional[List[str]] = None,
preserve_input_text: Optional[bool] = True,
task_type: Optional[str] = None,
**kwargs,
):
"""
@@ -164,6 +166,12 @@ def generate_text_func(
Caikit producer id associated with the module
eos_token: str
End of sequence token to be used with generation
preserve_input_text: bool
Applicable only for the CAUSAL_LM task type.
Whether the source string should be included in the generated output,
e.g., as a prefix. Default True (the source string appears as a prefix).
task_type: str or None
Task type such as CAUSAL_LM, SEQ_2_SEQ_LM, SEQ_CLS or None
{}
Returns:
GeneratedTextResult
@@ -235,6 +243,13 @@ def generate_text_func(
for g in generate_ids
]

if not preserve_input_text:
generated_text = __postprocess_remove_input_text(
tokenizer, preds, inputs, task_type
)
else:
generated_text = preds[0]

if (eos_token and tokenizer.decode(generate_ids[0, -1].item()) == eos_token) or (
generate_ids[0, -1] == tokenizer.eos_token_id
):
@@ -251,14 +266,37 @@

return GeneratedTextResult(
generated_tokens=token_count,
generated_text=preds[0],
generated_text=generated_text,
finish_reason=finish_reason,
producer_id=producer_id,
input_token_count=input_token_count,
seed=seed,
)


def __postprocess_remove_input_text(tokenizer, preds, inputs, task_type):
"""For Causal LM task types, preserve_input_text set to False will
remove the input text from generated output.
"""
if task_type == "CAUSAL_LM":
prompt_length = len(
tokenizer.decode(
inputs["input_ids"][0],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
)
generated_text = preds[0][prompt_length:]
else:
log.warning(
"<NLP16125792W>",
f"preserve_input_text flag is not applicable for task type {task_type}. "
"Returning model generated prediction",
)
generated_text = preds[0]
return generated_text


def generate_text_func_stream(
model,
tokenizer,
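A standalone sketch of the prefix-stripping idea behind __postprocess_remove_input_text (my own illustration; transformers.AutoTokenizer and the "gpt2" tokenizer are assumptions): a causal LM echoes the prompt at the start of the decoded output, so slicing off the decoded prompt length removes it, while encoder-decoder (SEQ_2_SEQ_LM) models do not echo the prompt and are left untouched.

# Hedged sketch, not the repository's code: strip the echoed prompt from a causal LM prediction.
from transformers import AutoTokenizer

def strip_echoed_prompt(tokenizer, decoded_prediction, input_ids):
    """Remove the decoded prompt prefix from a decoded causal LM prediction."""
    prompt_text = tokenizer.decode(
        input_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return decoded_prediction[len(prompt_text):]

# Hypothetical usage ("gpt2" stands in for the served model's tokenizer):
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer("What is the boiling point of Nitrogen?", return_tensors="pt").input_ids[0]
full_output = "What is the boiling point of Nitrogen? It boils at about -196 C."
print(strip_echoed_prompt(tokenizer, full_output, ids))  # " It boils at about -196 C."
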
10 changes: 10 additions & 0 deletions tests/modules/text_generation/test_peft_prompt_tuning.py
@@ -380,6 +380,16 @@ def test_run_truncate_tokens_0(causal_lm_dummy_model):
assert isinstance(pred, GeneratedTextResult)


def test_run_with_preserve_input_text(causal_lm_dummy_model):
"""Ensure preserve input text removes input
from generated output when set to False"""
input_text = "This text doesn't matter"
pred = causal_lm_dummy_model.run(input_text, preserve_input_text=True)
assert input_text in pred.generated_text
pred = causal_lm_dummy_model.run(input_text, preserve_input_text=False)
assert input_text not in pred.generated_text


def test_run_sampling_param_ignored_greedy_decoding(causal_lm_dummy_model):
"""Ensure sampling parameter gets ignored when decoding method
is set to GREEDY
33 changes: 33 additions & 0 deletions tests/modules/text_generation/test_text_generation_local.py
@@ -155,6 +155,39 @@ def test_train_model_causallm(disable_wip, set_cpu_device):
assert isinstance(pred, GeneratedTextResult)


############################## Inference flags ################################


@pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported")
def test_train_model_causallm_preserve_input_text(disable_wip, set_cpu_device):
"""Ensure that we can finetune a causal-lm model on some toy data for 1+
steps & that preserve_input_text controls whether the input appears in the output."""
train_kwargs = {
"base_model": HFAutoCausalLM.bootstrap(
model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
),
"num_epochs": 1,
"train_stream": caikit.core.data_model.DataStream.from_iterable(
[
GenerationTrainRecord(
input="@foo what a cute dog!", output="no complaint"
),
]
),
"torch_dtype": torch.float32,
}
model = TextGeneration.train(**train_kwargs)
assert isinstance(model.model, HFAutoCausalLM)

# Ensure that preserve_input_text=True keeps the input in the output
pred = model.run("@bar what a cute cat!", preserve_input_text=True)
assert "@bar what a cute cat!" in pred.generated_text

# Ensure that preserve_input_text=False removes the input from the output
pred = model.run("@bar what a cute cat!", preserve_input_text=False)
assert "@bar what a cute cat!" not in pred.generated_text


############################## Error Cases ################################


62 changes: 62 additions & 0 deletions tests/toolkit/text_generation/test_model_run_utils.py
@@ -43,3 +43,65 @@ def test_generate_text_func_serialization_json(

serialized = getattr(generated_text, serialization_method)()
assert isinstance(serialized, expected_type)


@pytest.mark.parametrize("causal_model_fixture", ["causal_lm_dummy_model"])
def test_generate_text_func_preserve_input_causal_lm(request, causal_model_fixture):
"""For Causal LM task types, setting preserve_inout_text to True
will result in input text in model prediction. Setting to False will
strip the input text from model prediction.
"""
input_text = "What is the boiling point of liquid Nitrogen?"
causal_model = request.getfixturevalue(causal_model_fixture)
generated_text = generate_text_func(
model=causal_model.model,
tokenizer=causal_model.tokenizer,
producer_id=ProducerId("TextGeneration", "0.1.0"),
eos_token="<\n>",
text=input_text,
preserve_input_text=True,
task_type="CAUSAL_LM",
)
assert input_text in generated_text.generated_text
generated_text = generate_text_func(
model=causal_model.model,
tokenizer=causal_model.tokenizer,
producer_id=ProducerId("TextGeneration", "0.1.0"),
eos_token="<\n>",
text=input_text,
preserve_input_text=False,
task_type="CAUSAL_LM",
)
assert input_text not in generated_text.generated_text


@pytest.mark.parametrize("seq_model_fixture", ["seq2seq_lm_dummy_model"])
def test_generate_text_func_preserve_input(request, seq_model_fixture):
"""For Seq2Seq LM task types, setting preserve_inout_text to True
or False should not change predictions.
"""
input_text = "What is the boiling point of liquid Nitrogen?"
seq_model = request.getfixturevalue(seq_model_fixture)
generated_text = generate_text_func(
model=seq_model.model,
tokenizer=seq_model.tokenizer,
producer_id=ProducerId("TextGeneration", "0.1.0"),
eos_token="<\n>",
text=input_text,
preserve_input_text=True,
task_type="SEQ_2_SEQ_LM",
)
before_pred = generated_text.generated_text
generated_text = generate_text_func(
model=seq_model.model,
tokenizer=seq_model.tokenizer,
producer_id=ProducerId("TextGeneration", "0.1.0"),
eos_token="<\n>",
text=input_text,
preserve_input_text=False,
task_type="SEQ_2_SEQ_LM",
)
after_pred = generated_text.generated_text
assert before_pred == after_pred
