From 2bc2d43f23bcbd361b41ac14a64a3810045af8b8 Mon Sep 17 00:00:00 2001 From: tobymao Date: Mon, 2 Dec 2024 17:06:24 -0800 Subject: [PATCH] fix!: make signals serializable --- .../models/waiter_as_customer_by_day.sql | 6 +- examples/sushi/signals/__init__.py | 9 + sqlmesh/__init__.py | 6 +- sqlmesh/core/config/model.py | 8 +- sqlmesh/core/context.py | 1 - sqlmesh/core/dialect.py | 10 +- sqlmesh/core/loader.py | 57 ++++- sqlmesh/core/macros.py | 226 +++++++++-------- sqlmesh/core/model/definition.py | 42 ++- sqlmesh/core/model/meta.py | 45 +--- sqlmesh/core/scheduler.py | 132 +--------- sqlmesh/core/signal.py | 35 +++ sqlmesh/core/snapshot/definition.py | 84 ++++++ sqlmesh/dbt/loader.py | 2 + .../migrations/v0059_add_physical_version.py | 5 - sqlmesh/migrations/v0062_change_signals.py | 83 ++++++ sqlmesh/utils/metaprogramming.py | 2 +- tests/core/test_model.py | 52 ++-- tests/core/test_scheduler.py | 240 +++++------------- tests/core/test_snapshot.py | 110 +++++++- .../airflow/operators/test_sensor.py | 15 +- 21 files changed, 671 insertions(+), 499 deletions(-) create mode 100644 examples/sushi/signals/__init__.py create mode 100644 sqlmesh/core/signal.py delete mode 100644 sqlmesh/migrations/v0059_add_physical_version.py create mode 100644 sqlmesh/migrations/v0062_change_signals.py diff --git a/examples/sushi/models/waiter_as_customer_by_day.sql b/examples/sushi/models/waiter_as_customer_by_day.sql index 3f7e053c4..7dc12db87 100644 --- a/examples/sushi/models/waiter_as_customer_by_day.sql +++ b/examples/sushi/models/waiter_as_customer_by_day.sql @@ -8,7 +8,11 @@ MODEL ( audits ( not_null(columns := (waiter_id)), forall(criteria := (LENGTH(waiter_name) > 0)) - ) + ), + signals ( + test_signal(arg := 1) + ), + ); JINJA_QUERY_BEGIN; diff --git a/examples/sushi/signals/__init__.py b/examples/sushi/signals/__init__.py new file mode 100644 index 000000000..bd7c839fc --- /dev/null +++ b/examples/sushi/signals/__init__.py @@ -0,0 +1,9 @@ +import typing as t + +from sqlmesh import signal, DatetimeRanges + + +@signal() +def test_signal(batch: DatetimeRanges, arg: int = 0) -> t.Union[bool, DatetimeRanges]: + assert arg == 1 + return True diff --git a/sqlmesh/__init__.py b/sqlmesh/__init__.py index bfa0094dc..18bb9e30e 100644 --- a/sqlmesh/__init__.py +++ b/sqlmesh/__init__.py @@ -24,6 +24,7 @@ from sqlmesh.core.engine_adapter import EngineAdapter as EngineAdapter from sqlmesh.core.macros import SQL as SQL, macro as macro from sqlmesh.core.model import Model as Model, model as model +from sqlmesh.core.signal import signal as signal from sqlmesh.core.snapshot import Snapshot as Snapshot from sqlmesh.core.snapshot.evaluator import ( CustomMaterialization as CustomMaterialization, @@ -32,6 +33,7 @@ debug_mode_enabled as debug_mode_enabled, enable_debug_mode as enable_debug_mode, ) +from sqlmesh.utils.date import DatetimeRanges as DatetimeRanges try: from sqlmesh._version import __version__ as __version__, __version_tuple__ as __version_tuple__ @@ -171,8 +173,8 @@ def configure_logging( os.remove(path) if debug: + from signal import SIGUSR1 import faulthandler - import signal enable_debug_mode() @@ -180,4 +182,4 @@ def configure_logging( faulthandler.enable() # Windows doesn't support register so we check for it here if hasattr(faulthandler, "register"): - faulthandler.register(signal.SIGUSR1.value) + faulthandler.register(SIGUSR1.value) diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py index ec8d2aaaf..2fac33ff3 100644 --- a/sqlmesh/core/config/model.py +++ 
b/sqlmesh/core/config/model.py @@ -2,7 +2,7 @@ import typing as t -from sqlmesh.core.dialect import parse_one, extract_audit +from sqlmesh.core.dialect import parse_one, extract_func_call from sqlmesh.core.config.base import BaseConfig from sqlmesh.core.model.kind import ( ModelKind, @@ -11,7 +11,7 @@ on_destructive_change_validator, ) from sqlmesh.utils.date import TimeLike -from sqlmesh.core.model.meta import AuditReference +from sqlmesh.core.model.meta import FunctionCall from sqlmesh.utils.pydantic import field_validator @@ -44,7 +44,7 @@ class ModelDefaultsConfig(BaseConfig): storage_format: t.Optional[str] = None on_destructive_change: t.Optional[OnDestructiveChange] = None session_properties: t.Optional[t.Dict[str, t.Any]] = None - audits: t.Optional[t.List[AuditReference]] = None + audits: t.Optional[t.List[FunctionCall]] = None _model_kind_validator = model_kind_validator _on_destructive_change_validator = on_destructive_change_validator @@ -52,6 +52,6 @@ class ModelDefaultsConfig(BaseConfig): @field_validator("audits", mode="before") def _audits_validator(cls, v: t.Any) -> t.Any: if isinstance(v, list): - return [extract_audit(parse_one(audit)) for audit in v] + return [extract_func_call(parse_one(audit)) for audit in v] return v diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 5882e9470..8c5d565a2 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2179,7 +2179,6 @@ def _load_materializations_and_signals(self) -> None: if not self._loaded: for context_loader in self._loaders.values(): with sys_path(*context_loader.configs): - context_loader.loader.load_signals(self) context_loader.loader.load_materializations(self) def _select_models_for_run( diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 1639fd88f..2d284422c 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -1205,7 +1205,9 @@ def interpret_key_value_pairs( return {i.this.name: interpret_expression(i.expression) for i in e.expressions} -def extract_audit(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]]: +def extract_func_call( + v: exp.Expression, allow_tuples: bool = False +) -> t.Tuple[str, t.Dict[str, exp.Expression]]: kwargs = {} if isinstance(v, exp.Anonymous): @@ -1214,6 +1216,12 @@ def extract_audit(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression] elif isinstance(v, exp.Func): func = v.sql_name() args = list(v.args.values()) + elif isinstance(v, exp.Tuple): # airflow only + if not allow_tuples: + raise ConfigError("Audit name is missing (eg. 
MY_AUDIT())") + + func = "" + args = v.expressions else: return v.name.lower(), {} diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 493c9e54e..d1f989d3d 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -29,6 +29,7 @@ ) from sqlmesh.core.model.cache import load_optimized_query_and_mapping, optimized_query_cache_pool from sqlmesh.core.model import model as model_registry +from sqlmesh.core.signal import signal from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.dag import DAG from sqlmesh.utils.errors import ConfigError @@ -79,7 +80,7 @@ def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedPr self._dag = DAG() self._load_materializations() - self._load_signals() + signals = self._load_signals() config_mtimes: t.Dict[Path, t.List[float]] = defaultdict(list) for context_path, config in self._context.configs.items(): @@ -103,7 +104,13 @@ def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedPr else: standalone_audits[name] = audit - models = self._load_models(macros, jinja_macros, context.selected_gateway, audits) + models = self._load_models( + macros, + jinja_macros, + context.selected_gateway, + audits, + signals, + ) for model in models.values(): self._add_model_to_dag(model) @@ -131,10 +138,10 @@ def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedPr ) return project - def load_signals(self, context: GenericContext) -> None: + def load_signals(self, context: GenericContext) -> UniqueKeyDict[str, signal]: """Loads signals for the built-in scheduler.""" self._context = context - self._load_signals() + return self._load_signals() def load_materializations(self, context: GenericContext) -> None: """Loads materializations for the built-in scheduler.""" @@ -165,6 +172,7 @@ def _load_models( jinja_macros: JinjaMacroRegistry, gateway: t.Optional[str], audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """Loads all models.""" @@ -177,8 +185,8 @@ def _load_audits( def _load_materializations(self) -> None: """Loads custom materializations.""" - def _load_signals(self) -> None: - """Loads signals for the built-in scheduler.""" + def _load_signals(self) -> UniqueKeyDict[str, signal]: + return UniqueKeyDict("signals") def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]: return UniqueKeyDict("metrics") @@ -320,14 +328,15 @@ def _load_models( jinja_macros: JinjaMacroRegistry, gateway: t.Optional[str], audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """ Loads all of the models within the model directory with their associated audits into a Dict and creates the dag """ - models = self._load_sql_models(macros, jinja_macros, audits) + models = self._load_sql_models(macros, jinja_macros, audits, signals) models.update(self._load_external_models(audits, gateway)) - models.update(self._load_python_models(macros, jinja_macros, audits)) + models.update(self._load_python_models(macros, jinja_macros, audits, signals)) return models @@ -336,6 +345,7 @@ def _load_sql_models( macros: MacroRegistry, jinja_macros: JinjaMacroRegistry, audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """Loads the sql models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") @@ -380,6 +390,7 @@ def _load() -> Model: default_catalog=self._context.default_catalog, variables=variables, 
infer_names=config.model_naming.infer_names, + signal_definitions=signals, ) model = cache.get_or_load_model(path, _load) @@ -397,6 +408,7 @@ def _load_python_models( macros: MacroRegistry, jinja_macros: JinjaMacroRegistry, audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """Loads the python models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") @@ -452,8 +464,11 @@ def _load_materializations(self) -> None: if os.path.getsize(path): import_python_file(path, context_path) - def _load_signals(self) -> None: + def _load_signals(self) -> UniqueKeyDict[str, signal]: """Loads signals for the built-in scheduler.""" + + signals_max_mtime: t.Optional[float] = None + for context_path, config in self._context.configs.items(): for path in self._glob_paths( context_path / c.SIGNALS, @@ -461,13 +476,26 @@ def _load_signals(self) -> None: extension=".py", ): if os.path.getsize(path): + self._track_file(path) + signal_file_mtime = self._path_mtimes[path] + signals_max_mtime = ( + max(signals_max_mtime, signal_file_mtime) + if signals_max_mtime + else signal_file_mtime + ) import_python_file(path, context_path) + self._signals_max_mtime = signals_max_mtime + + return signal.get_registry() + def _load_audits( self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry ) -> UniqueKeyDict[str, Audit]: """Loads all the model audits.""" audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits") + audits_max_mtime: t.Optional[float] = None + for context_path, config in self._context.configs.items(): variables = self._variables(config) for path in self._glob_paths( @@ -475,6 +503,12 @@ def _load_audits( ): self._track_file(path) with open(path, "r", encoding="utf-8") as file: + audits_file_mtime = self._path_mtimes[path] + audits_max_mtime = ( + max(audits_max_mtime, audits_file_mtime) + if audits_max_mtime + else audits_file_mtime + ) expressions = parse(file.read(), default_dialect=config.model_defaults.dialect) audits = load_multiple_audits( expressions=expressions, @@ -488,6 +522,9 @@ def _load_audits( ) for audit in audits: audits_by_name[audit.name] = audit + + self._audits_max_mtime = audits_max_mtime + return audits_by_name def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]: @@ -550,6 +587,8 @@ def _model_cache_entry_id(self, model_path: Path) -> str: mtimes = [ self._loader._path_mtimes[model_path], self._loader._macros_max_mtime, + self._loader._signals_max_mtime, + self._loader._audits_max_mtime, self._loader._config_mtimes.get(self._context_path), self._loader._config_mtimes.get(c.SQLMESH_PATH), ] diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index b7fcdf0b0..c17ddb6c1 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -39,6 +39,7 @@ columns_to_types_all_known, registry_decorator, ) +from sqlmesh.utils.date import DatetimeRanges from sqlmesh.utils.errors import MacroEvalError, SQLMeshError from sqlmesh.utils.jinja import JinjaMacroRegistry, has_jinja from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception @@ -79,6 +80,7 @@ class MacroStrTemplate(Template): "List": t.List, "Tuple": t.Tuple, "Union": t.Union, + "DatetimeRanges": DatetimeRanges, } for klass in sqlglot.Parser.EXPRESSION_PARSERS: @@ -197,38 +199,7 @@ def send( raise SQLMeshError(f"Macro '{name}' does not exist.") try: - # Bind the macro's actual parameters to its formal parameters - sig = inspect.signature(func) - bound = sig.bind(self, *args, **kwargs) - 
bound.apply_defaults() - except Exception as e: - print_exception(e, self.python_env) - raise MacroEvalError("Error trying to eval macro.") from e - - try: - annotations = t.get_type_hints(func, localns=SUPPORTED_TYPES) - except NameError: # forward references aren't handled - annotations = {} - - # If the macro is annotated, we try coerce the actual parameters to the corresponding types - if annotations: - for arg, value in bound.arguments.items(): - typ = annotations.get(arg) - if not typ: - continue - - # Changes to bound.arguments will reflect in bound.args and bound.kwargs - # https://docs.python.org/3/library/inspect.html#inspect.BoundArguments.arguments - param = sig.parameters[arg] - if param.kind is inspect.Parameter.VAR_POSITIONAL: - bound.arguments[arg] = tuple(self._coerce(v, typ) for v in value) - elif param.kind is inspect.Parameter.VAR_KEYWORD: - bound.arguments[arg] = {k: self._coerce(v, typ) for k, v in value.items()} - else: - bound.arguments[arg] = self._coerce(value, typ) - - try: - return func(*bound.args, **bound.kwargs) + return call_macro(func, self.dialect, self._path, self, *args, **kwargs) # type: ignore except Exception as e: print_exception(e, self.python_env) raise MacroEvalError("Error trying to eval macro.") from e @@ -497,78 +468,7 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t. def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any: """Coerces the given expression to the specified type on a best-effort basis.""" - base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'." - try: - if typ is None or typ is t.Any: - return expr - base = t.get_origin(typ) or typ - - # We need to handle Union and TypeVars first since we cannot use isinstance with it - if base in UNION_TYPES: - for branch in t.get_args(typ): - try: - return self._coerce(expr, branch, True) - except Exception: - pass - raise SQLMeshError(base_err_msg) - if base is SQL and isinstance(expr, exp.Expression): - return expr.sql(self.dialect) - - if isinstance(expr, base): - return expr - if issubclass(base, exp.Expression): - d = Dialect.get_or_raise(self.dialect) - into = base if base in d.parser_class.EXPRESSION_PARSERS else None - if into is None: - if isinstance(expr, exp.Literal): - coerced = parse_one(expr.this) - else: - raise SQLMeshError( - f"{base_err_msg} Coercion to {base} requires a literal expression." 
- ) - else: - coerced = parse_one( - expr.this if isinstance(expr, exp.Literal) else expr.sql(), into=into - ) - if isinstance(coerced, base): - return coerced - raise SQLMeshError(base_err_msg) - - if base in (int, float, str) and isinstance(expr, exp.Literal): - return base(expr.this) - if base is str and isinstance(expr, exp.Column) and not expr.table: - return expr.name - if base is bool and isinstance(expr, exp.Boolean): - return expr.this - # if base is str and isinstance(expr, exp.Expression): - # return expr.sql(self.dialect) - if base is tuple and isinstance(expr, (exp.Tuple, exp.Array)): - generic = t.get_args(typ) - if not generic: - return tuple(expr.expressions) - if generic[-1] is ...: - return tuple(self._coerce(expr, generic[0]) for expr in expr.expressions) - elif len(generic) == len(expr.expressions): - return tuple( - self._coerce(expr, generic[i]) for i, expr in enumerate(expr.expressions) - ) - raise SQLMeshError(f"{base_err_msg} Expected {len(generic)} items.") - if base is list and isinstance(expr, (exp.Array, exp.Tuple)): - generic = t.get_args(typ) - if not generic: - return expr.expressions - return [self._coerce(expr, generic[0]) for expr in expr.expressions] - raise SQLMeshError(base_err_msg) - except Exception: - if strict: - raise - logger.error( - "Coercion of expression '%s' to type '%s' failed. Using non coerced expression at '%s'", - expr, - typ, - self._path, - ) - return expr + return _coerce(expr, typ, self.dialect, self._path, strict) class macro(registry_decorator): @@ -1284,3 +1184,121 @@ def normalize_macro_name(name: str) -> str: for m in macro.get_registry().values(): setattr(m, c.SQLMESH_BUILTIN, True) + + +def call_macro( + func: t.Callable, + dialect: DialectType, + path: Path, + *args: t.Any, + **kwargs: t.Any, +) -> t.Any: + # Bind the macro's actual parameters to its formal parameters + sig = inspect.signature(func) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + try: + annotations = t.get_type_hints(func, localns=SUPPORTED_TYPES) + except (NameError, TypeError): # forward references aren't handled + annotations = {} + + # If the macro is annotated, we try coerce the actual parameters to the corresponding types + if annotations: + for arg, value in bound.arguments.items(): + typ = annotations.get(arg) + if not typ: + continue + + # Changes to bound.arguments will reflect in bound.args and bound.kwargs + # https://docs.python.org/3/library/inspect.html#inspect.BoundArguments.arguments + param = sig.parameters[arg] + if param.kind is inspect.Parameter.VAR_POSITIONAL: + bound.arguments[arg] = tuple(_coerce(v, typ, dialect, path) for v in value) + elif param.kind is inspect.Parameter.VAR_KEYWORD: + bound.arguments[arg] = {k: _coerce(v, typ, dialect, path) for k, v in value.items()} + else: + bound.arguments[arg] = _coerce(value, typ, dialect, path) + + return func(*bound.args, **bound.kwargs) + + +def _coerce( + expr: exp.Expression, + typ: t.Any, + dialect: DialectType, + path: Path, + strict: bool = False, +) -> t.Any: + """Coerces the given expression to the specified type on a best-effort basis.""" + base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'." 
+ try: + if typ is None or typ is t.Any: + return expr + base = t.get_origin(typ) or typ + + # We need to handle Union and TypeVars first since we cannot use isinstance with it + if base in UNION_TYPES: + for branch in t.get_args(typ): + try: + return _coerce(expr, branch, dialect, path, strict=True) + except Exception: + pass + raise SQLMeshError(base_err_msg) + if base is SQL and isinstance(expr, exp.Expression): + return expr.sql(dialect) + + if isinstance(expr, base): + return expr + if issubclass(base, exp.Expression): + d = Dialect.get_or_raise(dialect) + into = base if base in d.parser_class.EXPRESSION_PARSERS else None + if into is None: + if isinstance(expr, exp.Literal): + coerced = parse_one(expr.this) + else: + raise SQLMeshError( + f"{base_err_msg} Coercion to {base} requires a literal expression." + ) + else: + coerced = parse_one( + expr.this if isinstance(expr, exp.Literal) else expr.sql(), into=into + ) + if isinstance(coerced, base): + return coerced + raise SQLMeshError(base_err_msg) + + if base in (int, float, str) and isinstance(expr, exp.Literal): + return base(expr.this) + if base is str and isinstance(expr, exp.Column) and not expr.table: + return expr.name + if base is bool and isinstance(expr, exp.Boolean): + return expr.this + if base is tuple and isinstance(expr, (exp.Tuple, exp.Array)): + generic = t.get_args(typ) + if not generic: + return tuple(expr.expressions) + if generic[-1] is ...: + return tuple(_coerce(expr, generic[0], dialect, path) for expr in expr.expressions) + elif len(generic) == len(expr.expressions): + return tuple( + _coerce(expr, generic[i], dialect, path) + for i, expr in enumerate(expr.expressions) + ) + raise SQLMeshError(f"{base_err_msg} Expected {len(generic)} items.") + if base is list and isinstance(expr, (exp.Array, exp.Tuple)): + generic = t.get_args(typ) + if not generic: + return expr.expressions + return [_coerce(expr, generic[0], dialect, path) for expr in expr.expressions] + raise SQLMeshError(base_err_msg) + except Exception: + if strict: + raise + logger.error( + "Coercion of expression '%s' to type '%s' failed. 
Using non coerced expression at '%s'", + expr, + typ, + path, + ) + return expr diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 6a76e411c..b00fc97aa 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -31,9 +31,10 @@ single_value_or_tuple, ) from sqlmesh.core.model.kind import ModelKindName, SeedKind, ModelKind, FullKind, create_model_kind -from sqlmesh.core.model.meta import ModelMeta, AuditReference +from sqlmesh.core.model.meta import ModelMeta, FunctionCall from sqlmesh.core.model.seed import CsvSeedReader, Seed, create_seed from sqlmesh.core.renderer import ExpressionRenderer, QueryRenderer +from sqlmesh.core.signal import SignalRegistry from sqlmesh.utils import columns_to_types_all_known, str_to_bool, UniqueKeyDict from sqlmesh.utils.date import TimeLike, make_inclusive, to_datetime, to_time_column from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error @@ -42,8 +43,10 @@ from sqlmesh.utils.pydantic import PydanticModel, PRIVATE_FIELDS from sqlmesh.utils.metaprogramming import ( Executable, + build_env, prepare_env, print_exception, + serialize_env, ) if t.TYPE_CHECKING: @@ -556,6 +559,7 @@ def _create_renderer(expression: exp.Expression) -> ExpressionRenderer: jinja_macro_registry=self.jinja_macros, python_env=self.python_env, only_execution_time=False, + quote_identifiers=False, ) def _render(e: exp.Expression) -> str | int | float | bool: @@ -575,7 +579,10 @@ def _render(e: exp.Expression) -> str | int | float | bool: return rendered.this return rendered.sql(dialect=self.dialect) - return [{t.this.name: _render(t.expression) for t in signal} for signal in self.signals] + # airflow only + return [ + {k: _render(v) for k, v in signal.items()} for name, signal in self.signals if not name + ] def ctas_query(self, **render_kwarg: t.Any) -> exp.Query: """Return a dummy query to do a CTAS. 
@@ -954,7 +961,11 @@ def metadata_hash(self) -> str: metadata.append(key) metadata.append(gen(value)) - metadata.extend(gen(s) for s in self.signals) + for signal_name, args in sorted(self.signals, key=lambda x: x[0]): + metadata.append(signal_name) + for k, v in sorted(args.items()): + metadata.append(f"{k}:{gen(v)}") + metadata.extend(self._additional_metadata) self._metadata_hash = hash_data(metadata) @@ -1598,7 +1609,7 @@ def load_sql_based_model( macros: t.Optional[MacroRegistry] = None, jinja_macros: t.Optional[JinjaMacroRegistry] = None, audits: t.Optional[t.Dict[str, ModelAudit]] = None, - default_audits: t.Optional[t.List[AuditReference]] = None, + default_audits: t.Optional[t.List[FunctionCall]] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, dialect: t.Optional[str] = None, physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, @@ -1918,10 +1929,11 @@ def _create_model( physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, - default_audits: t.Optional[t.List[AuditReference]] = None, + default_audits: t.Optional[t.List[FunctionCall]] = None, inline_audits: t.Optional[t.Dict[str, ModelAudit]] = None, module_path: Path = Path(), macros: t.Optional[MacroRegistry] = None, + signal_definitions: t.Optional[SignalRegistry] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ) -> Model: @@ -2018,7 +2030,16 @@ def _create_model( python_env=python_env, ) + env: t.Dict[str, t.Any] = {} + + for signal_name, _ in model.signals: + if signal_definitions and signal_name in signal_definitions: + func = signal_definitions[signal_name].func + setattr(func, c.SQLMESH_METADATA, True) + build_env(func, env=env, name=signal_name, path=module_path) + model.python_env.update(python_env) + model.python_env.update(serialize_env(env, path=module_path)) model._path = path model.set_time_format(time_column_format) @@ -2203,7 +2224,16 @@ def _meta_renderer( "virtual_properties_": lambda value: value, "session_properties_": lambda value: value, "allow_partials": exp.convert, - "signals": lambda values: exp.Tuple(expressions=values), + "signals": lambda values: exp.tuple_( + *( + exp.func( + name, *(exp.PropertyEQ(this=exp.var(k), expression=v) for k, v in args.items()) + ) + if name + else exp.Tuple(expressions=[exp.var(k).eq(v) for k, v in args.items()]) + for name, args in values + ) + ), } diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index bab149c25..2bd4fa13c 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -10,12 +10,11 @@ from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.core import dialect as d -from sqlmesh.core.dialect import normalize_model_name, extract_audit +from sqlmesh.core.dialect import normalize_model_name, extract_func_call from sqlmesh.core.model.common import ( bool_validator, default_catalog_validator, depends_on_validator, - parse_properties, properties_validator, ) from sqlmesh.core.model.kind import ( @@ -45,7 +44,7 @@ if t.TYPE_CHECKING: from sqlmesh.core._typing import CustomMaterializationProperties, SessionProperties -AuditReference = t.Tuple[str, t.Dict[str, exp.Expression]] +FunctionCall = t.Tuple[str, t.Dict[str, exp.Expression]] logger = logging.getLogger(__name__) @@ -67,7 +66,7 @@ class ModelMeta(_Node): column_descriptions_: t.Optional[t.Dict[str, str]] = Field( default=None, alias="column_descriptions" ) 
- audits: t.List[AuditReference] = [] + audits: t.List[FunctionCall] = [] grains: t.List[exp.Expression] = [] references: t.List[exp.Expression] = [] physical_schema_override: t.Optional[str] = None @@ -75,7 +74,7 @@ class ModelMeta(_Node): virtual_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="virtual_properties") session_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="session_properties") allow_partials: bool = False - signals: t.List[exp.Tuple] = [] + signals: t.List[FunctionCall] = [] enabled: bool = True physical_version: t.Optional[str] = None gateway: t.Optional[str] = None @@ -86,21 +85,23 @@ class ModelMeta(_Node): _default_catalog_validator = default_catalog_validator _depends_on_validator = depends_on_validator - @field_validator("audits", mode="before") - def _audits_validator(cls, v: t.Any) -> t.Any: + @field_validator("audits", "signals", mode="before") + def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any: + is_signal = getattr(field, "name" if hasattr(field, "name") else "field_name") == "signals" + if isinstance(v, (exp.Tuple, exp.Array)): - return [extract_audit(i) for i in v.expressions] + return [extract_func_call(i, allow_tuples=is_signal) for i in v.expressions] if isinstance(v, exp.Paren): - return [extract_audit(v.this)] + return [extract_func_call(v.this, allow_tuples=is_signal)] if isinstance(v, exp.Expression): - return [extract_audit(v)] + return [extract_func_call(v, allow_tuples=is_signal)] if isinstance(v, list): audits = [] for entry in v: if isinstance(entry, dict): args = entry - name = entry.pop("name") + name = "" if is_signal else entry.pop("name") elif isinstance(entry, (tuple, list)): name, args = entry else: @@ -286,28 +287,6 @@ def _refs_validator(cls, vs: t.Any, values: t.Dict[str, t.Any]) -> t.List[exp.Ex return refs - @field_validator("signals", mode="before") - @field_validator_v1_args - def _signals_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any: - if v is None: - return [] - - if isinstance(v, str): - dialect = values.get("dialect") - v = d.parse_one(v, dialect=dialect) - - if isinstance(v, (exp.Array, exp.Paren, exp.Tuple)): - tuples: t.List[exp.Expression] = ( - [v.unnest()] if isinstance(v, exp.Paren) else v.expressions - ) - signals = [parse_properties(cls, t, values) for t in tuples] - elif isinstance(v, list): - signals = [parse_properties(cls, t, values) for t in v] - else: - raise ConfigError(f"Unexpected signals '{v}'") - - return signals - @model_validator(mode="before") @model_validator_v1_args def _pre_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 92fb9e3b1..b14b7469c 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -1,6 +1,5 @@ from __future__ import annotations -import abc import logging import typing as t @@ -32,10 +31,8 @@ TimeLike, now, now_timestamp, - to_datetime, to_timestamp, validate_date_range, - DatetimeRanges, ) from sqlmesh.utils.errors import AuditError, CircuitBreakerError, SQLMeshError @@ -46,66 +43,6 @@ SchedulingUnit = t.Tuple[str, t.Tuple[Interval, int]] -class Signal(abc.ABC): - @abc.abstractmethod - def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]: - """Returns which intervals are ready from a list of scheduled intervals. 
- - When SQLMesh wishes to execute a batch of intervals, say between `a` and `d`, then - the `batch` parameter will contain each individual interval within this batch, - i.e.: `[a,b),[b,c),[c,d)`. - - This function may return `True` to indicate that the whole batch is ready, - `False` to indicate none of the batch's intervals are ready, or a list of - intervals (a batch) to indicate exactly which ones are ready. - - When returning a batch, the function is expected to return a subset of - the `batch` parameter, e.g.: `[a,b),[b,c)`. Note that it may return - gaps, e.g.: `[a,b),[c,d)`, but it may not alter the bounds of any of the - intervals. - - The interface allows an implementation to check batches of intervals without - having to actually compute individual intervals itself. - - Args: - batch: the list of intervals that are missing and scheduled to run. - - Returns: - Either `True` to indicate all intervals are ready, `False` to indicate none are - ready or a list of intervals to indicate exactly which ones are ready. - """ - - -SignalFactory = t.Callable[[t.Dict[str, t.Union[str, int, float, bool]]], Signal] -_registered_signal_factory: t.Optional[SignalFactory] = None - - -def signal_factory(f: SignalFactory) -> None: - """Specifies a function as the SignalFactory to use for building Signal instances from model signal metadata. - - Only one such function may be decorated with this decorator. - - Example: - import typing as t - from sqlmesh.core.scheduler import signal_factory, Batch, Signal - - class AlwaysReadySignal(Signal): - def check_intervals(self, batch: Batch) -> t.Union[bool, Batch]: - return True - - @signal_factory - def my_signal_factory(signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]) -> Signal: - return AlwaysReadySignal() - """ - - global _registered_signal_factory - - if _registered_signal_factory is not None and _registered_signal_factory.__code__ != f.__code__: - raise SQLMeshError("Only one function may be decorated with @signal_factory") - - _registered_signal_factory = f - - class Scheduler: """Schedules and manages the evaluation of snapshots. @@ -121,7 +58,6 @@ class Scheduler: state_sync: The state sync to pull saved snapshots. max_workers: The maximum number of parallel queries to run. console: The rich instance used for printing scheduling information. - signal_factory: A factory method for building Signal instances from model signal configuration. 
""" def __init__( @@ -133,7 +69,6 @@ def __init__( max_workers: int = 1, console: t.Optional[Console] = None, notification_target_manager: t.Optional[NotificationTargetManager] = None, - signal_factory: t.Optional[SignalFactory] = None, ): self.state_sync = state_sync self.snapshots = {s.snapshot_id: s for s in snapshots} @@ -145,7 +80,6 @@ def __init__( self.notification_target_manager = ( notification_target_manager or NotificationTargetManager() ) - self.signal_factory = signal_factory or _registered_signal_factory def merged_missing_intervals( self, @@ -401,19 +335,9 @@ def expand_range_as_interval( snapshot_batches = {} all_unready_intervals: t.Dict[str, set[Interval]] = {} for snapshot, intervals in snapshot_intervals.items(): - if self.signal_factory and snapshot.is_model: - unready = set(intervals) - - for signal in snapshot.model.render_signals( - start=start, end=end, execution_time=execution_time - ): - intervals = _check_ready_intervals( - signal=self.signal_factory(signal), - intervals=intervals, - ) - unready -= set(intervals) - else: - unready = set() + unready = set(intervals) + intervals = snapshot.check_ready_intervals(intervals) + unready -= set(intervals) for parent in snapshot.parents: if parent.name in all_unready_intervals: @@ -687,53 +611,3 @@ def _resolve_one_snapshot_per_version( snapshot_per_version[key] = snapshot return snapshot_per_version - - -def _contiguous_intervals( - intervals: Intervals, -) -> t.List[Intervals]: - """Given a list of intervals with gaps, returns a list of sequences of contiguous intervals.""" - contiguous_intervals = [] - current_batch: t.List[Interval] = [] - for interval in intervals: - if len(current_batch) == 0 or interval[0] == current_batch[-1][-1]: - current_batch.append(interval) - else: - contiguous_intervals.append(current_batch) - current_batch = [interval] - - if len(current_batch) > 0: - contiguous_intervals.append(current_batch) - - return contiguous_intervals - - -def _check_ready_intervals( - signal: Signal, - intervals: Intervals, -) -> Intervals: - """Returns a list of intervals that are considered ready by the provided signal. - - Note that this will handle gaps in the provided intervals. The returned intervals - may introduce new gaps. - """ - checked_intervals = [] - for interval_batch in _contiguous_intervals(intervals): - batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch] - - ready_intervals = signal.check_intervals(batch=batch) - if isinstance(ready_intervals, bool): - if not ready_intervals: - batch = [] - elif isinstance(ready_intervals, list): - for i in ready_intervals: - if i not in batch: - raise RuntimeError(f"Signal returned unknown interval {i}") - batch = ready_intervals - else: - raise ValueError( - f"unexpected return value from signal, expected bool | list, got {type(ready_intervals)}" - ) - - checked_intervals.extend([(to_timestamp(start), to_timestamp(end)) for start, end in batch]) - return checked_intervals diff --git a/sqlmesh/core/signal.py b/sqlmesh/core/signal.py new file mode 100644 index 000000000..d9ee67092 --- /dev/null +++ b/sqlmesh/core/signal.py @@ -0,0 +1,35 @@ +from __future__ import annotations + + +from sqlmesh.utils import UniqueKeyDict, registry_decorator + + +class signal(registry_decorator): + """Specifies a function which intervals are ready from a list of scheduled intervals. 
+ + When SQLMesh wishes to execute a batch of intervals, say between `a` and `d`, then + the `batch` parameter will contain each individual interval within this batch, + i.e.: `[a,b),[b,c),[c,d)`. + + This function may return `True` to indicate that the whole batch is ready, + `False` to indicate none of the batch's intervals are ready, or a list of + intervals (a batch) to indicate exactly which ones are ready. + + When returning a batch, the function is expected to return a subset of + the `batch` parameter, e.g.: `[a,b),[b,c)`. Note that it may return + gaps, e.g.: `[a,b),[c,d)`, but it may not alter the bounds of any of the + intervals. + + The interface allows an implementation to check batches of intervals without + having to actually compute individual intervals itself. + + Args: + batch: the list of intervals that are missing and scheduled to run. + + Returns: + Either `True` to indicate all intervals are ready, `False` to indicate none are + ready or a list of intervals to indicate exactly which ones are ready. + """ + + +SignalRegistry = UniqueKeyDict[str, signal] diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 3074f2a08..dae13b035 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta from enum import IntEnum from functools import cached_property, lru_cache +from pathlib import Path from pydantic import Field from sqlglot import exp @@ -12,6 +13,7 @@ from sqlmesh.core import constants as c from sqlmesh.core.audit import StandaloneAudit +from sqlmesh.core.macros import call_macro from sqlmesh.core.model import Model, ModelKindMixin, ModelKindName, ViewKind, CustomKind from sqlmesh.core.model.definition import _Model from sqlmesh.core.node import IntervalUnit, NodeType @@ -34,6 +36,7 @@ yesterday, ) from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.metaprogramming import prepare_env, print_exception from sqlmesh.utils.hashing import hash_data from sqlmesh.utils.pydantic import PydanticModel, field_validator @@ -885,6 +888,37 @@ def missing_intervals( model_end_ts, ) + def check_ready_intervals(self, intervals: Intervals) -> Intervals: + """Returns a list of intervals that are considered ready by the provided signal. + + Note that this will handle gaps in the provided intervals. The returned intervals + may introduce new gaps. + """ + signals = self.is_model and self.model.signals + + if not signals: + return intervals + + python_env = self.model.python_env + env = prepare_env(python_env) + + for signal_name, kwargs in signals: + try: + intervals = _check_ready_intervals( + env[signal_name], + intervals, + dialect=self.model.dialect, + path=self.model._path, + kwargs=kwargs, + ) + except SQLMeshError as e: + print_exception(e, python_env) + raise SQLMeshError( + f"{e} '{signal_name}' for '{self.model.name}' at {self.model._path}" + ) + + return intervals + def categorize_as(self, category: SnapshotChangeCategory) -> None: """Assigns the given category to this snapshot. 
@@ -1836,3 +1870,53 @@ def snapshots_to_dag(snapshots: t.Collection[Snapshot]) -> DAG[SnapshotId]: for snapshot in snapshots: dag.add(snapshot.snapshot_id, snapshot.parents) return dag + + +def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]: + """Given a list of intervals with gaps, returns a list of sequences of contiguous intervals.""" + contiguous_intervals = [] + current_batch: t.List[Interval] = [] + for interval in intervals: + if len(current_batch) == 0 or interval[0] == current_batch[-1][-1]: + current_batch.append(interval) + else: + contiguous_intervals.append(current_batch) + current_batch = [interval] + + if len(current_batch) > 0: + contiguous_intervals.append(current_batch) + + return contiguous_intervals + + +def _check_ready_intervals( + check: t.Callable, + intervals: Intervals, + dialect: DialectType = None, + path: Path = Path(), + kwargs: t.Optional[t.Dict] = None, +) -> Intervals: + checked_intervals: Intervals = [] + + for interval_batch in _contiguous_intervals(intervals): + batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch] + + try: + ready_intervals = call_macro(check, dialect, path, batch, **(kwargs or {})) + except Exception: + raise SQLMeshError("Error evaluating signal") + + if isinstance(ready_intervals, bool): + if not ready_intervals: + batch = [] + elif isinstance(ready_intervals, list): + for i in ready_intervals: + if i not in batch: + raise SQLMeshError(f"Unknown interval {i} for signal") + batch = ready_intervals + else: + raise SQLMeshError(f"Expected bool | list, got {type(ready_intervals)} for signal") + + checked_intervals.extend((to_timestamp(start), to_timestamp(end)) for start, end in batch) + + return checked_intervals diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 7a0561e95..4b6caa04f 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -13,6 +13,7 @@ from sqlmesh.core.loader import LoadedProject, Loader from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.model import Model, ModelCache +from sqlmesh.core.signal import signal from sqlmesh.dbt.basemodel import BMC, BaseModelConfig from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.model import ModelConfig @@ -94,6 +95,7 @@ def _load_models( jinja_macros: JinjaMacroRegistry, gateway: t.Optional[str], audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") diff --git a/sqlmesh/migrations/v0059_add_physical_version.py b/sqlmesh/migrations/v0059_add_physical_version.py deleted file mode 100644 index ae24ee390..000000000 --- a/sqlmesh/migrations/v0059_add_physical_version.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Add the physical_version model attribute.""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/sqlmesh/migrations/v0062_change_signals.py b/sqlmesh/migrations/v0062_change_signals.py new file mode 100644 index 000000000..d2060f552 --- /dev/null +++ b/sqlmesh/migrations/v0062_change_signals.py @@ -0,0 +1,83 @@ +"""Change serialization of signals to allow for function calls.""" + +import json + +import pandas as pd +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type + + +def migrate(state_sync, **kwargs): # type: ignore + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + index_type = index_text_type(engine_adapter.dialect) + if schema: + snapshots_table = 
f"{schema}.{snapshots_table}" + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + signals = node.pop("signals", None) + + if signals: + node["signals"] = [("", signal) for signal in signals] + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build("text"), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py index 4173c8eb2..2b2ab3f01 100644 --- a/sqlmesh/utils/metaprogramming.py +++ b/sqlmesh/utils/metaprogramming.py @@ -21,7 +21,7 @@ from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pydantic import PydanticModel -IGNORE_DECORATORS = {"macro", "model"} +IGNORE_DECORATORS = {"macro", "model", "signal"} def _is_relative_to(path: t.Optional[Path | str], other: t.Optional[Path | str]) -> bool: diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d87ca1b33..f3ecb00b4 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -3820,9 +3820,10 @@ def test_signals(): MODEL ( name db.table, signals [ + my_signal(arg = 1), ( - table_name = 'table_a', - ds = @end_ds, + table_name := 'table_a', + ds := @end_ds, ), ( table_name = 'table_b', @@ -3843,26 +3844,35 @@ def test_signals(): model = load_sql_based_model(expressions) assert model.signals == [ - exp.Tuple( - expressions=[ - exp.to_column("table_name").eq("table_a"), - exp.to_column("ds").eq(d.MacroVar(this="end_ds")), - ] + ( + "my_signal", + { + "arg": exp.Literal.number(1), + }, ), - exp.Tuple( - expressions=[ - exp.to_column("table_name").eq("table_b"), - exp.to_column("ds").eq(d.MacroVar(this="end_ds")), - exp.to_column("hour").eq(d.MacroVar(this="end_hour")), - ] + ( + "", + { + "table_name": exp.Literal.string("table_a"), + "ds": d.MacroVar(this="end_ds"), + }, ), - exp.Tuple( - expressions=[ - exp.to_column("bool_key").eq(True), - exp.to_column("int_key").eq(1), - exp.to_column("float_key").eq(1.0), - exp.to_column("string_key").eq("string"), - ] + ( + "", + { + "table_name": exp.Literal.string("table_b"), + "ds": d.MacroVar(this="end_ds"), + "hour": d.MacroVar(this="end_hour"), + }, + ), + ( + "", + { + "bool_key": exp.true(), + "int_key": exp.Literal.number(1), + "float_key": exp.Literal.number(1.0), + "string_key": exp.Literal.string("string"), + }, ), ] @@ -3874,7 +3884,7 @@ def test_signals(): ] assert ( - "signals ((table_name = 'table_a', ds = @end_ds), 
(table_name = 'table_b', ds = @end_ds, hour = @end_hour), (bool_key = TRUE, int_key = 1, float_key = 1.0, string_key = 'string')" + "signals (MY_SIGNAL(arg := 1), (table_name = 'table_a', ds = @end_ds), (table_name = 'table_b', ds = @end_ds, hour = @end_hour), (bool_key = TRUE, int_key = 1, float_key = 1.0, string_key = 'string'))" in model.render_definition()[0].sql() ) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 4b7de2aa5..804348d3d 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,7 +10,6 @@ from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.model.definition import AuditResult, SqlModel from sqlmesh.core.model.kind import ( - FullKind, IncrementalByTimeRangeKind, IncrementalByUniqueKeyKind, TimeColumn, @@ -20,10 +19,9 @@ Scheduler, interval_diff, compute_interval_params, - signal_factory, - Signal, SnapshotToIntervals, ) +from sqlmesh.core.signal import signal from sqlmesh.core.snapshot import ( Snapshot, SnapshotEvaluator, @@ -506,95 +504,6 @@ def test_external_model_audit(mocker, make_snapshot): spy.assert_called_once() -def test_contiguous_intervals(): - from sqlmesh.core.scheduler import _contiguous_intervals as ci - - assert ci([]) == [] - assert ci([(0, 1)]) == [[(0, 1)]] - assert ci([(0, 1), (1, 2), (2, 3)]) == [[(0, 1), (1, 2), (2, 3)]] - assert ci([(0, 1), (3, 4), (4, 5), (6, 7)]) == [ - [(0, 1)], - [(3, 4), (4, 5)], - [(6, 7)], - ] - - -def test_check_ready_intervals(mocker: MockerFixture): - from sqlmesh.core.scheduler import _check_ready_intervals - from sqlmesh.core.snapshot.definition import Interval - - def const_signal(const): - signal_mock = mocker.Mock() - signal_mock.check_intervals = mocker.MagicMock(return_value=const) - return signal_mock - - def assert_always_signal(intervals): - _check_ready_intervals(const_signal(True), intervals) == intervals - - assert_always_signal([]) - assert_always_signal([(0, 1)]) - assert_always_signal([(0, 1), (1, 2)]) - assert_always_signal([(0, 1), (2, 3)]) - - def assert_never_signal(intervals): - _check_ready_intervals(const_signal(False), intervals) == [] - - assert_never_signal([]) - assert_never_signal([(0, 1)]) - assert_never_signal([(0, 1), (1, 2)]) - assert_never_signal([(0, 1), (2, 3)]) - - def to_intervals(values: t.List[t.Tuple[int, int]]) -> t.List[Interval]: - return [(to_datetime(s), to_datetime(e)) for s, e in values] - - def assert_check_intervals( - intervals: t.List[t.Tuple[int, int]], - ready: t.List[t.List[t.Tuple[int, int]]], - expected: t.List[t.Tuple[int, int]], - ): - signal_mock = mocker.Mock() - signal_mock.check_intervals = mocker.MagicMock(side_effect=[to_intervals(r) for r in ready]) - _check_ready_intervals(signal_mock, intervals) == expected - - assert_check_intervals([], [], []) - assert_check_intervals([(0, 1)], [[]], []) - assert_check_intervals( - [(0, 1)], - [[(0, 1)]], - [(0, 1)], - ) - assert_check_intervals( - [(0, 1), (1, 2)], - [[(0, 1)]], - [(0, 1)], - ) - assert_check_intervals( - [(0, 1), (1, 2)], - [[(1, 2)]], - [(1, 2)], - ) - assert_check_intervals( - [(0, 1), (1, 2)], - [[(0, 1), (1, 2)]], - [(0, 1), (1, 2)], - ) - assert_check_intervals( - [(0, 1), (1, 2), (3, 4)], - [[], []], - [], - ) - assert_check_intervals( - [(0, 1), (1, 2), (3, 4)], - [[(0, 1)], []], - [(0, 1)], - ) - assert_check_intervals( - [(0, 1), (1, 2), (3, 4)], - [[(0, 1)], [(3, 4)]], - [(0, 1), (3, 4)], - ) - - def test_audit_failure_notifications( scheduler: Scheduler, waiter_names: Snapshot, mocker: MockerFixture ): @@ -661,56 
+570,6 @@ def _evaluate(): assert notify_mock.call_count == 1 -def test_signal_factory(mocker: MockerFixture, make_snapshot): - class AlwaysReadySignal(Signal): - def check_intervals(self, batch: DatetimeRanges): - return True - - signal_factory_invoked = 0 - - @signal_factory - def factory(signal_metadata): - nonlocal signal_factory_invoked - signal_factory_invoked += 1 - assert signal_metadata.get("kind") == "foo" - return AlwaysReadySignal() - - start = to_datetime("2023-01-01") - end = to_datetime("2023-01-07") - snapshot: Snapshot = make_snapshot( - SqlModel( - name="name", - kind=FullKind(), - owner="owner", - dialect="", - cron="@daily", - start=start, - query=parse_one("SELECT id FROM VALUES (1), (2) AS t(id)"), - signals=[{"kind": "foo"}], - ), - ) - snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) - scheduler = Scheduler( - snapshots=[snapshot], - snapshot_evaluator=snapshot_evaluator, - state_sync=mocker.MagicMock(), - max_workers=2, - default_catalog=None, - console=mocker.MagicMock(), - ) - merged_intervals = scheduler.merged_missing_intervals(start, end, end) - assert len(merged_intervals) == 1 - scheduler.run_merged_intervals( - merged_intervals=merged_intervals, - deployability_index=DeployabilityIndex.all_deployable(), - environment_naming_info=EnvironmentNamingInfo(), - start=start, - end=end, - ) - - assert signal_factory_invoked > 0 - - def test_interval_diff(): assert interval_diff([(1, 2)], []) == [(1, 2)] assert interval_diff([(1, 2)], [(1, 2)]) == [] @@ -741,51 +600,85 @@ def test_interval_diff(): def test_signal_intervals(mocker: MockerFixture, make_snapshot, get_batched_missing_intervals): - class TestSignal(Signal): - def __init__(self, signal: t.Dict): - self.name = signal["kind"] + @signal() + def signal_a(batch: DatetimeRanges): + return [batch[0], batch[1]] + + @signal() + def signal_b(batch: DatetimeRanges): + return batch[-49:] - def check_intervals(self, batch: DatetimeRanges): - if self.name == "a": - return [batch[0], batch[1]] - if self.name == "b": - return batch[-49:] + signals = signal.get_registry() a = make_snapshot( - SqlModel( - name="a", - kind="full", - start="2023-01-01", - query=parse_one("SELECT 1 x"), - signals=[{"kind": "a"}], + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name a, + kind FULL, + start '2023-01-01', + signals SIGNAL_A(), + ); + + SELECT 1 x; + """ + ), + signal_definitions=signals, ), ) + b = make_snapshot( - SqlModel( - name="b", - kind="full", - start="2023-01-01", - cron="@hourly", - query=parse_one("SELECT 2 x"), - signals=[{"kind": "b"}], + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name b, + kind FULL, + cron '@hourly', + start '2023-01-01', + signals SIGNAL_B(), + ); + + SELECT 2 x; + """ + ), + signal_definitions=signals, ), nodes={a.name: a.model}, ) + c = make_snapshot( - SqlModel( - name="c", - kind="full", - start="2023-01-01", - query=parse_one("select * from a union select * from b"), + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name c, + kind FULL, + start '2023-01-01', + ); + + SELECT * FROM a UNION SELECT * FROM b + """ + ), + signal_definitions=signals, ), nodes={a.name: a.model, b.name: b.model}, ) d = make_snapshot( - SqlModel( - name="d", - kind="full", - start="2023-01-01", - query=parse_one("select * from c union all select * from d"), + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name d, + kind FULL, + start '2023-01-01', + ); + + SELECT * FROM c UNION SELECT * FROM d + """ + ), 
+ signal_definitions=signals, ), nodes={a.name: a.model, b.name: b.model, c.name: c.model}, ) @@ -797,7 +690,6 @@ def check_intervals(self, batch: DatetimeRanges): state_sync=mocker.MagicMock(), max_workers=2, default_catalog=None, - signal_factory=lambda signal: TestSignal(signal), ) batches = get_batched_missing_intervals(scheduler, "2023-01-01", "2023-01-03", None) diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 1db72f815..7872ae39a 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -47,9 +47,13 @@ ) from sqlmesh.core.snapshot.cache import SnapshotCache from sqlmesh.core.snapshot.categorizer import categorize_change -from sqlmesh.core.snapshot.definition import display_name +from sqlmesh.core.snapshot.definition import ( + display_name, + _check_ready_intervals, + _contiguous_intervals, +) from sqlmesh.utils import AttributeDict -from sqlmesh.utils.date import to_date, to_datetime, to_timestamp +from sqlmesh.utils.date import DatetimeRanges, to_date, to_datetime, to_timestamp from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo @@ -878,6 +882,29 @@ def test_fingerprint_virtual_properties(model: Model, parent_model: Model): assert updated_fingerprint.data_hash == fingerprint.data_hash +def test_fingerprint_signals(sushi_context_pre_scheduling): + model = deepcopy(sushi_context_pre_scheduling.get_model("sushi.waiter_as_customer_by_day")) + fingerprint = fingerprint_from_node(model, nodes={}) + + def assert_metadata_only(): + model._metadata_hash = None + model._data_hash = None + updated_fingerprint = fingerprint_from_node(model, nodes={}) + + assert updated_fingerprint != fingerprint + assert updated_fingerprint.metadata_hash != fingerprint.metadata_hash + assert updated_fingerprint.data_hash == fingerprint.data_hash + + executable = model.python_env["test_signal"] + model.python_env["test_signal"].payload = executable.payload.replace("arg == 1", "arg == 2") + + assert_metadata_only() + + model = deepcopy(sushi_context_pre_scheduling.get_model("sushi.waiter_as_customer_by_day")) + model.signals.clear() + assert_metadata_only() + + def test_stamp(model: Model): original_fingerprint = fingerprint_from_node(model, nodes={}) @@ -2155,3 +2182,82 @@ def test_physical_version_pin_for_new_forward_only_models(make_snapshot): assert snapshot_f.version == "1234" assert snapshot_f.fingerprint != snapshot_e.fingerprint + + +def test_contiguous_intervals(): + assert _contiguous_intervals([]) == [] + assert _contiguous_intervals([(0, 1)]) == [[(0, 1)]] + assert _contiguous_intervals([(0, 1), (1, 2), (2, 3)]) == [[(0, 1), (1, 2), (2, 3)]] + assert _contiguous_intervals([(0, 1), (3, 4), (4, 5), (6, 7)]) == [ + [(0, 1)], + [(3, 4), (4, 5)], + [(6, 7)], + ] + + +def test_check_ready_intervals(mocker: MockerFixture): + def assert_always_signal(intervals): + _check_ready_intervals(lambda _: True, intervals) == intervals + + assert_always_signal([]) + assert_always_signal([(0, 1)]) + assert_always_signal([(0, 1), (1, 2)]) + assert_always_signal([(0, 1), (2, 3)]) + + def assert_never_signal(intervals): + _check_ready_intervals(lambda _: False, intervals) == [] + + assert_never_signal([]) + assert_never_signal([(0, 1)]) + assert_never_signal([(0, 1), (1, 2)]) + assert_never_signal([(0, 1), (2, 3)]) + + def to_intervals(values: t.List[t.Tuple[int, int]]) -> DatetimeRanges: + return [(to_datetime(s), to_datetime(e)) for s, e in values] + + def assert_check_intervals( + intervals: t.List[t.Tuple[int, 
int]],
+        ready: t.List[t.List[t.Tuple[int, int]]],
+        expected: t.List[t.Tuple[int, int]],
+    ):
+        mock = mocker.Mock()
+        mock.side_effect = [to_intervals(r) for r in ready]
+        assert _check_ready_intervals(mock, intervals) == expected
+
+    assert_check_intervals([], [], [])
+    assert_check_intervals([(0, 1)], [[]], [])
+    assert_check_intervals(
+        [(0, 1)],
+        [[(0, 1)]],
+        [(0, 1)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2)],
+        [[(0, 1)]],
+        [(0, 1)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2)],
+        [[(1, 2)]],
+        [(1, 2)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2)],
+        [[(0, 1), (1, 2)]],
+        [(0, 1), (1, 2)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2), (3, 4)],
+        [[], []],
+        [],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2), (3, 4)],
+        [[(0, 1)], []],
+        [(0, 1)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2), (3, 4)],
+        [[(0, 1)], [(3, 4)]],
+        [(0, 1), (3, 4)],
+    )
diff --git a/tests/schedulers/airflow/operators/test_sensor.py b/tests/schedulers/airflow/operators/test_sensor.py
index 7bb4a2eb6..c13795dbf 100644
--- a/tests/schedulers/airflow/operators/test_sensor.py
+++ b/tests/schedulers/airflow/operators/test_sensor.py
@@ -122,12 +122,15 @@ def test_external_sensor(mocker: MockerFixture, make_snapshot, set_airflow_as_li
             name="this",
             query=parse_one("select 1"),
             signals=[
-                {"table_name": "test_table_name_a", "ds": parse_one("@end_ds")},
-                {
-                    "table_name": "test_table_name_b",
-                    "ds": parse_one("@end_ds"),
-                    "hour": parse_one("@end_hour"),
-                },
+                ("", {"table_name": "test_table_name_a", "ds": parse_one("@end_ds")}),
+                (
+                    "",
+                    {
+                        "table_name": "test_table_name_b",
+                        "ds": parse_one("@end_ds"),
+                        "hour": parse_one("@end_hour"),
+                    },
+                ),
             ],
         )
     )
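
A minimal sketch of the signal API this patch introduces, modeled on the examples/sushi/signals/__init__.py file added above. The function name `ready_up_to` and its body are illustrative assumptions rather than part of the change; the `@signal()` decorator, the `DatetimeRanges` batch parameter, the keyword arguments rendered from the MODEL definition, and the bool-or-subset return contract all come from the patch itself.

import typing as t

from sqlmesh import signal, DatetimeRanges


@signal()
def ready_up_to(batch: DatetimeRanges, limit: int = 0) -> t.Union[bool, DatetimeRanges]:
    # `batch` is a list of contiguous (start, end) datetime pairs that are
    # missing and scheduled to run. Returning True marks the whole batch
    # ready, False marks none of it ready, and returning a subset of `batch`
    # (gaps are allowed, but interval bounds may not change) marks exactly
    # those intervals ready.
    #
    # `limit` is a hypothetical argument supplied from a model definition,
    # e.g. signals (ready_up_to(limit := 2)); call_macro coerces it to the
    # annotated int before this function runs.
    return batch[:limit] if limit else False

Because the loader now serializes decorated signal functions into each model's python_env (via build_env/serialize_env in _create_model), editing a signal's body changes only the snapshot's metadata hash while leaving its data hash intact, as exercised by test_fingerprint_signals above.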