Skip to content

Commit

Permalink
Feat!: Add support for virtual statements to be executed post update
Browse files Browse the repository at this point in the history
  • Loading branch information
themisvaltinos committed Dec 16, 2024
1 parent b2b87d8 commit f57f88e
Show file tree
Hide file tree
Showing 8 changed files with 348 additions and 14 deletions.
62 changes: 58 additions & 4 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class JinjaStatement(Jinja):
pass


class VirtualStatement(exp.Expression):
pass


class ModelKind(exp.Expression):
arg_types = {"this": True, "expressions": False}

Expand Down Expand Up @@ -749,6 +753,8 @@ def _is_command_statement(command: str, tokens: t.List[Token], pos: int) -> bool
JINJA_QUERY_BEGIN = "JINJA_QUERY_BEGIN"
JINJA_STATEMENT_BEGIN = "JINJA_STATEMENT_BEGIN"
JINJA_END = "JINJA_END"
ON_VIRTUAL_UPDATE_BEGIN = "ON_VIRTUAL_UPDATE_BEGIN"
ON_VIRTUAL_UPDATE_END = "ON_VIRTUAL_UPDATE_END"


def _is_jinja_statement_begin(tokens: t.List[Token], pos: int) -> bool:
Expand All @@ -771,10 +777,24 @@ def jinja_statement(statement: str) -> JinjaStatement:
return JinjaStatement(this=exp.Literal.string(statement.strip()))


def _is_virtual_statement_begin(tokens: t.List[Token], pos: int) -> bool:
return _is_command_statement(ON_VIRTUAL_UPDATE_BEGIN, tokens, pos)


def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool:
return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos)


def virtual_statement(statement: exp.Expression) -> VirtualStatement:
return VirtualStatement(this=statement)


class ChunkType(Enum):
JINJA_QUERY = auto()
JINJA_STATEMENT = auto()
SQL = auto()
VIRTUAL_STATEMENT = auto()
VIRTUAL_JINJA_STATEMENT = auto()


def parse_one(
Expand Down Expand Up @@ -814,9 +834,14 @@ def parse(
total = len(tokens)

pos = 0
virtual = False
while pos < total:
token = tokens[pos]
if _is_jinja_end(tokens, pos) or (
if _is_virtual_statement_end(tokens, pos):
pos += 2
virtual = False
chunks.append(([], ChunkType.SQL))
elif _is_jinja_end(tokens, pos) or (
chunks[-1][1] == ChunkType.SQL
and token.token_type == TokenType.SEMICOLON
and pos < total - 1
Expand All @@ -827,13 +852,32 @@ def parse(
# Jinja end statement
chunks[-1][0].append(token)
pos += 2
chunks.append(([], ChunkType.SQL))
if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END:
# This is required for nested Jinja statements that precede
# SQL statements within an ON_VIRTUAL_UPDATE block
chunks.append(
(
[Token(TokenType.VAR, text=ON_VIRTUAL_UPDATE_BEGIN)],
ChunkType.VIRTUAL_STATEMENT,
)
)
else:
chunks.append(([], ChunkType.SQL))
elif _is_jinja_query_begin(tokens, pos):
chunks.append(([token], ChunkType.JINJA_QUERY))
pos += 2
elif _is_jinja_statement_begin(tokens, pos):
chunks.append(([token], ChunkType.JINJA_STATEMENT))
chunks.append(
(
[token],
ChunkType.VIRTUAL_JINJA_STATEMENT if virtual else ChunkType.JINJA_STATEMENT,
)
)
pos += 2
elif _is_virtual_statement_begin(tokens, pos):
chunks.append(([token], ChunkType.VIRTUAL_STATEMENT))
pos += 2
virtual = True
else:
chunks[-1][0].append(token)
pos += 1
Expand All @@ -850,13 +894,23 @@ def parse(
if expression:
expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1])
expressions.append(expression)
elif chunk_type == ChunkType.VIRTUAL_STATEMENT:
sql_chunk = chunk[1:-1]
for expression in parser.parse(sql_chunk, sql):
if expression:
expression.meta["sql"] = expression.sql(dialect=dialect)
expressions.append(virtual_statement(expression))
else:
start, *_, end = chunk
segment = sql[start.end + 2 : end.start - 1]
factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
expression = factory(segment.strip())
expression.meta["sql"] = sql[start.start : end.end + 1]
expressions.append(expression)
expressions.append(
virtual_statement(expression)
if chunk_type == ChunkType.VIRTUAL_JINJA_STATEMENT
else expression
)

return expressions

Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def depends_on(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[
"expressions_",
"pre_statements_",
"post_statements_",
"on_virtual_update_",
"unique_key",
mode="before",
check_fields=False,
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/model/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def model(
**self.kwargs,
}

for key in ("pre_statements", "post_statements"):
for key in ("pre_statements", "post_statements", "on_virtual_update"):
statements = common_kwargs.get(key)
if statements:
common_kwargs[key] = [
Expand Down
36 changes: 28 additions & 8 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ class _Model(ModelMeta, frozen=True):
post_statements_: t.Optional[t.List[exp.Expression]] = Field(
default=None, alias="post_statements"
)
on_virtual_update_: t.Optional[t.List[exp.Expression]] = Field(
default=None, alias="on_virtual_update"
)

_expressions_validator = expression_validator

Expand Down Expand Up @@ -499,10 +502,18 @@ def pre_statements(self) -> t.List[exp.Expression]:
def post_statements(self) -> t.List[exp.Expression]:
return self.post_statements_ or []

@property
def on_virtual_update(self) -> t.List[exp.Expression]:
return self.on_virtual_update_ or []

@property
def macro_definitions(self) -> t.List[d.MacroDef]:
"""All macro definitions from the list of expressions."""
return [s for s in self.pre_statements + self.post_statements if isinstance(s, d.MacroDef)]
return [
s
for s in self.pre_statements + self.post_statements + self.on_virtual_update
if isinstance(s, d.MacroDef)
]

def _render_statements(
self,
Expand Down Expand Up @@ -891,7 +902,7 @@ def _data_hash_values(self) -> t.List[str]:
data.append(key)
data.append(gen(value))

for statement in (*self.pre_statements, *self.post_statements):
for statement in (*self.pre_statements, *self.post_statements, *self.on_virtual_update):
statement_exprs: t.List[exp.Expression] = []
if not isinstance(statement, d.MacroDef):
rendered = self._statement_renderer(statement).render()
Expand Down Expand Up @@ -984,7 +995,7 @@ def _additional_metadata(self) -> t.List[str]:
if metadata_only_macros:
additional_metadata.append(str(metadata_only_macros))

for statement in (*self.pre_statements, *self.post_statements):
for statement in (*self.pre_statements, *self.post_statements, *self.on_virtual_update):
if self._is_metadata_statement(statement):
additional_metadata.append(gen(statement))

Expand Down Expand Up @@ -1056,6 +1067,7 @@ class SqlModel(_Model):
query: The main query representing the model.
pre_statements: The list of SQL statements that precede the model's query.
post_statements: The list of SQL statements that follow after the model's query.
on_virtual_update: The list of SQL statements to be executed after virtual update.
"""

query: t.Union[exp.Query, d.JinjaQuery, d.MacroFunc]
Expand Down Expand Up @@ -1117,6 +1129,7 @@ def render_definition(
result.extend(self.pre_statements)
result.append(self.query)
result.extend(self.post_statements)
result.extend(self.on_virtual_update)
return result

@property
Expand Down Expand Up @@ -1680,7 +1693,7 @@ def load_sql_based_model(
rendered_meta = rendered_meta_exprs[0]

# Extract the query and any pre/post statements
query_or_seed_insert, pre_statements, post_statements, inline_audits = (
query_or_seed_insert, pre_statements, post_statements, on_virtual_update, inline_audits = (
_split_sql_model_statements(expressions[1:], path, dialect=dialect)
)

Expand Down Expand Up @@ -1717,6 +1730,7 @@ def load_sql_based_model(
common_kwargs = dict(
pre_statements=pre_statements,
post_statements=post_statements,
on_virtual_update=on_virtual_update,
defaults=defaults,
path=path,
module_path=module_path,
Expand Down Expand Up @@ -1968,6 +1982,8 @@ def _create_model(
statements.append(kwargs["query"])
if "post_statements" in kwargs:
statements.extend(kwargs["post_statements"])
if "on_virtual_update" in kwargs:
statements.extend(kwargs["on_virtual_update"])

jinja_macro_references, used_variables = extract_macro_references_and_variables(
*(gen(e) for e in statements)
Expand Down Expand Up @@ -2057,6 +2073,7 @@ def _split_sql_model_statements(
t.Optional[exp.Expression],
t.List[exp.Expression],
t.List[exp.Expression],
t.List[exp.Expression],
UniqueKeyDict[str, ModelAudit],
]:
"""Extracts the SELECT query from a sequence of expressions.
Expand All @@ -2075,6 +2092,7 @@ def _split_sql_model_statements(

query_positions = []
sql_statements = []
on_virtual_update = []
inline_audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("inline_audits")

idx = 0
Expand All @@ -2086,7 +2104,9 @@ def _split_sql_model_statements(
loaded_audit = load_audit([expr, expressions[idx + 1]], dialect=dialect)
assert isinstance(loaded_audit, ModelAudit)
inline_audits[loaded_audit.name] = loaded_audit
idx += 2
idx += 1
elif isinstance(expr, d.VirtualStatement):
on_virtual_update.append(expr.this)
else:
if (
isinstance(expr, (exp.Query, d.JinjaQuery))
Expand All @@ -2098,16 +2118,16 @@ def _split_sql_model_statements(
):
query_positions.append((expr, idx))
sql_statements.append(expr)
idx += 1
idx += 1

if not query_positions:
return None, sql_statements, [], inline_audits
return None, sql_statements, [], on_virtual_update, inline_audits

elif len(query_positions) > 1:
raise_config_error("Only one SELECT query is allowed per model", path)

query, pos = query_positions[0]
return query, sql_statements[:pos], sql_statements[pos + 1 :], inline_audits
return query, sql_statements[:pos], sql_statements[pos + 1 :], on_virtual_update, inline_audits


def _resolve_session_properties(
Expand Down
33 changes: 32 additions & 1 deletion sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sqlmesh.schedulers.airflow.client import AirflowClient, BaseAirflowClient
from sqlmesh.schedulers.airflow.mwaa_client import MWAAClient
from sqlmesh.utils.errors import PlanError, SQLMeshError
from sqlmesh.utils.date import now

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -309,9 +310,10 @@ def _update_views(

completed = False
try:
added_snapshots = [snapshots[s.snapshot_id] for s in promotion_result.added]
self._promote_snapshots(
plan,
[snapshots[s.snapshot_id] for s in promotion_result.added],
added_snapshots,
environment.naming_info,
deployability_index=deployability_index,
on_complete=lambda s: self.console.update_promotion_progress(s, True),
Expand All @@ -323,6 +325,17 @@ def _update_views(
promotion_result.removed_environment_naming_info,
on_complete=lambda s: self.console.update_promotion_progress(s, False),
)

if promoted_snapshots := [
s for s in added_snapshots if s.is_model and not s.is_symbolic
]:
self._virtual_statements(
plan,
promoted_snapshots,
snapshots,
deployability_index,
)

self.state_sync.finalize(environment)
completed = True
finally:
Expand Down Expand Up @@ -354,6 +367,24 @@ def _demote_snapshots(
target_snapshots, environment_naming_info, on_complete=on_complete
)

def _virtual_statements(
self,
plan: EvaluatablePlan,
target_snapshots: t.Iterable[Snapshot],
snapshots: t.Dict[SnapshotId, Snapshot],
deployability_index: t.Optional[DeployabilityIndex] = None,
) -> None:
self.snapshot_evaluator._execute_virtual_statements(
target_snapshots,
snapshots,
plan.start,
plan.end,
plan.execution_time or now(),
plan.environment.naming_info,
self.default_catalog,
deployability_index,
)

def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]) -> None:
if not plan.restatements:
return
Expand Down
15 changes: 15 additions & 0 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,6 +1591,21 @@ def to_table_mapping(
}


def to_view_mapping(
snapshots: t.Iterable[Snapshot],
environment_naming_info: EnvironmentNamingInfo,
default_catalog: t.Optional[str] = None,
dialect: t.Optional[str] = None,
) -> t.Dict[str, str]:
return {
snapshot.name: snapshot.display_name(
environment_naming_info, default_catalog=default_catalog, dialect=dialect
)
for snapshot in snapshots
if snapshot.is_model
}


def has_paused_forward_only(
targets: t.Iterable[SnapshotIdLike],
snapshots: t.Union[t.List[Snapshot], t.Dict[SnapshotId, Snapshot]],
Expand Down
Loading

0 comments on commit f57f88e

Please sign in to comment.