From a1d8e400bc58fa923508255ee4d0500987e9c568 Mon Sep 17 00:00:00 2001 From: tobymao Date: Mon, 2 Dec 2024 17:06:24 -0800 Subject: [PATCH] fix!: make signals serializable --- docs/guides/signals.md | 135 +++------- .../models/waiter_as_customer_by_day.sql | 6 +- examples/sushi/signals/__init__.py | 9 + sqlmesh/__init__.py | 6 +- sqlmesh/core/config/model.py | 8 +- sqlmesh/core/context.py | 1 - sqlmesh/core/dialect.py | 10 +- sqlmesh/core/loader.py | 57 ++++- sqlmesh/core/macros.py | 226 +++++++++-------- sqlmesh/core/model/definition.py | 42 ++- sqlmesh/core/model/meta.py | 45 +--- sqlmesh/core/scheduler.py | 132 +--------- sqlmesh/core/signal.py | 35 +++ sqlmesh/core/snapshot/definition.py | 84 ++++++ sqlmesh/dbt/loader.py | 2 + ...model.py => v0060_move_audits_to_model.py} | 0 ...e.py => v0061_mysql_fix_blob_text_type.py} | 0 ..._gateway.py => v0062_add_model_gateway.py} | 0 sqlmesh/migrations/v0063_change_signals.py | 83 ++++++ sqlmesh/utils/metaprogramming.py | 2 +- tests/core/test_model.py | 52 ++-- tests/core/test_scheduler.py | 240 +++++------------- tests/core/test_snapshot.py | 110 +++++++- .../airflow/operators/test_sensor.py | 15 +- 24 files changed, 702 insertions(+), 598 deletions(-) create mode 100644 examples/sushi/signals/__init__.py create mode 100644 sqlmesh/core/signal.py rename sqlmesh/migrations/{v0059_move_audits_to_model.py => v0060_move_audits_to_model.py} (100%) rename sqlmesh/migrations/{v0060_mysql_fix_blob_text_type.py => v0061_mysql_fix_blob_text_type.py} (100%) rename sqlmesh/migrations/{v0061_add_model_gateway.py => v0062_add_model_gateway.py} (100%) create mode 100644 sqlmesh/migrations/v0063_change_signals.py diff --git a/docs/guides/signals.md b/docs/guides/signals.md index 3466d7b2a..8ce254340 100644 --- a/docs/guides/signals.md +++ b/docs/guides/signals.md @@ -14,7 +14,7 @@ The scheduler uses two criteria to determine whether a model should be evaluated Signals allow you to specify additional criteria that must be met before the scheduler evaluates the model. -A signal definition has two components: a "checking" function that checks whether a criterion is met and a "factory function" that provides the checking function to SQLMesh. Before describing the checking function, we provide some background information about how the scheduler works. +A signal definition is simply a function that checks whether a criterion is met. Before describing the checking function, we provide some background information about how the scheduler works. The scheduler doesn't actually evaluate "a model" - it evaluates a model over a specific time interval. This is clearest for incremental models, where only rows in the time interval are ingested during an evaluation. However, evaluation of non-temporal model kinds like `FULL` and `VIEW` are also based on a time interval: the model's `cron` frequency. @@ -22,90 +22,49 @@ The scheduler's decisions are based on these time intervals. For each model, the It then divides those into _batches_ (configured with the model's [batch_size](../concepts/models/overview.md#batch_size) parameter). For incremental models, it evaluates the model once for each batch. For non-incremental models, it evaluates the model once if any batch contains an interval. -Signal checking functions examines a batch of time intervals. The function has two inputs: signal metadata values defined your model and a batch of time intervals. 
It may return `True` if all intervals are ready for evaluation, `False` if no intervals are ready, or the time intervals themselves if only some are ready. A checking function is defined as a method on a `Signal` sub-class.
-
-A project may have one signal factory function. The factory function determines which checking function should be used for a given model and signal. Its inputs are the signal metadata values defined in your model, and it returns the checking function.
+Signal checking functions examine a batch of time intervals. A checking function is always called with a batch of time intervals (`DatetimeRanges`) and may optionally accept keyword arguments. It may return `True` if all intervals are ready for evaluation, `False` if no intervals are ready, or the time intervals themselves if only some are ready. A checking function is defined with the `@signal` decorator.
 
 ## Defining a signal
 
-To define a signal, create a `signals` directory in your project folder. Define your signal in a file named `__init__.py` in that directory.
-
-The file must:
+To define a signal, create a `signals` directory in your project folder. Define your signal in a file named `__init__.py` in that directory (additional Python files in that directory are loaded as well).
 
-- Define at least one `Signal` sub-class containing a `check_intervals` method
-- Define a factory function that returns a `Signal` sub-class and decorate the function with the `@signal_factory` decorator
+A signal is a function that accepts a batch (`DatetimeRanges`, i.e. `t.List[t.Tuple[datetime, datetime]]`) and returns a batch or a boolean. It must be decorated with the `@signal` decorator.
 
 We now demonstrate signals of varying complexity.
 
 ### Simple example
 
-This example defines the `RandomSignal` class and its mandatory `check_intervals()` method.
+This example defines a `random_signal()` function.
 
 The method returns `True` (indicating that all intervals are ready for evaluation) if a random number is greater than a threshold specified in the model definition:
 
 ```python linenums="1"
 import random
 import typing as t
 
-from sqlmesh.core.scheduler import signal_factory, Signal
-from sqlmesh.utils.date import DatetimeRanges
-
+from sqlmesh import signal, DatetimeRanges
 
-class RandomSignal(Signal):
-    def __init__(
-        self,
-        signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]
-    ):
-        self.signal_metadata = signal_metadata
 
-    def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
-        threshold = self.signal_metadata["threshold"]
-        return random.random() > threshold
+@signal()
+def random_signal(batch: DatetimeRanges, threshold: float) -> t.Union[bool, DatetimeRanges]:
+    return random.random() > threshold
 ```
 
-Note that the `RandomSignal` class sub-classes `Signal` and takes a `signal_metadata` argument.
-
-The `check_intervals()` method extracts the threshold metadata and compares a random number to it.
-
-We can now add a factory function that returns the `RandomSignal` sub-class. 
Note the `@signal_factory` decorator on line 8:
-
-```python linenums="1" hl_lines="8-10"
-import random
-import typing as t
-from sqlmesh.core.scheduler import signal_factory, Signal
-from sqlmesh.utils.date import DatetimeRanges
-
-
-class RandomSignal(Signal):
-    def __init__(
-        self,
-        signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]
-    ):
-        self.signal_metadata = signal_metadata
-
-    def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
-        threshold = self.signal_metadata["threshold"]
-        return random.random() > threshold
-
-
-@signal_factory
-def my_signal_factory(signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]) -> Signal:
-    return RandomSignal(signal_metadata)
-```
-
-We now have a working signal!
+Note that `random_signal()` takes a mandatory, user-defined `threshold` argument.
 
-We specify that a model should use the signal by passing metadata to the model DDL's `signals` key.
+The `random_signal()` function extracts the `threshold` value and compares a random number to it. The argument's type is inferred using the same [rules as SQLMesh Macros](../concepts/macros/sqlmesh_macros.md#typed-macros).
 
-The `signals` key accepts an array delimited by brackets `[]`. Each tuple in the list should contain the metadata needed for one signal evaluation.
+Now that we have a working signal, we specify that a model should use it by passing metadata to the model DDL's `signals` key.
 
-This example specifies that the `RandomSignal` should evaluate once with a threshold of 0.5:
+The `signals` key accepts an array delimited by brackets `[]`. Each function call in the list supplies the arguments needed for one signal evaluation.
+
+This example specifies that `random_signal()` should evaluate once with a threshold of 0.5:
 
 ```sql linenums="1" hl_lines="4-6"
 MODEL (
   name example.signal_model,
   kind FULL,
   signals [
-    (threshold = 0.5), # specify threshold value
+    random_signal(threshold = 0.5), # specify threshold value
   ]
 );
 
@@ -116,47 +75,31 @@ The next time this project is `sqlmesh run`, our signal will metaphorically flip
 
 ### Advanced Example
 
-This example demonstrates more advanced use of signals, including:
-
-- Multiple signals in one model
-- A signal returning a subset of intervals from a batch (rather than a single `True`/`False` value for all intervals in the batch)
-
-In this example, there are two signals. 
+This example demonstrates a more advanced use of signals: a signal that returns a subset of the intervals in a batch (rather than a single `True`/`False` value for all intervals in the batch).
 
 ```python
 import typing as t
-from datetime import datetime
-
-from sqlmesh.core.scheduler import signal_factory, Signal
-from sqlmesh.utils.date import DatetimeRanges, to_datetime
-
-class AlwaysReady(Signal):
-    # signal that indicates every interval is always ready
-    def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
-        return True
+from sqlmesh import signal, DatetimeRanges
+from sqlmesh.utils.date import to_datetime
 
-class OneweekAgo(Signal):
-    def __init__(self, dt: datetime):
-        self.dt = dt
+# signal that returns only intervals that are <= 1 week ago
+@signal()
+def one_week_ago(batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
+    dt = to_datetime("1 week ago")
 
-    # signal that returns only intervals that are <= 1 week ago
-    def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
-        return [
-            (start, end)
-            for start, end in batch
-            if start <= self.dt
-        ]
+    return [
+        (start, end)
+        for start, end in batch
+        if start <= dt
+    ]
 ```
 
-Instead of returning a single `True`/`False` value for whether a batch of intervals is ready for evaluation, the `OneweekAgo` signal class returns specific intervals from the batch.
-Its `check_intervals` method accepts a `dt` datetime argument, to which It compares the beginning of each interval in the batch. If the interval start is before that argument, the interval is ready for evaluation and included in the returned list.
-These signals can be added to a model like so. Now that we have more than one signal, we must have a way to tell the signal factory which signal should be called.
+Instead of returning a single `True`/`False` value for whether a batch of intervals is ready for evaluation, the `one_week_ago()` function returns specific intervals from the batch.
 
-In this example, we use the `kind` key to tell the signal factory which signal class should be called. The key name is arbitrary, and you may choose any key name you want.
-
-This example model specifies two `kind` values so both signals are called.
+It computes a datetime one week in the past and compares the start of each interval in the batch to it. If an interval starts before that datetime, it is ready for evaluation and is included in the returned list.
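+Before attaching the signal to a model, you can sanity-check it by invoking the registered checking function directly. This is a minimal sketch rather than project code: it assumes `signals/__init__.py` has been imported (so the `@signal` decorator has populated the registry), and the interval values are illustrative:
+
+```python
+from sqlmesh.core.signal import signal
+from sqlmesh.utils.date import to_datetime
+
+# look up the checking function registered by the @signal decorator
+checker = signal.get_registry()["one_week_ago"].func
+
+# a batch with one interval in the past and one in the far future
+batch = [
+    (to_datetime("2023-01-01"), to_datetime("2023-01-02")),
+    (to_datetime("2999-01-01"), to_datetime("2999-01-02")),
+]
+
+# only the first interval starts more than one week ago
+assert checker(batch) == batch[:1]
+```
+
+This signal can be added to a model like so: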
```sql linenums="1" hl_lines="7-10" MODEL ( @@ -166,26 +109,10 @@ MODEL ( ), start '2 week ago', signals [ - (kind = 'a'), - (kind = 'b'), + one_week_ago(), ] ); SELECT @start_ds AS ds ``` - -Our signal factory definition extracts the `kind` key value from the `signal_metadata`, then instantiates the signal class corresponding to the value: - -```python linenums="1" hl_lines="3-3" -@signal_factory -def my_signal_factory(signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]) -> Signal: - kind = signal_metadata["kind"] - - if kind == "a": - return AlwaysReady() - if kind == "b": - return OneweekAgo(to_datetime("1 week ago")) - raise Exception(f"Unknown signal {kind}") -``` - diff --git a/examples/sushi/models/waiter_as_customer_by_day.sql b/examples/sushi/models/waiter_as_customer_by_day.sql index 3f7e053c4..7dc12db87 100644 --- a/examples/sushi/models/waiter_as_customer_by_day.sql +++ b/examples/sushi/models/waiter_as_customer_by_day.sql @@ -8,7 +8,11 @@ MODEL ( audits ( not_null(columns := (waiter_id)), forall(criteria := (LENGTH(waiter_name) > 0)) - ) + ), + signals ( + test_signal(arg := 1) + ), + ); JINJA_QUERY_BEGIN; diff --git a/examples/sushi/signals/__init__.py b/examples/sushi/signals/__init__.py new file mode 100644 index 000000000..bd7c839fc --- /dev/null +++ b/examples/sushi/signals/__init__.py @@ -0,0 +1,9 @@ +import typing as t + +from sqlmesh import signal, DatetimeRanges + + +@signal() +def test_signal(batch: DatetimeRanges, arg: int = 0) -> t.Union[bool, DatetimeRanges]: + assert arg == 1 + return True diff --git a/sqlmesh/__init__.py b/sqlmesh/__init__.py index bfa0094dc..18bb9e30e 100644 --- a/sqlmesh/__init__.py +++ b/sqlmesh/__init__.py @@ -24,6 +24,7 @@ from sqlmesh.core.engine_adapter import EngineAdapter as EngineAdapter from sqlmesh.core.macros import SQL as SQL, macro as macro from sqlmesh.core.model import Model as Model, model as model +from sqlmesh.core.signal import signal as signal from sqlmesh.core.snapshot import Snapshot as Snapshot from sqlmesh.core.snapshot.evaluator import ( CustomMaterialization as CustomMaterialization, @@ -32,6 +33,7 @@ debug_mode_enabled as debug_mode_enabled, enable_debug_mode as enable_debug_mode, ) +from sqlmesh.utils.date import DatetimeRanges as DatetimeRanges try: from sqlmesh._version import __version__ as __version__, __version_tuple__ as __version_tuple__ @@ -171,8 +173,8 @@ def configure_logging( os.remove(path) if debug: + from signal import SIGUSR1 import faulthandler - import signal enable_debug_mode() @@ -180,4 +182,4 @@ def configure_logging( faulthandler.enable() # Windows doesn't support register so we check for it here if hasattr(faulthandler, "register"): - faulthandler.register(signal.SIGUSR1.value) + faulthandler.register(SIGUSR1.value) diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py index ec8d2aaaf..2fac33ff3 100644 --- a/sqlmesh/core/config/model.py +++ b/sqlmesh/core/config/model.py @@ -2,7 +2,7 @@ import typing as t -from sqlmesh.core.dialect import parse_one, extract_audit +from sqlmesh.core.dialect import parse_one, extract_func_call from sqlmesh.core.config.base import BaseConfig from sqlmesh.core.model.kind import ( ModelKind, @@ -11,7 +11,7 @@ on_destructive_change_validator, ) from sqlmesh.utils.date import TimeLike -from sqlmesh.core.model.meta import AuditReference +from sqlmesh.core.model.meta import FunctionCall from sqlmesh.utils.pydantic import field_validator @@ -44,7 +44,7 @@ class ModelDefaultsConfig(BaseConfig): storage_format: t.Optional[str] = None 
on_destructive_change: t.Optional[OnDestructiveChange] = None session_properties: t.Optional[t.Dict[str, t.Any]] = None - audits: t.Optional[t.List[AuditReference]] = None + audits: t.Optional[t.List[FunctionCall]] = None _model_kind_validator = model_kind_validator _on_destructive_change_validator = on_destructive_change_validator @@ -52,6 +52,6 @@ class ModelDefaultsConfig(BaseConfig): @field_validator("audits", mode="before") def _audits_validator(cls, v: t.Any) -> t.Any: if isinstance(v, list): - return [extract_audit(parse_one(audit)) for audit in v] + return [extract_func_call(parse_one(audit)) for audit in v] return v diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 7257fb58a..521f397e6 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2187,7 +2187,6 @@ def _load_materializations_and_signals(self) -> None: if not self._loaded: for context_loader in self._loaders.values(): with sys_path(*context_loader.configs): - context_loader.loader.load_signals(self) context_loader.loader.load_materializations(self) def _select_models_for_run( diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 1639fd88f..2d284422c 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -1205,7 +1205,9 @@ def interpret_key_value_pairs( return {i.this.name: interpret_expression(i.expression) for i in e.expressions} -def extract_audit(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]]: +def extract_func_call( + v: exp.Expression, allow_tuples: bool = False +) -> t.Tuple[str, t.Dict[str, exp.Expression]]: kwargs = {} if isinstance(v, exp.Anonymous): @@ -1214,6 +1216,12 @@ def extract_audit(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression] elif isinstance(v, exp.Func): func = v.sql_name() args = list(v.args.values()) + elif isinstance(v, exp.Tuple): # airflow only + if not allow_tuples: + raise ConfigError("Audit name is missing (eg. 
MY_AUDIT())") + + func = "" + args = v.expressions else: return v.name.lower(), {} diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 493c9e54e..d1f989d3d 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -29,6 +29,7 @@ ) from sqlmesh.core.model.cache import load_optimized_query_and_mapping, optimized_query_cache_pool from sqlmesh.core.model import model as model_registry +from sqlmesh.core.signal import signal from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.dag import DAG from sqlmesh.utils.errors import ConfigError @@ -79,7 +80,7 @@ def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedPr self._dag = DAG() self._load_materializations() - self._load_signals() + signals = self._load_signals() config_mtimes: t.Dict[Path, t.List[float]] = defaultdict(list) for context_path, config in self._context.configs.items(): @@ -103,7 +104,13 @@ def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedPr else: standalone_audits[name] = audit - models = self._load_models(macros, jinja_macros, context.selected_gateway, audits) + models = self._load_models( + macros, + jinja_macros, + context.selected_gateway, + audits, + signals, + ) for model in models.values(): self._add_model_to_dag(model) @@ -131,10 +138,10 @@ def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedPr ) return project - def load_signals(self, context: GenericContext) -> None: + def load_signals(self, context: GenericContext) -> UniqueKeyDict[str, signal]: """Loads signals for the built-in scheduler.""" self._context = context - self._load_signals() + return self._load_signals() def load_materializations(self, context: GenericContext) -> None: """Loads materializations for the built-in scheduler.""" @@ -165,6 +172,7 @@ def _load_models( jinja_macros: JinjaMacroRegistry, gateway: t.Optional[str], audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """Loads all models.""" @@ -177,8 +185,8 @@ def _load_audits( def _load_materializations(self) -> None: """Loads custom materializations.""" - def _load_signals(self) -> None: - """Loads signals for the built-in scheduler.""" + def _load_signals(self) -> UniqueKeyDict[str, signal]: + return UniqueKeyDict("signals") def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]: return UniqueKeyDict("metrics") @@ -320,14 +328,15 @@ def _load_models( jinja_macros: JinjaMacroRegistry, gateway: t.Optional[str], audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """ Loads all of the models within the model directory with their associated audits into a Dict and creates the dag """ - models = self._load_sql_models(macros, jinja_macros, audits) + models = self._load_sql_models(macros, jinja_macros, audits, signals) models.update(self._load_external_models(audits, gateway)) - models.update(self._load_python_models(macros, jinja_macros, audits)) + models.update(self._load_python_models(macros, jinja_macros, audits, signals)) return models @@ -336,6 +345,7 @@ def _load_sql_models( macros: MacroRegistry, jinja_macros: JinjaMacroRegistry, audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """Loads the sql models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") @@ -380,6 +390,7 @@ def _load() -> Model: default_catalog=self._context.default_catalog, variables=variables, 
infer_names=config.model_naming.infer_names, + signal_definitions=signals, ) model = cache.get_or_load_model(path, _load) @@ -397,6 +408,7 @@ def _load_python_models( macros: MacroRegistry, jinja_macros: JinjaMacroRegistry, audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """Loads the python models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") @@ -452,8 +464,11 @@ def _load_materializations(self) -> None: if os.path.getsize(path): import_python_file(path, context_path) - def _load_signals(self) -> None: + def _load_signals(self) -> UniqueKeyDict[str, signal]: """Loads signals for the built-in scheduler.""" + + signals_max_mtime: t.Optional[float] = None + for context_path, config in self._context.configs.items(): for path in self._glob_paths( context_path / c.SIGNALS, @@ -461,13 +476,26 @@ def _load_signals(self) -> None: extension=".py", ): if os.path.getsize(path): + self._track_file(path) + signal_file_mtime = self._path_mtimes[path] + signals_max_mtime = ( + max(signals_max_mtime, signal_file_mtime) + if signals_max_mtime + else signal_file_mtime + ) import_python_file(path, context_path) + self._signals_max_mtime = signals_max_mtime + + return signal.get_registry() + def _load_audits( self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry ) -> UniqueKeyDict[str, Audit]: """Loads all the model audits.""" audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits") + audits_max_mtime: t.Optional[float] = None + for context_path, config in self._context.configs.items(): variables = self._variables(config) for path in self._glob_paths( @@ -475,6 +503,12 @@ def _load_audits( ): self._track_file(path) with open(path, "r", encoding="utf-8") as file: + audits_file_mtime = self._path_mtimes[path] + audits_max_mtime = ( + max(audits_max_mtime, audits_file_mtime) + if audits_max_mtime + else audits_file_mtime + ) expressions = parse(file.read(), default_dialect=config.model_defaults.dialect) audits = load_multiple_audits( expressions=expressions, @@ -488,6 +522,9 @@ def _load_audits( ) for audit in audits: audits_by_name[audit.name] = audit + + self._audits_max_mtime = audits_max_mtime + return audits_by_name def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]: @@ -550,6 +587,8 @@ def _model_cache_entry_id(self, model_path: Path) -> str: mtimes = [ self._loader._path_mtimes[model_path], self._loader._macros_max_mtime, + self._loader._signals_max_mtime, + self._loader._audits_max_mtime, self._loader._config_mtimes.get(self._context_path), self._loader._config_mtimes.get(c.SQLMESH_PATH), ] diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index b7fcdf0b0..c17ddb6c1 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -39,6 +39,7 @@ columns_to_types_all_known, registry_decorator, ) +from sqlmesh.utils.date import DatetimeRanges from sqlmesh.utils.errors import MacroEvalError, SQLMeshError from sqlmesh.utils.jinja import JinjaMacroRegistry, has_jinja from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception @@ -79,6 +80,7 @@ class MacroStrTemplate(Template): "List": t.List, "Tuple": t.Tuple, "Union": t.Union, + "DatetimeRanges": DatetimeRanges, } for klass in sqlglot.Parser.EXPRESSION_PARSERS: @@ -197,38 +199,7 @@ def send( raise SQLMeshError(f"Macro '{name}' does not exist.") try: - # Bind the macro's actual parameters to its formal parameters - sig = inspect.signature(func) - bound = sig.bind(self, *args, **kwargs) - 
bound.apply_defaults() - except Exception as e: - print_exception(e, self.python_env) - raise MacroEvalError("Error trying to eval macro.") from e - - try: - annotations = t.get_type_hints(func, localns=SUPPORTED_TYPES) - except NameError: # forward references aren't handled - annotations = {} - - # If the macro is annotated, we try coerce the actual parameters to the corresponding types - if annotations: - for arg, value in bound.arguments.items(): - typ = annotations.get(arg) - if not typ: - continue - - # Changes to bound.arguments will reflect in bound.args and bound.kwargs - # https://docs.python.org/3/library/inspect.html#inspect.BoundArguments.arguments - param = sig.parameters[arg] - if param.kind is inspect.Parameter.VAR_POSITIONAL: - bound.arguments[arg] = tuple(self._coerce(v, typ) for v in value) - elif param.kind is inspect.Parameter.VAR_KEYWORD: - bound.arguments[arg] = {k: self._coerce(v, typ) for k, v in value.items()} - else: - bound.arguments[arg] = self._coerce(value, typ) - - try: - return func(*bound.args, **bound.kwargs) + return call_macro(func, self.dialect, self._path, self, *args, **kwargs) # type: ignore except Exception as e: print_exception(e, self.python_env) raise MacroEvalError("Error trying to eval macro.") from e @@ -497,78 +468,7 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t. def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any: """Coerces the given expression to the specified type on a best-effort basis.""" - base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'." - try: - if typ is None or typ is t.Any: - return expr - base = t.get_origin(typ) or typ - - # We need to handle Union and TypeVars first since we cannot use isinstance with it - if base in UNION_TYPES: - for branch in t.get_args(typ): - try: - return self._coerce(expr, branch, True) - except Exception: - pass - raise SQLMeshError(base_err_msg) - if base is SQL and isinstance(expr, exp.Expression): - return expr.sql(self.dialect) - - if isinstance(expr, base): - return expr - if issubclass(base, exp.Expression): - d = Dialect.get_or_raise(self.dialect) - into = base if base in d.parser_class.EXPRESSION_PARSERS else None - if into is None: - if isinstance(expr, exp.Literal): - coerced = parse_one(expr.this) - else: - raise SQLMeshError( - f"{base_err_msg} Coercion to {base} requires a literal expression." 
- ) - else: - coerced = parse_one( - expr.this if isinstance(expr, exp.Literal) else expr.sql(), into=into - ) - if isinstance(coerced, base): - return coerced - raise SQLMeshError(base_err_msg) - - if base in (int, float, str) and isinstance(expr, exp.Literal): - return base(expr.this) - if base is str and isinstance(expr, exp.Column) and not expr.table: - return expr.name - if base is bool and isinstance(expr, exp.Boolean): - return expr.this - # if base is str and isinstance(expr, exp.Expression): - # return expr.sql(self.dialect) - if base is tuple and isinstance(expr, (exp.Tuple, exp.Array)): - generic = t.get_args(typ) - if not generic: - return tuple(expr.expressions) - if generic[-1] is ...: - return tuple(self._coerce(expr, generic[0]) for expr in expr.expressions) - elif len(generic) == len(expr.expressions): - return tuple( - self._coerce(expr, generic[i]) for i, expr in enumerate(expr.expressions) - ) - raise SQLMeshError(f"{base_err_msg} Expected {len(generic)} items.") - if base is list and isinstance(expr, (exp.Array, exp.Tuple)): - generic = t.get_args(typ) - if not generic: - return expr.expressions - return [self._coerce(expr, generic[0]) for expr in expr.expressions] - raise SQLMeshError(base_err_msg) - except Exception: - if strict: - raise - logger.error( - "Coercion of expression '%s' to type '%s' failed. Using non coerced expression at '%s'", - expr, - typ, - self._path, - ) - return expr + return _coerce(expr, typ, self.dialect, self._path, strict) class macro(registry_decorator): @@ -1284,3 +1184,121 @@ def normalize_macro_name(name: str) -> str: for m in macro.get_registry().values(): setattr(m, c.SQLMESH_BUILTIN, True) + + +def call_macro( + func: t.Callable, + dialect: DialectType, + path: Path, + *args: t.Any, + **kwargs: t.Any, +) -> t.Any: + # Bind the macro's actual parameters to its formal parameters + sig = inspect.signature(func) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + try: + annotations = t.get_type_hints(func, localns=SUPPORTED_TYPES) + except (NameError, TypeError): # forward references aren't handled + annotations = {} + + # If the macro is annotated, we try coerce the actual parameters to the corresponding types + if annotations: + for arg, value in bound.arguments.items(): + typ = annotations.get(arg) + if not typ: + continue + + # Changes to bound.arguments will reflect in bound.args and bound.kwargs + # https://docs.python.org/3/library/inspect.html#inspect.BoundArguments.arguments + param = sig.parameters[arg] + if param.kind is inspect.Parameter.VAR_POSITIONAL: + bound.arguments[arg] = tuple(_coerce(v, typ, dialect, path) for v in value) + elif param.kind is inspect.Parameter.VAR_KEYWORD: + bound.arguments[arg] = {k: _coerce(v, typ, dialect, path) for k, v in value.items()} + else: + bound.arguments[arg] = _coerce(value, typ, dialect, path) + + return func(*bound.args, **bound.kwargs) + + +def _coerce( + expr: exp.Expression, + typ: t.Any, + dialect: DialectType, + path: Path, + strict: bool = False, +) -> t.Any: + """Coerces the given expression to the specified type on a best-effort basis.""" + base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'." 
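+    # Coercion is attempted from most to least specific: unions/TypeVars
+    # first (isinstance cannot handle them), then SQL strings, exact
+    # instance matches, re-parsing into sqlglot expressions, Python
+    # literals, and finally tuple/list generics.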
+ try: + if typ is None or typ is t.Any: + return expr + base = t.get_origin(typ) or typ + + # We need to handle Union and TypeVars first since we cannot use isinstance with it + if base in UNION_TYPES: + for branch in t.get_args(typ): + try: + return _coerce(expr, branch, dialect, path, strict=True) + except Exception: + pass + raise SQLMeshError(base_err_msg) + if base is SQL and isinstance(expr, exp.Expression): + return expr.sql(dialect) + + if isinstance(expr, base): + return expr + if issubclass(base, exp.Expression): + d = Dialect.get_or_raise(dialect) + into = base if base in d.parser_class.EXPRESSION_PARSERS else None + if into is None: + if isinstance(expr, exp.Literal): + coerced = parse_one(expr.this) + else: + raise SQLMeshError( + f"{base_err_msg} Coercion to {base} requires a literal expression." + ) + else: + coerced = parse_one( + expr.this if isinstance(expr, exp.Literal) else expr.sql(), into=into + ) + if isinstance(coerced, base): + return coerced + raise SQLMeshError(base_err_msg) + + if base in (int, float, str) and isinstance(expr, exp.Literal): + return base(expr.this) + if base is str and isinstance(expr, exp.Column) and not expr.table: + return expr.name + if base is bool and isinstance(expr, exp.Boolean): + return expr.this + if base is tuple and isinstance(expr, (exp.Tuple, exp.Array)): + generic = t.get_args(typ) + if not generic: + return tuple(expr.expressions) + if generic[-1] is ...: + return tuple(_coerce(expr, generic[0], dialect, path) for expr in expr.expressions) + elif len(generic) == len(expr.expressions): + return tuple( + _coerce(expr, generic[i], dialect, path) + for i, expr in enumerate(expr.expressions) + ) + raise SQLMeshError(f"{base_err_msg} Expected {len(generic)} items.") + if base is list and isinstance(expr, (exp.Array, exp.Tuple)): + generic = t.get_args(typ) + if not generic: + return expr.expressions + return [_coerce(expr, generic[0], dialect, path) for expr in expr.expressions] + raise SQLMeshError(base_err_msg) + except Exception: + if strict: + raise + logger.error( + "Coercion of expression '%s' to type '%s' failed. 
Using non coerced expression at '%s'", + expr, + typ, + path, + ) + return expr diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 6a76e411c..b00fc97aa 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -31,9 +31,10 @@ single_value_or_tuple, ) from sqlmesh.core.model.kind import ModelKindName, SeedKind, ModelKind, FullKind, create_model_kind -from sqlmesh.core.model.meta import ModelMeta, AuditReference +from sqlmesh.core.model.meta import ModelMeta, FunctionCall from sqlmesh.core.model.seed import CsvSeedReader, Seed, create_seed from sqlmesh.core.renderer import ExpressionRenderer, QueryRenderer +from sqlmesh.core.signal import SignalRegistry from sqlmesh.utils import columns_to_types_all_known, str_to_bool, UniqueKeyDict from sqlmesh.utils.date import TimeLike, make_inclusive, to_datetime, to_time_column from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error @@ -42,8 +43,10 @@ from sqlmesh.utils.pydantic import PydanticModel, PRIVATE_FIELDS from sqlmesh.utils.metaprogramming import ( Executable, + build_env, prepare_env, print_exception, + serialize_env, ) if t.TYPE_CHECKING: @@ -556,6 +559,7 @@ def _create_renderer(expression: exp.Expression) -> ExpressionRenderer: jinja_macro_registry=self.jinja_macros, python_env=self.python_env, only_execution_time=False, + quote_identifiers=False, ) def _render(e: exp.Expression) -> str | int | float | bool: @@ -575,7 +579,10 @@ def _render(e: exp.Expression) -> str | int | float | bool: return rendered.this return rendered.sql(dialect=self.dialect) - return [{t.this.name: _render(t.expression) for t in signal} for signal in self.signals] + # airflow only + return [ + {k: _render(v) for k, v in signal.items()} for name, signal in self.signals if not name + ] def ctas_query(self, **render_kwarg: t.Any) -> exp.Query: """Return a dummy query to do a CTAS. 
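The hunks above and below both consume the new in-memory representation of `signals`: a list of `(name, kwargs)` pairs instead of raw `exp.Tuple`s. A minimal sketch of that shape — the values are illustrative and mirror the updated `test_signals` expectations in `tests/core/test_model.py`; bare tuples, kept for Airflow compatibility, are stored under an empty name:

```python
from sqlglot import exp

from sqlmesh.core import dialect as d

signals = [
    # a named signal call: my_signal(arg = 1)
    ("my_signal", {"arg": exp.Literal.number(1)}),
    # a bare tuple (Airflow only): (table_name = 'table_a', ds = @end_ds)
    ("", {"table_name": exp.Literal.string("table_a"), "ds": d.MacroVar(this="end_ds")}),
]
```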
@@ -954,7 +961,11 @@ def metadata_hash(self) -> str: metadata.append(key) metadata.append(gen(value)) - metadata.extend(gen(s) for s in self.signals) + for signal_name, args in sorted(self.signals, key=lambda x: x[0]): + metadata.append(signal_name) + for k, v in sorted(args.items()): + metadata.append(f"{k}:{gen(v)}") + metadata.extend(self._additional_metadata) self._metadata_hash = hash_data(metadata) @@ -1598,7 +1609,7 @@ def load_sql_based_model( macros: t.Optional[MacroRegistry] = None, jinja_macros: t.Optional[JinjaMacroRegistry] = None, audits: t.Optional[t.Dict[str, ModelAudit]] = None, - default_audits: t.Optional[t.List[AuditReference]] = None, + default_audits: t.Optional[t.List[FunctionCall]] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, dialect: t.Optional[str] = None, physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, @@ -1918,10 +1929,11 @@ def _create_model( physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, - default_audits: t.Optional[t.List[AuditReference]] = None, + default_audits: t.Optional[t.List[FunctionCall]] = None, inline_audits: t.Optional[t.Dict[str, ModelAudit]] = None, module_path: Path = Path(), macros: t.Optional[MacroRegistry] = None, + signal_definitions: t.Optional[SignalRegistry] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ) -> Model: @@ -2018,7 +2030,16 @@ def _create_model( python_env=python_env, ) + env: t.Dict[str, t.Any] = {} + + for signal_name, _ in model.signals: + if signal_definitions and signal_name in signal_definitions: + func = signal_definitions[signal_name].func + setattr(func, c.SQLMESH_METADATA, True) + build_env(func, env=env, name=signal_name, path=module_path) + model.python_env.update(python_env) + model.python_env.update(serialize_env(env, path=module_path)) model._path = path model.set_time_format(time_column_format) @@ -2203,7 +2224,16 @@ def _meta_renderer( "virtual_properties_": lambda value: value, "session_properties_": lambda value: value, "allow_partials": exp.convert, - "signals": lambda values: exp.Tuple(expressions=values), + "signals": lambda values: exp.tuple_( + *( + exp.func( + name, *(exp.PropertyEQ(this=exp.var(k), expression=v) for k, v in args.items()) + ) + if name + else exp.Tuple(expressions=[exp.var(k).eq(v) for k, v in args.items()]) + for name, args in values + ) + ), } diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index bab149c25..2bd4fa13c 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -10,12 +10,11 @@ from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.core import dialect as d -from sqlmesh.core.dialect import normalize_model_name, extract_audit +from sqlmesh.core.dialect import normalize_model_name, extract_func_call from sqlmesh.core.model.common import ( bool_validator, default_catalog_validator, depends_on_validator, - parse_properties, properties_validator, ) from sqlmesh.core.model.kind import ( @@ -45,7 +44,7 @@ if t.TYPE_CHECKING: from sqlmesh.core._typing import CustomMaterializationProperties, SessionProperties -AuditReference = t.Tuple[str, t.Dict[str, exp.Expression]] +FunctionCall = t.Tuple[str, t.Dict[str, exp.Expression]] logger = logging.getLogger(__name__) @@ -67,7 +66,7 @@ class ModelMeta(_Node): column_descriptions_: t.Optional[t.Dict[str, str]] = Field( default=None, alias="column_descriptions" ) 
- audits: t.List[AuditReference] = [] + audits: t.List[FunctionCall] = [] grains: t.List[exp.Expression] = [] references: t.List[exp.Expression] = [] physical_schema_override: t.Optional[str] = None @@ -75,7 +74,7 @@ class ModelMeta(_Node): virtual_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="virtual_properties") session_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="session_properties") allow_partials: bool = False - signals: t.List[exp.Tuple] = [] + signals: t.List[FunctionCall] = [] enabled: bool = True physical_version: t.Optional[str] = None gateway: t.Optional[str] = None @@ -86,21 +85,23 @@ class ModelMeta(_Node): _default_catalog_validator = default_catalog_validator _depends_on_validator = depends_on_validator - @field_validator("audits", mode="before") - def _audits_validator(cls, v: t.Any) -> t.Any: + @field_validator("audits", "signals", mode="before") + def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any: + is_signal = getattr(field, "name" if hasattr(field, "name") else "field_name") == "signals" + if isinstance(v, (exp.Tuple, exp.Array)): - return [extract_audit(i) for i in v.expressions] + return [extract_func_call(i, allow_tuples=is_signal) for i in v.expressions] if isinstance(v, exp.Paren): - return [extract_audit(v.this)] + return [extract_func_call(v.this, allow_tuples=is_signal)] if isinstance(v, exp.Expression): - return [extract_audit(v)] + return [extract_func_call(v, allow_tuples=is_signal)] if isinstance(v, list): audits = [] for entry in v: if isinstance(entry, dict): args = entry - name = entry.pop("name") + name = "" if is_signal else entry.pop("name") elif isinstance(entry, (tuple, list)): name, args = entry else: @@ -286,28 +287,6 @@ def _refs_validator(cls, vs: t.Any, values: t.Dict[str, t.Any]) -> t.List[exp.Ex return refs - @field_validator("signals", mode="before") - @field_validator_v1_args - def _signals_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any: - if v is None: - return [] - - if isinstance(v, str): - dialect = values.get("dialect") - v = d.parse_one(v, dialect=dialect) - - if isinstance(v, (exp.Array, exp.Paren, exp.Tuple)): - tuples: t.List[exp.Expression] = ( - [v.unnest()] if isinstance(v, exp.Paren) else v.expressions - ) - signals = [parse_properties(cls, t, values) for t in tuples] - elif isinstance(v, list): - signals = [parse_properties(cls, t, values) for t in v] - else: - raise ConfigError(f"Unexpected signals '{v}'") - - return signals - @model_validator(mode="before") @model_validator_v1_args def _pre_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 92fb9e3b1..b14b7469c 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -1,6 +1,5 @@ from __future__ import annotations -import abc import logging import typing as t @@ -32,10 +31,8 @@ TimeLike, now, now_timestamp, - to_datetime, to_timestamp, validate_date_range, - DatetimeRanges, ) from sqlmesh.utils.errors import AuditError, CircuitBreakerError, SQLMeshError @@ -46,66 +43,6 @@ SchedulingUnit = t.Tuple[str, t.Tuple[Interval, int]] -class Signal(abc.ABC): - @abc.abstractmethod - def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]: - """Returns which intervals are ready from a list of scheduled intervals. 
- - When SQLMesh wishes to execute a batch of intervals, say between `a` and `d`, then - the `batch` parameter will contain each individual interval within this batch, - i.e.: `[a,b),[b,c),[c,d)`. - - This function may return `True` to indicate that the whole batch is ready, - `False` to indicate none of the batch's intervals are ready, or a list of - intervals (a batch) to indicate exactly which ones are ready. - - When returning a batch, the function is expected to return a subset of - the `batch` parameter, e.g.: `[a,b),[b,c)`. Note that it may return - gaps, e.g.: `[a,b),[c,d)`, but it may not alter the bounds of any of the - intervals. - - The interface allows an implementation to check batches of intervals without - having to actually compute individual intervals itself. - - Args: - batch: the list of intervals that are missing and scheduled to run. - - Returns: - Either `True` to indicate all intervals are ready, `False` to indicate none are - ready or a list of intervals to indicate exactly which ones are ready. - """ - - -SignalFactory = t.Callable[[t.Dict[str, t.Union[str, int, float, bool]]], Signal] -_registered_signal_factory: t.Optional[SignalFactory] = None - - -def signal_factory(f: SignalFactory) -> None: - """Specifies a function as the SignalFactory to use for building Signal instances from model signal metadata. - - Only one such function may be decorated with this decorator. - - Example: - import typing as t - from sqlmesh.core.scheduler import signal_factory, Batch, Signal - - class AlwaysReadySignal(Signal): - def check_intervals(self, batch: Batch) -> t.Union[bool, Batch]: - return True - - @signal_factory - def my_signal_factory(signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]) -> Signal: - return AlwaysReadySignal() - """ - - global _registered_signal_factory - - if _registered_signal_factory is not None and _registered_signal_factory.__code__ != f.__code__: - raise SQLMeshError("Only one function may be decorated with @signal_factory") - - _registered_signal_factory = f - - class Scheduler: """Schedules and manages the evaluation of snapshots. @@ -121,7 +58,6 @@ class Scheduler: state_sync: The state sync to pull saved snapshots. max_workers: The maximum number of parallel queries to run. console: The rich instance used for printing scheduling information. - signal_factory: A factory method for building Signal instances from model signal configuration. 
""" def __init__( @@ -133,7 +69,6 @@ def __init__( max_workers: int = 1, console: t.Optional[Console] = None, notification_target_manager: t.Optional[NotificationTargetManager] = None, - signal_factory: t.Optional[SignalFactory] = None, ): self.state_sync = state_sync self.snapshots = {s.snapshot_id: s for s in snapshots} @@ -145,7 +80,6 @@ def __init__( self.notification_target_manager = ( notification_target_manager or NotificationTargetManager() ) - self.signal_factory = signal_factory or _registered_signal_factory def merged_missing_intervals( self, @@ -401,19 +335,9 @@ def expand_range_as_interval( snapshot_batches = {} all_unready_intervals: t.Dict[str, set[Interval]] = {} for snapshot, intervals in snapshot_intervals.items(): - if self.signal_factory and snapshot.is_model: - unready = set(intervals) - - for signal in snapshot.model.render_signals( - start=start, end=end, execution_time=execution_time - ): - intervals = _check_ready_intervals( - signal=self.signal_factory(signal), - intervals=intervals, - ) - unready -= set(intervals) - else: - unready = set() + unready = set(intervals) + intervals = snapshot.check_ready_intervals(intervals) + unready -= set(intervals) for parent in snapshot.parents: if parent.name in all_unready_intervals: @@ -687,53 +611,3 @@ def _resolve_one_snapshot_per_version( snapshot_per_version[key] = snapshot return snapshot_per_version - - -def _contiguous_intervals( - intervals: Intervals, -) -> t.List[Intervals]: - """Given a list of intervals with gaps, returns a list of sequences of contiguous intervals.""" - contiguous_intervals = [] - current_batch: t.List[Interval] = [] - for interval in intervals: - if len(current_batch) == 0 or interval[0] == current_batch[-1][-1]: - current_batch.append(interval) - else: - contiguous_intervals.append(current_batch) - current_batch = [interval] - - if len(current_batch) > 0: - contiguous_intervals.append(current_batch) - - return contiguous_intervals - - -def _check_ready_intervals( - signal: Signal, - intervals: Intervals, -) -> Intervals: - """Returns a list of intervals that are considered ready by the provided signal. - - Note that this will handle gaps in the provided intervals. The returned intervals - may introduce new gaps. - """ - checked_intervals = [] - for interval_batch in _contiguous_intervals(intervals): - batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch] - - ready_intervals = signal.check_intervals(batch=batch) - if isinstance(ready_intervals, bool): - if not ready_intervals: - batch = [] - elif isinstance(ready_intervals, list): - for i in ready_intervals: - if i not in batch: - raise RuntimeError(f"Signal returned unknown interval {i}") - batch = ready_intervals - else: - raise ValueError( - f"unexpected return value from signal, expected bool | list, got {type(ready_intervals)}" - ) - - checked_intervals.extend([(to_timestamp(start), to_timestamp(end)) for start, end in batch]) - return checked_intervals diff --git a/sqlmesh/core/signal.py b/sqlmesh/core/signal.py new file mode 100644 index 000000000..d9ee67092 --- /dev/null +++ b/sqlmesh/core/signal.py @@ -0,0 +1,35 @@ +from __future__ import annotations + + +from sqlmesh.utils import UniqueKeyDict, registry_decorator + + +class signal(registry_decorator): + """Specifies a function which intervals are ready from a list of scheduled intervals. 
+ + When SQLMesh wishes to execute a batch of intervals, say between `a` and `d`, then + the `batch` parameter will contain each individual interval within this batch, + i.e.: `[a,b),[b,c),[c,d)`. + + This function may return `True` to indicate that the whole batch is ready, + `False` to indicate none of the batch's intervals are ready, or a list of + intervals (a batch) to indicate exactly which ones are ready. + + When returning a batch, the function is expected to return a subset of + the `batch` parameter, e.g.: `[a,b),[b,c)`. Note that it may return + gaps, e.g.: `[a,b),[c,d)`, but it may not alter the bounds of any of the + intervals. + + The interface allows an implementation to check batches of intervals without + having to actually compute individual intervals itself. + + Args: + batch: the list of intervals that are missing and scheduled to run. + + Returns: + Either `True` to indicate all intervals are ready, `False` to indicate none are + ready or a list of intervals to indicate exactly which ones are ready. + """ + + +SignalRegistry = UniqueKeyDict[str, signal] diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 3074f2a08..dae13b035 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta from enum import IntEnum from functools import cached_property, lru_cache +from pathlib import Path from pydantic import Field from sqlglot import exp @@ -12,6 +13,7 @@ from sqlmesh.core import constants as c from sqlmesh.core.audit import StandaloneAudit +from sqlmesh.core.macros import call_macro from sqlmesh.core.model import Model, ModelKindMixin, ModelKindName, ViewKind, CustomKind from sqlmesh.core.model.definition import _Model from sqlmesh.core.node import IntervalUnit, NodeType @@ -34,6 +36,7 @@ yesterday, ) from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.metaprogramming import prepare_env, print_exception from sqlmesh.utils.hashing import hash_data from sqlmesh.utils.pydantic import PydanticModel, field_validator @@ -885,6 +888,37 @@ def missing_intervals( model_end_ts, ) + def check_ready_intervals(self, intervals: Intervals) -> Intervals: + """Returns a list of intervals that are considered ready by the provided signal. + + Note that this will handle gaps in the provided intervals. The returned intervals + may introduce new gaps. + """ + signals = self.is_model and self.model.signals + + if not signals: + return intervals + + python_env = self.model.python_env + env = prepare_env(python_env) + + for signal_name, kwargs in signals: + try: + intervals = _check_ready_intervals( + env[signal_name], + intervals, + dialect=self.model.dialect, + path=self.model._path, + kwargs=kwargs, + ) + except SQLMeshError as e: + print_exception(e, python_env) + raise SQLMeshError( + f"{e} '{signal_name}' for '{self.model.name}' at {self.model._path}" + ) + + return intervals + def categorize_as(self, category: SnapshotChangeCategory) -> None: """Assigns the given category to this snapshot. 
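`check_ready_intervals` above splits the scheduled intervals into contiguous runs and hands each run to the signal as one batch. A quick sketch of the `_contiguous_intervals` helper added at the bottom of this file (the interval values are arbitrary, and the helper is module-private, so this is purely illustrative):

```python
from sqlmesh.core.snapshot.definition import _contiguous_intervals

# The gap between (1, 2) and (3, 4) splits the run into two batches,
# so a signal is invoked once per contiguous run.
assert _contiguous_intervals([(0, 1), (1, 2), (3, 4)]) == [
    [(0, 1), (1, 2)],
    [(3, 4)],
]
```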
@@ -1836,3 +1870,53 @@ def snapshots_to_dag(snapshots: t.Collection[Snapshot]) -> DAG[SnapshotId]: for snapshot in snapshots: dag.add(snapshot.snapshot_id, snapshot.parents) return dag + + +def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]: + """Given a list of intervals with gaps, returns a list of sequences of contiguous intervals.""" + contiguous_intervals = [] + current_batch: t.List[Interval] = [] + for interval in intervals: + if len(current_batch) == 0 or interval[0] == current_batch[-1][-1]: + current_batch.append(interval) + else: + contiguous_intervals.append(current_batch) + current_batch = [interval] + + if len(current_batch) > 0: + contiguous_intervals.append(current_batch) + + return contiguous_intervals + + +def _check_ready_intervals( + check: t.Callable, + intervals: Intervals, + dialect: DialectType = None, + path: Path = Path(), + kwargs: t.Optional[t.Dict] = None, +) -> Intervals: + checked_intervals: Intervals = [] + + for interval_batch in _contiguous_intervals(intervals): + batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch] + + try: + ready_intervals = call_macro(check, dialect, path, batch, **(kwargs or {})) + except Exception: + raise SQLMeshError("Error evaluating signal") + + if isinstance(ready_intervals, bool): + if not ready_intervals: + batch = [] + elif isinstance(ready_intervals, list): + for i in ready_intervals: + if i not in batch: + raise SQLMeshError(f"Unknown interval {i} for signal") + batch = ready_intervals + else: + raise SQLMeshError(f"Expected bool | list, got {type(ready_intervals)} for signal") + + checked_intervals.extend((to_timestamp(start), to_timestamp(end)) for start, end in batch) + + return checked_intervals diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 7a0561e95..4b6caa04f 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -13,6 +13,7 @@ from sqlmesh.core.loader import LoadedProject, Loader from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.model import Model, ModelCache +from sqlmesh.core.signal import signal from sqlmesh.dbt.basemodel import BMC, BaseModelConfig from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.model import ModelConfig @@ -94,6 +95,7 @@ def _load_models( jinja_macros: JinjaMacroRegistry, gateway: t.Optional[str], audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") diff --git a/sqlmesh/migrations/v0059_move_audits_to_model.py b/sqlmesh/migrations/v0060_move_audits_to_model.py similarity index 100% rename from sqlmesh/migrations/v0059_move_audits_to_model.py rename to sqlmesh/migrations/v0060_move_audits_to_model.py diff --git a/sqlmesh/migrations/v0060_mysql_fix_blob_text_type.py b/sqlmesh/migrations/v0061_mysql_fix_blob_text_type.py similarity index 100% rename from sqlmesh/migrations/v0060_mysql_fix_blob_text_type.py rename to sqlmesh/migrations/v0061_mysql_fix_blob_text_type.py diff --git a/sqlmesh/migrations/v0061_add_model_gateway.py b/sqlmesh/migrations/v0062_add_model_gateway.py similarity index 100% rename from sqlmesh/migrations/v0061_add_model_gateway.py rename to sqlmesh/migrations/v0062_add_model_gateway.py diff --git a/sqlmesh/migrations/v0063_change_signals.py b/sqlmesh/migrations/v0063_change_signals.py new file mode 100644 index 000000000..d2060f552 --- /dev/null +++ b/sqlmesh/migrations/v0063_change_signals.py @@ -0,0 +1,83 @@ +"""Change serialization of signals to allow 
for function calls.""" + +import json + +import pandas as pd +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type + + +def migrate(state_sync, **kwargs): # type: ignore + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + index_type = index_text_type(engine_adapter.dialect) + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + signals = node.pop("signals", None) + + if signals: + node["signals"] = [("", signal) for signal in signals] + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build("text"), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py index 4173c8eb2..2b2ab3f01 100644 --- a/sqlmesh/utils/metaprogramming.py +++ b/sqlmesh/utils/metaprogramming.py @@ -21,7 +21,7 @@ from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pydantic import PydanticModel -IGNORE_DECORATORS = {"macro", "model"} +IGNORE_DECORATORS = {"macro", "model", "signal"} def _is_relative_to(path: t.Optional[Path | str], other: t.Optional[Path | str]) -> bool: diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d87ca1b33..f3ecb00b4 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -3820,9 +3820,10 @@ def test_signals(): MODEL ( name db.table, signals [ + my_signal(arg = 1), ( - table_name = 'table_a', - ds = @end_ds, + table_name := 'table_a', + ds := @end_ds, ), ( table_name = 'table_b', @@ -3843,26 +3844,35 @@ def test_signals(): model = load_sql_based_model(expressions) assert model.signals == [ - exp.Tuple( - expressions=[ - exp.to_column("table_name").eq("table_a"), - exp.to_column("ds").eq(d.MacroVar(this="end_ds")), - ] + ( + "my_signal", + { + "arg": exp.Literal.number(1), + }, ), - exp.Tuple( - expressions=[ - exp.to_column("table_name").eq("table_b"), - exp.to_column("ds").eq(d.MacroVar(this="end_ds")), - exp.to_column("hour").eq(d.MacroVar(this="end_hour")), - ] + ( + "", + { + "table_name": exp.Literal.string("table_a"), + "ds": d.MacroVar(this="end_ds"), + }, ), - exp.Tuple( - expressions=[ - exp.to_column("bool_key").eq(True), - exp.to_column("int_key").eq(1), - exp.to_column("float_key").eq(1.0), - exp.to_column("string_key").eq("string"), - ] + ( + "", + { + "table_name": 
exp.Literal.string("table_b"), + "ds": d.MacroVar(this="end_ds"), + "hour": d.MacroVar(this="end_hour"), + }, + ), + ( + "", + { + "bool_key": exp.true(), + "int_key": exp.Literal.number(1), + "float_key": exp.Literal.number(1.0), + "string_key": exp.Literal.string("string"), + }, ), ] @@ -3874,7 +3884,7 @@ def test_signals(): ] assert ( - "signals ((table_name = 'table_a', ds = @end_ds), (table_name = 'table_b', ds = @end_ds, hour = @end_hour), (bool_key = TRUE, int_key = 1, float_key = 1.0, string_key = 'string')" + "signals (MY_SIGNAL(arg := 1), (table_name = 'table_a', ds = @end_ds), (table_name = 'table_b', ds = @end_ds, hour = @end_hour), (bool_key = TRUE, int_key = 1, float_key = 1.0, string_key = 'string'))" in model.render_definition()[0].sql() ) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 4b7de2aa5..804348d3d 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,7 +10,6 @@ from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.model.definition import AuditResult, SqlModel from sqlmesh.core.model.kind import ( - FullKind, IncrementalByTimeRangeKind, IncrementalByUniqueKeyKind, TimeColumn, @@ -20,10 +19,9 @@ Scheduler, interval_diff, compute_interval_params, - signal_factory, - Signal, SnapshotToIntervals, ) +from sqlmesh.core.signal import signal from sqlmesh.core.snapshot import ( Snapshot, SnapshotEvaluator, @@ -506,95 +504,6 @@ def test_external_model_audit(mocker, make_snapshot): spy.assert_called_once() -def test_contiguous_intervals(): - from sqlmesh.core.scheduler import _contiguous_intervals as ci - - assert ci([]) == [] - assert ci([(0, 1)]) == [[(0, 1)]] - assert ci([(0, 1), (1, 2), (2, 3)]) == [[(0, 1), (1, 2), (2, 3)]] - assert ci([(0, 1), (3, 4), (4, 5), (6, 7)]) == [ - [(0, 1)], - [(3, 4), (4, 5)], - [(6, 7)], - ] - - -def test_check_ready_intervals(mocker: MockerFixture): - from sqlmesh.core.scheduler import _check_ready_intervals - from sqlmesh.core.snapshot.definition import Interval - - def const_signal(const): - signal_mock = mocker.Mock() - signal_mock.check_intervals = mocker.MagicMock(return_value=const) - return signal_mock - - def assert_always_signal(intervals): - _check_ready_intervals(const_signal(True), intervals) == intervals - - assert_always_signal([]) - assert_always_signal([(0, 1)]) - assert_always_signal([(0, 1), (1, 2)]) - assert_always_signal([(0, 1), (2, 3)]) - - def assert_never_signal(intervals): - _check_ready_intervals(const_signal(False), intervals) == [] - - assert_never_signal([]) - assert_never_signal([(0, 1)]) - assert_never_signal([(0, 1), (1, 2)]) - assert_never_signal([(0, 1), (2, 3)]) - - def to_intervals(values: t.List[t.Tuple[int, int]]) -> t.List[Interval]: - return [(to_datetime(s), to_datetime(e)) for s, e in values] - - def assert_check_intervals( - intervals: t.List[t.Tuple[int, int]], - ready: t.List[t.List[t.Tuple[int, int]]], - expected: t.List[t.Tuple[int, int]], - ): - signal_mock = mocker.Mock() - signal_mock.check_intervals = mocker.MagicMock(side_effect=[to_intervals(r) for r in ready]) - _check_ready_intervals(signal_mock, intervals) == expected - - assert_check_intervals([], [], []) - assert_check_intervals([(0, 1)], [[]], []) - assert_check_intervals( - [(0, 1)], - [[(0, 1)]], - [(0, 1)], - ) - assert_check_intervals( - [(0, 1), (1, 2)], - [[(0, 1)]], - [(0, 1)], - ) - assert_check_intervals( - [(0, 1), (1, 2)], - [[(1, 2)]], - [(1, 2)], - ) - assert_check_intervals( - [(0, 1), (1, 2)], - [[(0, 1), (1, 2)]], - [(0, 1), (1, 2)], - ) - 
assert_check_intervals( - [(0, 1), (1, 2), (3, 4)], - [[], []], - [], - ) - assert_check_intervals( - [(0, 1), (1, 2), (3, 4)], - [[(0, 1)], []], - [(0, 1)], - ) - assert_check_intervals( - [(0, 1), (1, 2), (3, 4)], - [[(0, 1)], [(3, 4)]], - [(0, 1), (3, 4)], - ) - - def test_audit_failure_notifications( scheduler: Scheduler, waiter_names: Snapshot, mocker: MockerFixture ): @@ -661,56 +570,6 @@ def _evaluate(): assert notify_mock.call_count == 1 -def test_signal_factory(mocker: MockerFixture, make_snapshot): - class AlwaysReadySignal(Signal): - def check_intervals(self, batch: DatetimeRanges): - return True - - signal_factory_invoked = 0 - - @signal_factory - def factory(signal_metadata): - nonlocal signal_factory_invoked - signal_factory_invoked += 1 - assert signal_metadata.get("kind") == "foo" - return AlwaysReadySignal() - - start = to_datetime("2023-01-01") - end = to_datetime("2023-01-07") - snapshot: Snapshot = make_snapshot( - SqlModel( - name="name", - kind=FullKind(), - owner="owner", - dialect="", - cron="@daily", - start=start, - query=parse_one("SELECT id FROM VALUES (1), (2) AS t(id)"), - signals=[{"kind": "foo"}], - ), - ) - snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) - scheduler = Scheduler( - snapshots=[snapshot], - snapshot_evaluator=snapshot_evaluator, - state_sync=mocker.MagicMock(), - max_workers=2, - default_catalog=None, - console=mocker.MagicMock(), - ) - merged_intervals = scheduler.merged_missing_intervals(start, end, end) - assert len(merged_intervals) == 1 - scheduler.run_merged_intervals( - merged_intervals=merged_intervals, - deployability_index=DeployabilityIndex.all_deployable(), - environment_naming_info=EnvironmentNamingInfo(), - start=start, - end=end, - ) - - assert signal_factory_invoked > 0 - - def test_interval_diff(): assert interval_diff([(1, 2)], []) == [(1, 2)] assert interval_diff([(1, 2)], [(1, 2)]) == [] @@ -741,51 +600,85 @@ def test_interval_diff(): def test_signal_intervals(mocker: MockerFixture, make_snapshot, get_batched_missing_intervals): - class TestSignal(Signal): - def __init__(self, signal: t.Dict): - self.name = signal["kind"] + @signal() + def signal_a(batch: DatetimeRanges): + return [batch[0], batch[1]] + + @signal() + def signal_b(batch: DatetimeRanges): + return batch[-49:] - def check_intervals(self, batch: DatetimeRanges): - if self.name == "a": - return [batch[0], batch[1]] - if self.name == "b": - return batch[-49:] + signals = signal.get_registry() a = make_snapshot( - SqlModel( - name="a", - kind="full", - start="2023-01-01", - query=parse_one("SELECT 1 x"), - signals=[{"kind": "a"}], + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name a, + kind FULL, + start '2023-01-01', + signals SIGNAL_A(), + ); + + SELECT 1 x; + """ + ), + signal_definitions=signals, ), ) + b = make_snapshot( - SqlModel( - name="b", - kind="full", - start="2023-01-01", - cron="@hourly", - query=parse_one("SELECT 2 x"), - signals=[{"kind": "b"}], + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name b, + kind FULL, + cron '@hourly', + start '2023-01-01', + signals SIGNAL_B(), + ); + + SELECT 2 x; + """ + ), + signal_definitions=signals, ), nodes={a.name: a.model}, ) + c = make_snapshot( - SqlModel( - name="c", - kind="full", - start="2023-01-01", - query=parse_one("select * from a union select * from b"), + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name c, + kind FULL, + start '2023-01-01', + ); + + SELECT * FROM a UNION SELECT * FROM b + """ + ), 
+            signal_definitions=signals,
         ),
         nodes={a.name: a.model, b.name: b.model},
     )
 
     d = make_snapshot(
-        SqlModel(
-            name="d",
-            kind="full",
-            start="2023-01-01",
-            query=parse_one("select * from c union all select * from d"),
+        load_sql_based_model(
+            parse(  # type: ignore
+                """
+            MODEL (
+                name d,
+                kind FULL,
+                start '2023-01-01',
+            );
+
+            SELECT * FROM c UNION SELECT * FROM d
+            """
+            ),
+            signal_definitions=signals,
         ),
         nodes={a.name: a.model, b.name: b.model, c.name: c.model},
     )
@@ -797,7 +690,6 @@ def check_intervals(self, batch: DatetimeRanges):
         state_sync=mocker.MagicMock(),
         max_workers=2,
         default_catalog=None,
-        signal_factory=lambda signal: TestSignal(signal),
     )
 
     batches = get_batched_missing_intervals(scheduler, "2023-01-01", "2023-01-03", None)
diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py
index 1db72f815..7872ae39a 100644
--- a/tests/core/test_snapshot.py
+++ b/tests/core/test_snapshot.py
@@ -47,9 +47,13 @@
 )
 from sqlmesh.core.snapshot.cache import SnapshotCache
 from sqlmesh.core.snapshot.categorizer import categorize_change
-from sqlmesh.core.snapshot.definition import display_name
+from sqlmesh.core.snapshot.definition import (
+    display_name,
+    _check_ready_intervals,
+    _contiguous_intervals,
+)
 from sqlmesh.utils import AttributeDict
-from sqlmesh.utils.date import to_date, to_datetime, to_timestamp
+from sqlmesh.utils.date import DatetimeRanges, to_date, to_datetime, to_timestamp
 from sqlmesh.utils.errors import SQLMeshError
 from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo
 
@@ -878,6 +882,29 @@ def test_fingerprint_virtual_properties(model: Model, parent_model: Model):
     assert updated_fingerprint.data_hash == fingerprint.data_hash
 
 
+def test_fingerprint_signals(sushi_context_pre_scheduling):
+    model = deepcopy(sushi_context_pre_scheduling.get_model("sushi.waiter_as_customer_by_day"))
+    fingerprint = fingerprint_from_node(model, nodes={})
+
+    def assert_metadata_only():
+        model._metadata_hash = None
+        model._data_hash = None
+        updated_fingerprint = fingerprint_from_node(model, nodes={})
+
+        assert updated_fingerprint != fingerprint
+        assert updated_fingerprint.metadata_hash != fingerprint.metadata_hash
+        assert updated_fingerprint.data_hash == fingerprint.data_hash
+
+    executable = model.python_env["test_signal"]
+    model.python_env["test_signal"].payload = executable.payload.replace("arg == 1", "arg == 2")
+
+    assert_metadata_only()
+
+    model = deepcopy(sushi_context_pre_scheduling.get_model("sushi.waiter_as_customer_by_day"))
+    model.signals.clear()
+    assert_metadata_only()
+
+
 def test_stamp(model: Model):
     original_fingerprint = fingerprint_from_node(model, nodes={})
 
@@ -2155,3 +2182,82 @@ def test_physical_version_pin_for_new_forward_only_models(make_snapshot):
     assert snapshot_f.version == "1234"
 
     assert snapshot_f.fingerprint != snapshot_e.fingerprint
+
+
+def test_contiguous_intervals():
+    assert _contiguous_intervals([]) == []
+    assert _contiguous_intervals([(0, 1)]) == [[(0, 1)]]
+    assert _contiguous_intervals([(0, 1), (1, 2), (2, 3)]) == [[(0, 1), (1, 2), (2, 3)]]
+    assert _contiguous_intervals([(0, 1), (3, 4), (4, 5), (6, 7)]) == [
+        [(0, 1)],
+        [(3, 4), (4, 5)],
+        [(6, 7)],
+    ]
+
+
+def test_check_ready_intervals(mocker: MockerFixture):
+    def assert_always_signal(intervals):
+        assert _check_ready_intervals(lambda _: True, intervals) == intervals
+
+    assert_always_signal([])
+    assert_always_signal([(0, 1)])
+    assert_always_signal([(0, 1), (1, 2)])
+    assert_always_signal([(0, 1), (2, 3)])
+
+    def assert_never_signal(intervals):
+        assert _check_ready_intervals(lambda _: False, intervals) == []
+
+    assert_never_signal([])
+    assert_never_signal([(0, 1)])
+    assert_never_signal([(0, 1), (1, 2)])
+    assert_never_signal([(0, 1), (2, 3)])
+
+    def to_intervals(values: t.List[t.Tuple[int, int]]) -> DatetimeRanges:
+        return [(to_datetime(s), to_datetime(e)) for s, e in values]
+
+    def assert_check_intervals(
+        intervals: t.List[t.Tuple[int, int]],
+        ready: t.List[t.List[t.Tuple[int, int]]],
+        expected: t.List[t.Tuple[int, int]],
+    ):
+        mock = mocker.Mock()
+        mock.side_effect = [to_intervals(r) for r in ready]
+        assert _check_ready_intervals(mock, intervals) == expected
+
+    assert_check_intervals([], [], [])
+    assert_check_intervals([(0, 1)], [[]], [])
+    assert_check_intervals(
+        [(0, 1)],
+        [[(0, 1)]],
+        [(0, 1)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2)],
+        [[(0, 1)]],
+        [(0, 1)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2)],
+        [[(1, 2)]],
+        [(1, 2)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2)],
+        [[(0, 1), (1, 2)]],
+        [(0, 1), (1, 2)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2), (3, 4)],
+        [[], []],
+        [],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2), (3, 4)],
+        [[(0, 1)], []],
+        [(0, 1)],
+    )
+    assert_check_intervals(
+        [(0, 1), (1, 2), (3, 4)],
+        [[(0, 1)], [(3, 4)]],
+        [(0, 1), (3, 4)],
+    )
diff --git a/tests/schedulers/airflow/operators/test_sensor.py b/tests/schedulers/airflow/operators/test_sensor.py
index 7bb4a2eb6..c13795dbf 100644
--- a/tests/schedulers/airflow/operators/test_sensor.py
+++ b/tests/schedulers/airflow/operators/test_sensor.py
@@ -122,12 +122,15 @@ def test_external_sensor(mocker: MockerFixture, make_snapshot, set_airflow_as_li
             name="this",
             query=parse_one("select 1"),
             signals=[
-                {"table_name": "test_table_name_a", "ds": parse_one("@end_ds")},
-                {
-                    "table_name": "test_table_name_b",
-                    "ds": parse_one("@end_ds"),
-                    "hour": parse_one("@end_hour"),
-                },
+                ("", {"table_name": "test_table_name_a", "ds": parse_one("@end_ds")}),
+                (
+                    "",
+                    {
+                        "table_name": "test_table_name_b",
+                        "ds": parse_one("@end_ds"),
+                        "hour": parse_one("@end_hour"),
+                    },
+                ),
             ],
         )
     )
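
Taken together, the migration and the updated tests pin down the new on-disk format: a model's signals are serialized as `(name, kwargs)` pairs, with the empty string used as the name for legacy signals that predate the `@signal` decorator. A minimal standalone sketch of the rewrite that `v0063_change_signals` applies to each stored snapshot payload (the payload values below are hypothetical; only the fields the migration touches are shown):

```python
import json

# Hypothetical pre-0063 snapshot payload: signals stored as bare kwargs dicts.
snapshot = json.dumps(
    {"node": {"name": "db.table", "signals": [{"table_name": "table_a"}]}}
)

parsed_snapshot = json.loads(snapshot)
node = parsed_snapshot["node"]
signals = node.pop("signals", None)

# Legacy entries have no associated signal function, so they are re-keyed
# under the empty-string name, matching the migration above.
if signals:
    node["signals"] = [("", signal) for signal in signals]

assert node["signals"] == [("", {"table_name": "table_a"})]

# Tuples serialize as JSON arrays, so the rewritten payload round-trips
# cleanly through json.dumps/json.loads, which is what makes the new
# representation serializable in the state tables.
json.dumps(parsed_snapshot)
```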