Fix: Always treat forward-only models as non-deployable (#3510)
izeigerman authored Dec 12, 2024
1 parent c512e63 commit 434be40
Showing 5 changed files with 86 additions and 25 deletions.
sqlmesh/core/scheduler.py (12 changes: 8 additions & 4 deletions)
@@ -116,12 +116,9 @@ def merged_missing_intervals(
         validate_date_range(start, end)
 
         snapshots: t.Collection[Snapshot] = self.snapshot_per_version.values()
-        if selected_snapshots is not None:
-            snapshots = [s for s in snapshots if s.name in selected_snapshots]
-
         self.state_sync.refresh_snapshot_intervals(snapshots)
 
-        return compute_interval_params(
+        snapshots_to_intervals = compute_interval_params(
             snapshots,
             start=start or earliest_start_date(snapshots),
             end=end or now(),
@@ -132,6 +129,13 @@
             ignore_cron=ignore_cron,
             end_bounded=end_bounded,
         )
+        # Filtering snapshots after computing missing intervals because we need all snapshots in order
+        # to correctly infer start dates.
+        if selected_snapshots is not None:
+            snapshots_to_intervals = {
+                s: i for s, i in snapshots_to_intervals.items() if s.name in selected_snapshots
+            }
+        return snapshots_to_intervals
 
     def evaluate(
         self,
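The reordering above matters because start-date inference draws on every snapshot, not just the selected ones. Below is a minimal, self-contained sketch of the idea; FakeSnapshot, earliest_start, and missing_intervals are made-up illustrative names, not sqlmesh's actual Snapshot or compute_interval_params API.

from datetime import date
from typing import Dict, List, Optional, Set, Tuple


# Hypothetical stand-ins for illustration only.
class FakeSnapshot:
    def __init__(self, name: str, start: Optional[date] = None):
        self.name = name
        self.start = start  # None means "infer the start from the full snapshot set"


def earliest_start(snapshots: List[FakeSnapshot]) -> date:
    # The inferred default start comes from the FULL collection of snapshots.
    configured = [s.start for s in snapshots if s.start is not None]
    return min(configured) if configured else date.today()


def missing_intervals(
    snapshots: List[FakeSnapshot], selected: Optional[Set[str]] = None
) -> Dict[str, Tuple[date, date]]:
    # 1. Compute intervals for ALL snapshots so start inference sees everything.
    default_start = earliest_start(snapshots)
    intervals = {s.name: (s.start or default_start, date(2024, 12, 12)) for s in snapshots}
    # 2. Only then narrow the result to the selected snapshots, mirroring the
    #    reordering in this commit.
    if selected is not None:
        intervals = {name: i for name, i in intervals.items() if name in selected}
    return intervals


snaps = [FakeSnapshot("a", date(2023, 1, 1)), FakeSnapshot("b")]
print(missing_intervals(snaps, selected={"b"}))  # "b" still inherits the 2023-01-01 start

Filtering before the computation would hide snapshot "a" from the inference step and give "b" a later default start; filtering the resulting mapping afterwards, as the commit does, keeps the inferred dates intact.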
sqlmesh/core/snapshot/definition.py (13 changes: 3 additions & 10 deletions)
@@ -1340,13 +1340,7 @@ def _visit(node: SnapshotId, deployable: bool = True) -> None:
 
             if deployable and node in snapshots:
                 snapshot = snapshots[node]
-                # Capture uncategorized snapshot which represents a forward-only model.
-                is_uncategorized_forward_only_model = (
-                    snapshot.change_category is None
-                    and snapshot.previous_versions
-                    and snapshot.is_model
-                    and snapshot.model.forward_only
-                )
+                is_forward_only_model = snapshot.is_model and snapshot.model.forward_only
 
                 is_valid_start = (
                     snapshot.is_valid_start(
@@ -1359,7 +1353,7 @@ def _visit(node: SnapshotId, deployable: bool = True) -> None:
                 if (
                     snapshot.is_forward_only
                     or snapshot.is_indirect_non_breaking
-                    or is_uncategorized_forward_only_model
+                    or is_forward_only_model
                     or not is_valid_start
                 ):
                     # FORWARD_ONLY and INDIRECT_NON_BREAKING snapshots are not deployable by nature.
@@ -1372,8 +1366,7 @@ def _visit(node: SnapshotId, deployable: bool = True) -> None:
                 else:
                     this_deployable = True
                     children_deployable = is_valid_start and not (
-                        snapshot.is_paused
-                        and (snapshot.is_forward_only or is_uncategorized_forward_only_model)
+                        snapshot.is_paused and (snapshot.is_forward_only or is_forward_only_model)
                     )
             else:
                 this_deployable, children_deployable = False, False
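The net effect of the change above is that a model declared with forward_only is treated as non-deployable even when its snapshot has an explicit change category; previously only uncategorized forward-only models with previous versions were caught. A rough sketch of the simplified predicate follows; SnapshotSketch and ChangeCategory are simplified stand-ins, not sqlmesh's real Snapshot and SnapshotChangeCategory classes, and the is_valid_start handling is omitted.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


# Simplified stand-in types that only echo the shape of the check above.
class ChangeCategory(Enum):
    BREAKING = auto()
    FORWARD_ONLY = auto()
    INDIRECT_NON_BREAKING = auto()
    METADATA = auto()


@dataclass
class SnapshotSketch:
    is_model: bool
    model_forward_only: bool
    change_category: Optional[ChangeCategory] = None


def is_non_deployable(snapshot: SnapshotSketch) -> bool:
    # After this commit, being a forward-only model is enough on its own; the
    # previous extra conditions (no category assigned yet, previous versions
    # exist) are gone.
    is_forward_only_model = snapshot.is_model and snapshot.model_forward_only
    return (
        snapshot.change_category is ChangeCategory.FORWARD_ONLY
        or snapshot.change_category is ChangeCategory.INDIRECT_NON_BREAKING
        or is_forward_only_model
    )


# A forward-only model categorized as METADATA is now still non-deployable,
# matching the updated expectations in tests/core/test_snapshot.py below.
print(is_non_deployable(SnapshotSketch(True, True, ChangeCategory.METADATA)))  # True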
tests/core/test_integration.py (74 changes: 70 additions & 4 deletions)
@@ -1774,6 +1774,72 @@ def test_new_forward_only_model_concurrent_versions(init_and_plan_context: t.Callable):
     assert df.to_dict() == {"ds": {0: "2023-01-07"}, "b": {0: None}}
 
 
+@freeze_time("2023-01-08 15:00:00")
+def test_new_forward_only_model_same_dev_environment(init_and_plan_context: t.Callable):
+    context, plan = init_and_plan_context("examples/sushi")
+    context.apply(plan)
+
+    new_model_expr = d.parse(
+        """
+        MODEL (
+            name memory.sushi.new_model,
+            kind INCREMENTAL_BY_TIME_RANGE (
+                time_column ds,
+                forward_only TRUE,
+                on_destructive_change 'allow',
+            ),
+        );
+        SELECT '2023-01-07' AS ds, 1 AS a;
+        """
+    )
+    new_model = load_sql_based_model(new_model_expr)
+
+    # Add the first version of the model and apply it to dev.
+    context.upsert_model(new_model)
+    snapshot_a = context.get_snapshot(new_model.name)
+    plan_a = context.plan("dev", no_prompts=True)
+    snapshot_a = plan_a.snapshots[snapshot_a.snapshot_id]
+
+    assert snapshot_a.snapshot_id in plan_a.context_diff.new_snapshots
+    assert snapshot_a.snapshot_id in plan_a.context_diff.added
+    assert snapshot_a.change_category == SnapshotChangeCategory.BREAKING
+
+    context.apply(plan_a)
+
+    df = context.fetchdf("SELECT * FROM memory.sushi__dev.new_model")
+    assert df.to_dict() == {"ds": {0: "2023-01-07"}, "a": {0: 1}}
+
+    new_model_alt_expr = d.parse(
+        """
+        MODEL (
+            name memory.sushi.new_model,
+            kind INCREMENTAL_BY_TIME_RANGE (
+                time_column ds,
+                forward_only TRUE,
+                on_destructive_change 'allow',
+            ),
+        );
+        SELECT '2023-01-07' AS ds, 1 AS b;
+        """
+    )
+    new_model_alt = load_sql_based_model(new_model_alt_expr)
+
+    # Add the second version of the model and apply it to the same environment.
+    context.upsert_model(new_model_alt)
+    snapshot_b = context.get_snapshot(new_model_alt.name)
+
+    context.invalidate_environment("dev", sync=True)
+    plan_b = context.plan("dev", no_prompts=True)
+    snapshot_b = plan_b.snapshots[snapshot_b.snapshot_id]
+
+    context.apply(plan_b)
+
+    df = context.fetchdf("SELECT * FROM memory.sushi__dev.new_model").replace({np.nan: None})
+    assert df.to_dict() == {"ds": {0: "2023-01-07"}, "b": {0: 1}}
+
+
 def test_plan_twice_with_star_macro_yields_no_diff(tmp_path: Path):
     init_example_project(tmp_path, dialect="duckdb")
 
@@ -2564,7 +2630,7 @@ def get_default_catalog_and_non_tables(
     ) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
     assert len(prod_views) == 13
     assert len(dev_views) == 0
-    assert len(user_default_tables) == 13
+    assert len(user_default_tables) == 16
     assert state_metadata.schemas == ["sqlmesh"]
     assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
         {
@@ -2583,7 +2649,7 @@ def get_default_catalog_and_non_tables(
     ) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
     assert len(prod_views) == 13
     assert len(dev_views) == 13
-    assert len(user_default_tables) == 13
+    assert len(user_default_tables) == 16
     assert len(non_default_tables) == 0
     assert state_metadata.schemas == ["sqlmesh"]
     assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
@@ -2603,7 +2669,7 @@ def get_default_catalog_and_non_tables(
     ) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
     assert len(prod_views) == 13
     assert len(dev_views) == 26
-    assert len(user_default_tables) == 13
+    assert len(user_default_tables) == 16
     assert len(non_default_tables) == 0
     assert state_metadata.schemas == ["sqlmesh"]
     assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
@@ -2624,7 +2690,7 @@ def get_default_catalog_and_non_tables(
     ) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
     assert len(prod_views) == 13
     assert len(dev_views) == 13
-    assert len(user_default_tables) == 13
+    assert len(user_default_tables) == 16
     assert len(non_default_tables) == 0
     assert state_metadata.schemas == ["sqlmesh"]
     assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
tests/core/test_snapshot.py (10 changes: 4 additions & 6 deletions)
@@ -1734,17 +1734,15 @@ def test_deployability_index_categorized_forward_only_model(make_snapshot):
     snapshot_b.parents = (snapshot_a.snapshot_id,)
     snapshot_b.categorize_as(SnapshotChangeCategory.METADATA)
 
-    # The fact that the model is forward only should be ignored if an actual category
-    # has been assigned.
     deployability_index = DeployabilityIndex.create(
         {s.snapshot_id: s for s in [snapshot_a, snapshot_b]}
     )
 
-    assert deployability_index.is_deployable(snapshot_a)
-    assert deployability_index.is_deployable(snapshot_b)
+    assert not deployability_index.is_deployable(snapshot_a)
+    assert not deployability_index.is_deployable(snapshot_b)
 
-    assert deployability_index.is_representative(snapshot_a)
-    assert deployability_index.is_representative(snapshot_b)
+    assert not deployability_index.is_representative(snapshot_a)
+    assert not deployability_index.is_representative(snapshot_b)
 
 
 def test_deployability_index_missing_parent(make_snapshot):
tests/integrations/jupyter/test_magics.py (2 changes: 1 addition & 1 deletion)
@@ -291,7 +291,7 @@ def test_plan(
 
     # TODO: Should this be going to stdout? This is printing the status updates for when each batch finishes for
    # the models and how long it took
-    assert len(output.stdout.strip().split("\n")) == 22
+    assert len(output.stdout.strip().split("\n")) == 23
     assert not output.stderr
     assert len(output.outputs) == 4
     text_output = convert_all_html_output_to_text(output)
