Support catalog for Spark Snowflake #432

Open
wants to merge 7 commits into master
219 changes: 219 additions & 0 deletions src/main/scala/net/snowflake/spark/snowflake/catalog/SfCatalog.scala
@@ -0,0 +1,219 @@
package net.snowflake.spark.snowflake.catalog

import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations
import net.snowflake.spark.snowflake.Parameters.{
MergedParameters,
PARAM_SF_DATABASE,
PARAM_SF_DBTABLE,
PARAM_SF_SCHEMA
}
import net.snowflake.spark.snowflake.{DefaultJDBCWrapper, Parameters}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.{
NoSuchNamespaceException,
NoSuchTableException
}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.sql.SQLException
import scala.collection.convert.ImplicitConversions.`map AsScala`
import scala.collection.mutable.ArrayBuilder

class SfCatalog extends TableCatalog with Logging with SupportsNamespaces {
var catalogName: String = null
var params: MergedParameters = _
val jdbcWrapper = DefaultJDBCWrapper

override def name(): String = {
require(catalogName != null, "The SfCatalog is not initialized")
catalogName
}

override def initialize(
name: String,
options: CaseInsensitiveStringMap
): Unit = {
val map = options.asCaseSensitiveMap().toMap
// placeholder values so that Parameters.mergeParameters' required-parameter check passes
params = Parameters.mergeParameters(
map +
(PARAM_SF_DATABASE -> "__invalid_database") +
(PARAM_SF_SCHEMA -> "__invalid_schema") +
(PARAM_SF_DBTABLE -> "__invalid_dbtable")
)
catalogName = name
}

override def listTables(namespace: Array[String]): Array[Identifier] = {
checkNamespace(namespace)
val catalog = if (namespace.length == 2) namespace(0) else null
val schemaPattern = if (namespace.length == 2) namespace(1) else null
val rs = DefaultJDBCWrapper
.getConnector(params)
.getMetaData()
.getTables(catalog, schemaPattern, "%", Array("TABLE"))
new Iterator[Identifier] {
def hasNext = rs.next()
def next() = Identifier.of(namespace, rs.getString("TABLE_NAME"))
}.toArray

}

override def tableExists(ident: Identifier): Boolean = {
checkNamespace(ident.namespace())
DefaultJDBCWrapper.tableExists(params, getFullTableName(ident))
}

override def dropTable(ident: Identifier): Boolean = {
checkNamespace(ident.namespace())
val conn = DefaultJDBCWrapper.getConnector(params)
conn.dropTable(getFullTableName(ident))
}

override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = {
checkNamespace(oldIdent.namespace())
val conn = DefaultJDBCWrapper.getConnector(params)
// rename the existing table (oldIdent) to the new identifier
conn.renameTable(getFullTableName(newIdent), getFullTableName(oldIdent))
}

override def loadTable(ident: Identifier): Table = {
checkNamespace(ident.namespace())
val map = params.parameters
params = Parameters.mergeParameters(
map +
(PARAM_SF_DBTABLE -> getTableName(ident)) +
(PARAM_SF_DATABASE -> getDatabase(ident)) +
(PARAM_SF_SCHEMA -> getSchema(ident))
)
try {
SfTable(ident, jdbcWrapper, params)
} catch {
case _: SQLException =>
throw new NoSuchTableException(ident)

}
}

override def alterTable(ident: Identifier, changes: TableChange*): Table = {
throw new UnsupportedOperationException(
"SfCatalog does not support altering table operation"
)
}

override def namespaceExists(namespace: Array[String]): Boolean =
namespace match {
case Array(catalog, schema) =>
val rs = DefaultJDBCWrapper
.getConnector(params)
.getMetaData()
.getSchemas(catalog, schema)

while (rs.next()) {
val tableSchema = rs.getString("TABLE_SCHEM")
if (tableSchema == schema) return true
}
false
case _ => false
}

override def listNamespaces(): Array[Array[String]] = {
val schemaBuilder = ArrayBuilder.make[Array[String]]
val rs = DefaultJDBCWrapper.getConnector(params).getMetaData().getSchemas()
while (rs.next()) {
schemaBuilder += Array(rs.getString("TABLE_SCHEM"))
}
schemaBuilder.result
}

override def listNamespaces(
namespace: Array[String]
): Array[Array[String]] = {
namespace match {
case Array() =>
listNamespaces()
case Array(_, _) if namespaceExists(namespace) =>
Array()
case _ =>
throw new NoSuchNamespaceException(namespace)
}
}

override def loadNamespaceMetadata(
namespace: Array[String]
): java.util.Map[String, String] = {
namespace match {
case Array(catalog, schema) =>
if (!namespaceExists(namespace)) {
throw new NoSuchNamespaceException(
Array(catalog, schema)
)
}
new java.util.HashMap[String, String]()
case _ =>
throw new NoSuchNamespaceException(namespace)
}
}

override def createTable(
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
properties: java.util.Map[String, String]
): Table = {
throw new UnsupportedOperationException(
"SfCatalog does not support creating table operation"
)
}

override def alterNamespace(
namespace: Array[String],
changes: NamespaceChange*
): Unit = {
throw new UnsupportedOperationException(
"SfCatalog does not support altering namespace operation"
)

}

override def dropNamespace(
namespace: Array[String],
cascade: Boolean
): Boolean = {
throw new UnsupportedOperationException(
"SfCatalog does not support dropping namespace operation"
)
}

private def checkNamespace(namespace: Array[String]): Unit = {
// a database and schema comprise a namespace in Snowflake
if (namespace.length != 2) {
throw new NoSuchNamespaceException(namespace)
}
}

override def createNamespace(
namespace: Array[String],
metadata: java.util.Map[String, String]
): Unit = {
throw new UnsupportedOperationException(
"SfCatalog does not support creating namespace operation"
)
}

private def getTableName(ident: Identifier): String = ident.name()
private def getDatabase(ident: Identifier): String = ident.namespace()(0)
private def getSchema(ident: Identifier): String = ident.namespace()(1)
private def getFullTableName(ident: Identifier): String =
(ident.namespace() :+ ident.name()).mkString(".")
}
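
Editor's note: the PR carries no usage notes, so here is a minimal sketch of how SfCatalog could be wired into a Spark session once merged. The catalog name "snowflake", the placeholder connection values, and the object name SfCatalogUsageSketch are illustrative assumptions; the option keys (sfURL, sfUser, sfPassword) mirror the connector's existing parameter names, and Spark forwards every spark.sql.catalog.snowflake.* entry to SfCatalog.initialize via CaseInsensitiveStringMap.

// Minimal usage sketch (not part of this PR); names and credentials are placeholders.
import org.apache.spark.sql.SparkSession

object SfCatalogUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("sf-catalog-usage-sketch")
      // Register the plugin; every "spark.sql.catalog.snowflake.*" option is passed
      // to SfCatalog.initialize() through CaseInsensitiveStringMap.
      .config("spark.sql.catalog.snowflake",
        "net.snowflake.spark.snowflake.catalog.SfCatalog")
      .config("spark.sql.catalog.snowflake.sfURL", "<account>.snowflakecomputing.com")
      .config("spark.sql.catalog.snowflake.sfUser", "<user>")
      .config("spark.sql.catalog.snowflake.sfPassword", "<password>")
      .getOrCreate()

    // A namespace is (database, schema), so tables are addressed as
    // <catalog>.<database>.<schema>.<table>.
    spark.sql("SHOW NAMESPACES IN snowflake").show()
    spark.sql("SELECT * FROM snowflake.MY_DB.PUBLIC.MY_TABLE LIMIT 10").show()

    spark.stop()
  }
}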
38 changes: 38 additions & 0 deletions src/main/scala/net/snowflake/spark/snowflake/catalog/SfScan.scala
@@ -0,0 +1,38 @@
package net.snowflake.spark.snowflake.catalog

import net.snowflake.spark.snowflake.SnowflakeRelation
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}

case class SfScan(
Contributor

Can you explain how SfScan/SfScanBuilder/SfWriteBuilder/SfTable are used? They are defined in this PR, but I don't understand when and how they are used. Could you please add some test cases for them?

Author
@zhaohehuhu zhaohehuhu Aug 26, 2022

Sure. I just want to implement some of the data source interfaces that already exist in Spark so that data on Snowflake can be read and written through a catalog (a read/write sketch follows this file's diff).

relation: SnowflakeRelation,
prunedSchema: StructType,
pushedFilters: Array[Filter]
) extends V1Scan {

override def readSchema(): StructType = prunedSchema

override def toV1TableScan[T <: BaseRelation with TableScan](
context: SQLContext
): T = {
new BaseRelation with TableScan {
override def sqlContext: SQLContext = context
override def schema: StructType = prunedSchema
override def needConversion: Boolean = relation.needConversion
override def buildScan(): RDD[Row] = {
val columnList = prunedSchema.map(_.name).toArray
relation.buildScan(columnList, pushedFilters)
}
}.asInstanceOf[T]
}

override def description(): String = {
super.description() + ", prunedSchema: " + seqToString(prunedSchema) +
", PushedFilters: " + seqToString(pushedFilters)
}

private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
}
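
Editor's note: to answer the review question above in code, a hedged sketch of the paths these classes sit on, assuming the "snowflake" catalog registered in the earlier sketch; the table and column names are invented. spark.table() resolves the identifier through SfCatalog.loadTable to an SfTable, the projection and filter are handed to SfScanBuilder (pruneColumns / pushFilters), and build() yields the SfScan that wraps SnowflakeRelation; writes go through SfTable.newWriteBuilder to SfWriterBuilder (V1_BATCH_WRITE).

// Hedged read/write sketch (not part of this PR); identifiers are invented.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object SfCatalogReadWriteSketch {
  def main(args: Array[String]): Unit = {
    // Assumes the "snowflake" catalog was registered as in the earlier sketch.
    val spark = SparkSession.builder().getOrCreate()

    // Read path: loadTable -> SfTable -> newScanBuilder -> SfScanBuilder -> SfScan.
    // The select/filter below are what Spark passes to pruneColumns / pushFilters.
    val df = spark
      .table("snowflake.MY_DB.PUBLIC.EVENTS")
      .select("EVENT_ID", "EVENT_TIME")
      .filter(col("EVENT_TIME") > "2022-01-01")
    df.show()

    // Write path: SfTable.newWriteBuilder returns SfWriterBuilder (V1_BATCH_WRITE),
    // used when appending through a catalog identifier.
    df.writeTo("snowflake.MY_DB.PUBLIC.EVENTS_ARCHIVE").append()
  }
}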
72 changes: 72 additions & 0 deletions src/main/scala/net/snowflake/spark/snowflake/catalog/SfScanBuilder.scala
@@ -0,0 +1,72 @@
package net.snowflake.spark.snowflake.catalog

import net.snowflake.spark.snowflake.Parameters.MergedParameters
import net.snowflake.spark.snowflake.{
FilterPushdown,
JDBCWrapper,
SnowflakeRelation
}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.{
Scan,
ScanBuilder,
SupportsPushDownFilters,
SupportsPushDownRequiredColumns
}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

case class SfScanBuilder(
session: SparkSession,
schema: StructType,
params: MergedParameters,
jdbcWrapper: JDBCWrapper
) extends ScanBuilder
with SupportsPushDownFilters
with SupportsPushDownRequiredColumns
with Logging {
private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis

private var pushedFilter = Array.empty[Filter]

private var finalSchema = schema

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
val (pushed, unSupported) = filters.partition(filter =>
FilterPushdown
.buildFilterStatement(
schema,
filter,
true
)
.isDefined
)
this.pushedFilter = pushed
unSupported
}

override def pushedFilters(): Array[Filter] = pushedFilter

override def pruneColumns(requiredSchema: StructType): Unit = {
val requiredCols = requiredSchema.fields
.map(PartitioningUtils.getColName(_, isCaseSensitive))
.toSet
val fields = schema.fields.filter { field =>
val colName = PartitioningUtils.getColName(field, isCaseSensitive)
requiredCols.contains(colName)
}
finalSchema = StructType(fields)
}

override def build(): Scan = {
SfScan(
SnowflakeRelation(jdbcWrapper, params, Option(schema))(
session.sqlContext
),
finalSchema,
pushedFilters
)
}
}
56 changes: 56 additions & 0 deletions src/main/scala/net/snowflake/spark/snowflake/catalog/SfTable.scala
@@ -0,0 +1,56 @@
package net.snowflake.spark.snowflake.catalog

import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import net.snowflake.spark.snowflake.{DefaultJDBCWrapper, JDBCWrapper}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.sql.Connection
import java.util
import scala.collection.JavaConverters._

case class SfTable(
ident: Identifier,
jdbcWrapper: JDBCWrapper,
params: MergedParameters
) extends Table
with SupportsRead
with SupportsWrite
with Logging {

override def name(): String =
(ident.namespace() :+ ident.name()).mkString(".")

override def schema(): StructType = {
val conn: Connection = DefaultJDBCWrapper.getConnector(params)
try {
conn.tableSchema(name, params)
} finally {
conn.close()
}
}

override def capabilities(): util.Set[TableCapability] = {
Set(
TableCapability.BATCH_READ,
TableCapability.V1_BATCH_WRITE,
TableCapability.TRUNCATE
).asJava
}

override def newScanBuilder(
options: CaseInsensitiveStringMap
): ScanBuilder = {
SfScanBuilder(SparkSession.active, schema, params, jdbcWrapper)
}

override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
SfWriterBuilder(jdbcWrapper, params)
}
}