From 6be7b7c8adcfd5d4ca5a6f3802eebe0f5875dff2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Christoph=20M=2E=20L=C3=BCscher?=
Date: Mon, 9 Dec 2024 10:23:53 +0100
Subject: [PATCH] add returnn ppl job

---
 returnn/perplexity.py | 123 +++++++++++++++++++++++++++++++++++++++++++
 returnn/training.py   |  10 ++---
 2 files changed, 128 insertions(+), 5 deletions(-)
 create mode 100644 returnn/perplexity.py

diff --git a/returnn/perplexity.py b/returnn/perplexity.py
new file mode 100644
index 00000000..67987a76
--- /dev/null
+++ b/returnn/perplexity.py
@@ -0,0 +1,123 @@
+__all__ = ["ReturnnCalculatePerplexityJob"]
+
+import copy
+import shutil
+import subprocess as sp
+from typing import Union
+
+from sisyphus import Job, Task, setup_path, tk
+
+import i6_core.util as util
+
+from .config import ReturnnConfig
+from .training import PtCheckpoint, Checkpoint
+
+Path = setup_path(__package__)
+
+
+class ReturnnCalculatePerplexityJob(Job):
+    """
+    Calculates the perplexity of a language model trained in RETURNN
+    on an evaluation data set
+    """
+
+    def __init__(
+        self,
+        returnn_config: ReturnnConfig,
+        returnn_model: Union[PtCheckpoint, Checkpoint],
+        eval_dataset: tk.Path,
+        *,
+        log_verbosity: int = 3,
+        returnn_root: tk.Path,
+        returnn_python_exe: tk.Path,
+    ):
+        """
+        :param returnn_config: RETURNN config of the trained language model, a copy of it is modified and used
+        :param returnn_model: checkpoint of the model to evaluate
+        :param eval_dataset: evaluation data, stored in the config as ``eval_datasets["eval"]``
+        :param log_verbosity: RETURNN log verbosity, written to the post config and excluded from the job hash
+        :param returnn_root: RETURNN root folder, ``rnn.py`` is taken from there
+        :param returnn_python_exe: python executable used to call RETURNN
+        """
+        # Work on a private copy: the caller's ReturnnConfig is commonly shared with other jobs
+        # and must not lose its datasets or gain a "model" entry as a side effect of creating this job.
+        returnn_config = copy.deepcopy(returnn_config)
+
+        # A config without "train"/"dev" is fine here, so a missing key is not an error.
+        returnn_config.config.pop("train", None)
+        returnn_config.config.pop("dev", None)
+        returnn_config.config["eval_datasets"] = {"eval": eval_dataset}
+
+        # TODO verify paths
+        if isinstance(returnn_model, PtCheckpoint):
+            model_path = returnn_model.path
+        elif isinstance(returnn_model, Checkpoint):
+            model_path = returnn_model.index_path
+        else:
+            raise NotImplementedError(f"returnn model has unknown type: {type(returnn_model)}")
+        self.add_input(model_path)
+
+        returnn_config.config["model"] = model_path
+        returnn_config.post_config["log_verbosity"] = log_verbosity
+
+        self.returnn_config = returnn_config
+
+        self.returnn_python_exe = returnn_python_exe
+        self.returnn_root = returnn_root
+
+        self.out_returnn_config_file = self.output_path("returnn.config")
+        self.out_returnn_log = self.output_path("returnn.log")
+        self.out_perplexities = self.output_var("ppl_score")
+
+        self.rqmt = {"gpu": 0, "cpu": 2, "mem": 4, "time": 4}
+
+    def tasks(self):
+        """Yield the three tasks of this job: file creation, the RETURNN call and the gather step."""
+        yield Task("create_files", mini_task=True)
+        yield Task("run", resume="run", rqmt=self.rqmt)
+        yield Task("gather", mini_task=True)
+
+    def _get_run_cmd(self):
+        """
+        :return: argument list calling ``rnn.py`` with the written config and the task set to "eval"
+        """
+        return [
+            self.returnn_python_exe.get_path(),
+            self.returnn_root.join_right("rnn.py").get_path(),
+            self.out_returnn_config_file.get_path(),
+            # "++<key>" and its value have to be separate argv entries, a joined "++task eval" only
+            # works after shell word splitting (rnn.sh), not when the list is passed to subprocess
+            "++task",
+            "eval",
+        ]
+
+    def create_files(self):
+        """Write the RETURNN config and an ``rnn.sh`` helper script containing the run command."""
+        self.returnn_config.write(self.out_returnn_config_file.get_path())
+
+        util.create_executable("rnn.sh", self._get_run_cmd())
+
+    def run(self):
+        """Call RETURNN and move its log file to the job output."""
+        sp.check_call(self._get_run_cmd())
+
+        # NOTE(review): this expects RETURNN to write a file named "returnn_log" into the work dir,
+        # but no "log" entry is set in the config by this job - confirm
+        shutil.move("returnn_log", self.out_returnn_log.get_path())
+
+    def gather(self):
+        """Placeholder: print the keys of the perplexity output, no score is stored yet."""
+        # NOTE(review): out_perplexities comes from output_var() and is never set in this job;
+        # confirm that it offers keys() and fill it with the score parsed from the RETURNN log
+        for data_key in self.out_perplexities.keys():
+            print(data_key)
+
+    @classmethod
+    def hash(cls, parsed_args):
+        """
+        :param parsed_args: constructor arguments as passed by Sisyphus
+        :return: job hash which does not depend on ``log_verbosity``
+        """
+        # build a filtered copy instead of deleting from the dict handed in by Sisyphus
+        hashed_args = {key: value for key, value in parsed_args.items() if key != "log_verbosity"}
+        return super().hash(hashed_args)
diff --git a/returnn/training.py b/returnn/training.py
index 03e17127..6e6f565f 100644
--- a/returnn/training.py
+++ b/returnn/training.py
@@ -35,12 +35,12 @@ class ReturnnModel:
     This is deprecated, use :class:`Checkpoint` instead.
     """
 
-    def __init__(self, returnn_config_file, model, epoch):
+    def __init__(self, returnn_config_file: Path, model: Path, epoch: int):
         """
 
-        :param Path returnn_config_file: Path to a returnn config file
-        :param Path model: Path to a RETURNN checkpoint (only the .meta for Tensorflow)
-        :param int epoch:
+        :param returnn_config_file: Path to a returnn config file
+        :param model: Path to a RETURNN checkpoint (only the .meta for Tensorflow)
+        :param epoch:
         """
         self.returnn_config_file = returnn_config_file
         self.model = model
@@ -52,7 +52,7 @@ class Checkpoint:
     Checkpoint object which holds the (Tensorflow) index file path as tk.Path,
     and will return the checkpoint path as common prefix of the .index/.meta/.data[...]
 
-    A checkpoint object should directly assigned to a RasrConfig entry (do not call `.ckpt_path`)
+    A checkpoint object should directly be assigned to a RasrConfig entry (do not call `.ckpt_path`)
     so that the hash will resolve correctly
     """
 