fix!: make signals serializable
tobymao committed Dec 9, 2024
1 parent c43266a commit e3ac8db
Showing 23 changed files with 671 additions and 494 deletions.
6 changes: 5 additions & 1 deletion examples/sushi/models/waiter_as_customer_by_day.sql
@@ -8,7 +8,11 @@ MODEL (
audits (
not_null(columns := (waiter_id)),
forall(criteria := (LENGTH(waiter_name) > 0))
)
),
signals (
test_signal(arg := 1)
),

);

JINJA_QUERY_BEGIN;
9 changes: 9 additions & 0 deletions examples/sushi/signals/__init__.py
@@ -0,0 +1,9 @@
import typing as t

from sqlmesh import signal, DatetimeRanges


@signal()
def test_signal(batch: DatetimeRanges, arg: int = 0) -> t.Union[bool, DatetimeRanges]:
assert arg == 1
return True
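
For reference, this test signal exercises the new flow end to end: it is declared in Python, registered via the @signal() decorator, and referenced from the sushi model above as signals (test_signal(arg := 1)). The sketch below is illustrative only and not part of this commit; it assumes DatetimeRanges behaves as a list of (start, end) datetime pairs (as the type hints above suggest) and that to_datetime is available from sqlmesh.utils.date. It shows how a signal can also return just the subset of intervals that are ready rather than a plain boolean.

import typing as t

from sqlmesh import signal, DatetimeRanges
from sqlmesh.utils.date import to_datetime


@signal()
def settled(batch: DatetimeRanges, cutoff: str = "1 week ago") -> t.Union[bool, DatetimeRanges]:
    # Hypothetical example: only intervals that end before the cutoff are ready,
    # so returning a filtered batch lets the scheduler run just those intervals.
    dt = to_datetime(cutoff)
    return [(start, end) for start, end in batch if end <= dt]

A model would then reference it the same way as test_signal above, e.g. signals (settled(cutoff := '1 week ago')).
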
6 changes: 4 additions & 2 deletions sqlmesh/__init__.py
@@ -24,6 +24,7 @@
from sqlmesh.core.engine_adapter import EngineAdapter as EngineAdapter
from sqlmesh.core.macros import SQL as SQL, macro as macro
from sqlmesh.core.model import Model as Model, model as model
from sqlmesh.core.signal import signal as signal
from sqlmesh.core.snapshot import Snapshot as Snapshot
from sqlmesh.core.snapshot.evaluator import (
CustomMaterialization as CustomMaterialization,
@@ -32,6 +33,7 @@
debug_mode_enabled as debug_mode_enabled,
enable_debug_mode as enable_debug_mode,
)
from sqlmesh.utils.date import DatetimeRanges as DatetimeRanges

try:
from sqlmesh._version import __version__ as __version__, __version_tuple__ as __version_tuple__
@@ -171,13 +173,13 @@ def configure_logging(
os.remove(path)

if debug:
from signal import SIGUSR1
import faulthandler
import signal

enable_debug_mode()

# Enable threadumps.
faulthandler.enable()
# Windows doesn't support register so we check for it here
if hasattr(faulthandler, "register"):
faulthandler.register(signal.SIGUSR1.value)
faulthandler.register(SIGUSR1.value)
8 changes: 4 additions & 4 deletions sqlmesh/core/config/model.py
@@ -2,7 +2,7 @@

import typing as t

from sqlmesh.core.dialect import parse_one, extract_audit
from sqlmesh.core.dialect import parse_one, extract_func_call
from sqlmesh.core.config.base import BaseConfig
from sqlmesh.core.model.kind import (
ModelKind,
@@ -11,7 +11,7 @@
on_destructive_change_validator,
)
from sqlmesh.utils.date import TimeLike
from sqlmesh.core.model.meta import AuditReference
from sqlmesh.core.model.meta import FunctionCall
from sqlmesh.utils.pydantic import field_validator


@@ -44,14 +44,14 @@ class ModelDefaultsConfig(BaseConfig):
storage_format: t.Optional[str] = None
on_destructive_change: t.Optional[OnDestructiveChange] = None
session_properties: t.Optional[t.Dict[str, t.Any]] = None
audits: t.Optional[t.List[AuditReference]] = None
audits: t.Optional[t.List[FunctionCall]] = None

_model_kind_validator = model_kind_validator
_on_destructive_change_validator = on_destructive_change_validator

@field_validator("audits", mode="before")
def _audits_validator(cls, v: t.Any) -> t.Any:
if isinstance(v, list):
return [extract_audit(parse_one(audit)) for audit in v]
return [extract_func_call(parse_one(audit)) for audit in v]

return v
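
Illustrative only (not from the diff): with this validator, audit references in the model defaults config can still be given as plain strings and are normalized into (name, kwargs) pairs via extract_func_call. The class and field names come from the hunk above; the printed shape in the comment is an expectation, not captured output.

from sqlmesh.core.config.model import ModelDefaultsConfig

defaults = ModelDefaultsConfig(
    dialect="duckdb",
    audits=["not_null(columns := (id))"],
)
# Each string is parsed and reduced to a FunctionCall, i.e. roughly
# ("not_null", {"columns": <parsed expression for (id)>}).
print(defaults.audits)
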
1 change: 0 additions & 1 deletion sqlmesh/core/context.py
@@ -2179,7 +2179,6 @@ def _load_materializations_and_signals(self) -> None:
if not self._loaded:
for context_loader in self._loaders.values():
with sys_path(*context_loader.configs):
context_loader.loader.load_signals(self)
context_loader.loader.load_materializations(self)

def _select_models_for_run(
10 changes: 9 additions & 1 deletion sqlmesh/core/dialect.py
@@ -1205,7 +1205,9 @@ def interpret_key_value_pairs(
return {i.this.name: interpret_expression(i.expression) for i in e.expressions}


def extract_audit(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
def extract_func_call(
v: exp.Expression, allow_tuples: bool = False
) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
kwargs = {}

if isinstance(v, exp.Anonymous):
@@ -1214,6 +1216,12 @@ def extract_audit(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]
elif isinstance(v, exp.Func):
func = v.sql_name()
args = list(v.args.values())
elif isinstance(v, exp.Tuple): # airflow only
if not allow_tuples:
raise ConfigError("Audit name is missing (eg. MY_AUDIT())")

func = ""
args = v.expressions
else:
return v.name.lower(), {}

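A rough usage sketch of the renamed helper (illustrative only; the commented results are expectations based on the hunk above, not output captured from this commit):

from sqlmesh.core.dialect import parse_one, extract_func_call

# A call with keyword-style arguments, as written in MODEL audits/signals blocks.
name, kwargs = extract_func_call(parse_one("test_signal(arg := 1)"))
# expected: name == "test_signal", kwargs == {"arg": <sqlglot expression for 1>}

# A bare tuple is only tolerated when allow_tuples=True (the airflow-only path);
# otherwise a ConfigError("Audit name is missing (eg. MY_AUDIT())") is raised.
name, kwargs = extract_func_call(parse_one("(a := 1, b := 2)"), allow_tuples=True)
# expected: name == "" and kwargs built from the tuple's expressions
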
57 changes: 48 additions & 9 deletions sqlmesh/core/loader.py
@@ -29,6 +29,7 @@
)
from sqlmesh.core.model.cache import load_optimized_query_and_mapping, optimized_query_cache_pool
from sqlmesh.core.model import model as model_registry
from sqlmesh.core.signal import signal
from sqlmesh.utils import UniqueKeyDict
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.errors import ConfigError
@@ -79,7 +80,7 @@ def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedPr
self._dag = DAG()

self._load_materializations()
self._load_signals()
signals = self._load_signals()

config_mtimes: t.Dict[Path, t.List[float]] = defaultdict(list)
for context_path, config in self._context.configs.items():
@@ -103,7 +104,13 @@
else:
standalone_audits[name] = audit

models = self._load_models(macros, jinja_macros, context.selected_gateway, audits)
models = self._load_models(
macros,
jinja_macros,
context.selected_gateway,
audits,
signals,
)

for model in models.values():
self._add_model_to_dag(model)
@@ -131,10 +138,10 @@ def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedPr
)
return project

def load_signals(self, context: GenericContext) -> None:
def load_signals(self, context: GenericContext) -> UniqueKeyDict[str, signal]:
"""Loads signals for the built-in scheduler."""
self._context = context
self._load_signals()
return self._load_signals()

def load_materializations(self, context: GenericContext) -> None:
"""Loads materializations for the built-in scheduler."""
@@ -165,6 +172,7 @@ def _load_models(
jinja_macros: JinjaMacroRegistry,
gateway: t.Optional[str],
audits: UniqueKeyDict[str, ModelAudit],
signals: UniqueKeyDict[str, signal],
) -> UniqueKeyDict[str, Model]:
"""Loads all models."""

@@ -177,8 +185,8 @@ def _load_audits(
def _load_materializations(self) -> None:
"""Loads custom materializations."""

def _load_signals(self) -> None:
"""Loads signals for the built-in scheduler."""
def _load_signals(self) -> UniqueKeyDict[str, signal]:
return UniqueKeyDict("signals")

def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
return UniqueKeyDict("metrics")
@@ -320,14 +328,15 @@ def _load_models(
jinja_macros: JinjaMacroRegistry,
gateway: t.Optional[str],
audits: UniqueKeyDict[str, ModelAudit],
signals: UniqueKeyDict[str, signal],
) -> UniqueKeyDict[str, Model]:
"""
Loads all of the models within the model directory with their associated
audits into a Dict and creates the dag
"""
models = self._load_sql_models(macros, jinja_macros, audits)
models = self._load_sql_models(macros, jinja_macros, audits, signals)
models.update(self._load_external_models(audits, gateway))
models.update(self._load_python_models(macros, jinja_macros, audits))
models.update(self._load_python_models(macros, jinja_macros, audits, signals))

return models

@@ -336,6 +345,7 @@ def _load_sql_models(
macros: MacroRegistry,
jinja_macros: JinjaMacroRegistry,
audits: UniqueKeyDict[str, ModelAudit],
signals: UniqueKeyDict[str, signal],
) -> UniqueKeyDict[str, Model]:
"""Loads the sql models into a Dict"""
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
@@ -380,6 +390,7 @@ def _load() -> Model:
default_catalog=self._context.default_catalog,
variables=variables,
infer_names=config.model_naming.infer_names,
signal_definitions=signals,
)

model = cache.get_or_load_model(path, _load)
@@ -397,6 +408,7 @@ def _load_python_models(
macros: MacroRegistry,
jinja_macros: JinjaMacroRegistry,
audits: UniqueKeyDict[str, ModelAudit],
signals: UniqueKeyDict[str, signal],
) -> UniqueKeyDict[str, Model]:
"""Loads the python models into a Dict"""
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
@@ -452,29 +464,51 @@ def _load_materializations(self) -> None:
if os.path.getsize(path):
import_python_file(path, context_path)

def _load_signals(self) -> None:
def _load_signals(self) -> UniqueKeyDict[str, signal]:
"""Loads signals for the built-in scheduler."""

signals_max_mtime: t.Optional[float] = None

for context_path, config in self._context.configs.items():
for path in self._glob_paths(
context_path / c.SIGNALS,
ignore_patterns=config.ignore_patterns,
extension=".py",
):
if os.path.getsize(path):
self._track_file(path)
signal_file_mtime = self._path_mtimes[path]
signals_max_mtime = (
max(signals_max_mtime, signal_file_mtime)
if signals_max_mtime
else signal_file_mtime
)
import_python_file(path, context_path)

self._signals_max_mtime = signals_max_mtime

return signal.get_registry()

def _load_audits(
self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
) -> UniqueKeyDict[str, Audit]:
"""Loads all the model audits."""
audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits")
audits_max_mtime: t.Optional[float] = None

for context_path, config in self._context.configs.items():
variables = self._variables(config)
for path in self._glob_paths(
context_path / c.AUDITS, ignore_patterns=config.ignore_patterns, extension=".sql"
):
self._track_file(path)
with open(path, "r", encoding="utf-8") as file:
audits_file_mtime = self._path_mtimes[path]
audits_max_mtime = (
max(audits_max_mtime, audits_file_mtime)
if audits_max_mtime
else audits_file_mtime
)
expressions = parse(file.read(), default_dialect=config.model_defaults.dialect)
audits = load_multiple_audits(
expressions=expressions,
@@ -488,6 +522,9 @@
)
for audit in audits:
audits_by_name[audit.name] = audit

self._audits_max_mtime = audits_max_mtime

return audits_by_name

def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
@@ -550,6 +587,8 @@ def _model_cache_entry_id(self, model_path: Path) -> str:
mtimes = [
self._loader._path_mtimes[model_path],
self._loader._macros_max_mtime,
self._loader._signals_max_mtime,
self._loader._audits_max_mtime,
self._loader._config_mtimes.get(self._context_path),
self._loader._config_mtimes.get(c.SQLMESH_PATH),
]
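The two new entries in this mtimes list mean a model's cache id now also depends on the newest signal and audit file mtimes, so editing any of those files invalidates cached models. A minimal sketch of the max-mtime bookkeeping pattern the loader hunks above rely on (illustrative; the helper name is hypothetical):

import os
import typing as t

def max_mtime(paths: t.Iterable[str]) -> t.Optional[float]:
    # Keep the newest modification time across a group of files; any change to
    # any of them bumps this value and therefore the derived cache entry id.
    result: t.Optional[float] = None
    for path in paths:
        mtime = os.path.getmtime(path)
        result = mtime if result is None else max(result, mtime)
    return result
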
