Skip to content

Commit

Permalink
feat(schema)!: add spinner schema validation
Browse files Browse the repository at this point in the history
  • Loading branch information
leiteg committed Sep 5, 2024
1 parent 8205d17 commit 2a53f5e
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ dependencies = [
"ipykernel==6.29.4",
"tokenize-rt==5.2.0",
"seaborn==0.13.2",
"scipy==1.13.1"
"scipy==1.13.1",
"pydantic~=2.8",
]

[project.optional-dependencies]
Expand Down
217 changes: 217 additions & 0 deletions spinner/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
from __future__ import annotations

import ast
from functools import cache, cached_property
from typing import Any, Literal, Self

from jinja2 import Environment, Template, meta
from pydantic import (
BaseModel,
Field,
PositiveFloat,
PositiveInt,
RootModel,
model_validator,
)


class SpinnerMetadata(BaseModel, str_strip_whitespace=True):
description: str
version: str = Field(pattern=r"v?\d+\.\d+(\.\d+)?")
runs: int = Field(gt=0)
timeout: PositiveFloat | None = Field(default=None, gt=0.0)
retry: bool = Field(default=False)
retry_limit: PositiveInt = Field(default=1, ge=0)


class SpinnerCommand(BaseModel, str_strip_whitespace=True):
template: str

def __hash__(self) -> int:
return hash(self.template)

@cache
def parse(self, env: Environment | None = None) -> Template:
if not env:
env = Environment()
return env.parse(self.template)


SpinnerOutputType = Literal["contains"]


class SpinnerLambda(BaseModel, str_strip_whitespace=True):
name: str
func: str = Field(alias="lambda")

@model_validator(mode="after")
def validate_lambda(self) -> Self:
try:
ast.parse(source=self.func)
except SyntaxError as e:
raise ValueError(f"syntax error: {e}") from e
return self


class SpinnerOutput(BaseModel, str_strip_whitespace=True):
type: SpinnerOutputType
pattern: str
to_float: SpinnerLambda


class SpinnerPlot(BaseModel, str_strip_whitespace=True):
title: str = Field(default="")
x_axis: str
y_axis: str
group_by: str | list[str] | None = Field(default=None)


class SpinnerApplication(BaseModel, str_strip_whitespace=True):
command: SpinnerCommand
output: list[SpinnerOutput] = Field(default_factory=list)
plot: list[SpinnerPlot] = Field(default_factory=list)

def _validate_plot(self, plot: SpinnerPlot) -> None:
assert plot.x_axis in self.variables, f"unknown x-axis {plot.x_axis!r}"
assert plot.y_axis in self.variables, f"unknonw y-axis {plot.y_axis!r}"
assert plot.group_by in self.variables, f"unknown group by {plot.x_axis!r}"

@model_validator(mode="after")
def validate_plots(self) -> Self:
for plot in self.plot:
self._validate_plot(plot)
return self

@cached_property
def placeholders(self) -> set[str]:
return meta.find_undeclared_variables(self.command.parse())

@cached_property
def output_variables(self) -> set[str]:
return set(x.to_float.name for x in self.output)

@cached_property
def variables(self) -> set[str]:
return self.placeholders | self.output_variables


class SpinnerApplications(RootModel, str_strip_whitespace=True):
root: dict[str, SpinnerApplication] = Field(default_factory=dict)

def items(self):
return self.root.items()

def __iter__(self):
return iter(self.root)

def __getitem__(self, key) -> SpinnerApplication | None:
return self.root.get(key)


class SpinnerBenchmark(RootModel, str_strip_whitespace=True):
root: dict[str, list[Any]] = Field(default_factory=dict)

@cached_property
def parameters(self):
return set(self.root.keys())

def items(self):
return self.root.items()

def __iter__(self):
return iter(self.root)

def __getitem__(self, key) -> list[Any] | None:
return self.root.get(key)


class SpinnerBenchmarks(RootModel, str_strip_whitespace=True):
root: dict[str, SpinnerBenchmark] = Field(default_factory=dict)

def items(self):
return self.root.items()

def __iter__(self):
return iter(self.root)

def __getitem__(self, key) -> SpinnerApplication | None:
return self.root.get(key)


class SpinnerConfig(BaseModel, str_strip_whitespace=True):
metadata: SpinnerMetadata
applications: SpinnerApplications = Field(default_factory=dict)
benchmarks: SpinnerBenchmarks = Field(default_factory=dict)

@model_validator(mode="after")
def validate_benchmark_keys(self) -> Self:
for key in self.benchmarks:
assert key in self.applications, f"benchmark {key!r} is not an application."
return self

@model_validator(mode="after")
def validate_benchmark_parameters(self) -> Self:
for name, parameters in self.benchmarks.items():
application = self.applications[name]
for parameter in parameters:
assert (
parameter in application.placeholders
), f"parameter {parameter!r} is not valid."
return self

@model_validator(mode="after")
def validate_application_placeholders(self) -> Self:
for name, application in self.applications.items():
placeholders = application.placeholders
benchmark = self.benchmarks[name]
for placeholder in placeholders:
assert (
placeholder in benchmark.parameters
), f"placeholder {placeholder!r} for {name!r} is not defined in benchmark."
return self


if __name__ == "__main__":
import yaml
from pydantic import ValidationError
from rich import print

raw_data = """
metadata:
description: Lorem ipsum
version: v1.0
runs: 10
applications:
example:
command:
template: >
sleep {{sleep_duration}} {{nodes}}
output:
- type: contains
pattern: "Runtime: "
to_float:
name: runtime
lambda: >
print("Hello")
plot:
- title: Lorem
x_axis: runtime
y_axis: runtime
benchmarks:
example:
sleep_duration: [1, 2, 3]
nodes: [1, 2, 3]
"""

data = yaml.safe_load(raw_data)

try:
model = SpinnerConfig(**data)
print(model)
except ValidationError as e:
for error in e.errors():
kind = error["type"]
location = ".".join((str(x) for x in error["loc"]))
message = error["msg"]
print(f"ERROR: {kind} in {location!r}:")
print(f"| {message}")

0 comments on commit 2a53f5e

Please sign in to comment.