Skip to content

Commit

Permalink
simplify the template syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
mccalluc committed Sep 27, 2024
1 parent bb4aca4 commit 772e84c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 43 deletions.
66 changes: 30 additions & 36 deletions dp_creator_ii/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ def __init__(self, path):
template_path = Path(__file__).parent / "templates" / path
self._template = template_path.read_text()

def fill_expressions(self, map):
for k, v in map.items():
def fill_expressions(self, **kwargs):
for k, v in kwargs.items():
self._template = self._template.replace(k, v)
return self

def fill_values(self, map):
for k, v in map.items():
def fill_values(self, **kwargs):
for k, v in kwargs.items():
self._template = self._template.replace(k, repr(v))
return self

def fill_blocks(self, map):
for k, v in map.items():
def fill_blocks(self, **kwargs):
for k, v in kwargs.items():
k_re = re.escape(k)
self._template = re.sub(
rf"^(\s*){k_re}$",
Expand All @@ -53,60 +53,54 @@ def __str__(self):
def _make_context_for_notebook(csv_path, unit, loss, weights):
return str(
_Template("context.py").fill_values(
{
"CSV_PATH": csv_path,
"UNIT": unit,
"LOSS": loss,
"WEIGHTS": weights,
}
CSV_PATH=csv_path,
UNIT=unit,
LOSS=loss,
WEIGHTS=weights,
)
)


def _make_context_for_script(unit, loss, weights):
return str(
_Template("context.py")
.fill_expressions({"CSV_PATH": "csv_path"})
.fill_expressions(
CSV_PATH="csv_path",
)
.fill_values(
{
"UNIT": unit,
"LOSS": loss,
"WEIGHTS": weights,
}
UNIT=unit,
LOSS=loss,
WEIGHTS=weights,
)
)


def _make_imports():
return str(_Template("imports.py").fill_values({}))
return str(_Template("imports.py").fill_values())


def make_notebook_py(csv_path, unit, loss, weights):
return str(
_Template("notebook.py").fill_blocks(
{
"IMPORTS_BLOCK": _make_imports(),
"CONTEXT_BLOCK": _make_context_for_notebook(
csv_path=csv_path,
unit=unit,
loss=loss,
weights=weights,
),
}
IMPORTS_BLOCK=_make_imports(),
CONTEXT_BLOCK=_make_context_for_notebook(
csv_path=csv_path,
unit=unit,
loss=loss,
weights=weights,
),
)
)


def make_script_py(unit, loss, weights):
return str(
_Template("script.py").fill_blocks(
{
"IMPORTS_BLOCK": _make_imports(),
"CONTEXT_BLOCK": _make_context_for_script(
unit=unit,
loss=loss,
weights=weights,
),
}
IMPORTS_BLOCK=_make_imports(),
CONTEXT_BLOCK=_make_context_for_script(
unit=unit,
loss=loss,
weights=weights,
),
)
)
12 changes: 5 additions & 7 deletions dp_creator_ii/tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ def test_fill_template():
context_template = _Template("context.py")
context_block = str(
context_template.fill_values(
{
"CSV_PATH": fake_csv,
"UNIT": 1,
"LOSS": 1,
"WEIGHTS": [1],
}
CSV_PATH=fake_csv,
UNIT=1,
LOSS=1,
WEIGHTS=[1],
)
)
assert f"data=pl.scan_csv('{fake_csv}')" in context_block
Expand All @@ -30,7 +28,7 @@ def test_fill_template_unfilled_slots():
Exception,
match=re.escape("context.py has unfilled slots: CSV_PATH, LOSS, UNIT, WEIGHTS"),
):
str(context_template.fill_values({}))
str(context_template.fill_values())


def test_make_notebook():
Expand Down

0 comments on commit 772e84c

Please sign in to comment.