From 71a00a400c277515e042afe3a30d4cb6269cd392 Mon Sep 17 00:00:00 2001
From: Eduard Tudenhoefner
Date: Thu, 25 Jan 2024 17:15:51 +0100
Subject: [PATCH] move temp view detection

---
 .../sql/catalyst/analysis/CheckViews.scala | 43 -----------------
 .../analysis/RewriteViewCommands.scala     | 46 +++++++++++++++++++
 2 files changed, 46 insertions(+), 43 deletions(-)

diff --git a/spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala b/spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala
index 2d1645b95ee9..95f54ccaf724 100644
--- a/spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala
+++ b/spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala
@@ -20,9 +20,7 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.logical.View
 import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.ViewCatalog
@@ -36,7 +34,6 @@ object CheckViews extends (LogicalPlan => Unit) {
       case CreateIcebergView(ResolvedIdentifier(_: ViewCatalog, ident), _, query, columnAliases, _,
       queryColumnNames, _, _, _, _, _) =>
         verifyColumnCount(ident, columnAliases, query)
-        verifyTemporaryObjectsDontExist(ident, query)
         SchemaUtils.checkColumnNameDuplication(queryColumnNames, SQLConf.get.resolver)
 
       case _ => // OK
@@ -62,44 +59,4 @@ object CheckViews extends (LogicalPlan => Unit) {
       }
     }
   }
-
-  /**
-   * Permanent views are not allowed to reference temp objects
-   */
-  private def verifyTemporaryObjectsDontExist(
-    name: Identifier,
-    child: LogicalPlan): Unit = {
-    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-
-    val tempViews = collectTemporaryViews(child)
-    tempViews.foreach { nameParts =>
-      throw new AnalysisException(
-        errorClass = "INVALID_TEMP_OBJ_REFERENCE",
-        messageParameters = Map(
-          "obj" -> "VIEW",
-          "objName" -> name.name(),
-          "tempObj" -> "VIEW",
-          "tempObjName" -> nameParts.quoted))
-    }
-
-    // TODO: check for temp function names
-  }
-
-
-  /**
-   * Collect all temporary views and return the identifiers separately
-   */
-  private def collectTemporaryViews(child: LogicalPlan): Seq[Seq[String]] = {
-    def collectTempViews(child: LogicalPlan): Seq[Seq[String]] = {
-      child.flatMap {
-        case view: View if view.isTempView => Seq(view.desc.identifier.nameParts)
-        case plan => plan.expressions.flatMap(_.flatMap {
-          case e: SubqueryExpression => collectTempViews(e.plan)
-          case _ => Seq.empty
-        })
-      }.distinct
-    }
-
-    collectTempViews(child)
-  }
 }
diff --git a/spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala b/spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala
index 7b4883494d96..884f6c9f774f 100644
--- a/spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala
+++ b/spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteViewCommands.scala
@@ -19,15 +19,19 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.CreateView
 import org.apache.spark.sql.catalyst.plans.logical.DropView
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.View
 import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView
 import org.apache.spark.sql.catalyst.plans.logical.views.DropIcebergView
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.connector.catalog.CatalogPlugin
+import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.LookupCatalog
 import org.apache.spark.sql.connector.catalog.ViewCatalog
 
@@ -45,6 +49,7 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
 
     case CreateView(ResolvedView(resolved), userSpecifiedColumns, comment, properties,
     Some(queryText), query, allowExisting, replace) =>
+      verifyTemporaryObjectsDontExist(resolved.identifier, query)
       CreateIcebergView(child = resolved,
         queryText = queryText,
         query = query,
@@ -76,4 +81,45 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
         None
     }
   }
+
+  /**
+   * Fails with INVALID_TEMP_OBJ_REFERENCE if the query of a permanent view references a temp view
+   */
+  private def verifyTemporaryObjectsDontExist(
+    name: Identifier,
+    child: LogicalPlan): Unit = {
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+    val tempViews = collectTemporaryViews(child)
+    tempViews.foreach { nameParts =>
+      throw new AnalysisException(
+        errorClass = "INVALID_TEMP_OBJ_REFERENCE",
+        messageParameters = Map(
+          "obj" -> "VIEW",
+          "objName" -> name.name(),
+          "tempObj" -> "VIEW",
+          "tempObjName" -> nameParts.quoted))
+    }
+
+    // TODO: check for temp function names
+  }
+
+  /**
+   * Collect the distinct name parts of all temp views referenced by the plan, including subqueries
+   */
+  private def collectTemporaryViews(child: LogicalPlan): Seq[Seq[String]] = {
+    def collectTempViews(child: LogicalPlan): Seq[Seq[String]] = {
+      child.flatMap {
+        case unresolved: UnresolvedRelation if isTempView(unresolved.multipartIdentifier) =>
+          Seq(unresolved.multipartIdentifier)
+        case view: View if view.isTempView => Seq(view.desc.identifier.nameParts)
+        case plan => plan.expressions.flatMap(_.flatMap {
+          case e: SubqueryExpression => collectTempViews(e.plan)
+          case _ => Seq.empty
+        })
+      }.distinct
+    }
+
+    collectTempViews(child)
+  }
 }