feat: Add Entity df in format of a Spark DataFrame instead of just pd.DataFrame or string for SparkOfflineStore (#3988)

* remove unused parameter when init sparksource

Signed-off-by: tanlocnguyen <[email protected]>

* feat: add entity df to SparkOfflineStore when get_historical_features

Signed-off-by: tanlocnguyen <[email protected]>

* fix: lint error

Signed-off-by: tanlocnguyen <[email protected]>

---------

Signed-off-by: tanlocnguyen <[email protected]>
Co-authored-by: tanlocnguyen <[email protected]>
ElliotNguyen68 authored Mar 6, 2024
1 parent f604af9 commit 43b2c28
Showing 1 changed file with 17 additions and 10 deletions.
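The change is easiest to see from the caller's side: the entity_df handed to the Spark offline store may now be a native Spark DataFrame, not only a pandas DataFrame or a SQL string. A minimal sketch of the intended usage, assuming a repo already configured with the Spark offline store (the entity column, feature reference, and repo path are illustrative, not taken from this commit):

```python
from datetime import datetime

from pyspark.sql import SparkSession
from feast import FeatureStore

store = FeatureStore(repo_path=".")  # assumes SparkOfflineStore in feature_store.yaml
spark = SparkSession.builder.getOrCreate()

# Previously entity_df had to be a pandas DataFrame or a SQL string;
# a Spark DataFrame now works directly, avoiding a round-trip through pandas.
entity_df = spark.createDataFrame(
    [(1001, datetime(2024, 3, 1))],
    schema="driver_id long, event_timestamp timestamp",
)

job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],  # illustrative feature reference
)
print(job.to_df().head())
```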
@@ -125,7 +125,7 @@ def get_historical_features(
         config: RepoConfig,
         feature_views: List[FeatureView],
         feature_refs: List[str],
-        entity_df: Union[pandas.DataFrame, str],
+        entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
         registry: Registry,
         project: str,
         full_feature_names: bool = False,
@@ -473,15 +473,16 @@ def _get_entity_df_event_timestamp_range(
             entity_df_event_timestamp.min().to_pydatetime(),
             entity_df_event_timestamp.max().to_pydatetime(),
         )
-    elif isinstance(entity_df, str):
+    elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
         # If the entity_df is a string (SQL query), determine range
         # from table
-        df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
-
-        # Checks if executing entity sql resulted in any data
-        if df.rdd.isEmpty():
-            raise EntitySQLEmptyResults(entity_df)
-
+        if isinstance(entity_df, str):
+            df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
+            # Checks if executing entity sql resulted in any data
+            if df.rdd.isEmpty():
+                raise EntitySQLEmptyResults(entity_df)
+        else:
+            df = entity_df
         # TODO(kzhang132): need utc conversion here.

         entity_df_event_timestamp_range = (
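In the new branch the Spark DataFrame is used as-is, and the code that follows (truncated above) derives the event-timestamp range from it. As a standalone illustration of how such a range can be computed with PySpark aggregates, not the exact code in this file (the column name is illustrative):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("2024-03-01 00:00:00",), ("2024-03-05 12:00:00",)], ["event_timestamp"]
).withColumn("event_timestamp", F.to_timestamp("event_timestamp"))

# Aggregate min/max once on the cluster, then collect a single row to the driver.
row = df.agg(
    F.min("event_timestamp").alias("min_ts"),
    F.max("event_timestamp").alias("max_ts"),
).first()
entity_df_event_timestamp_range = (row["min_ts"], row["max_ts"])
```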
@@ -499,8 +500,11 @@ def _get_entity_schema(
 ) -> Dict[str, np.dtype]:
     if isinstance(entity_df, pd.DataFrame):
         return dict(zip(entity_df.columns, entity_df.dtypes))
-    elif isinstance(entity_df, str):
-        entity_spark_df = spark_session.sql(entity_df)
+    elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
+        if isinstance(entity_df, str):
+            entity_spark_df = spark_session.sql(entity_df)
+        else:
+            entity_spark_df = entity_df
         return dict(
             zip(
                 entity_spark_df.columns,
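Whichever branch ran, the truncated zip(...) above pairs column names with dtypes. A short sketch of the raw material Spark provides for that mapping (the numpy-dtype conversion helper used in the real file is elided here):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
entity_spark_df = spark.createDataFrame([(1001, "a")], ["driver_id", "label"])

# Spark exposes dtypes as (column, type-string) pairs, e.g. ("driver_id", "bigint").
print(entity_spark_df.dtypes)   # [('driver_id', 'bigint'), ('label', 'string')]
schema = dict(zip(entity_spark_df.columns, (t for _, t in entity_spark_df.dtypes)))
```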
@@ -526,6 +530,9 @@ def _upload_entity_df(
     elif isinstance(entity_df, str):
         spark_session.sql(entity_df).createOrReplaceTempView(table_name)
         return
+    elif isinstance(entity_df, pyspark.sql.DataFrame):
+        entity_df.createOrReplaceTempView(table_name)
+        return
     else:
         raise InvalidEntityType(type(entity_df))

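The upload path gains a symmetric branch: a Spark DataFrame is registered directly as a temp view instead of first being materialized. A minimal standalone sketch of the same mechanism (the view name is illustrative; the real code uses a generated table_name):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
entity_df = spark.createDataFrame([(1001,)], ["driver_id"])

# Registering the DataFrame as a temp view lets the later point-in-time
# join SQL refer to the entity rows by name, with no write to storage.
entity_df.createOrReplaceTempView("feast_entity_df")
spark.sql("SELECT COUNT(*) AS n FROM feast_entity_df").show()
```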
