fix: ensure that restatements in prod also trigger restatements in dev
erindru committed Dec 18, 2024
1 parent 54f49ee commit ae9a565
Showing 5 changed files with 684 additions and 3 deletions.
47 changes: 47 additions & 0 deletions docs/concepts/plans.md
@@ -350,3 +350,50 @@ See examples below for how to restate both based on model names and model tags.
```bash
sqlmesh plan --restate-model "+db*" --restate-model "tag:+exp*"
```
=== "Specific Date Range"
```bash
sqlmesh plan --restate-model "db.model_a" --start "2024-01-01" --end "2024-01-10"
```
### Restating production vs development
Restatement plans behave differently depending on whether you're targeting the `prod` environment or a [development environment](./environments.md#how-to-use-environments).
If you target a development environment like so:
```bash
sqlmesh plan dev --restate-model "db.model_a" --start "2024-01-01" --end "2024-01-10"
```
the restatement plan will restate the requested intervals for the specified model in the `dev` environment. Versions of the model in other environments will be unaffected.
However, if you target the `prod` environment:
```bash
sqlmesh plan --restate-model "db.model_a" --start "2024-01-01" --end "2024-01-10"
```
the restatement plan will restate the intervals in the `prod` table *and clear the intervals from state for every other version of that model*.
This means that the next time you do a run in `dev`, those intervals will be restated in the development environment as well.
The reason for this is to prevent old data from getting promoted to `prod`. One of the benefits of SQLMesh is being able to [reuse tables](#virtual-update) from development environments to ensure that production deployments consist of quick, painless pointer swaps.
!!! info
    If restating data in `prod` did not also trigger a restatement in `dev`, then when `sqlmesh plan` is run against `prod` to deploy changes, a table containing old data may be promoted.
Note that this behaviour also causes downstream tables that only exist in development environments to have the affected intervals cleared as well. Consider the following example:
- Table `A` exists in `prod`
- A virtual environment `dev` is created with new tables `B` and `C` downstream of `A`
    - the DAG in `prod` looks like `A`
    - the DAG in `dev` looks like `A <- B <- C`
- A restatement plan is created against table `A` in `prod`
- SQLMesh will ensure that the affected intervals are also cleared for `B` and `C` in `dev` even though those tables do not exist in `prod`
!!! info
    If a restatement plan against `prod` clears intervals from state for tables in development environments, you need to run `sqlmesh run <env>` to trigger the reprocessing of that data.

    This is because SQLMesh limits the work done by the `prod` restatement plan to just the `prod` environment, to ensure the restatement can be applied as quickly as possible and to avoid potentially unnecessary work. We do not assume that every snapshot in a development environment will eventually be deployed to `prod`.
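For example (reusing the hypothetical model and environment names from the examples above), a full cycle of restating in `prod` and then reprocessing in `dev` might look like this:

```bash
# The dev environment already contains models downstream of db.model_a

# Restate db.model_a in prod; this also clears the affected intervals from state
# for the dev versions of db.model_a and anything downstream of it
sqlmesh plan --restate-model "db.model_a" --start "2024-01-01" --end "2024-01-10"

# The cleared intervals are only reprocessed in dev on the next cadence run there
sqlmesh run dev
```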
16 changes: 15 additions & 1 deletion sqlmesh/core/plan/builder.py
@@ -27,7 +27,7 @@
from sqlmesh.core.snapshot.definition import Interval, SnapshotId
from sqlmesh.utils import columns_to_types_all_known, random_id
from sqlmesh.utils.dag import DAG
-from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday_ds
+from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday_ds, to_timestamp
from sqlmesh.utils.errors import NoChangesPlanError, PlanError, SQLMeshError

logger = logging.getLogger(__name__)
@@ -341,6 +341,7 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool:
for downstream_s_id in dag.downstream(snapshot.snapshot_id):
if is_restateable_snapshot(self._context_diff.snapshots[downstream_s_id]):
restatements[downstream_s_id] = dummy_interval

# Get restatement intervals for all restated snapshots and make sure that if a snapshot expands its
# restatement range, its downstream dependencies all expand their restatement ranges as well.
for s_id in dag:
@@ -362,7 +363,20 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool:
] + [interval]
snapshot_start = min(i[0] for i in possible_intervals)
snapshot_end = max(i[1] for i in possible_intervals)

# We may be tasked with restating a time range smaller than the target snapshot's interval unit.
# For example, when restating an hour of hourly Model A, which has a downstream dependency of daily Model B,
# we need to ensure the whole affected day in Model B is restated.
floored_snapshot_start = snapshot.node.interval_unit.cron_floor(snapshot_start)
floored_snapshot_end = snapshot.node.interval_unit.cron_floor(snapshot_end)
if floored_snapshot_end <= floored_snapshot_start:
snapshot_start = to_timestamp(floored_snapshot_start)
snapshot_end = to_timestamp(
snapshot.node.interval_unit.cron_next(floored_snapshot_end)
)

restatements[s_id] = (snapshot_start, snapshot_end)

return restatements

def _build_directly_and_indirectly_modified(
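As an aside, the interval-expansion logic added above can be illustrated with a small, self-contained sketch. The helper names below are hypothetical stand-ins for SQLMesh's `interval_unit.cron_floor` / `cron_next` helpers, assuming a daily downstream model:

```python
from datetime import datetime, timedelta


def floor_to_day(ts: datetime) -> datetime:
    # Stand-in for interval_unit.cron_floor on a daily model
    return ts.replace(hour=0, minute=0, second=0, microsecond=0)


def expand_to_daily_interval(start: datetime, end: datetime) -> tuple[datetime, datetime]:
    """Expand a restatement window smaller than one day to cover the whole day."""
    floored_start = floor_to_day(start)
    floored_end = floor_to_day(end)
    if floored_end <= floored_start:
        # The requested range fits inside a single day, so restate the full day;
        # "cron_next" of the floored end is simply the start of the next day here.
        return floored_start, floored_end + timedelta(days=1)
    return start, end


# Restating 10:00-11:00 of an hourly model expands to the full day for a daily downstream model
print(expand_to_daily_interval(datetime(2024, 1, 1, 10), datetime(2024, 1, 1, 11)))
# -> (datetime.datetime(2024, 1, 1, 0, 0), datetime.datetime(2024, 1, 2, 0, 0))
```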
4 changes: 4 additions & 0 deletions sqlmesh/core/plan/definition.py
@@ -284,6 +284,10 @@ def is_selected_for_backfill(self, model_fqn: str) -> bool:
def plan_id(self) -> str:
return self.environment.plan_id

@property
def is_prod(self) -> bool:
return not self.is_dev


class PlanStatus(str, Enum):
STARTED = "started"
54 changes: 52 additions & 2 deletions sqlmesh/core/plan/evaluator.py
@@ -25,6 +25,7 @@
from sqlmesh.core.notification_target import (
NotificationTarget,
)
from sqlmesh.core.snapshot.definition import Interval
from sqlmesh.core.plan.definition import EvaluatablePlan
from sqlmesh.core.scheduler import Scheduler
from sqlmesh.core.snapshot import (
@@ -43,6 +44,7 @@
from sqlmesh.schedulers.airflow.client import AirflowClient, BaseAirflowClient
from sqlmesh.schedulers.airflow.mwaa_client import MWAAClient
from sqlmesh.utils.errors import PlanError, SQLMeshError
from sqlmesh.utils.dag import DAG

logger = logging.getLogger(__name__)

@@ -358,11 +360,59 @@ def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapsho
if not plan.restatements:
return

snapshot_intervals_to_restate = {
(snapshots_by_name[name].table_info, intervals)
for name, intervals in plan.restatements.items()
}

if plan.is_prod:
# Restating intervals on prod plans should mean that the intervals are cleared across
# all environments, not just the version currently in prod
# This ensures that work done in dev environments can still be promoted to prod
# by forcing dev environments to re-run intervals that changed in prod
#
# Without this rule, it's possible that promoting a dev table to prod will introduce old data to prod
snapshot_intervals_to_restate.update(
self._restatement_intervals_across_all_environments(plan.restatements)
)

self.state_sync.remove_intervals(
-            [(snapshots_by_name[name], interval) for name, interval in plan.restatements.items()],
-            remove_shared_versions=not plan.is_dev,
+            snapshot_intervals=list(snapshot_intervals_to_restate),
+            remove_shared_versions=plan.is_prod,
)

def _restatement_intervals_across_all_environments(
self, prod_restatements: t.Dict[str, Interval]
) -> t.Set[t.Tuple[SnapshotTableInfo, Interval]]:
"""
Given a map of snapshot names + intervals to restate in prod:
- Look up matching snapshots across all environments (match based on name - regardless of version)
- For each match, also match downstream snapshots
- Return all matches mapped to the intervals of the prod snapshot being restated
The goal here is to produce a list of intervals to invalidate across all environments so that a cadence
run in those environments causes the intervals to be repopulated
"""
if not prod_restatements:
return set()

snapshots_to_restate: t.Set[t.Tuple[SnapshotTableInfo, Interval]] = set()

for env in self.state_sync.get_environments():
keyed_snapshots = {s.name: s.table_info for s in env.snapshots}

# We don't just restate matching snapshots, we also have to restate anything downstream of them
# so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev
env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots})

for restatement, intervals in prod_restatements.items():
affected_snapshot_names = [restatement] + env_dag.downstream(restatement)
snapshots_to_restate.update(
{(keyed_snapshots[a], intervals) for a in affected_snapshot_names}
)

return snapshots_to_restate


class BaseAirflowPlanEvaluator(PlanEvaluator):
def __init__(
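To make the new helper's behaviour concrete, here is a rough, self-contained sketch of the same idea. The toy `downstream` function and the data below are made up for illustration; the real code uses SQLMesh's `DAG` utility and the environments returned by state sync:

```python
import typing as t

Interval = t.Tuple[int, int]  # (start_ts, end_ts), mirroring the Interval alias used above


def downstream(parents_by_node: t.Dict[str, t.Set[str]], name: str) -> t.List[str]:
    """Every node that transitively depends on `name`, given a {node: parents} mapping."""
    children: t.Dict[str, t.Set[str]] = {}
    for node, parents in parents_by_node.items():
        for parent in parents:
            children.setdefault(parent, set()).add(node)
    result: t.List[str] = []
    stack = [name]
    while stack:
        for child in sorted(children.get(stack.pop(), set())):
            result.append(child)
            stack.append(child)
    return result


# prod restates model A for a given interval (timestamps abbreviated to small ints for readability)
prod_restatements: t.Dict[str, Interval] = {"A": (1, 10)}

# the dev environment contains A plus two downstream models that do not exist in prod
dev_dag = {"A": set(), "B": {"A"}, "C": {"B"}}

to_restate: t.Set[t.Tuple[str, Interval]] = set()
for restated_name, interval in prod_restatements.items():
    for affected in [restated_name] + downstream(dev_dag, restated_name):
        to_restate.add((affected, interval))

print(sorted(to_restate))  # [('A', (1, 10)), ('B', (1, 10)), ('C', (1, 10))]
```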