From d0cf85ee3137959a1667bc4d84b74e7fce3911a8 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Tue, 3 Sep 2024 10:23:14 +0200 Subject: [PATCH] Parse `TextDict`s chunkwise to avoid `OverflowError` --- returnn/search.py | 26 ++++++++------------------ text/convert.py | 10 ++++------ util.py | 21 +++++++++++++++++++++ 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/returnn/search.py b/returnn/search.py index 472460a5..a4bff091 100644 --- a/returnn/search.py +++ b/returnn/search.py @@ -348,8 +348,7 @@ def tasks(self): yield Task("run", mini_task=True) def run(self): - d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) - assert isinstance(d, dict) # seq_tag -> bpe string + d = util.parse_text_dict(self.search_py_output) assert not os.path.exists(self.out_word_search_results.get_path()) with util.uopen(self.out_word_search_results, "wt") as out: out.write("{\n") @@ -400,8 +399,7 @@ def tasks(self): yield Task("run", mini_task=True) def run(self): - d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) - assert isinstance(d, dict) # seq_tag -> bpe string + d = util.parse_text_dict(self.search_py_output) assert not os.path.exists(self.out_search_results.get_path()) def _transform_text(s: str): @@ -446,8 +444,7 @@ def tasks(self): def run(self): corpus = Corpus() corpus.load(self.bliss_corpus.get_path()) - d = eval(util.uopen(self.recog_words_file.get_path(), "rt").read()) - assert isinstance(d, dict), "only search output file with dict format is supported" + d = util.parse_text_dict(self.recog_words_file) with util.uopen(self.out_ctm_file.get_path(), "wt") as out: out.write(";; []\n") for seg in corpus.segments(): @@ -531,10 +528,7 @@ def tasks(self): yield Task("run", mini_task=True) def run(self): - # nan/inf should not be needed, but avoids errors at this point and will print an error below, - # that we don't expect an N-best list here. - d = eval(util.uopen(self.recog_words_file, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) - assert isinstance(d, dict), "only search output file with dict format is supported" + d = util.parse_text_dict(self.recog_words_file) if self.seq_order_file is not None: seq_order = eval(util.uopen(self.seq_order_file, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) assert isinstance(seq_order, (dict, list, tuple)) @@ -647,8 +641,7 @@ def tasks(self): def run(self): """run""" - d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) - assert isinstance(d, dict) # seq_tag -> bpe string + d = util.parse_text_dict(self.search_py_output) assert not os.path.exists(self.out_best_search_results.get_path()) with util.uopen(self.out_best_search_results, "wt") as out: out.write("{\n") @@ -686,8 +679,7 @@ def tasks(self): def run(self): """run""" - d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) - assert isinstance(d, dict) # seq_tag -> bpe string + d = util.parse_text_dict(self.search_py_output) assert not os.path.exists(self.out_search_results.get_path()) with util.uopen(self.out_search_results, "wt") as out: out.write("{\n") @@ -727,8 +719,7 @@ def tasks(self): def run(self): """run""" - d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) - assert isinstance(d, dict) # seq_tag -> bpe string + d = util.parse_text_dict(self.search_py_output) assert not os.path.exists(self.out_search_results.get_path()) with util.uopen(self.out_search_results, "wt") as out: out.write("{\n") @@ -786,8 +777,7 @@ def logsumexp(*args): lsp = numpy.log(sum(numpy.exp(a - a_max) for a in args)) return a_max + lsp - d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) - assert isinstance(d, dict) # seq_tag -> bpe string + d = util.parse_text_dict(self.search_py_output) assert not os.path.exists(self.out_search_results.get_path()) with util.uopen(self.out_search_results, "wt") as out: out.write("{\n") diff --git a/text/convert.py b/text/convert.py index f9079feb..a832aa80 100644 --- a/text/convert.py +++ b/text/convert.py @@ -3,10 +3,10 @@ "TextDictToStmJob", ] -from typing import Optional, Union, Sequence, Dict, List, Tuple +from typing import Union, Sequence, Dict, Tuple import re from sisyphus import Job, Path, Task -from i6_core.util import uopen +from i6_core.util import parse_text_dict, uopen class TextDictToTextLinesJob(Job): @@ -30,8 +30,7 @@ def tasks(self): def run(self): # nan/inf should not be needed, but avoids errors at this point and will print an error below, # that we don't expect an N-best list here. - d = eval(uopen(self.text_dict, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) - assert isinstance(d, dict) # seq_tag -> text + d = parse_text_dict(self.text_dict) with uopen(self.out_text_lines, "wt") as out: for seq_tag, entry in d.items(): @@ -83,8 +82,7 @@ def tasks(self): def run(self): # nan/inf should not be needed, but avoids errors at this point and will print an error below, # that we don't expect an N-best list here. - c = eval(uopen(self.text_dict, "rt").read(), {"nan": float("nan"), "inf": float("inf")}) - assert isinstance(c, dict) + c = parse_text_dict(self.text_dict) all_tags = [ ("d%d" % i, "default%d" % i, "all other segments of category %d" % i) diff --git a/util.py b/util.py index 4bc8fe56..617096a2 100644 --- a/util.py +++ b/util.py @@ -383,3 +383,24 @@ def update_nested_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]): else: dict1[k] = v return dict1 + + +def parse_text_dict(path: Union[str, tk.Path]) -> Dict[str, str]: + """ + Loads the text dict at :param:`path` making sure not to trigger line counter overflow. + """ + + with uopen(path, "rt") as text_dict_file: + txt = text_dict_file.read() + + # remove leading and trailing dict brackets + txt = txt.strip().strip("{}").strip() + + lines = txt.splitlines() + result = { + k: v + # parse chunkwise to avoid line counter overflow when the text dict is very large + for chunk in chunks(lines, max(1, len(lines) // 1000)) + for k, v in eval("\n".join(["{", *chunk, "}"]), {"nan": float("nan"), "inf": float("inf")}).items() + } + return result