SNOW-1819531: lqb bug fixes #2670

Open

wants to merge 12 commits into base: main
@@ -1002,6 +1002,9 @@ def projection_complexities(self) -> List[Dict[PlanNodeCategory, int]]:
dependent_column_complexity = (
subquery_projection_name_complexity_map[dependent_column]
)
assert (
PlanNodeCategory.COLUMN in projection_complexity
), projection_complexity
projection_complexity[PlanNodeCategory.COLUMN] -= 1
projection_complexity = sum_node_complexities(
projection_complexity, dependent_column_complexity
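
For context, the guarded decrement swaps a bare column reference for the complexity of the subquery projection it resolves to. Below is a minimal, self-contained sketch of that bookkeeping, assuming complexity maps are plain category-to-count dicts; the names are stand-ins, not the real Snowpark helpers (`sum_node_complexities` adds per-category counts the same way `Counter.update` does here).

from collections import Counter

# Hypothetical stand-ins for two PlanNodeCategory buckets.
COLUMN, LITERAL = "COLUMN", "LITERAL"

def substitute_dependent_column(projection: dict, dependent: dict) -> dict:
    # The reference being inlined must have been counted as a COLUMN once;
    # the assert added in this PR guards exactly this before decrementing.
    assert COLUMN in projection, projection
    merged = Counter(projection)
    merged[COLUMN] -= 1       # drop the bare reference...
    merged.update(dependent)  # ...and charge the subquery projection's cost
    return dict(merged)

print(substitute_dependent_column({COLUMN: 2}, {COLUMN: 1, LITERAL: 3}))
# {'COLUMN': 2, 'LITERAL': 3}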

@@ -477,6 +477,14 @@ def _is_relaxed_pipeline_breaker(self, node: LogicalPlan) -> bool:
if isinstance(node, SelectStatement):
return True

if isinstance(node, SnowflakePlan):
return node.source_plan is not None and self._is_relaxed_pipeline_breaker(
node.source_plan
)

if isinstance(node, SelectSnowflakePlan):
return self._is_relaxed_pipeline_breaker(node.snowflake_plan)

return False

def _is_node_pipeline_breaker(self, node: LogicalPlan) -> bool:
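
The two new branches let the check see through plan wrappers: a SnowflakePlan defers to the logical plan it was built from, and a SelectSnowflakePlan defers to the SnowflakePlan it wraps, so a SelectStatement nested under either wrapper still counts as a relaxed pipeline breaker. A self-contained sketch with stand-in classes:

class SelectStatement: ...

class SnowflakePlan:
    def __init__(self, source_plan=None):
        self.source_plan = source_plan

class SelectSnowflakePlan:
    def __init__(self, snowflake_plan):
        self.snowflake_plan = snowflake_plan

def is_relaxed_pipeline_breaker(node) -> bool:
    if isinstance(node, SelectStatement):
        return True
    if isinstance(node, SnowflakePlan):
        # Only recurse while the wrapper still carries its logical source plan.
        return node.source_plan is not None and is_relaxed_pipeline_breaker(
            node.source_plan
        )
    if isinstance(node, SelectSnowflakePlan):
        return is_relaxed_pipeline_breaker(node.snowflake_plan)
    return False

# Both levels of wrapping resolve to the inner SelectStatement.
assert is_relaxed_pipeline_breaker(SelectSnowflakePlan(SnowflakePlan(SelectStatement())))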

@@ -169,6 +169,7 @@ def do_resolve_with_resolved_children(
iceberg_config=logical_plan.iceberg_config,
table_exists=logical_plan.table_exists,
)
resolved_plan.referenced_ctes = resolved_child.referenced_ctes

elif isinstance(
logical_plan,
@@ -197,6 +198,7 @@ def do_resolve_with_resolved_children(
resolved_plan = super().do_resolve_with_resolved_children(
logical_plan, resolved_children, df_aliased_col_name_to_real_col_name
)
resolved_plan.referenced_ctes = resolved_child.referenced_ctes

elif isinstance(logical_plan, Selectable):
# overwrite the Selectable resolving to make sure we are triggering
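
Both `referenced_ctes` assignments address the same gap: when a DDL/DML wrapper is re-resolved, the freshly built plan starts with empty CTE bookkeeping, so the set collected on the resolved child is copied forward explicitly. A hedged sketch of the pattern, with a stand-in class rather than the real resolver:

class ResolvedPlan:
    def __init__(self):
        self.referenced_ctes = frozenset()

def resolve_wrapper(resolved_child: ResolvedPlan) -> ResolvedPlan:
    resolved_plan = ResolvedPlan()  # re-resolution starts with an empty set...
    resolved_plan.referenced_ctes = resolved_child.referenced_ctes  # ...copy it forward
    return resolved_plan

child = ResolvedPlan()
child.referenced_ctes = frozenset({"SNOWPARK_TEMP_CTE_1"})  # hypothetical CTE name
assert resolve_wrapper(child).referenced_ctes == child.referenced_ctes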

48 changes: 42 additions & 6 deletions src/snowflake/snowpark/_internal/compiler/utils.py
@@ -34,7 +34,10 @@
TableMerge,
TableUpdate,
)
from snowflake.snowpark._internal.analyzer.unary_plan_node import UnaryNode
from snowflake.snowpark._internal.analyzer.unary_plan_node import (
CreateViewCommand,
UnaryNode,
)
from snowflake.snowpark._internal.compiler.query_generator import (
QueryGenerator,
SnowflakeCreateTablePlanInfo,
@@ -123,7 +126,7 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selecta
snowflake_plan = query_generator.resolve(plan)
return SelectSnowflakePlan(snowflake_plan, analyzer=query_generator)

if not parent._is_valid_for_replacement:
if not valid_for_replacement(parent):
raise ValueError(f"parent node {parent} is not valid for replacement.")

if old_child not in getattr(parent, "children_plan_nodes", parent.children):
@@ -195,6 +198,19 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selecta
raise ValueError(f"parent type {type(parent)} not supported")


def valid_for_replacement(node: LogicalPlan) -> bool:
if node._is_valid_for_replacement:
return True

if isinstance(node, SnowflakePlan) and node.source_plan is not None:
return node._is_valid_for_replacement

if isinstance(node, SelectSnowflakePlan):
return valid_for_replacement(node.snowflake_plan)

return False
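
This helper generalizes the old `parent._is_valid_for_replacement` flag check so that wrapper nodes delegate to the plan they wrap. One subtlety: the SnowflakePlan branch can only return False when reached, because a node whose own flag is set already returned True at the top; in effect only SelectSnowflakePlan delegates, while SnowflakePlan wrappers with an unset flag are rejected. A self-contained sketch with stand-in classes:

class SnowflakePlan:
    def __init__(self, valid=False, source_plan=None):
        self._is_valid_for_replacement = valid
        self.source_plan = source_plan

class SelectSnowflakePlan:
    def __init__(self, snowflake_plan, valid=False):
        self._is_valid_for_replacement = valid
        self.snowflake_plan = snowflake_plan

def valid_for_replacement(node) -> bool:
    if node._is_valid_for_replacement:
        return True
    if isinstance(node, SnowflakePlan) and node.source_plan is not None:
        return node._is_valid_for_replacement  # always False here; see note above
    if isinstance(node, SelectSnowflakePlan):
        return valid_for_replacement(node.snowflake_plan)
    return False

# The wrapper is replaceable because the wrapped plan's own flag is set.
assert valid_for_replacement(SelectSnowflakePlan(SnowflakePlan(valid=True)))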


def update_resolvable_node(
node: TreeNode,
query_generator: QueryGenerator,
@@ -219,7 +235,7 @@ def update_resolvable_node(
resolve_node(SelectSnowflakePlan, query_generator) will resolve both SelectSnowflakePlan and SnowflakePlan nodes.
"""

if not node._is_valid_for_replacement:
if not valid_for_replacement(node):
raise ValueError(f"node {node} is not valid for update.")

if not isinstance(node, (SnowflakePlan, Selectable)):
@@ -306,7 +322,17 @@ def get_snowflake_plan_queries(

plan_queries = plan.queries
post_action_queries = plan.post_actions
if len(plan.referenced_ctes) > 0:
if len(plan.referenced_ctes) > 0 and not isinstance(
plan.source_plan,
(
SnowflakeCreateTable,
CreateViewCommand,
TableUpdate,
TableDelete,
TableMerge,
CopyIntoLocationNode,
),
):
# make a copy of the original query to avoid any update to the
# original query object
plan_queries = copy.deepcopy(plan.queries)
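
The new isinstance guard narrows when referenced CTE definitions get re-attached to the final query: for DDL/DML source plans (create table/view, update, delete, merge, copy into location) the CTE definitions already travel with the child query, so attaching them again at the top level is skipped. A hedged illustration of what the unguarded path does, assuming CTEs are attached by prepending a WITH clause; the helper below is illustrative, not the real API:

referenced_ctes = {"temp_cte_1": "SELECT 1 AS a"}  # hypothetical name -> definition

def prepend_with_clause(sql: str) -> str:
    defs = ", ".join(f"{name} AS ({body})" for name, body in referenced_ctes.items())
    return f"WITH {defs} {sql}"

print(prepend_with_clause("SELECT a FROM temp_cte_1"))
# WITH temp_cte_1 AS (SELECT 1 AS a) SELECT a FROM temp_cte_1
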
@@ -349,6 +375,7 @@ def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
saved in the directory `snowpark_query_plan_plots` with the given `filename`. For example,
we can set the environment variables as follows:

$ export SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD=10  # minimum complexity score to start plotting
$ export ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING=true
$ export TMPDIR="/tmp"
$ ls /tmp/snowpark_query_plan_plots/ # to see the plots
@@ -365,6 +392,11 @@ def plot_plan_if_enabled(root: LogicalPlan, filename: str) -> None:
):
return

if get_complexity_score(root) < int(
os.environ.get("SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD", 0)
):
return
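
Plotting is now gated twice: the feature flag must be on, and the root plan's complexity score must reach the threshold (which defaults to 0, i.e. no filtering). A minimal sketch of the combined gate; the exact string comparison for the enable flag is an assumption:

import os

def should_plot(score: int) -> bool:
    # Assumption: the enable flag is compared case-insensitively to "true".
    if os.environ.get("ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING", "false").lower() != "true":
        return False
    # New gate from this PR: skip plans whose score is below the threshold.
    return score >= int(os.environ.get("SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD", 0))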

import graphviz # pyright: ignore[reportMissingImports]

def get_stat(node: LogicalPlan):
@@ -397,6 +429,9 @@ def get_name(node: Optional[LogicalPlan]) -> str:
name = f"{name} :: ({'| '.join(properties)})"

score = get_complexity_score(node)
num_ref_ctes = -1
if isinstance(node, (SnowflakePlan, Selectable)):
num_ref_ctes = len(node.referenced_ctes)
sql_text = ""
if isinstance(node, Selectable):
sql_text = node.sql_query
@@ -405,7 +440,7 @@ def get_name(node: Optional[LogicalPlan]) -> str:
sql_size = len(sql_text)
sql_preview = sql_text[:50]

return f"{name=}\n{score=}, {sql_size=}\n{sql_preview=}"
return f"{name=}\n{score=}, {num_ref_ctes=}, {sql_size=}\n{sql_preview=}"

g = graphviz.Graph(format="png")

@@ -415,7 +450,8 @@ def get_name(node: Optional[LogicalPlan]) -> str:
next_level = []
for node in curr_level:
node_id = hex(id(node))
g.node(node_id, get_stat(node))
color = "lightblue" if valid_for_replacement(node) else "red"
g.node(node_id, get_stat(node), color=color)
if isinstance(node, (Selectable, SnowflakePlan)):
children = node.children_plan_nodes
else:

23 changes: 17 additions & 6 deletions tests/integ/scala/test_datatype_suite.py
@@ -600,11 +600,14 @@ def test_iceberg_nested_fields(
Utils.drop_table(structured_type_session, transformed_table_name)


@pytest.mark.skip(
reason="SNOW-1819531: Error in _contains_external_cte_ref when analyzing lqb"
@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="local testing does not fully support structured types yet.",
run=False,
)
@pytest.mark.parametrize("cte_enabled", [True, False])
def test_struct_dtype_iceberg_lqb(
structured_type_session, local_testing_mode, structured_type_support
structured_type_session, local_testing_mode, structured_type_support, cte_enabled
):
if not (
structured_type_support
@@ -641,12 +644,14 @@ def test_struct_dtype_iceberg_lqb(
is_query_compilation_stage_enabled = (
structured_type_session._query_compilation_stage_enabled
)
is_cte_optimization_enabled = structured_type_session._cte_optimization_enabled
is_large_query_breakdown_enabled = (
structured_type_session._large_query_breakdown_enabled
)
original_bounds = structured_type_session._large_query_breakdown_complexity_bounds
try:
structured_type_session._query_compilation_stage_enabled = True
structured_type_session._cte_optimization_enabled = cte_enabled
structured_type_session._large_query_breakdown_enabled = True
structured_type_session._large_query_breakdown_complexity_bounds = (300, 600)

@@ -695,9 +700,14 @@
)

queries = union_df.queries
# assert that the queries are broken down into 2 queries and 1 post action
assert len(queries["queries"]) == 2, queries["queries"]
assert len(queries["post_actions"]) == 1
if cte_enabled:
# when CTE is enabled, WithQueryBlock makes pipeline breaker nodes ineligible
assert len(queries["queries"]) == 1
assert len(queries["post_actions"]) == 0
else:
# assert that the queries are broken down into 2 queries and 1 post action
assert len(queries["queries"]) == 2, queries["queries"]
assert len(queries["post_actions"]) == 1
final_df = structured_type_session.table(write_table)

# assert that
@@ -707,6 +717,7 @@
structured_type_session._query_compilation_stage_enabled = (
is_query_compilation_stage_enabled
)
structured_type_session._cte_optimization_enabled = is_cte_optimization_enabled
structured_type_session._large_query_breakdown_enabled = (
is_large_query_breakdown_enabled
)

15 changes: 12 additions & 3 deletions tests/integ/test_large_query_breakdown.py
@@ -731,16 +731,21 @@ def test_large_query_breakdown_enabled_parameter(session, caplog):

@pytest.mark.skipif(IS_IN_STORED_PROC, reason="requires graphviz")
@pytest.mark.parametrize("enabled", [False, True])
def test_plotter(session, large_query_df, enabled):
@pytest.mark.parametrize("score_threshold", [0, 3000000000000])
def test_plotter(session, large_query_df, enabled, score_threshold):
original_plotter_enabled = os.environ.get("ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING")
original_threshold = os.environ.get("SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD")
try:
os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"] = str(enabled)
os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD"] = str(score_threshold)
render_expected = enabled and (score_threshold == 0)

tmp_dir = tempfile.gettempdir()

with patch("graphviz.Graph.render") as mock_render:
large_query_df.collect()
assert mock_render.called == enabled
if not enabled:
assert mock_render.called == render_expected
if not render_expected:
return

assert mock_render.call_count == 5
@@ -762,3 +767,7 @@ def test_plotter(session, large_query_df, enabled):
] = original_plotter_enabled
else:
del os.environ["ENABLE_SNOWPARK_LOGICAL_PLAN_PLOTTING"]
if original_threshold is not None:
os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD"] = original_threshold
else:
del os.environ["SNOWPARK_LOGICAL_PLAN_PLOTTING_THRESHOLD"]