diff --git a/examples/evaluate_model.py b/examples/evaluate_model.py
index e74601df..3a902527 100644
--- a/examples/evaluate_model.py
+++ b/examples/evaluate_model.py
@@ -65,11 +65,25 @@ def parse_args() -> argparse.Namespace:
         help="JSON file to dump raw source / target texts to.",
         default="model_preds.json",
     )
+    parser.add_argument(
+        "--max_new_tokens",
+        help="Maximum number of new tokens to be generated",
+        type=int,
+        default=20,
+    )
+    parser.add_argument(
+        "--truncate_input_tokens",
+        help="Number of allowed input tokens (no truncation=0)",
+        type=int,
+        default=0,
+    )
     args = parser.parse_args()
     return args
 
 
-def get_model_preds_and_references(model, validation_stream):
+def get_model_preds_and_references(
+    model, validation_stream, max_new_tokens, truncate_input_tokens
+):
     """Given a model & a validation stream, run the model against every example
     in the validation stream and compare the outputs to the target/output sequence.
 
@@ -79,6 +93,10 @@ def get_model_preds_and_references(model, validation_stream):
         validation_stream: DataStream[GenerationTrainRecord]
             Validation stream with labeled targets that we want to compare to
             our model's predictions.
+        max_new_tokens: int
+            Max number of new tokens to be generated, i.e., output limit
+        truncate_input_tokens: int
+            Number of allowed input tokens, i.e., input limit
 
     Returns:
         Tuple(List)
@@ -90,7 +108,11 @@ def get_model_preds_and_references(model, validation_stream):
     for datum in tqdm(validation_stream):
         # Local .run() currently prepends the input text to the generated string;
         # Ensure that we're just splitting the first predicted token & beyond.
-        raw_model_text = model.run(datum.input).generated_text
+        raw_model_text = model.run(
+            datum.input,
+            max_new_tokens=max_new_tokens,
+            truncate_input_tokens=truncate_input_tokens,
+        ).generated_text
         parse_pred_text = raw_model_text.split(datum.input)[-1].strip()
         model_preds.append(parse_pred_text)
         targets.append(datum.output)
@@ -153,7 +175,9 @@ def export_model_preds(preds_file, predictions, validation_stream, verbalizer):
 
     # Run the data through the model; save the predictions & references
     print_colored("Getting model predictions...")
-    predictions, references = get_model_preds_and_references(model, validation_stream)
+    predictions, references = get_model_preds_and_references(
+        model, validation_stream, args.max_new_tokens, args.truncate_input_tokens
+    )
     print_colored(
         "Exporting model preds, source, verbalized source, and ground truth targets to {}".format(
             args.preds_file
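
Example usage with the new flags -- a sketch only: the script's remaining arguments
(model path, dataset selection, etc.) are defined elsewhere in evaluate_model.py and
are omitted here, and the --preds_file flag name is inferred from the args.preds_file
reference in the last hunk:

    python examples/evaluate_model.py \
        --max_new_tokens 50 \
        --truncate_input_tokens 512 \
        --preds_file model_preds.json

With the defaults (--max_new_tokens 20, --truncate_input_tokens 0), generation is
capped at 20 new tokens and no input truncation is applied, per the argument help
text above.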