diff --git a/dp_creator_ii/template.py b/dp_creator_ii/template.py index 91ad658..767c03a 100644 --- a/dp_creator_ii/template.py +++ b/dp_creator_ii/template.py @@ -19,18 +19,18 @@ def __init__(self, path): template_path = Path(__file__).parent / "templates" / path self._template = template_path.read_text() - def fill_expressions(self, map): - for k, v in map.items(): + def fill_expressions(self, **kwargs): + for k, v in kwargs.items(): self._template = self._template.replace(k, v) return self - def fill_values(self, map): - for k, v in map.items(): + def fill_values(self, **kwargs): + for k, v in kwargs.items(): self._template = self._template.replace(k, repr(v)) return self - def fill_blocks(self, map): - for k, v in map.items(): + def fill_blocks(self, **kwargs): + for k, v in kwargs.items(): k_re = re.escape(k) self._template = re.sub( rf"^(\s*){k_re}$", @@ -53,12 +53,10 @@ def __str__(self): def _make_context_for_notebook(csv_path, unit, loss, weights): return str( _Template("context.py").fill_values( - { - "CSV_PATH": csv_path, - "UNIT": unit, - "LOSS": loss, - "WEIGHTS": weights, - } + CSV_PATH=csv_path, + UNIT=unit, + LOSS=loss, + WEIGHTS=weights, ) ) @@ -66,33 +64,31 @@ def _make_context_for_notebook(csv_path, unit, loss, weights): def _make_context_for_script(unit, loss, weights): return str( _Template("context.py") - .fill_expressions({"CSV_PATH": "csv_path"}) + .fill_expressions( + CSV_PATH="csv_path", + ) .fill_values( - { - "UNIT": unit, - "LOSS": loss, - "WEIGHTS": weights, - } + UNIT=unit, + LOSS=loss, + WEIGHTS=weights, ) ) def _make_imports(): - return str(_Template("imports.py").fill_values({})) + return str(_Template("imports.py").fill_values()) def make_notebook_py(csv_path, unit, loss, weights): return str( _Template("notebook.py").fill_blocks( - { - "IMPORTS_BLOCK": _make_imports(), - "CONTEXT_BLOCK": _make_context_for_notebook( - csv_path=csv_path, - unit=unit, - loss=loss, - weights=weights, - ), - } + IMPORTS_BLOCK=_make_imports(), + CONTEXT_BLOCK=_make_context_for_notebook( + csv_path=csv_path, + unit=unit, + loss=loss, + weights=weights, + ), ) ) @@ -100,13 +96,11 @@ def make_notebook_py(csv_path, unit, loss, weights): def make_script_py(unit, loss, weights): return str( _Template("script.py").fill_blocks( - { - "IMPORTS_BLOCK": _make_imports(), - "CONTEXT_BLOCK": _make_context_for_script( - unit=unit, - loss=loss, - weights=weights, - ), - } + IMPORTS_BLOCK=_make_imports(), + CONTEXT_BLOCK=_make_context_for_script( + unit=unit, + loss=loss, + weights=weights, + ), ) ) diff --git a/dp_creator_ii/tests/test_template.py b/dp_creator_ii/tests/test_template.py index 073c766..4d9e509 100644 --- a/dp_creator_ii/tests/test_template.py +++ b/dp_creator_ii/tests/test_template.py @@ -13,12 +13,10 @@ def test_fill_template(): context_template = _Template("context.py") context_block = str( context_template.fill_values( - { - "CSV_PATH": fake_csv, - "UNIT": 1, - "LOSS": 1, - "WEIGHTS": [1], - } + CSV_PATH=fake_csv, + UNIT=1, + LOSS=1, + WEIGHTS=[1], ) ) assert f"data=pl.scan_csv('{fake_csv}', encoding=\"utf8-lossy\")" in context_block @@ -30,7 +28,7 @@ def test_fill_template_unfilled_slots(): Exception, match=re.escape("context.py has unfilled slots: CSV_PATH, LOSS, UNIT, WEIGHTS"), ): - str(context_template.fill_values({})) + str(context_template.fill_values()) def test_make_notebook():