fix: Rewrite Spark materialization engine to use mapInPandas (#3936)
rewrite spark materialization engine to use mapInPandas

Signed-off-by: tokoko <[email protected]>
tokoko authored Feb 13, 2024
1 parent 4e450ad commit dbb59ba
Showing 1 changed file with 35 additions and 32 deletions.
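At a glance, the change swaps callback shapes: foreachPartition hands its callback a plain iterator of Row objects, so the old code rebuilt a pandas DataFrame row by row via asDict(), while mapInPandas delivers each partition as an iterator of ready-made pandas DataFrames (converted through Arrow) and expects the callback to yield DataFrames matching a declared schema. A minimal sketch of the two shapes, with illustrative names that are not from the Feast codebase:

```python
from typing import Iterator

import pandas as pd
from pyspark.sql import Row


# Old shape: foreachPartition is an action; the callback sees raw Rows
# and must assemble a pandas DataFrame itself, one dict per row.
def process_rows(rows: Iterator[Row]) -> None:
    pdf = pd.DataFrame.from_records(r.asDict() for r in rows)
    # ... write pdf to the online store ...


# New shape: mapInPandas is a transformation; Spark hands the callback
# Arrow-converted pandas DataFrames and collects whatever it yields.
def process_batches(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in batches:
        ...  # write pdf to the online store
    yield pd.DataFrame({"status": [1]})  # must match the declared schema
```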
@@ -3,6 +3,7 @@
 from typing import Callable, List, Literal, Optional, Sequence, Union, cast
 
 import dill
 import pandas
+import pandas as pd
 import pyarrow
 from tqdm import tqdm
@@ -178,9 +179,9 @@ def _materialize_one(
             self.repo_config.batch_engine.partitions
         )
 
-        spark_df.foreachPartition(
-            lambda x: _process_by_partition(x, spark_serialized_artifacts)
-        )
+        spark_df.mapInPandas(
+            lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int"
+        ).count()  # dummy action to force evaluation
 
         return SparkMaterializationJob(
             job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
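Note the trailing .count() above: foreachPartition is an action, but mapInPandas only declares a transformation, so nothing executes until some action consumes the result. A runnable toy demonstration of that laziness, assuming a local Spark 3.x session (names illustrative):

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.range(10)  # toy input with a single "id" column


def side_effect(batches):
    for pdf in batches:
        print(f"processing {len(pdf)} rows")  # runs on the executors
    yield pd.DataFrame({"status": [1]})


mapped = df.mapInPandas(side_effect, "status int")
# Nothing has executed yet: mapInPandas only builds a plan.
mapped.count()  # the dummy action forces every partition to run
```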
@@ -225,38 +226,40 @@ def unserialize(self):
         return feature_view, online_store, repo_config
 
 
-def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
-    """Load pandas df to online store"""
-
-    # convert to pyarrow table
-    dicts = []
-    for row in rows:
-        dicts.append(row.asDict())
+def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArtifacts):
+    for pdf in iterator:
+        if pdf.shape[0] == 0:
+            print("Skipping")
+            return
 
-    df = pd.DataFrame.from_records(dicts)
-    if df.shape[0] == 0:
-        print("Skipping")
-        return
+        table = pyarrow.Table.from_pandas(pdf)
 
-    table = pyarrow.Table.from_pandas(df)
+        (
+            feature_view,
+            online_store,
+            repo_config,
+        ) = spark_serialized_artifacts.unserialize()
 
-    # unserialize artifacts
-    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()
+        if feature_view.batch_source.field_mapping is not None:
+            table = _run_pyarrow_field_mapping(
+                table, feature_view.batch_source.field_mapping
+            )
 
-    if feature_view.batch_source.field_mapping is not None:
-        table = _run_pyarrow_field_mapping(
-            table, feature_view.batch_source.field_mapping
-        )
+        join_key_to_value_type = {
+            entity.name: entity.dtype.to_value_type()
+            for entity in feature_view.entity_columns
+        }
 
-    join_key_to_value_type = {
-        entity.name: entity.dtype.to_value_type()
-        for entity in feature_view.entity_columns
-    }
-
-    rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type)
-    online_store.online_write_batch(
-        repo_config,
-        feature_view,
-        rows_to_write,
-        lambda x: None,
-    )
+        rows_to_write = _convert_arrow_to_proto(
+            table, feature_view, join_key_to_value_type
+        )
+        online_store.online_write_batch(
+            repo_config,
+            feature_view,
+            rows_to_write,
+            lambda x: None,
+        )
+
+    yield pd.DataFrame(
+        [pd.Series(range(1, 2))]
+    )  # dummy result because mapInPandas needs to return something
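The new _map_by_partition is a generator: it drains every incoming batch for its side effect (the online-store write) and yields a single dummy one-row frame, since mapInPandas requires output matching the declared schema even when the job is pure side effect. A stripped-down, runnable sketch of that contract with the store write stubbed out; unlike the commit, it skips an empty batch with continue instead of ending the generator, and every name here is illustrative:

```python
from typing import Iterator

import pandas as pd
import pyarrow
from pyspark.sql import SparkSession


def write_batches(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in batches:
        if pdf.shape[0] == 0:
            continue  # nothing to write for an empty batch
        table = pyarrow.Table.from_pandas(pdf)  # Arrow table, as in the commit
        print(f"would write {table.num_rows} rows")  # stub for the store write
    # mapInPandas insists on output rows matching the declared schema
    yield pd.DataFrame({"status": [1]})


if __name__ == "__main__":
    spark = SparkSession.builder.master("local[2]").getOrCreate()
    spark.range(100).mapInPandas(write_batches, "status int").count()
```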
