-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(schema)!: add spinner schema validation
- Loading branch information
Showing
2 changed files
with
219 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
from __future__ import annotations | ||
|
||
import ast | ||
from functools import cache, cached_property | ||
from typing import Any, Literal, Self | ||
|
||
from jinja2 import Environment, Template, meta | ||
from pydantic import ( | ||
BaseModel, | ||
Field, | ||
PositiveFloat, | ||
PositiveInt, | ||
RootModel, | ||
model_validator, | ||
) | ||
|
||
|
||
class SpinnerMetadata(BaseModel, str_strip_whitespace=True): | ||
description: str | ||
version: str = Field(pattern=r"v?\d+\.\d+(\.\d+)?") | ||
runs: int = Field(gt=0) | ||
timeout: PositiveFloat | None = Field(default=None, gt=0.0) | ||
retry: bool = Field(default=False) | ||
retry_limit: PositiveInt = Field(default=1, ge=0) | ||
|
||
|
||
class SpinnerCommand(BaseModel, str_strip_whitespace=True): | ||
template: str | ||
|
||
def __hash__(self) -> int: | ||
return hash(self.template) | ||
|
||
@cache | ||
def parse(self, env: Environment | None = None) -> Template: | ||
if not env: | ||
env = Environment() | ||
return env.parse(self.template) | ||
|
||
|
||
SpinnerOutputType = Literal["contains"] | ||
|
||
|
||
class SpinnerLambda(BaseModel, str_strip_whitespace=True): | ||
name: str | ||
func: str = Field(alias="lambda") | ||
|
||
@model_validator(mode="after") | ||
def validate_lambda(self) -> Self: | ||
try: | ||
ast.parse(source=self.func) | ||
except SyntaxError as e: | ||
raise ValueError(f"syntax error: {e}") from e | ||
return self | ||
|
||
|
||
class SpinnerOutput(BaseModel, str_strip_whitespace=True): | ||
type: SpinnerOutputType | ||
pattern: str | ||
to_float: SpinnerLambda | ||
|
||
|
||
class SpinnerPlot(BaseModel, str_strip_whitespace=True): | ||
title: str = Field(default="") | ||
x_axis: str | ||
y_axis: str | ||
group_by: str | list[str] | None = Field(default=None) | ||
|
||
|
||
class SpinnerApplication(BaseModel, str_strip_whitespace=True): | ||
command: SpinnerCommand | ||
output: list[SpinnerOutput] = Field(default_factory=list) | ||
plot: list[SpinnerPlot] = Field(default_factory=list) | ||
|
||
def _validate_plot(self, plot: SpinnerPlot) -> None: | ||
assert plot.x_axis in self.variables, f"unknown x-axis {plot.x_axis!r}" | ||
assert plot.y_axis in self.variables, f"unknonw y-axis {plot.y_axis!r}" | ||
assert plot.group_by in self.variables, f"unknown group by {plot.x_axis!r}" | ||
|
||
@model_validator(mode="after") | ||
def validate_plots(self) -> Self: | ||
for plot in self.plot: | ||
self._validate_plot(plot) | ||
return self | ||
|
||
@cached_property | ||
def placeholders(self) -> set[str]: | ||
return meta.find_undeclared_variables(self.command.parse()) | ||
|
||
@cached_property | ||
def output_variables(self) -> set[str]: | ||
return set(x.to_float.name for x in self.output) | ||
|
||
@cached_property | ||
def variables(self) -> set[str]: | ||
return self.placeholders | self.output_variables | ||
|
||
|
||
class SpinnerApplications(RootModel, str_strip_whitespace=True): | ||
root: dict[str, SpinnerApplication] = Field(default_factory=dict) | ||
|
||
def items(self): | ||
return self.root.items() | ||
|
||
def __iter__(self): | ||
return iter(self.root) | ||
|
||
def __getitem__(self, key) -> SpinnerApplication | None: | ||
return self.root.get(key) | ||
|
||
|
||
class SpinnerBenchmark(RootModel, str_strip_whitespace=True): | ||
root: dict[str, list[Any]] = Field(default_factory=dict) | ||
|
||
@cached_property | ||
def parameters(self): | ||
return set(self.root.keys()) | ||
|
||
def items(self): | ||
return self.root.items() | ||
|
||
def __iter__(self): | ||
return iter(self.root) | ||
|
||
def __getitem__(self, key) -> list[Any] | None: | ||
return self.root.get(key) | ||
|
||
|
||
class SpinnerBenchmarks(RootModel, str_strip_whitespace=True): | ||
root: dict[str, SpinnerBenchmark] = Field(default_factory=dict) | ||
|
||
def items(self): | ||
return self.root.items() | ||
|
||
def __iter__(self): | ||
return iter(self.root) | ||
|
||
def __getitem__(self, key) -> SpinnerApplication | None: | ||
return self.root.get(key) | ||
|
||
|
||
class SpinnerConfig(BaseModel, str_strip_whitespace=True): | ||
metadata: SpinnerMetadata | ||
applications: SpinnerApplications = Field(default_factory=dict) | ||
benchmarks: SpinnerBenchmarks = Field(default_factory=dict) | ||
|
||
@model_validator(mode="after") | ||
def validate_benchmark_keys(self) -> Self: | ||
for key in self.benchmarks: | ||
assert key in self.applications, f"benchmark {key!r} is not an application." | ||
return self | ||
|
||
@model_validator(mode="after") | ||
def validate_benchmark_parameters(self) -> Self: | ||
for name, parameters in self.benchmarks.items(): | ||
application = self.applications[name] | ||
for parameter in parameters: | ||
assert ( | ||
parameter in application.placeholders | ||
), f"parameter {parameter!r} is not valid." | ||
return self | ||
|
||
@model_validator(mode="after") | ||
def validate_application_placeholders(self) -> Self: | ||
for name, application in self.applications.items(): | ||
placeholders = application.placeholders | ||
benchmark = self.benchmarks[name] | ||
for placeholder in placeholders: | ||
assert ( | ||
placeholder in benchmark.parameters | ||
), f"placeholder {placeholder!r} for {name!r} is not defined in benchmark." | ||
return self | ||
|
||
|
||
if __name__ == "__main__": | ||
import yaml | ||
from pydantic import ValidationError | ||
from rich import print | ||
|
||
raw_data = """ | ||
metadata: | ||
description: Lorem ipsum | ||
version: v1.0 | ||
runs: 10 | ||
applications: | ||
example: | ||
command: | ||
template: > | ||
sleep {{sleep_duration}} {{nodes}} | ||
output: | ||
- type: contains | ||
pattern: "Runtime: " | ||
to_float: | ||
name: runtime | ||
lambda: > | ||
print("Hello") | ||
plot: | ||
- title: Lorem | ||
x_axis: runtime | ||
y_axis: runtime | ||
benchmarks: | ||
example: | ||
sleep_duration: [1, 2, 3] | ||
nodes: [1, 2, 3] | ||
""" | ||
|
||
data = yaml.safe_load(raw_data) | ||
|
||
try: | ||
model = SpinnerConfig(**data) | ||
print(model) | ||
except ValidationError as e: | ||
for error in e.errors(): | ||
kind = error["type"] | ||
location = ".".join((str(x) for x in error["loc"])) | ||
message = error["msg"] | ||
print(f"ERROR: {kind} in {location!r}:") | ||
print(f"| {message}") |