
Update metadata export logic for join derivation #879

Merged: 32 commits on Dec 6, 2024
34 changes: 32 additions & 2 deletions spark/src/main/scala/ai/chronon/spark/Analyzer.scala
@@ -25,7 +25,7 @@ import ai.chronon.spark.Driver.parseConf
import com.yahoo.memory.Memory
import com.yahoo.sketches.ArrayOfStringsSerDe
import com.yahoo.sketches.frequencies.{ErrorType, ItemsSketch}
-import org.apache.spark.sql.{DataFrame, Row, types}
+import org.apache.spark.sql.{Column, DataFrame, Row, types}
import org.apache.spark.sql.functions.{col, from_unixtime, lit, sum, when}
import org.apache.spark.sql.types.{StringType, StructType}
import ai.chronon.api.DataModel.{DataModel, Entities, Events}
@@ -407,8 +407,38 @@ class Analyzer(tableUtils: TableUtils,
)
}
}
// Derive the join online fetching output schema with metadata
val aggMetadata: ListBuffer[AggregationMetadata] = if (joinConf.hasDerivations) {
val keyColumns: List[String] = joinConf.joinParts.toScala
.flatMap(joinPart => {
val keyCols: Seq[String] = joinPart.groupBy.keyColumns.toScala
if (joinPart.keyMapping == null || joinPart.keyMapping.isEmpty) keyCols
else {
keyCols.map(key => joinPart.keyMapping.getOrDefault(key, key))
Collaborator: @yuli-han keyMapping is a mapping from left_key to right_key, but here you are doing the reverse.

Collaborator: Let's include this in a unit test.

Collaborator: Good catch.

Collaborator (Author): Thanks for catching! Updated the code and added a unit test to cover it. Also tested locally with a manually created test config; the result is attached in the testing doc above.
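For reference, a minimal self-contained sketch of the corrected direction (hypothetical names, plain Scala rather than the Chronon types): since keyMapping maps left_key to right_key, translating right-side groupBy key columns back to left-side names requires the inverted map.

// Hypothetical illustration only, not the exact fix in the later commit.
val keyMapping = Map("left_user_id" -> "user_id")  // left_key -> right_key
val rightToLeft = keyMapping.map { case (left, right) => (right, left) }
val keyCols = Seq("user_id", "item_id")            // right-side groupBy key columns
val leftKeyCols = keyCols.map(key => rightToLeft.getOrElse(key, key))
// leftKeyCols == Seq("left_user_id", "item_id")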

}
})
.distinct
val tsDsSchema: Map[String, DataType] = {
Map("ts" -> api.StringType, "ds" -> api.StringType)
}
val sparkSchema = {
val keySchema = leftSchema.filter(tup => keyColumns.contains(tup._1))
Collaborator: We will need to handle key mapping here.

Collaborator (Author): Handled in the code above.

val schema: Seq[(String, DataType)] = keySchema.toSeq ++ rightSchema.toSeq ++ tsDsSchema
StructType(SparkConversions.fromChrononSchema(schema))
}
val dummyOutputDf = tableUtils.sparkSession
.createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema)
val finalOutputColumns: Array[Column] =
joinConf.derivationsScala.finalOutputColumn(rightSchema.toArray.map(_._1)).toArray
val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*)
val columns = SparkConversions.toChrononSchema(
StructType(derivedDummyOutputDf.schema.filterNot(tup => tsDsSchema.contains(tup.name))))
ListBuffer(columns.map { tup => toAggregationMetadata(tup._1, tup._2, joinConf.hasDerivations) }: _*)
} else {
aggregationsMetadata
Collaborator (@hzding621, Nov 20, 2024): nit: rename agg, because agg is specific to group_by, but now we have external parts and derivations:

  • aggMetadata => joinOutputFieldsMetadata
  • aggregationsMetadata => joinIntermediateFieldsMetadata

Collaborator (Author): The concept of this aggMetadata is closer to output features (not including keys). It will be used as the source data for features in the MLI tool. I am thinking of renaming it to featuresMetadata or joinOutputValuesMetadata.

Collaborator: @yuli-han rename?

  • aggMetadata => joinOutputValuesMetadata
  • aggregationsMetadata => joinIntermediateValuesMetadata

}
// (schema map showing the names and datatypes, right side feature aggregations metadata for metadata upload)
-    (leftSchema ++ rightSchema, aggregationsMetadata)
+    (leftSchema ++ rightSchema, aggMetadata)
}

// validate the schema of the left and right side of the join and make sure the types match
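The core trick in the Analyzer change above is worth noting on its own: to compute a join's post-derivation output schema without materializing any data, build an empty DataFrame carrying the pre-derivation schema, apply the derivation expressions to it, and read the derived schema off the result. A minimal standalone sketch of the same idea (hypothetical column names, plain Spark, no Chronon types):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

val spark = SparkSession.builder().master("local[1]").appName("schema-derivation").getOrCreate()

// Pre-derivation schema (hypothetical columns).
val preSchema = StructType(Seq(
  StructField("user", StringType),
  StructField("time_spent_ms_average", LongType)
))

// Empty DataFrame: only the schema matters, no rows are ever created.
val emptyDf = spark.createDataFrame(spark.sparkContext.parallelize(Seq.empty[Row]), preSchema)

// Selecting the derivation expressions over the empty frame yields the derived schema.
val derivedDf = emptyDf.select(expr("user"), expr("time_spent_ms_average / 1000 AS time_spent_s"))
derivedDf.schema.fields.foreach(f => println(s"${f.name}: ${f.dataType}"))
// user: StringType, time_spent_s: DoubleType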
88 changes: 74 additions & 14 deletions spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala
@@ -18,28 +18,19 @@ package ai.chronon.spark.test

import ai.chronon.aggregator.test.Column
import ai.chronon.api
-import ai.chronon.api.{
-  Accuracy,
-  Builders,
-  Constants,
-  JoinPart,
-  LongType,
-  Operation,
-  PartitionSpec,
-  StringType,
-  TimeUnit,
-  Window
-}
+import ai.chronon.api.StructField
+import ai.chronon.api.Builders.Derivation
+import ai.chronon.api.{Accuracy, Builders, Constants, JoinPart, LongType, Operation, PartitionSpec, StringType, TimeUnit, Window}
import ai.chronon.api.Extensions._
import ai.chronon.spark.Extensions._
-import ai.chronon.spark.GroupBy.renderDataSourceQuery
+import ai.chronon.spark.GroupBy.{logger, renderDataSourceQuery}
import ai.chronon.spark.SemanticHashUtils.{tableHashesChanged, tablesToRecompute}
import ai.chronon.spark._
import ai.chronon.spark.stats.SummaryJob
import com.google.gson.Gson
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StructType, StringType => SparkStringType}
+import org.apache.spark.sql.types.{StructType, LongType => SparkLongType, StringType => SparkStringType}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
import org.junit.Assert._
import org.junit.Test
@@ -1547,4 +1538,73 @@ class JoinTest {
assert(
thrown2.getMessage.contains("Table or view not found") && thrown3.getMessage.contains("Table or view not found"))
}

@Test
def testJoinDerivationAnalyzer(): Unit = {
lazy val spark: SparkSession = SparkSessionBuilder.build("JoinTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
val tableUtils = TableUtils(spark)
val namespace = "test_join_derivation" + "_" + Random.alphanumeric.take(6).mkString
tableUtils.createDatabase(namespace)
val viewsGroupBy = getViewsGroupBy(suffix = "cumulative", makeCumulative = true, namespace)
val joinConf = getEventsEventsTemporal("cumulative", namespace)
joinConf.setDerivations(Seq(
Derivation(
name = "*",
expression = "*"
), Derivation(
name = "test_feature_name",
expression = f"${viewsGroupBy.metaData.name}_time_spent_ms_average"
Collaborator: Can you add some derivations that use ts and ds as inputs?

Collaborator (@hzding621, Nov 20, 2024): Similarly, can we add a test case for key columns as output? Such as:

Derivation(
  name = "event_id",
  expression = "ext_contextual_event_id"
)

Collaborator (@pengyu-hou, Dec 4, 2024): Plus one on the ts and ds expressions. Those values will be derived from the request time.

)
).asJava)
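Derivations along the lines requested in the thread above might look like the following (hypothetical names and expressions, assuming ts is epoch milliseconds and ds is the partition date string):

Derivation(
  name = "event_hour_of_day",
  expression = "hour(from_unixtime(ts / 1000))"
),
Derivation(
  name = "partition_date",
  expression = "ds"
)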


val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
val (_, aggregationsMetadata) =
new Analyzer(tableUtils, joinConf, monthAgo, today).analyzeJoin(joinConf, enableHitter = false)
aggregationsMetadata.foreach(agg => assertTrue(agg.operation == "Derivation"))
assertTrue(aggregationsMetadata.exists(_.name == "test_feature_name"))
}

@Test
def testJoinDerivationOnExternalAnalyzer(): Unit = {
lazy val spark: SparkSession = SparkSessionBuilder.build("JoinTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
val tableUtils = TableUtils(spark)
val namespace = "test_join_derivation" + "_" + Random.alphanumeric.take(6).mkString
tableUtils.createDatabase(namespace)
val joinConfWithExternal = getEventsEventsTemporal("cumulative", namespace)

joinConfWithExternal.setOnlineExternalParts(Seq(
Builders.ExternalPart(
Builders.ContextualSource(
fields = Array(
StructField("user_txn_count_30d", LongType),
StructField("item", StringType)
)
)
)
).asJava
)

joinConfWithExternal.setDerivations(
Seq(
Builders.Derivation(
name = "*"
),
// contextual feature rename
Builders.Derivation(
name = "user_txn_count_30d",
expression = "ext_contextual_user_txn_count_30d"
),
Builders.Derivation(
name = "item",
expression = "ext_contextual_item"
)
).asJava
)

val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
val (_, aggregationsMetadata) =
new Analyzer(tableUtils, joinConfWithExternal, monthAgo, today).analyzeJoin(joinConfWithExternal, enableHitter = false)
aggregationsMetadata.foreach(agg => assertTrue(agg.operation == "Derivation"))
assertTrue(aggregationsMetadata.exists(_.name == "user_txn_count_30d"))
assertTrue(aggregationsMetadata.exists(_.name == "item"))
}
}