Update metadata export logic for join derivation #879
Changes from 31 commits
@@ -25,14 +25,14 @@ import ai.chronon.spark.Driver.parseConf
 import com.yahoo.memory.Memory
 import com.yahoo.sketches.ArrayOfStringsSerDe
 import com.yahoo.sketches.frequencies.{ErrorType, ItemsSketch}
-import org.apache.spark.sql.{DataFrame, Row, types}
+import org.apache.spark.sql.{Column, DataFrame, Row, types}
 import org.apache.spark.sql.functions.{col, from_unixtime, lit, sum, when}
 import org.apache.spark.sql.types.{StringType, StructType}
 import ai.chronon.api.DataModel.{DataModel, Entities, Events}

 import scala.collection.{Seq, immutable, mutable}
 import scala.collection.mutable.ListBuffer
-import scala.util.ScalaJavaConversions.ListOps
+import scala.util.ScalaJavaConversions.{ListOps, MapOps}

 //@SerialVersionUID(3457890987L)
 //class ItemSketchSerializable(var mapSize: Int) extends ItemsSketch[String](mapSize) with Serializable {}
@@ -407,8 +407,45 @@ class Analyzer(tableUtils: TableUtils,
         )
       }
     }
+    // Derive the join online fetching output schema with metadata
+    val aggMetadata: ListBuffer[AggregationMetadata] = if (joinConf.hasDerivations) {
+      val keyColumns: List[String] = joinConf.joinParts.toScala
+        .flatMap(joinPart => {
+          val keyCols: Seq[String] = joinPart.groupBy.keyColumns.toScala
+          if (joinPart.keyMapping == null) keyCols
+          else {
+            keyCols.map(key => {
+              val findKey = joinPart.keyMapping.toScala.find(_._2 == key)
+              if (findKey.isDefined) {
+                findKey.get._1
+              } else {
+                key
+              }
+            })
+          }
+        })
+        .distinct
+      val tsDsSchema: Map[String, DataType] = {
+        Map("ts" -> api.StringType, "ds" -> api.StringType)
+      }
+      val sparkSchema = {
+        val keySchema = leftSchema.filter(tup => keyColumns.contains(tup._1))
+        val schema: Seq[(String, DataType)] = keySchema.toSeq ++ rightSchema.toSeq ++ tsDsSchema
+        StructType(SparkConversions.fromChrononSchema(schema))
+      }
+      val dummyOutputDf = tableUtils.sparkSession
+        .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema)
+      val finalOutputColumns: Array[Column] =
+        joinConf.derivationsScala.finalOutputColumn(rightSchema.toArray.map(_._1)).toArray
+      val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*)
+      val columns = SparkConversions.toChrononSchema(
+        StructType(derivedDummyOutputDf.schema.filterNot(tup => tsDsSchema.contains(tup.name))))
+      ListBuffer(columns.map { tup => toAggregationMetadata(tup._1, tup._2, joinConf.hasDerivations) }: _*)
+    } else {
+      aggregationsMetadata
+    }
     // (schema map showing the names and datatypes, right side feature aggregations metadata for metadata upload)
-    (leftSchema ++ rightSchema, aggregationsMetadata)
+    (leftSchema ++ rightSchema, aggMetadata)
   }

   // validate the schema of the left and right side of the join and make sure the types match

Review thread (on `val keySchema = leftSchema.filter(...)`):
Reviewer: We will need to handle key mapping here.
Author: Handled in the code above.

Review thread (on the `aggMetadata` name):
Reviewer: nit: rename `agg` — "agg" is specific to group_bys, but now we also have external parts and derivations.
Author: The concept of this `aggMetadata` is more like output features (not including keys). It will be used as the source data of features in the MLI tool. I am thinking of renaming this to `featuresMetadata` or `joinOutputValuesMetadata`.
Reviewer: @yuli-han rename?
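The interesting technique in this hunk is schema derivation without data: the Analyzer builds an empty DataFrame carrying the pre-derivation schema (keys ++ right-side values ++ ts/ds), applies the derivation expressions to it, and lets Catalyst resolve the output types, which then feed the exported metadata. Below is a minimal standalone sketch of the same idea in plain Spark; the column names are hypothetical and `selectExpr` stands in for Chronon's `finalOutputColumn`:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

object DeriveSchemaSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("derive-schema").getOrCreate()

  // Pre-derivation schema: join keys ++ right-side features ++ ts/ds.
  val preDerivationSchema = StructType(Seq(
    StructField("item", StringType),
    StructField("unit_test_item_views_time_spent_ms_average", DoubleType),
    StructField("ts", StringType),
    StructField("ds", StringType)
  ))

  // No rows are needed -- only the analyzed schema matters.
  val dummyDf = spark.createDataFrame(
    spark.sparkContext.parallelize(Seq.empty[Row]), preDerivationSchema)

  // Apply the derivation expressions and let Catalyst infer the result types.
  val derivedDf = dummyDf.selectExpr(
    "item",
    "unit_test_item_views_time_spent_ms_average AS test_feature_name",
    "ts",
    "ds")

  // Strip ts/ds back out; what remains is the derived online output schema.
  derivedDf.schema
    .filterNot(f => Set("ts", "ds").contains(f.name))
    .foreach(f => println(s"${f.name}: ${f.dataType.simpleString}"))

  spark.stop()
}
```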
@@ -18,28 +18,19 @@ package ai.chronon.spark.test

 import ai.chronon.aggregator.test.Column
 import ai.chronon.api
-import ai.chronon.api.{
-  Accuracy,
-  Builders,
-  Constants,
-  JoinPart,
-  LongType,
-  Operation,
-  PartitionSpec,
-  StringType,
-  TimeUnit,
-  Window
-}
+import ai.chronon.api.StructField
+import ai.chronon.api.Builders.Derivation
+import ai.chronon.api.{Accuracy, Builders, Constants, JoinPart, LongType, Operation, PartitionSpec, StringType, TimeUnit, Window}
 import ai.chronon.api.Extensions._
 import ai.chronon.spark.Extensions._
-import ai.chronon.spark.GroupBy.renderDataSourceQuery
+import ai.chronon.spark.GroupBy.{logger, renderDataSourceQuery}
 import ai.chronon.spark.SemanticHashUtils.{tableHashesChanged, tablesToRecompute}
 import ai.chronon.spark._
 import ai.chronon.spark.stats.SummaryJob
 import com.google.gson.Gson
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StructType, StringType => SparkStringType}
+import org.apache.spark.sql.types.{StructType, LongType => SparkLongType, StringType => SparkStringType}
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
 import org.junit.Assert._
 import org.junit.Test
@@ -976,6 +967,44 @@ class JoinTest {
     )
   }

+  private def getViewsGroupByWithKeyMapping(suffix: String, makeCumulative: Boolean = false, namespace: String) = {
+    val viewsSchema = List(
+      Column("user", api.StringType, 10000),
+      Column("item_id", api.StringType, 100),
+      Column("time_spent_ms", api.LongType, 5000)
+    )
+    val spark: SparkSession = SparkSessionBuilder.build("JoinTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val tableUtils = TableUtils(spark)
+    val viewsTable = s"$namespace.view_$suffix"
+    val df = DataFrameGen.events(spark, viewsSchema, count = 1000, partitions = 200)
+
+    val viewsSource = Builders.Source.events(
+      table = viewsTable,
+      query = Builders.Query(selects = Builders.Selects("time_spent_ms"), startPartition = yearAgo),
+      isCumulative = makeCumulative
+    )
+
+    val dfToWrite = if (makeCumulative) {
+      // Move all events into latest partition and set isCumulative on thrift object
+      df.drop("ds").withColumn("ds", lit(today))
+    } else { df }
+
+    spark.sql(s"DROP TABLE IF EXISTS $viewsTable")
+    dfToWrite.save(viewsTable, Map("tblProp1" -> "1"))
+
+    Builders.GroupBy(
+      sources = Seq(viewsSource),
+      keyColumns = Seq("item_id"),
+      aggregations = Seq(
+        Builders.Aggregation(operation = Operation.AVERAGE, inputColumn = "time_spent_ms"),
+        Builders.Aggregation(operation = Operation.MIN, inputColumn = "ts"),
+        Builders.Aggregation(operation = Operation.MAX, inputColumn = "ts")
+      ),
+      metaData = Builders.MetaData(name = "unit_test.item_views_key_mapping", namespace = namespace, team = "item_team"),
+      accuracy = Accuracy.TEMPORAL
+    )
+  }
+
   private def getEventsEventsTemporal(nameSuffix: String = "", namespace: String) = {
     val spark: SparkSession = SparkSessionBuilder.build("JoinTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
     val tableUtils = TableUtils(spark)
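Note that this group-by keys on `item_id` while the join's left side uses `item`; the Analyzer's reverse lookup over `keyMapping` is what reconciles the two. A tiny self-contained illustration of that lookup, reusing the names from this test (not Chronon's actual code path):

```scala
// Reverse-lookup sketch: keyMapping maps left column -> group-by key, so
// resolving a group-by key back to its left-side name means searching the
// map by value, falling back to the key itself when no entry exists.
object KeyMappingLookupSketch extends App {
  val keyMapping = Map("item" -> "item_id") // left column -> group-by key
  val groupByKeys = Seq("item_id", "user")  // "user" has no mapping entry

  val leftKeys = groupByKeys.map { key =>
    keyMapping.find(_._2 == key).map(_._1).getOrElse(key)
  }

  assert(leftKeys == Seq("item", "user"))
  println(leftKeys.mkString(", ")) // item, user
}
```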
@@ -1547,4 +1576,110 @@ class JoinTest {
     assert(
       thrown2.getMessage.contains("Table or view not found") && thrown3.getMessage.contains("Table or view not found"))
   }
+
+  @Test
+  def testJoinDerivationAnalyzer(): Unit = {
+    lazy val spark: SparkSession = SparkSessionBuilder.build("JoinTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val tableUtils = TableUtils(spark)
+    val namespace = "test_join_derivation" + "_" + Random.alphanumeric.take(6).mkString
+    tableUtils.createDatabase(namespace)
+    val viewsGroupBy = getViewsGroupBy(suffix = "cumulative", makeCumulative = true, namespace)
+    val joinConf = getEventsEventsTemporal("cumulative", namespace)
+    joinConf.setDerivations(Seq(
+      Derivation(
+        name = "*",
+        expression = "*"
+      ),
+      Derivation(
+        name = "test_feature_name",
+        expression = f"${viewsGroupBy.metaData.name}_time_spent_ms_average"
+      )
+    ).asJava)
+
+    val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
+    val (_, aggregationsMetadata) =
+      new Analyzer(tableUtils, joinConf, monthAgo, today).analyzeJoin(joinConf, enableHitter = false)
+    aggregationsMetadata.foreach(agg => assertTrue(agg.operation == "Derivation"))
+    assertTrue(aggregationsMetadata.exists(_.name == "test_feature_name"))
+  }

Review thread (on the derivation expressions):
Reviewer: Can you add some derivations that use ts and ds?
Reviewer: Similarly, can we add a test case for key columns as output?
Reviewer: Plus one on the ts and ds expressions. Those values will be derived from the request time.
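To make the reviewers' request concrete, derivations exercising ts, ds, and a key column might look like the following. This is a hypothetical sketch, not part of this PR; it assumes derivation expressions are Spark SQL over the pre-derivation output columns, as the Analyzer's dummy-DataFrame select suggests:

```scala
// Hypothetical derivations covering the reviewers' asks: ts/ds expressions
// (derived from the request time at fetch time) and a key column as output.
joinConf.setDerivations(Seq(
  Builders.Derivation(name = "*", expression = "*"),
  // ts is epoch millis in the pre-derivation output
  Builders.Derivation(name = "event_hour_of_day", expression = "hour(from_unixtime(ts / 1000))"),
  // ds is the partition date string
  Builders.Derivation(name = "event_date", expression = "ds"),
  // key column surfaced as an output feature
  Builders.Derivation(name = "item_key", expression = "item")
).asJava)
```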
+
+  @Test
+  def testJoinDerivationOnExternalAnalyzer(): Unit = {
+    lazy val spark: SparkSession = SparkSessionBuilder.build("JoinTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val tableUtils = TableUtils(spark)
+    val namespace = "test_join_derivation" + "_" + Random.alphanumeric.take(6).mkString
+    tableUtils.createDatabase(namespace)
+    val joinConfWithExternal = getEventsEventsTemporal("cumulative", namespace)
+
+    joinConfWithExternal.setOnlineExternalParts(Seq(
+      Builders.ExternalPart(
+        Builders.ContextualSource(
+          fields = Array(
+            StructField("user_txn_count_30d", LongType),
+            StructField("item", StringType)
+          )
+        )
+      )
+    ).asJava)
+
+    joinConfWithExternal.setDerivations(
+      Seq(
+        Builders.Derivation(
+          name = "*"
+        ),
+        // contextual feature rename
+        Builders.Derivation(
+          name = "user_txn_count_30d",
+          expression = "ext_contextual_user_txn_count_30d"
+        ),
+        Builders.Derivation(
+          name = "item",
+          expression = "ext_contextual_item"
+        )
+      ).asJava
+    )
+
+    val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
+    val (_, aggregationsMetadata) =
+      new Analyzer(tableUtils, joinConfWithExternal, monthAgo, today).analyzeJoin(joinConfWithExternal, enableHitter = false)
+    aggregationsMetadata.foreach(agg => assertTrue(agg.operation == "Derivation"))
+    assertTrue(aggregationsMetadata.exists(_.name == "user_txn_count_30d"))
+    assertTrue(aggregationsMetadata.exists(_.name == "item"))
+  }
+
+  @Test
+  def testJoinDerivationWithKeyAnalyzer(): Unit = {
+    lazy val spark: SparkSession = SparkSessionBuilder.build("JoinTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val tableUtils = TableUtils(spark)
+    val namespace = "test_join_derivation" + "_" + Random.alphanumeric.take(6).mkString
+    tableUtils.createDatabase(namespace)
+    val joinConfWithDerivationWithKey = getEventsEventsTemporal("cumulative", namespace)
+    val viewsGroupBy = getViewsGroupBy(suffix = "cumulative", makeCumulative = true, namespace)
+    val viewGroupByWithKeyMapping = getViewsGroupByWithKeyMapping("cumulative", makeCumulative = true, namespace)
+
+    joinConfWithDerivationWithKey.setJoinParts(
+      Seq(Builders.JoinPart(
+        groupBy = viewGroupByWithKeyMapping,
+        keyMapping = Map("item" -> "item_id")
+      )).asJava
+    )
+
+    joinConfWithDerivationWithKey.setDerivations(
+      Seq(
+        Builders.Derivation(
+          name = "*"
+        ),
+        Builders.Derivation(
+          name = "item",
+          expression = "item"
+        )
+      ).asJava
+    )
+
+    val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
+    val (_, aggregationsMetadata) =
+      new Analyzer(tableUtils, joinConfWithDerivationWithKey, monthAgo, today).analyzeJoin(joinConfWithDerivationWithKey, enableHitter = false)
+    aggregationsMetadata.foreach(agg => assertTrue(agg.operation == "Derivation"))
+    assertTrue(aggregationsMetadata.exists(_.name == f"${viewsGroupBy.metaData.name}_time_spent_ms_average"))
+    assertTrue(aggregationsMetadata.exists(_.name == f"${viewGroupByWithKeyMapping.metaData.name}_time_spent_ms_average"))
+    assertTrue(aggregationsMetadata.exists(_.name == "item"))
+  }
 }
Review thread (on the Analyzer's keyColumns computation):
Reviewer: use joinPart.rightToLeft: https://github.com/airbnb/chronon/blob/main/api/src/main/scala/ai/chronon/api/Extensions.scala#L726
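If that suggestion were adopted, the hand-rolled reverse lookup in the Analyzer hunk could collapse to a one-liner. A sketch, assuming rightToLeft (linked above) yields the right-key-to-left-column Map[String, String] for a join part:

```scala
// Sketch, assuming joinPart.rightToLeft: Map[String, String] maps each
// group-by (right) key column to its left-side column name.
val keyColumns: List[String] = joinConf.joinParts.toScala
  .flatMap(_.rightToLeft.values)
  .distinct
```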