Feat: Add support for auto restatements
izeigerman committed Dec 18, 2024
1 parent a3e3382 commit a321984
Showing 23 changed files with 1,540 additions and 100 deletions.
43 changes: 43 additions & 0 deletions docs/concepts/models/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,49 @@ Some properties are only available in specific model kinds - see the [model conf
### disable_restatement
: Set this to true to indicate that [data restatement](../plans.md#restatement-plans) is disabled for this model.

### auto_restatement_cron
: A cron expression that determines when SQLMesh should automatically restate this model. Restatement re-evaluates either the most recent intervals (the number is controlled by [`auto_restatement_intervals`](#auto_restatement_intervals)) for model kinds that support that property, or the entire model for model kinds that don't. Downstream models that depend on this model are also restated. Auto-restatement is only applied when running the `sqlmesh run` command against the production environment.

A common use case for auto-restatement is to periodically re-evaluate a model (less frequently than the model's cron) to account for late-arriving data or dimension changes. However, relying on this feature is generally not recommended, as it often indicates an underlying issue with the data model or dependency chain. Instead, users should prefer setting the [`lookback`](#lookback) property to handle late-arriving data more effectively.

Unlike the [`lookback`](#lookback) property, which only controls the time range of data scanned, auto-restatement rewrites all previously processed data for this model in the target table.

For model kinds that don't support [`auto_restatement_intervals`](#auto_restatement_intervals), the table will be re-created from scratch.

Models with [`disable_restatement`](#disable_restatement) set to `true` will not be restated automatically even if this property is set.

**NOTE**: Models with this property set can only be [previewed](../plans.md#data-preview-for-forward-only-changes) in development environments, which means that the data computed in those environments will not be reused in production.

```sql linenums="1" hl_lines="6"
MODEL (
  name test_db.national_holidays,
  cron '@daily',
  kind INCREMENTAL_BY_UNIQUE_KEY (
    unique_key key,
    auto_restatement_cron '@weekly',
  )
);
```

### auto_restatement_intervals
: The number of most recent intervals to restate automatically. This property only takes effect in conjunction with [`auto_restatement_cron`](#auto_restatement_cron).

If not specified, the entire model will be restated.

This property is only supported for the `INCREMENTAL_BY_TIME_RANGE` model kind.

```sql linenums="1" hl_lines="7"
MODEL (
  name test_db.national_holidays,
  cron '@daily',
  kind INCREMENTAL_BY_TIME_RANGE (
    time_column event_ts,
    auto_restatement_cron '@weekly',
    auto_restatement_intervals 7, -- automatically restate the last 7 days of data
  )
);
```
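
To make the effect of `auto_restatement_intervals` concrete, here is a minimal Python sketch of how the restated time range could be derived. It is an illustration only: `restatement_window`, `model_start`, and `interval` are hypothetical names, and the logic SQLMesh actually uses is not shown in this part of the diff.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple


def restatement_window(
    execution_time: datetime,
    interval: timedelta,
    auto_restatement_intervals: Optional[int],
    model_start: datetime,
) -> Tuple[datetime, datetime]:
    """Return the [start, end) range a hypothetical auto-restatement would cover."""
    # Align the end of the window to the last fully elapsed interval boundary.
    elapsed = (execution_time - model_start) // interval
    end = model_start + elapsed * interval
    if auto_restatement_intervals is None:
        # No interval count given: the entire model is restated.
        return model_start, end
    # Otherwise only the most recent N intervals are restated.
    start = max(model_start, end - auto_restatement_intervals * interval)
    return start, end


model_start = datetime(2024, 1, 1, tzinfo=timezone.utc)
execution_time = datetime(2024, 12, 18, 9, 30, tzinfo=timezone.utc)

# Daily intervals with auto_restatement_intervals = 7.
print(restatement_window(execution_time, timedelta(days=1), 7, model_start))
# Property not set: the whole model range is covered.
print(restatement_window(execution_time, timedelta(days=1), None, model_start))
```

With a daily interval and a value of 7, the window covers the seven complete days before the execution time; without the property, it spans everything from the model's start.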

## Macros
Macros can be used for passing in parameterized arguments such as dates, as well as for making SQL less repetitive. By default, SQLMesh provides several predefined macro variables that can be used. Macros are used by prefixing with the `@` symbol. For more information, refer to [macros](../macros/overview.md).

1 change: 1 addition & 0 deletions sqlmesh/core/context.py
Expand Up @@ -1932,6 +1932,7 @@ def _run(
ignore_cron=ignore_cron,
circuit_breaker=circuit_breaker,
selected_snapshots=select_models,
auto_restatement_enabled=environment.lower() == c.PROD,
)

def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
17 changes: 16 additions & 1 deletion sqlmesh/core/model/definition.py
Expand Up @@ -30,12 +30,13 @@
parse_dependencies,
single_value_or_tuple,
)
from sqlmesh.core.model.kind import ModelKindName, SeedKind, ModelKind, FullKind, create_model_kind
from sqlmesh.core.model.meta import ModelMeta, FunctionCall
from sqlmesh.core.model.kind import ModelKindName, SeedKind, ModelKind, FullKind, create_model_kind
from sqlmesh.core.model.seed import CsvSeedReader, Seed, create_seed
from sqlmesh.core.renderer import ExpressionRenderer, QueryRenderer
from sqlmesh.core.signal import SignalRegistry
from sqlmesh.utils import columns_to_types_all_known, str_to_bool, UniqueKeyDict
from sqlmesh.utils.cron import CroniterCache
from sqlmesh.utils.date import TimeLike, make_inclusive, to_datetime, to_time_column
from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error
from sqlmesh.utils.hashing import hash_data
Expand Down Expand Up @@ -770,6 +771,20 @@ def forward_only(self) -> bool:
def disable_restatement(self) -> bool:
return getattr(self.kind, "disable_restatement", False)

@property
def auto_restatement_intervals(self) -> t.Optional[int]:
return getattr(self.kind, "auto_restatement_intervals", None)

@property
def auto_restatement_cron(self) -> t.Optional[str]:
return getattr(self.kind, "auto_restatement_cron", None)

def auto_restatement_croniter(self, value: TimeLike) -> CroniterCache:
cron = self.auto_restatement_cron
if cron is None:
raise SQLMeshError("Auto restatement cron is not set.")
return CroniterCache(cron, value)

@property
def wap_supported(self) -> bool:
return self.kind.is_materialized and (self.storage_format or "").lower() == "iceberg"
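
Review note: the new model properties read the values off the kind with `getattr(..., None)`, so kinds that do not define the fields simply report `None`. A minimal, self-contained sketch of that pattern follows; the `*Sketch` classes and `require_auto_restatement_cron` are hypothetical stand-ins, not SQLMesh classes, and the sketch raises `ValueError` where the diff raises `SQLMeshError`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FullKindSketch:
    """Hypothetical kind that defines no auto-restatement fields."""

    name: str = "FULL"


@dataclass
class IncrementalKindSketch:
    """Hypothetical kind that carries both auto-restatement fields."""

    name: str = "INCREMENTAL_BY_TIME_RANGE"
    auto_restatement_cron: Optional[str] = None
    auto_restatement_intervals: Optional[int] = None


@dataclass
class ModelSketch:
    kind: object

    @property
    def auto_restatement_cron(self) -> Optional[str]:
        # Falls back to None when the kind has no such attribute.
        return getattr(self.kind, "auto_restatement_cron", None)

    @property
    def auto_restatement_intervals(self) -> Optional[int]:
        return getattr(self.kind, "auto_restatement_intervals", None)

    def require_auto_restatement_cron(self) -> str:
        # Mirrors the guard in the diff: fail loudly if the cron is unset.
        cron = self.auto_restatement_cron
        if cron is None:
            raise ValueError("Auto restatement cron is not set.")
        return cron


full = ModelSketch(FullKindSketch())
weekly = ModelSketch(IncrementalKindSketch(auto_restatement_cron="@weekly"))
print(full.auto_restatement_cron)    # None: the kind lacks the field
print(weekly.auto_restatement_cron)  # @weekly
```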
33 changes: 31 additions & 2 deletions sqlmesh/core/model/kind.py
Expand Up @@ -22,6 +22,7 @@
SQLGlotListOfFields,
SQLGlotPositiveInt,
SQLGlotString,
SQLGlotCron,
column_validator,
field_validator,
field_validator_v1_args,
Expand Down Expand Up @@ -325,6 +326,7 @@ def _kind_dialect_validator(cls: t.Type, v: t.Optional[str]) -> str:

class _Incremental(_ModelKind):
on_destructive_change: OnDestructiveChange = OnDestructiveChange.ERROR
auto_restatement_cron: t.Optional[SQLGlotCron] = None

_on_destructive_change_validator = on_destructive_change_validator

Expand All @@ -333,6 +335,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
return [
*super().metadata_hash_values,
str(self.on_destructive_change),
self.auto_restatement_cron,
]

def to_expression(
Expand All @@ -341,8 +344,11 @@ def to_expression(
return super().to_expression(
expressions=[
*(expressions or []),
_property(
"on_destructive_change", exp.Literal.string(self.on_destructive_change.value)
*_properties(
{
"on_destructive_change": self.on_destructive_change.value,
"auto_restatement_cron": self.auto_restatement_cron,
}
),
],
)
Expand Down Expand Up @@ -399,6 +405,7 @@ class IncrementalByTimeRangeKind(_IncrementalBy):
ModelKindName.INCREMENTAL_BY_TIME_RANGE
)
time_column: TimeColumn
auto_restatement_intervals: t.Optional[SQLGlotPositiveInt] = None

_time_column_validator = TimeColumn.validator()

Expand All @@ -409,13 +416,27 @@ def to_expression(
expressions=[
*(expressions or []),
self.time_column.to_property(kwargs.get("dialect") or ""),
*(
[_property("auto_restatement_intervals", self.auto_restatement_intervals)]
if self.auto_restatement_intervals is not None
else []
),
]
)

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
return [*super().data_hash_values, gen(self.time_column.column), self.time_column.format]

@property
def metadata_hash_values(self) -> t.List[t.Optional[str]]:
return [
*super().metadata_hash_values,
str(self.auto_restatement_intervals)
if self.auto_restatement_intervals is not None
else None,
]


class IncrementalByUniqueKeyKind(_IncrementalBy):
name: t.Literal[ModelKindName.INCREMENTAL_BY_UNIQUE_KEY] = (
Expand Down Expand Up @@ -807,6 +828,8 @@ class CustomKind(_ModelKind):
batch_size: t.Optional[SQLGlotPositiveInt] = None
batch_concurrency: t.Optional[SQLGlotPositiveInt] = None
lookback: t.Optional[SQLGlotPositiveInt] = None
auto_restatement_cron: t.Optional[SQLGlotCron] = None
auto_restatement_intervals: t.Optional[SQLGlotPositiveInt] = None

_properties_validator = properties_validator

Expand Down Expand Up @@ -844,6 +867,10 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
str(self.batch_concurrency) if self.batch_concurrency is not None else None,
str(self.forward_only),
str(self.disable_restatement),
self.auto_restatement_cron,
str(self.auto_restatement_intervals)
if self.auto_restatement_intervals is not None
else None,
]

def to_expression(
Expand All @@ -861,6 +888,8 @@ def to_expression(
"batch_size": self.batch_size,
"batch_concurrency": self.batch_concurrency,
"lookback": self.lookback,
"auto_restatement_cron": self.auto_restatement_cron,
"auto_restatement_intervals": self.auto_restatement_intervals,
}
),
],
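
Review note: both new fields are added to `metadata_hash_values` rather than `data_hash_values`, which suggests that changing them alters the metadata fingerprint only. The sketch below illustrates the `str(x) if x is not None else None` convention feeding a hash. `hash_values` is a hypothetical stand-in built on `hashlib`; it is not SQLMesh's `hash_data`, whose implementation is not shown here.

```python
import hashlib
from typing import List, Optional


def hash_values(values: List[Optional[str]]) -> str:
    """Hypothetical hashing helper; NOT SQLMesh's `hash_data`."""
    digest = hashlib.sha256()
    for value in values:
        # Encode None distinctly from "" so an unset field differs from an empty one.
        digest.update(b"\x00" if value is None else b"\x01" + value.encode("utf-8"))
        digest.update(b";")
    return digest.hexdigest()


def metadata_hash_values(cron: Optional[str], intervals: Optional[int]) -> List[Optional[str]]:
    # Same convention as the diff: stringify the integer only when it is set.
    return [cron, str(intervals) if intervals is not None else None]


unset = hash_values(metadata_hash_values("@weekly", None))
seven = hash_values(metadata_hash_values("@weekly", 7))
print(unset != seven)  # setting the interval count changes the hash
```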
16 changes: 2 additions & 14 deletions sqlmesh/core/node.py
Expand Up @@ -14,6 +14,7 @@
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.pydantic import (
PydanticModel,
SQLGlotCron,
field_validator,
model_validator,
model_validator_v1_args,
Expand Down Expand Up @@ -194,7 +195,7 @@ class _Node(PydanticModel):
owner: t.Optional[str] = None
start: t.Optional[TimeLike] = None
end: t.Optional[TimeLike] = None
cron: str = "@daily"
cron: SQLGlotCron = "@daily"
interval_unit_: t.Optional[IntervalUnit] = Field(alias="interval_unit", default=None)
tags: t.List[str] = []
stamp: t.Optional[str] = None
Expand Down Expand Up @@ -240,19 +241,6 @@ def _date_validator(cls, v: t.Any) -> t.Optional[TimeLike]:
raise ConfigError(f"'{v}' needs to be time-like: https://pypi.org/project/dateparser")
return v

@field_validator("cron", mode="before")
@classmethod
def _cron_validator(cls, v: t.Any) -> t.Optional[str]:
cron = str_or_exp_to_str(v)
if cron:
from croniter import CroniterBadCronError, croniter

try:
croniter(cron)
except CroniterBadCronError:
raise ConfigError(f"Invalid cron expression '{cron}'")
return cron

@field_validator("owner", "description", "stamp", mode="before")
@classmethod
def _string_expr_validator(cls, v: t.Any) -> t.Optional[str]:
23 changes: 21 additions & 2 deletions sqlmesh/core/plan/builder.py
Expand Up @@ -131,7 +131,9 @@ def __init__(
self._choices: t.Dict[SnapshotId, SnapshotChangeCategory] = {}

self._start = start
if not self._start and self._forward_only_preview_needed:
if not self._start and (
self._forward_only_preview_needed or self._auto_restatement_preview_needed
):
self._start = default_start or yesterday_ds()

self._plan_id: str = random_id()
Expand Down Expand Up @@ -673,8 +675,25 @@ def _forward_only_preview_needed(self) -> bool:
self._enable_preview
and any(
snapshot.model.forward_only
for snapshot, _ in self._context_diff.modified_snapshots.values()
for snapshot in self._modified_and_added_snapshots
if snapshot.is_model
)
)
)

@cached_property
def _auto_restatement_preview_needed(self) -> bool:
return self._is_dev and any(
snapshot.model.auto_restatement_cron is not None
for snapshot in self._modified_and_added_snapshots
if snapshot.is_model
)

@cached_property
def _modified_and_added_snapshots(self) -> t.List[Snapshot]:
return [
snapshot
for snapshot in self._context_diff.snapshots.values()
if snapshot.name in self._context_diff.modified_snapshots
or snapshot.snapshot_id in self._context_diff.added
]
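
Review note: the preview check only looks at snapshots that were modified or added in the plan, and only in development environments. A minimal sketch of that filter, assuming simplified stand-ins for the real objects (`SnapshotSketch`, `modified_and_added`, and `auto_restatement_preview_needed` are hypothetical names; the real code reads these sets from the context diff):

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Set


@dataclass(frozen=True)
class SnapshotSketch:
    """Hypothetical, flattened stand-in for a snapshot and its model."""

    name: str
    snapshot_id: str
    is_model: bool = True
    auto_restatement_cron: Optional[str] = None


def modified_and_added(
    snapshots: Dict[str, SnapshotSketch],
    modified_names: Set[str],
    added_ids: Set[str],
) -> List[SnapshotSketch]:
    # Keep a snapshot if its name was modified or its id was newly added.
    return [
        s
        for s in snapshots.values()
        if s.name in modified_names or s.snapshot_id in added_ids
    ]


def auto_restatement_preview_needed(
    is_dev: bool,
    snapshots: Dict[str, SnapshotSketch],
    modified_names: Set[str],
    added_ids: Set[str],
) -> bool:
    # Only dev plans need a preview, and only if a changed model sets the cron.
    return is_dev and any(
        s.auto_restatement_cron is not None
        for s in modified_and_added(snapshots, modified_names, added_ids)
        if s.is_model
    )


snapshots = {
    "a": SnapshotSketch("a", "a-id", auto_restatement_cron="@weekly"),  # unchanged
    "b": SnapshotSketch("b", "b-id", auto_restatement_cron="@weekly"),  # added
    "c": SnapshotSketch("c", "c-id"),                                   # modified
}
print(auto_restatement_preview_needed(True, snapshots, {"c"}, {"b-id"}))
```

An unchanged model with the cron set (`a` above) does not trigger a preview by itself; only changed or new models count.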
21 changes: 18 additions & 3 deletions sqlmesh/core/scheduler.py
Expand Up @@ -16,13 +16,15 @@
from sqlmesh.core.snapshot import (
DeployabilityIndex,
Snapshot,
SnapshotId,
SnapshotEvaluator,
apply_auto_restatements,
earliest_start_date,
missing_intervals,
merge_intervals,
Intervals,
)
from sqlmesh.core.snapshot.definition import Interval, expand_range
from sqlmesh.core.snapshot.definition import SnapshotId, merge_intervals
from sqlmesh.core.state_sync import StateSync
from sqlmesh.utils import format_exception
from sqlmesh.utils.concurrency import concurrent_apply_to_dag, NodeExecutionFailedError
Expand Down Expand Up @@ -116,8 +118,6 @@ def merged_missing_intervals(
validate_date_range(start, end)

snapshots: t.Collection[Snapshot] = self.snapshot_per_version.values()
self.state_sync.refresh_snapshot_intervals(snapshots)

snapshots_to_intervals = compute_interval_params(
snapshots,
start=start or earliest_start_date(snapshots),
Expand Down Expand Up @@ -156,6 +156,7 @@ def evaluate(
execution_time: The date/time time reference to use for execution time. Defaults to now.
deployability_index: Determines snapshots that are deployable in the context of this evaluation.
batch_index: If the snapshot is part of a batch of related snapshots; which index in the batch is it
auto_restatement_enabled: Whether to enable auto restatements.
kwargs: Additional kwargs to pass to the renderer.
"""
validate_date_range(start, end)
Expand Down Expand Up @@ -225,6 +226,7 @@ def run(
selected_snapshots: t.Optional[t.Set[str]] = None,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
deployability_index: t.Optional[DeployabilityIndex] = None,
auto_restatement_enabled: bool = False,
) -> bool:
"""Concurrently runs all snapshots in topological order.
Expand All @@ -243,6 +245,7 @@ def run(
selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run.
circuit_breaker: An optional handler which checks if the run should be aborted.
deployability_index: Determines snapshots that are deployable in the context of this render.
auto_restatement_enabled: Whether to enable auto restatements.
Returns:
True if the execution was successful and False otherwise.
Expand All @@ -266,6 +269,15 @@ def run(
else DeployabilityIndex.all_deployable()
)
execution_time = execution_time or now()

self.state_sync.refresh_snapshot_intervals(self.snapshots.values())
if auto_restatement_enabled:
auto_restated_intervals = apply_auto_restatements(self.snapshots, execution_time)
self.state_sync.add_snapshots_intervals(auto_restated_intervals)
self.state_sync.update_auto_restatements(
{s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()}
)

merged_intervals = self.merged_missing_intervals(
start,
end,
Expand All @@ -288,6 +300,7 @@ def run(
circuit_breaker=circuit_breaker,
start=start,
end=end,
auto_restatement_enabled=auto_restatement_enabled,
)

self.console.stop_evaluation_progress(success=not errors)
Expand Down Expand Up @@ -377,6 +390,7 @@ def run_merged_intervals(
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
auto_restatement_enabled: bool = False,
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
"""Runs precomputed batches of missing intervals.
Expand All @@ -388,6 +402,7 @@ def run_merged_intervals(
circuit_breaker: An optional handler which checks if the run should be aborted.
start: The start of the run.
end: The end of the run.
auto_restatement_enabled: Whether to enable auto restatements.
Returns:
A tuple of errors and skipped intervals.
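
Review note: in `run`, the interval refresh and `apply_auto_restatements` now happen before `merged_missing_intervals` is computed. The body of `apply_auto_restatements` is not visible in this part of the diff, so the following is only a sketch of why that ordering matters, under the assumption that restated intervals are dropped from the processed set and therefore show up as missing. The integer intervals and both helper functions are hypothetical.

```python
from typing import List, Tuple

Interval = Tuple[int, int]  # half-open [start, end), integers for simplicity


def remove_interval(intervals: List[Interval], start: int, end: int) -> List[Interval]:
    """Subtract [start, end) from a sorted list of non-overlapping intervals."""
    result: List[Interval] = []
    for lo, hi in intervals:
        if hi <= start or lo >= end:
            result.append((lo, hi))  # no overlap, keep as-is
            continue
        if lo < start:
            result.append((lo, start))  # keep the part before the removed range
        if hi > end:
            result.append((end, hi))  # keep the part after the removed range
    return result


def missing(intervals: List[Interval], start: int, end: int) -> List[Interval]:
    """Return the gaps in [start, end) that the sorted intervals do not cover."""
    gaps: List[Interval] = []
    cursor = start
    for lo, hi in intervals:
        if hi <= cursor:
            continue
        if lo >= end:
            break
        if lo > cursor:
            gaps.append((cursor, lo))
        cursor = max(cursor, hi)
    if cursor < end:
        gaps.append((cursor, end))
    return gaps


processed = [(0, 10)]
print(missing(processed, 0, 10))  # nothing to do before restatement

# Assumed effect of an auto-restatement of the last 3 intervals:
processed = remove_interval(processed, 7, 10)
print(missing(processed, 0, 10))  # the restated range is scheduled again
```

If the missing intervals were computed before the restatement step, the restated range would not be picked up until the next run.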
1 change: 1 addition & 0 deletions sqlmesh/core/snapshot/__init__.py
Expand Up @@ -15,6 +15,7 @@
SnapshotNameVersionLike as SnapshotNameVersionLike,
SnapshotTableCleanupTask as SnapshotTableCleanupTask,
SnapshotTableInfo as SnapshotTableInfo,
apply_auto_restatements as apply_auto_restatements,
earliest_start_date as earliest_start_date,
fingerprint_from_node as fingerprint_from_node,
has_paused_forward_only as has_paused_forward_only,