Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RETURNN compute perplexity job #563

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions returnn/perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
__all__ = ["ReturnnCalculatePerplexityJob"]

import shutil
import subprocess as sp
from typing import Union

from sisyphus import Job, Task, setup_path, tk

import i6_core.util as util

from .config import ReturnnConfig
from .training import PtCheckpoint, Checkpoint

Path = setup_path(__package__)


class ReturnnCalculatePerplexityJob(Job):
"""
Calculates the perplexity of a language model trained in RETURNN
on an evaluation data set
"""

def __init__(
self,
returnn_config: ReturnnConfig,
returnn_model: Union[PtCheckpoint, Checkpoint],
eval_dataset: tk.Path,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
eval_dataset: tk.Path,
eval_dataset: Union[tk.Path, Dict[str, Any]],

in principle any dataset is valid? Actually, does eval_datasets = {"eval": "/path/to/file"} constitute a correct dataset definition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not tested if any dataset is valid. I am assuming that any dataset with text should be valid.. but this should be used with LmDataset.

No that is not the correct dataset definition. It needs to be something like: {"eval": "class": "LmDatset", ...}

*,
log_verbosity: int = 3,
returnn_root: tk.Path,
returnn_python_exe: tk.Path,
):
returnn_config.config.pop("train")
returnn_config.config.pop("dev")
returnn_config.config["eval_datasets"] = {"eval": eval_dataset}

# TODO verify paths
if isinstance(returnn_model, PtCheckpoint):
model_path = returnn_model.path
self.add_input(returnn_model.path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should already be input to the Job as we pass the PT/Checkpoint object to the Job.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure??
Pt/Checkpoint are normal Python classes which I think are not covered by the extract_paths from sisyphus https://github.com/rwth-i6/sisyphus/blob/master/sisyphus/tools.py#L74
or am I missing something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure??

never without actually testing ;)

but very confident because:

  • in recognition we also just give the checkpoint object and it works

are not covered by the extract_paths from sisyphus

  • extract_paths should arrive in the last else and then call get_object_state which should then via get_members_ descend into the dict of the Checkpoint object and return the underlying tk.Path object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes. thanks for the explanation :)

elif isinstance(returnn_model, Checkpoint):
model_path = returnn_model.index_path
self.add_input(returnn_model.index_path)
else:
raise NotImplementedError(f"returnn model has unknown type: {type(returnn_model)}")

returnn_config.config["model"] = model_path

returnn_config.post_config["log_verbosity"] = log_verbosity

self.returnn_config = returnn_config

self.returnn_python_exe = returnn_python_exe
self.returnn_root = returnn_root

self.out_returnn_config_file = self.output_path("returnn.config")
self.out_returnn_log = self.output_path("returnn.log")
self.out_perplexities = self.output_var("ppl_score")

self.rqmt = {"gpu": 0, "cpu": 2, "mem": 4, "time": 4}

def tasks(self):
yield Task("create_files", mini_task=True)
yield Task("run", resume="run", rqmt=self.rqmt)
yield Task("gather", mini_task=True)

def _get_run_cmd(self):
run_cmd = [
self.returnn_python_exe.get_path(),
self.returnn_root.join_right("rnn.py").get_path(),
self.out_returnn_config_file.get_path(),
"++task eval",
]
return run_cmd

def create_files(self):
self.returnn_config.write(self.out_returnn_config_file.get_path())

util.create_executable("rnn.sh", self._get_run_cmd())

def run(self):
sp.check_call(self._get_run_cmd())

shutil.move("returnn_log", self.out_returnn_log.get_path())

def gather(self):
for data_key in self.out_perplexities.keys():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tk.Variable has no keys()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks I think that is from an earlier version...

print(data_key)

@classmethod
def hash(cls, parsed_args):
del parsed_args["log_verbosity"]
return super().hash(parsed_args)
10 changes: 5 additions & 5 deletions returnn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ class ReturnnModel:
This is deprecated, use :class:`Checkpoint` instead.
"""

def __init__(self, returnn_config_file, model, epoch):
def __init__(self, returnn_config_file: Path, model: Path, epoch: int):
"""

:param Path returnn_config_file: Path to a returnn config file
:param Path model: Path to a RETURNN checkpoint (only the .meta for Tensorflow)
:param int epoch:
:param returnn_config_file: Path to a returnn config file
:param model: Path to a RETURNN checkpoint (only the .meta for Tensorflow)
:param epoch:
"""
self.returnn_config_file = returnn_config_file
self.model = model
Expand All @@ -52,7 +52,7 @@ class Checkpoint:
Checkpoint object which holds the (Tensorflow) index file path as tk.Path,
and will return the checkpoint path as common prefix of the .index/.meta/.data[...]

A checkpoint object should directly assigned to a RasrConfig entry (do not call `.ckpt_path`)
A checkpoint object should be directly assigned to a RasrConfig entry (do not call `.ckpt_path`)
so that the hash will resolve correctly
"""

Expand Down
Loading