[featurestore] Update custom source sample (#2765)

* Update sample
* update
* fix lint
Showing 3 changed files with 94 additions and 18 deletions.
...tore/featuresets/transactions_custom_source/feature_process_code/transaction_transform.py (46 additions, 0 deletions)
@@ -0,0 +1,46 @@
+from pyspark.sql import functions as F
+from pyspark.sql.window import Window
+from pyspark.ml import Transformer
+from pyspark.sql.dataframe import DataFrame
+
+
+class TransactionFeatureTransformer(Transformer):
+    def _transform(self, df: DataFrame) -> DataFrame:
+        days = lambda i: i * 86400
+        w_3d = (
+            Window.partitionBy("accountID")
+            .orderBy(F.col("timestamp").cast("long"))
+            .rangeBetween(-days(3), 0)
+        )
+        w_7d = (
+            Window.partitionBy("accountID")
+            .orderBy(F.col("timestamp").cast("long"))
+            .rangeBetween(-days(7), 0)
+        )
+        res = (
+            df.withColumn("transaction_7d_count", F.count("transactionID").over(w_7d))
+            .withColumn(
+                "transaction_amount_7d_sum", F.sum("transactionAmount").over(w_7d)
+            )
+            .withColumn(
+                "transaction_amount_7d_avg", F.avg("transactionAmount").over(w_7d)
+            )
+            .withColumn("transaction_3d_count", F.count("transactionID").over(w_3d))
+            .withColumn(
+                "transaction_amount_3d_sum", F.sum("transactionAmount").over(w_3d)
+            )
+            .withColumn(
+                "transaction_amount_3d_avg", F.avg("transactionAmount").over(w_3d)
+            )
+            .select(
+                "accountID",
+                "timestamp",
+                "transaction_3d_count",
+                "transaction_amount_3d_sum",
+                "transaction_amount_3d_avg",
+                "transaction_7d_count",
+                "transaction_amount_7d_sum",
+                "transaction_amount_7d_avg",
+            )
+        )
+        return res
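For a quick local check of this transformer (a sketch, not part of the commit: the local Spark session and the sample rows are assumptions), it can be applied to a small in-memory DataFrame:

from datetime import datetime

from pyspark.sql import SparkSession

from transaction_transform import TransactionFeatureTransformer

spark = SparkSession.builder.getOrCreate()

# Hypothetical rows with the columns the transformer expects.
df = spark.createDataFrame(
    [
        ("A1", "T1", 100.0, datetime(2023, 1, 1, 12, 0)),
        ("A1", "T2", 40.0, datetime(2023, 1, 3, 12, 0)),
        ("A1", "T3", 10.0, datetime(2023, 1, 7, 12, 0)),
    ],
    ["accountID", "transactionID", "transactionAmount", "timestamp"],
)

# Transformer.transform() dispatches to the _transform() defined above.
features = TransactionFeatureTransformer().transform(df)
features.show()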
@@ -6,16 +6,30 @@
 class CustomSourceTransformer:
     def __init__(self, **kwargs):
-        self.path = kwargs.get(
-            "path",
-            "wasbs://data@azuremlexampledata.blob.core.windows.net/feature-store-prp/datasources/transactions-source/*.parquet",
-        )
+        self.path = kwargs.get("source_path")
+        self.timestamp_column_name = kwargs.get("timestamp_column_name")
+        if not self.path:
+            raise Exception("`source_path` is not provided")
+        if not self.timestamp_column_name:
+            raise Exception("`timestamp_column_name` is not provided")
 
     def process(
         self, start_time: datetime, end_time: datetime, **kwargs
     ) -> "pyspark.sql.DataFrame":
         from pyspark.sql import SparkSession
         from pyspark.sql.functions import col, lit, to_timestamp
 
         spark = SparkSession.builder.getOrCreate()
-        df = spark.read.parquet(self.path)
+        df = spark.read.json(self.path)
 
+        if start_time:
+            df = df.filter(
+                col(self.timestamp_column_name) >= to_timestamp(lit(start_time))
+            )
+
+        if end_time:
+            df = df.filter(
+                col(self.timestamp_column_name) < to_timestamp(lit(end_time))
+            )
+
         return df
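A rough driver for the updated class (a sketch, not in the commit; the window bounds are made up, and the path is the one the notebook passes via kwargs):

from datetime import datetime

transformer = CustomSourceTransformer(
    source_path="wasbs://data@azuremlexampledata.blob.core.windows.net/feature-store-prp/datasources/transactions-source-json/*.json",
    timestamp_column_name="timestamp",
)

# Rows are filtered to the half-open window [start_time, end_time).
df = transformer.process(
    start_time=datetime(2023, 1, 1),
    end_time=datetime(2023, 2, 1),
)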
@@ -9,7 +9,7 @@
"Managed feature store supports defining a custom source for data. A custom source definition allows you to define their own logic to load data from any data storage. This allows support for complex scenarios, such as\n", | ||
"- Loading data from multiple tables with a complex join logic.\n", | ||
"- Loading data efficiently from data sources that have a custom partition format.\n", | ||
"- Support for data sources that do not use natively supported formats: `MLTable` and delta table. \n", | ||
"- Support for data sources that do not use natively supported formats, e.g: parquet, `MLTable` and delta table. \n", | ||
" \n", | ||
"In this tutorial you will configure a feature set to consume data from a user-defined custom data source." | ||
] | ||
|
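Sketching the contract such a custom source class must satisfy (the class name here is hypothetical; the concrete sample appears in the next hunk):

from datetime import datetime


class MyCustomSource:
    def __init__(self, **kwargs):
        # Receives the kwargs dictionary supplied to CustomFeatureSource in the spec.
        ...

    def process(self, start_time: datetime, end_time: datetime, **kwargs):
        # Must return a pyspark.sql.DataFrame limited to [start_time, end_time).
        ...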
@@ -198,19 +198,30 @@
"\n", | ||
"class CustomSourceTransformer:\n", | ||
" def __init__(self, **kwargs):\n", | ||
" self.path = kwargs.get(\n", | ||
" \"path\",\n", | ||
" \"wasbs://[email protected]/feature-store-prp/datasources/transactions-source/*.parquet\",\n", | ||
" )\n", | ||
" self.path = kwargs.get(\"source_path\")\n", | ||
" self.timestamp_column_name = kwargs.get(\"timestamp_column_name\")\n", | ||
" if not self.path:\n", | ||
" raise Exception(\"`source_path` is not provided\")\n", | ||
" if not self.timestamp_column_name:\n", | ||
" raise Exception(\"`timestamp_column_name` is not provided\")\n", | ||
"\n", | ||
" def process(\n", | ||
" self, start_time: datetime, end_time: datetime, **kwargs\n", | ||
" ) -> \"pyspark.sql.DataFrame\":\n", | ||
" from pyspark.sql import SparkSession\n", | ||
" from pyspark.sql.functions import col, lit, to_timestamp\n", | ||
"\n", | ||
" spark = SparkSession.builder.getOrCreate()\n", | ||
" df = spark.read.parquet(self.path)\n", | ||
" df = spark.read.json(self.path)\n", | ||
"\n", | ||
" if start_time:\n", | ||
" df = df.filter(col(self.timestamp_column_name) >= to_timestamp(lit(start_time)))\n", | ||
"\n", | ||
" if end_time:\n", | ||
" df = df.filter(col(self.timestamp_column_name) < to_timestamp(lit(end_time)))\n", | ||
"\n", | ||
" return df\n", | ||
"\n", | ||
"```" | ||
] | ||
}, | ||
|
@@ -238,29 +249,34 @@
" TransformationCode,\n", | ||
" Column,\n", | ||
" ColumnType,\n", | ||
" SourceType,\n", | ||
" DateTimeOffset,\n", | ||
" TimestampColumn,\n", | ||
")\n", | ||
"\n", | ||
"transactions_source_process_code_path = (\n", | ||
" root_dir\n", | ||
" + \"/featurestore/featuresets/transactions_custom_source/source_process_code\"\n", | ||
")\n", | ||
"transactions_feature_transform_code_path = (\n", | ||
" root_dir\n", | ||
" + \"/featurestore/featuresets/transactions_custom_source/feature_process_code\"\n", | ||
")\n", | ||
"\n", | ||
"udf_featureset_spec = create_feature_set_spec(\n", | ||
" source=CustomFeatureSource(\n", | ||
" kwargs={\n", | ||
" \"path\": \"wasbs://[email protected]/feature-store-prp/datasources/transactions-source/*.parquet\",\n", | ||
" \"source_path\": \"wasbs://[email protected]/feature-store-prp/datasources/transactions-source-json/*.json\",\n", | ||
" \"timestamp_column_name\": \"timestamp\",\n", | ||
" },\n", | ||
" timestamp_column=TimestampColumn(name=\"timestamp\"),\n", | ||
" source_delay=DateTimeOffset(days=0, hours=0, minutes=20),\n", | ||
" source_process_code=SourceProcessCode(\n", | ||
" path=source_transform_code_path,\n", | ||
" process_class=\"source_transformation.CustomSourceTransformer\",\n", | ||
" path=transactions_source_process_code_path,\n", | ||
" process_class=\"source_process.CustomSourceTransformer\",\n", | ||
" ),\n", | ||
" ),\n", | ||
" transformation_code=TransformationCode(\n", | ||
" path=transactions_featureset_code_path,\n", | ||
" feature_transformation=TransformationCode(\n", | ||
" path=transactions_feature_transform_code_path,\n", | ||
" transformer_class=\"transaction_transform.TransactionFeatureTransformer\",\n", | ||
" ),\n", | ||
" index_columns=[Column(name=\"accountID\", type=ColumnType.string)],\n", | ||
|
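Beyond this hunk, the notebook's usual next steps are to preview the spec and dump it to a spec folder for registration; a hedged sketch following the feature-store tutorial pattern (treat the method names and the spec folder path as assumptions to verify against your azureml-featurestore version):

# Preview the feature set spec by materializing it through the custom source.
preview_df = udf_featureset_spec.to_spark_dataframe()
preview_df.show()

# Persist the spec so it can be registered as a feature set asset.
spec_folder = root_dir + "/featurestore/featuresets/transactions_custom_source/spec"
udf_featureset_spec.dump(spec_folder)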
@@ -492,7 +508,7 @@
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.13" | ||
"version": "3.8.13" | ||
}, | ||
"microsoft": { | ||
"host": { | ||
|