Skip to content

Commit

Permalink
move temp view detection
Browse files: browse the repository at this point in the history
  • Loading branch information
nastra committed Jan 25, 2024
1 parent 7b86876 commit 71a00a4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.View
import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.ViewCatalog
Expand All @@ -36,7 +34,6 @@ object CheckViews extends (LogicalPlan => Unit) {
case CreateIcebergView(ResolvedIdentifier(_: ViewCatalog, ident), _, query, columnAliases, _,
queryColumnNames, _, _, _, _, _) =>
verifyColumnCount(ident, columnAliases, query)
verifyTemporaryObjectsDontExist(ident, query)
SchemaUtils.checkColumnNameDuplication(queryColumnNames, SQLConf.get.resolver)

case _ => // OK
Expand All @@ -62,44 +59,4 @@ object CheckViews extends (LogicalPlan => Unit) {
}
}
}

/**
 * Rejects a permanent view definition whose query references a temporary view.
 *
 * @param name  identifier of the view being created; only `name.name()` is put into the error
 * @param child query plan of the view being created
 * @throws AnalysisException with error class `INVALID_TEMP_OBJ_REFERENCE`, raised for the first
 *                           temporary view returned by `collectTemporaryViews`
 */
private def verifyTemporaryObjectsDontExist(
    name: Identifier,
    child: LogicalPlan): Unit = {
  // brings `.quoted` for multi-part identifiers into scope
  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

  // Throwing aborts the traversal, so only the first offending view is ever reported.
  collectTemporaryViews(child).headOption.foreach { tempViewNameParts =>
    throw new AnalysisException(
      errorClass = "INVALID_TEMP_OBJ_REFERENCE",
      messageParameters = Map(
        "obj" -> "VIEW",
        "objName" -> name.name(),
        "tempObj" -> "VIEW",
        "tempObjName" -> tempViewNameParts.quoted))
  }

  // TODO: check for temp function names
}


/**
 * Gathers the name parts of the temporary views referenced by the given plan.
 *
 * Plan nodes that are temporary [[View]]s contribute their identifier; for every other node the
 * plans held by its [[SubqueryExpression]]s are searched recursively. Duplicates are removed.
 *
 * @param child plan to inspect
 * @return distinct multi-part names of the temporary views that were found
 */
private def collectTemporaryViews(child: LogicalPlan): Seq[Seq[String]] = {
  // Temporary views reachable through the subquery expressions of a single plan node.
  def fromSubqueries(node: LogicalPlan): Seq[Seq[String]] =
    node.expressions.flatMap { expression =>
      expression.flatMap {
        case subquery: SubqueryExpression => collectTemporaryViews(subquery.plan)
        case _ => Seq.empty
      }
    }

  val referenced = child.flatMap {
    case view: View if view.isTempView => Seq(view.desc.identifier.nameParts)
    case other => fromSubqueries(other)
  }

  referenced.distinct
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.CreateView
import org.apache.spark.sql.catalyst.plans.logical.DropView
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.View
import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView
import org.apache.spark.sql.catalyst.plans.logical.views.DropIcebergView
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.connector.catalog.CatalogPlugin
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.LookupCatalog
import org.apache.spark.sql.connector.catalog.ViewCatalog

Expand All @@ -45,6 +49,7 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi

case CreateView(ResolvedView(resolved), userSpecifiedColumns, comment, properties,
Some(queryText), query, allowExisting, replace) =>
verifyTemporaryObjectsDontExist(resolved.identifier, query)
CreateIcebergView(child = resolved,
queryText = queryText,
query = query,
Expand Down Expand Up @@ -76,4 +81,45 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
None
}
}

/**
 * Ensures that the query of a permanent view does not reference any temporary view.
 *
 * @param name  identifier of the view being created; only `name.name()` is put into the error
 * @param child query plan of the view being created
 * @throws AnalysisException with error class `INVALID_TEMP_OBJ_REFERENCE`, raised for the first
 *                           temporary view returned by `collectTemporaryViews`
 */
private def verifyTemporaryObjectsDontExist(
    name: Identifier,
    child: LogicalPlan): Unit = {
  // brings `.quoted` for multi-part identifiers into scope
  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

  collectTemporaryViews(child) match {
    // Only the first referenced temporary view is reported.
    case firstTempView +: _ =>
      throw new AnalysisException(
        errorClass = "INVALID_TEMP_OBJ_REFERENCE",
        messageParameters = Map(
          "obj" -> "VIEW",
          "objName" -> name.name(),
          "tempObj" -> "VIEW",
          "tempObjName" -> firstTempView.quoted))

    case _ => // no temporary views referenced
  }

  // TODO: check for temp function names
}

/**
 * Gathers the name parts of the temporary views referenced by the given plan.
 *
 * A node contributes a name when it is either an [[UnresolvedRelation]] whose multi-part
 * identifier satisfies `isTempView` (defined outside this excerpt) or a [[View]] flagged as a
 * temp view. For every other node the plans held by its [[SubqueryExpression]]s are searched
 * recursively. Duplicates are removed.
 *
 * @param child plan to inspect
 * @return distinct multi-part names of the temporary views that were found
 */
private def collectTemporaryViews(child: LogicalPlan): Seq[Seq[String]] = {
  // Temporary views reachable through the subquery expressions of a single plan node.
  def fromSubqueries(node: LogicalPlan): Seq[Seq[String]] =
    node.expressions.flatMap { expression =>
      expression.flatMap {
        case subquery: SubqueryExpression => collectTemporaryViews(subquery.plan)
        case _ => Seq.empty
      }
    }

  val referenced = child.flatMap {
    case relation: UnresolvedRelation if isTempView(relation.multipartIdentifier) =>
      Seq(relation.multipartIdentifier)
    case view: View if view.isTempView => Seq(view.desc.identifier.nameParts)
    case other => fromSubqueries(other)
  }

  referenced.distinct
}
}

0 comments on commit 71a00a4

Please sign in to comment.