Skip to content

Commit

Permalink
Merge pull request caikit#232 from alex-jw-brooks/eval_seq_lens
Browse files Browse the repository at this point in the history
Add support for sequence lengths in eval
  • Loading branch information
gkumbhat authored Oct 11, 2023
2 parents a1b5674 + b494dd3 commit 0fdaf9d
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions examples/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,25 @@ def parse_args() -> argparse.Namespace:
help="JSON file to dump raw source / target texts to.",
default="model_preds.json",
)
parser.add_argument(
"--max_new_tokens",
help="Maximum number of new tokens to be generated",
type=int,
default=20,
)
parser.add_argument(
"--truncate_input_tokens",
help="Number of allowed input tokens (no truncation=0)",
type=int,
default=0,
)
args = parser.parse_args()
return args


def get_model_preds_and_references(model, validation_stream):
def get_model_preds_and_references(
model, validation_stream, max_new_tokens, truncate_input_tokens
):
"""Given a model & a validation stream, run the model against every example in the validation
stream and compare the outputs to the target/output sequence.
Expand All @@ -79,6 +93,10 @@ def get_model_preds_and_references(model, validation_stream):
validation_stream: DataStream[GenerationTrainRecord]
Validation stream with labeled targets that we want to compare to our model's
predictions.
max_new_tokens: int
Max number of new tokens to be generated, i.e., output limit
truncate_input_tokens: int
Number of allowed input tokens, i.e., input limit
Returns:
Tuple(List)
Expand All @@ -90,7 +108,11 @@ def get_model_preds_and_references(model, validation_stream):
for datum in tqdm(validation_stream):
# Local .run() currently prepends the input text to the generated string;
# Ensure that we're just splitting the first predicted token & beyond.
raw_model_text = model.run(datum.input).generated_text
raw_model_text = model.run(
datum.input,
max_new_tokens=max_new_tokens,
truncate_input_tokens=truncate_input_tokens,
).generated_text
parse_pred_text = raw_model_text.split(datum.input)[-1].strip()
model_preds.append(parse_pred_text)
targets.append(datum.output)
Expand Down Expand Up @@ -153,7 +175,9 @@ def export_model_preds(preds_file, predictions, validation_stream, verbalizer):

# Run the data through the model; save the predictions & references
print_colored("Getting model predictions...")
predictions, references = get_model_preds_and_references(model, validation_stream)
predictions, references = get_model_preds_and_references(
model, validation_stream, args.max_new_tokens, args.truncate_input_tokens
)
print_colored(
"Exporting model preds, source, verbalized source, and ground truth targets to {}".format(
args.preds_file
Expand Down

0 comments on commit 0fdaf9d

Please sign in to comment.