Skip to content

Commit

Permalink
fix!: make signals serializable
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Dec 10, 2024
1 parent bee88a4 commit a1d8e40
Show file tree
Hide file tree
Showing 24 changed files with 702 additions and 598 deletions.
135 changes: 31 additions & 104 deletions docs/guides/signals.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,98 +14,57 @@ The scheduler uses two criteria to determine whether a model should be evaluated

Signals allow you to specify additional criteria that must be met before the scheduler evaluates the model.

A signal definition has two components: a "checking" function that checks whether a criterion is met and a "factory function" that provides the checking function to SQLMesh. Before describing the checking function, we provide some background information about how the scheduler works.
A signal definition is simply a function that checks whether a criterion is met. Before describing the checking function, we provide some background information about how the scheduler works.

The scheduler doesn't actually evaluate "a model" - it evaluates a model over a specific time interval. This is clearest for incremental models, where only rows in the time interval are ingested during an evaluation. However, evaluation of non-temporal model kinds like `FULL` and `VIEW` are also based on a time interval: the model's `cron` frequency.

The scheduler's decisions are based on these time intervals. For each model, the scheduler examines a set of candidate intervals and identifies the ones that are ready for evaluation.

It then divides those into _batches_ (configured with the model's [batch_size](../concepts/models/overview.md#batch_size) parameter). For incremental models, it evaluates the model once for each batch. For non-incremental models, it evaluates the model once if any batch contains an interval.

Signal checking functions examines a batch of time intervals. The function has two inputs: signal metadata values defined your model and a batch of time intervals. It may return `True` if all intervals are ready for evaluation, `False` if no intervals are ready, or the time intervals themselves if only some are ready. A checking function is defined as a method on a `Signal` sub-class.

A project may have one signal factory function. The factory function determines which checking function should be used for a given model and signal. Its inputs are the signal metadata values defined in your model, and it returns the checking function.
Signal checking functions examines a batch of time intervals. The function is always called with a batch of time intervals (DateTimeRanges). It can also optionally be called with key word arguments. It may return `True` if all intervals are ready for evaluation, `False` if no intervals are ready, or the time intervals themselves if only some are ready. A checking function is defined with the `@signal` decorator.

## Defining a signal

To define a signal, create a `signals` directory in your project folder. Define your signal in a file named `__init__.py` in that directory.

The file must:
To define a signal, create a `signals` directory in your project folder. Define your signal in a file named `__init__.py` in that directory (you can have additional python file names as well).

- Define at least one `Signal` sub-class containing a `check_intervals` method
- Define a factory function that returns a `Signal` sub-class and decorate the function with the `@signal_factory` decorator
A signal is a function that accepts a batch (DateTimeRanges: t.List[t.Tuple[datetime, datetime]]) and returns a batch or a boolean. It needs use the @signal decorator.

We now demonstrate signals of varying complexity.

### Simple example

This example defines the `RandomSignal` class and its mandatory `check_intervals()` method.
This example defines a `RandomSignal` method.

The method returns `True` (indicating that all intervals are ready for evaluation) if a random number is greater than a threshold specified in the model definition:

```python linenums="1"
import random
import typing as t
from sqlmesh.core.scheduler import signal_factory, Signal
from sqlmesh.utils.date import DatetimeRanges

from sqlmesh import signal, DatetimeRanges

class RandomSignal(Signal):
def __init__(
self,
signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]
):
self.signal_metadata = signal_metadata

def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
threshold = self.signal_metadata["threshold"]
return random.random() > threshold
@signal()
def random_signal(batch: DatetimeRanges, threshold: float) -> t.Union[bool, DatetimeRanges]:
return random.random() > threshold
```

Note that the `RandomSignal` class sub-classes `Signal` and takes a `signal_metadata` argument.

The `check_intervals()` method extracts the threshold metadata and compares a random number to it.

We can now add a factory function that returns the `RandomSignal` sub-class. Note the `@signal_factory` decorator on line 8:

```python linenums="1" hl_lines="8-10"
import random
import typing as t
from sqlmesh.core.scheduler import signal_factory, Signal
from sqlmesh.utils.date import DatetimeRanges


class RandomSignal(Signal):
def __init__(
self,
signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]
):
self.signal_metadata = signal_metadata

def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
threshold = self.signal_metadata["threshold"]
return random.random() > threshold

Note that the `random_signal()` takes a mandatory user defined `threshold` argument.

@signal_factory
def my_signal_factory(signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]) -> Signal:
return RandomSignal(signal_metadata)
```

We now have a working signal!
The `random_signal()` method extracts the threshold metadata and compares a random number to it. The type is inferred based on the same [rules as SQLMesh Macros](../concepts/macros/sqlmesh_macros.md#typed-macros).

We specify that a model should use the signal by passing metadata to the model DDL's `signals` key.
Now that we have a working signal, we need to specify that a model should use the signal by passing metadata to the model DDL's `signals` key.

The `signals` key accepts an array delimited by brackets `[]`. Each tuple in the list should contain the metadata needed for one signal evaluation.
The `signals` key accepts an array delimited by brackets `[]`. Each function in the list should contain the metadata needed for one signal evaluation.

This example specifies that the `RandomSignal` should evaluate once with a threshold of 0.5:
This example specifies that the `random_signal()` should evaluate once with a threshold of 0.5:

```sql linenums="1" hl_lines="4-6"
MODEL (
name example.signal_model,
kind FULL,
signals [
(threshold = 0.5), # specify threshold value
random_signal(threshold = 0.5), # specify threshold value
]
);

Expand All @@ -116,47 +75,31 @@ The next time this project is `sqlmesh run`, our signal will metaphorically flip

### Advanced Example

This example demonstrates more advanced use of signals, including:

- Multiple signals in one model
- A signal returning a subset of intervals from a batch (rather than a single `True`/`False` value for all intervals in the batch)

In this example, there are two signals.
This example demonstrates more advanced use of signals: a signal returning a subset of intervals from a batch (rather than a single `True`/`False` value for all intervals in the batch)

```python
import typing as t
from datetime import datetime

from sqlmesh.core.scheduler import signal_factory, Signal
from sqlmesh.utils.date import DatetimeRanges, to_datetime


class AlwaysReady(Signal):
# signal that indicates every interval is always ready
def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
return True
from sqlmesh import signal, DatetimeRanges
from sqlmesh.utils.date import to_datetime


class OneweekAgo(Signal):
def __init__(self, dt: datetime):
self.dt = dt
# signal that returns only intervals that are <= 1 week ago
@signal()
def one_week_ago(batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
dt = to_datetime("1 week ago")

# signal that returns only intervals that are <= 1 week ago
def check_intervals(self, batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]:
return [
(start, end)
for start, end in batch
if start <= self.dt
]
return [
(start, end)
for start, end in batch
if start <= dt
]
```
Instead of returning a single `True`/`False` value for whether a batch of intervals is ready for evaluation, the `OneweekAgo` signal class returns specific intervals from the batch.

Its `check_intervals` method accepts a `dt` datetime argument, to which It compares the beginning of each interval in the batch. If the interval start is before that argument, the interval is ready for evaluation and included in the returned list.
These signals can be added to a model like so. Now that we have more than one signal, we must have a way to tell the signal factory which signal should be called.
Instead of returning a single `True`/`False` value for whether a batch of intervals is ready for evaluation, the `one_week_ago()` function returns specific intervals from the batch.

In this example, we use the `kind` key to tell the signal factory which signal class should be called. The key name is arbitrary, and you may choose any key name you want.

This example model specifies two `kind` values so both signals are called.
It generates a datetime argument, to which it compares the beginning of each interval in the batch. If the interval start is before that argument, the interval is ready for evaluation and included in the returned list.
These signals can be added to a model like so.

```sql linenums="1" hl_lines="7-10"
MODEL (
Expand All @@ -166,26 +109,10 @@ MODEL (
),
start '2 week ago',
signals [
(kind = 'a'),
(kind = 'b'),
one_week_ago(),
]
);


SELECT @start_ds AS ds
```

Our signal factory definition extracts the `kind` key value from the `signal_metadata`, then instantiates the signal class corresponding to the value:

```python linenums="1" hl_lines="3-3"
@signal_factory
def my_signal_factory(signal_metadata: t.Dict[str, t.Union[str, int, float, bool]]) -> Signal:
kind = signal_metadata["kind"]

if kind == "a":
return AlwaysReady()
if kind == "b":
return OneweekAgo(to_datetime("1 week ago"))
raise Exception(f"Unknown signal {kind}")
```

6 changes: 5 additions & 1 deletion examples/sushi/models/waiter_as_customer_by_day.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ MODEL (
audits (
not_null(columns := (waiter_id)),
forall(criteria := (LENGTH(waiter_name) > 0))
)
),
signals (
test_signal(arg := 1)
),

);

JINJA_QUERY_BEGIN;
Expand Down
9 changes: 9 additions & 0 deletions examples/sushi/signals/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import typing as t

from sqlmesh import signal, DatetimeRanges


@signal()
def test_signal(batch: DatetimeRanges, arg: int = 0) -> t.Union[bool, DatetimeRanges]:
assert arg == 1
return True
6 changes: 4 additions & 2 deletions sqlmesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sqlmesh.core.engine_adapter import EngineAdapter as EngineAdapter
from sqlmesh.core.macros import SQL as SQL, macro as macro
from sqlmesh.core.model import Model as Model, model as model
from sqlmesh.core.signal import signal as signal
from sqlmesh.core.snapshot import Snapshot as Snapshot
from sqlmesh.core.snapshot.evaluator import (
CustomMaterialization as CustomMaterialization,
Expand All @@ -32,6 +33,7 @@
debug_mode_enabled as debug_mode_enabled,
enable_debug_mode as enable_debug_mode,
)
from sqlmesh.utils.date import DatetimeRanges as DatetimeRanges

try:
from sqlmesh._version import __version__ as __version__, __version_tuple__ as __version_tuple__
Expand Down Expand Up @@ -171,13 +173,13 @@ def configure_logging(
os.remove(path)

if debug:
from signal import SIGUSR1
import faulthandler
import signal

enable_debug_mode()

# Enable threadumps.
faulthandler.enable()
# Windows doesn't support register so we check for it here
if hasattr(faulthandler, "register"):
faulthandler.register(signal.SIGUSR1.value)
faulthandler.register(SIGUSR1.value)
8 changes: 4 additions & 4 deletions sqlmesh/core/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing as t

from sqlmesh.core.dialect import parse_one, extract_audit
from sqlmesh.core.dialect import parse_one, extract_func_call
from sqlmesh.core.config.base import BaseConfig
from sqlmesh.core.model.kind import (
ModelKind,
Expand All @@ -11,7 +11,7 @@
on_destructive_change_validator,
)
from sqlmesh.utils.date import TimeLike
from sqlmesh.core.model.meta import AuditReference
from sqlmesh.core.model.meta import FunctionCall
from sqlmesh.utils.pydantic import field_validator


Expand Down Expand Up @@ -44,14 +44,14 @@ class ModelDefaultsConfig(BaseConfig):
storage_format: t.Optional[str] = None
on_destructive_change: t.Optional[OnDestructiveChange] = None
session_properties: t.Optional[t.Dict[str, t.Any]] = None
audits: t.Optional[t.List[AuditReference]] = None
audits: t.Optional[t.List[FunctionCall]] = None

_model_kind_validator = model_kind_validator
_on_destructive_change_validator = on_destructive_change_validator

@field_validator("audits", mode="before")
def _audits_validator(cls, v: t.Any) -> t.Any:
if isinstance(v, list):
return [extract_audit(parse_one(audit)) for audit in v]
return [extract_func_call(parse_one(audit)) for audit in v]

return v
1 change: 0 additions & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2187,7 +2187,6 @@ def _load_materializations_and_signals(self) -> None:
if not self._loaded:
for context_loader in self._loaders.values():
with sys_path(*context_loader.configs):
context_loader.loader.load_signals(self)
context_loader.loader.load_materializations(self)

def _select_models_for_run(
Expand Down
10 changes: 9 additions & 1 deletion sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,9 @@ def interpret_key_value_pairs(
return {i.this.name: interpret_expression(i.expression) for i in e.expressions}


def extract_audit(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
def extract_func_call(
v: exp.Expression, allow_tuples: bool = False
) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
kwargs = {}

if isinstance(v, exp.Anonymous):
Expand All @@ -1214,6 +1216,12 @@ def extract_audit(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]
elif isinstance(v, exp.Func):
func = v.sql_name()
args = list(v.args.values())
elif isinstance(v, exp.Tuple): # airflow only
if not allow_tuples:
raise ConfigError("Audit name is missing (eg. MY_AUDIT())")

func = ""
args = v.expressions
else:
return v.name.lower(), {}

Expand Down
Loading

0 comments on commit a1d8e40

Please sign in to comment.