Skip to content

Commit

Permalink
[featurestore] Update custom source sample (#2765)
Browse files Browse the repository at this point in the history
* Update sample

* update

* fix lint
  • Loading branch information
bastrik authored Oct 27, 2023
1 parent ebe993c commit 7128753
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.ml import Transformer
from pyspark.sql.dataframe import DataFrame


class TransactionFeatureTransformer(Transformer):
    """Compute rolling 3-day and 7-day transaction aggregates per account."""

    def _transform(self, df: DataFrame) -> DataFrame:
        """Return the rolling transaction features for every input row.

        For each row, the count of ``transactionID`` and the sum and average
        of ``transactionAmount`` are computed over the trailing 3-day and
        7-day windows of rows that share the same ``accountID``.

        Args:
            df: Input DataFrame. The columns read here are ``accountID``,
                ``timestamp``, ``transactionID`` and ``transactionAmount``.

        Returns:
            A DataFrame containing ``accountID``, ``timestamp`` and the six
            aggregate columns (3-day columns first, then 7-day columns).
        """
        seconds_per_day = 86400

        def trailing_window(days: int):
            # The timestamp is cast to long so the range bounds can be given
            # as a number of seconds; the upper bound 0 is the current row's
            # own timestamp value.
            return (
                Window.partitionBy("accountID")
                .orderBy(F.col("timestamp").cast("long"))
                .rangeBetween(-days * seconds_per_day, 0)
            )

        w_3d = trailing_window(3)
        w_7d = trailing_window(7)

        res = (
            df.withColumn("transaction_7d_count", F.count("transactionID").over(w_7d))
            .withColumn(
                "transaction_amount_7d_sum", F.sum("transactionAmount").over(w_7d)
            )
            .withColumn(
                "transaction_amount_7d_avg", F.avg("transactionAmount").over(w_7d)
            )
            .withColumn("transaction_3d_count", F.count("transactionID").over(w_3d))
            .withColumn(
                "transaction_amount_3d_sum", F.sum("transactionAmount").over(w_3d)
            )
            .withColumn(
                "transaction_amount_3d_avg", F.avg("transactionAmount").over(w_3d)
            )
            # Keep only the keys and the derived features, in a fixed order.
            .select(
                "accountID",
                "timestamp",
                "transaction_3d_count",
                "transaction_amount_3d_sum",
                "transaction_amount_3d_avg",
                "transaction_7d_count",
                "transaction_amount_7d_sum",
                "transaction_amount_7d_avg",
            )
        )
        return res
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,30 @@

class CustomSourceTransformer:
    """Load JSON source data and restrict it to a requested time window."""

    def __init__(self, **kwargs):
        """Read the source configuration from keyword arguments.

        Keyword Args:
            source_path: Path passed to ``spark.read.json`` when loading data.
            timestamp_column_name: Name of the column used for time filtering.

        Raises:
            ValueError: If ``source_path`` or ``timestamp_column_name`` is
                missing or empty.
        """
        self.path = kwargs.get("source_path")
        self.timestamp_column_name = kwargs.get("timestamp_column_name")
        if not self.path:
            raise ValueError("`source_path` is not provided")
        if not self.timestamp_column_name:
            raise ValueError("`timestamp_column_name` is not provided")

    def process(
        self, start_time: datetime, end_time: datetime, **kwargs
    ) -> "pyspark.sql.DataFrame":
        """Load the source data, filtered to ``[start_time, end_time)``.

        Args:
            start_time: Inclusive lower bound on the timestamp column; the
                lower-bound filter is skipped when this value is falsy.
            end_time: Exclusive upper bound on the timestamp column; the
                upper-bound filter is skipped when this value is falsy.
            **kwargs: Accepted but not used by this method.

        Returns:
            The DataFrame read from ``self.path`` with the applicable
            timestamp filters applied.
        """
        # Imported lazily so the class can be defined without pyspark loaded.
        from pyspark.sql import SparkSession
        from pyspark.sql.functions import col, lit, to_timestamp

        spark = SparkSession.builder.getOrCreate()
        df = spark.read.json(self.path)

        timestamp = col(self.timestamp_column_name)

        if start_time:
            df = df.filter(timestamp >= to_timestamp(lit(start_time)))

        if end_time:
            df = df.filter(timestamp < to_timestamp(lit(end_time)))

        return df
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"Managed feature store supports defining a custom source for data. A custom source definition allows you to define your own logic to load data from any data storage. This allows support for complex scenarios, such as\n",
"- Loading data from multiple tables with a complex join logic.\n",
"- Loading data efficiently from data sources that have a custom partition format.\n",
"- Support for data sources that do not use natively supported formats: `MLTable` and delta table. \n",
"- Support for data sources that do not use a natively supported format (e.g., parquet, `MLTable`, or delta table). \n",
" \n",
"In this tutorial you will configure a feature set to consume data from a user-defined custom data source."
]
Expand Down Expand Up @@ -198,19 +198,30 @@
"\n",
"class CustomSourceTransformer:\n",
" def __init__(self, **kwargs):\n",
" self.path = kwargs.get(\n",
" \"path\",\n",
" \"wasbs://[email protected]/feature-store-prp/datasources/transactions-source/*.parquet\",\n",
" )\n",
" self.path = kwargs.get(\"source_path\")\n",
" self.timestamp_column_name = kwargs.get(\"timestamp_column_name\")\n",
" if not self.path:\n",
" raise Exception(\"`source_path` is not provided\")\n",
" if not self.timestamp_column_name:\n",
" raise Exception(\"`timestamp_column_name` is not provided\")\n",
"\n",
" def process(\n",
" self, start_time: datetime, end_time: datetime, **kwargs\n",
" ) -> \"pyspark.sql.DataFrame\":\n",
" from pyspark.sql import SparkSession\n",
" from pyspark.sql.functions import col, lit, to_timestamp\n",
"\n",
" spark = SparkSession.builder.getOrCreate()\n",
" df = spark.read.parquet(self.path)\n",
" df = spark.read.json(self.path)\n",
"\n",
" if start_time:\n",
" df = df.filter(col(self.timestamp_column_name) >= to_timestamp(lit(start_time)))\n",
"\n",
" if end_time:\n",
" df = df.filter(col(self.timestamp_column_name) < to_timestamp(lit(end_time)))\n",
"\n",
" return df\n",
"\n",
"```"
]
},
Expand Down Expand Up @@ -238,29 +249,34 @@
" TransformationCode,\n",
" Column,\n",
" ColumnType,\n",
" SourceType,\n",
" DateTimeOffset,\n",
" TimestampColumn,\n",
")\n",
"\n",
"transactions_source_process_code_path = (\n",
" root_dir\n",
" + \"/featurestore/featuresets/transactions_custom_source/source_process_code\"\n",
")\n",
"transactions_feature_transform_code_path = (\n",
" root_dir\n",
" + \"/featurestore/featuresets/transactions_custom_source/feature_process_code\"\n",
")\n",
"\n",
"udf_featureset_spec = create_feature_set_spec(\n",
" source=CustomFeatureSource(\n",
" kwargs={\n",
" \"path\": \"wasbs://data@azuremlexampledata.blob.core.windows.net/feature-store-prp/datasources/transactions-source/*.parquet\",\n",
" \"source_path\": \"wasbs://data@azuremlexampledata.blob.core.windows.net/feature-store-prp/datasources/transactions-source-json/*.json\",\n",
" \"timestamp_column_name\": \"timestamp\",\n",
" },\n",
" timestamp_column=TimestampColumn(name=\"timestamp\"),\n",
" source_delay=DateTimeOffset(days=0, hours=0, minutes=20),\n",
" source_process_code=SourceProcessCode(\n",
" path=source_transform_code_path,\n",
" process_class=\"source_transformation.CustomSourceTransformer\",\n",
" path=transactions_source_process_code_path,\n",
" process_class=\"source_process.CustomSourceTransformer\",\n",
" ),\n",
" ),\n",
" transformation_code=TransformationCode(\n",
" path=transactions_featureset_code_path,\n",
" feature_transformation=TransformationCode(\n",
" path=transactions_feature_transform_code_path,\n",
" transformer_class=\"transaction_transform.TransactionFeatureTransformer\",\n",
" ),\n",
" index_columns=[Column(name=\"accountID\", type=ColumnType.string)],\n",
Expand Down Expand Up @@ -492,7 +508,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
"version": "3.8.13"
},
"microsoft": {
"host": {
Expand Down

0 comments on commit 7128753

Please sign in to comment.