feat: Support for nested timestamp fields in Spark Offline store (#4740)
EXPEbdodla authored Nov 6, 2024
1 parent 9bbc1c6 commit d4d94f8
Showing 7 changed files with 236 additions and 16 deletions.
18 changes: 11 additions & 7 deletions sdk/python/feast/data_source.py
@@ -176,7 +176,7 @@ class DataSource(ABC):
was created, used for deduplicating rows.
field_mapping (optional): A dictionary mapping of column names in this data
source to feature names in a feature table or view. Only used for feature
- columns, not entity or timestamp columns.
+ columns and timestamp columns, not entity columns.
description (optional) A human-readable description.
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
owner (optional): The owner of the data source, typically the email of the primary
@@ -463,9 +463,11 @@ def from_proto(data_source: DataSourceProto):
description=data_source.description,
tags=dict(data_source.tags),
owner=data_source.owner,
- batch_source=DataSource.from_proto(data_source.batch_source)
- if data_source.batch_source
- else None,
+ batch_source=(
+     DataSource.from_proto(data_source.batch_source)
+     if data_source.batch_source
+     else None
+ ),
)

def to_proto(self) -> DataSourceProto:
@@ -643,9 +645,11 @@ def from_proto(data_source: DataSourceProto):
description=data_source.description,
tags=dict(data_source.tags),
owner=data_source.owner,
- batch_source=DataSource.from_proto(data_source.batch_source)
- if data_source.batch_source
- else None,
+ batch_source=(
+     DataSource.from_proto(data_source.batch_source)
+     if data_source.batch_source
+     else None
+ ),
)

@staticmethod
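To illustrate the docstring change above: a field mapping can now rename a nested timestamp column to a flat name. A minimal sketch, mirroring the new Spark test added later in this commit (the source and column names here are illustrative):

from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import SparkSource

driver_source = SparkSource(
    name="driver_events",
    table="offline_store_database_name.driver_events",
    # The nested source column is exposed under a flat alias via field_mapping.
    timestamp_field="event_timestamp",
    field_mapping={"event_header.event_published_datetime_utc": "event_timestamp"},
)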
@@ -240,6 +240,8 @@ def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArti
) = spark_serialized_artifacts.unserialize()

if feature_view.batch_source.field_mapping is not None:
# The Spark offline store applies the field mapping in the pull_latest_from_table_or_query() call.
# This may be needed in the future if this materialization engine supports other offline stores.
table = _run_pyarrow_field_mapping(
table, feature_view.batch_source.field_mapping
)
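For context, the field mapping applied here at the pyarrow level amounts to a column rename. A minimal sketch of that idea, assuming a plain source-column-to-feature-name dict (this is an assumption, not necessarily Feast's exact _run_pyarrow_field_mapping implementation):

import pyarrow as pa

def rename_with_field_mapping(table: pa.Table, field_mapping: dict) -> pa.Table:
    # Assumed behavior: rename mapped source columns to their feature names, leave the rest untouched.
    new_names = [field_mapping.get(name, name) for name in table.column_names]
    return table.rename_columns(new_names)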
@@ -33,6 +33,7 @@
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import spark_schema_to_np_dtypes
from feast.utils import _get_fields_with_aliases

# Make sure spark warning doesn't raise more than once.
warnings.simplefilter("once", RuntimeWarning)
@@ -90,16 +91,22 @@ def pull_latest_from_table_or_query(
if created_timestamp_column:
timestamps.append(created_timestamp_column)
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
- field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
+ (fields_with_aliases, aliases) = _get_fields_with_aliases(
+     fields=join_key_columns + feature_name_columns + timestamps,
+     field_mappings=data_source.field_mapping,
+ )

+ fields_as_string = ", ".join(fields_with_aliases)
+ aliases_as_string = ", ".join(aliases)

start_date_str = _format_datetime(start_date)
end_date_str = _format_datetime(end_date)
query = f"""
SELECT
- {field_string}
+ {aliases_as_string}
{f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""}
FROM (
- SELECT {field_string},
+ SELECT {fields_as_string},
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_
FROM {from_expression} t1
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}')
@@ -279,14 +286,19 @@ def pull_all_from_table_or_query(
spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=config.offline_store
)
+ (fields_with_aliases, aliases) = _get_fields_with_aliases(
+     fields=join_key_columns + feature_name_columns + [timestamp_field],
+     field_mappings=data_source.field_mapping,
+ )

+ fields_with_alias_string = ", ".join(fields_with_aliases)

- fields = ", ".join(join_key_columns + feature_name_columns + [timestamp_field])
from_expression = data_source.get_table_query_string()
start_date = start_date.astimezone(tz=timezone.utc)
end_date = end_date.astimezone(tz=timezone.utc)

query = f"""
- SELECT {fields}
+ SELECT {fields_with_alias_string}
FROM {from_expression}
WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}'
"""
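For pull_all_from_table_or_query, which the new unit tests below do not cover, the generated SQL now carries the alias in the SELECT list while the WHERE clause keeps the source column. A rough sketch of the resulting query under the same nested-timestamp mapping used in the tests (illustrative, not asserted anywhere in this commit; <start_date> and <end_date> stand in for the formatted bounds):

expected_pull_all_query = """SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp
FROM `offline_store_database_name`.`offline_store_table_name`
WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP '<start_date>' AND TIMESTAMP '<end_date>'
"""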
33 changes: 29 additions & 4 deletions sdk/python/feast/utils.py
@@ -106,7 +106,8 @@ def _get_requested_feature_views_to_features_dict(
on_demand_feature_views: List["OnDemandFeatureView"],
) -> Tuple[Dict["FeatureView", List[str]], Dict["OnDemandFeatureView", List[str]]]:
"""Create a dict of FeatureView -> List[Feature] for all requested features.
- Set full_feature_names to True to have feature names prefixed by their feature view name."""
+ Set full_feature_names to True to have feature names prefixed by their feature view name.
+ """

feature_views_to_feature_map: Dict["FeatureView", List[str]] = defaultdict(list)
on_demand_feature_views_to_feature_map: Dict["OnDemandFeatureView", List[str]] = (
@@ -212,6 +213,28 @@ def _run_pyarrow_field_mapping(
return table


def _get_fields_with_aliases(
    fields: List[str],
    field_mappings: Dict[str, str],
) -> Tuple[List[str], List[str]]:
    """
    Get a list of fields with aliases based on the field mappings.
    """
    for field in fields:
        if "." in field and field not in field_mappings:
            raise ValueError(
                f"Feature {field} contains a '.' character, which is not allowed in field names. Use field mappings to rename fields."
            )
    fields_with_aliases = [
        f"{field} AS {field_mappings[field]}" if field in field_mappings else field
        for field in fields
    ]
    aliases = [
        field_mappings[field] if field in field_mappings else field for field in fields
    ]
    return (fields_with_aliases, aliases)
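To make the helper's behavior concrete, here is how it resolves the nested-timestamp mapping exercised by the new Spark tests (illustrative values):

fields = ["key1", "feature1", "event_header.event_published_datetime_utc"]
field_mappings = {"event_header.event_published_datetime_utc": "nested_timestamp"}

fields_with_aliases, aliases = _get_fields_with_aliases(fields, field_mappings)
# fields_with_aliases == ["key1", "feature1", "event_header.event_published_datetime_utc AS nested_timestamp"]
# aliases == ["key1", "feature1", "nested_timestamp"]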


def _coerce_datetime(ts):
"""
Depending on underlying time resolution, arrow to_pydict() sometimes returns pd
@@ -781,9 +804,11 @@ def _populate_response_from_feature_data(
"""
# Add the feature names to the response.
requested_feature_refs = [
- f"{table.projection.name_to_use()}__{feature_name}"
- if full_feature_names
- else feature_name
+ (
+     f"{table.projection.name_to_use()}__{feature_name}"
+     if full_feature_names
+     else feature_name
+ )
for feature_name in requested_features
]
online_features_response.metadata.feature_names.val.extend(requested_feature_refs)
@@ -1828,3 +1828,51 @@ def test_apply_entity_to_sql_registry_and_reinitialize_sql_registry(test_registr

updated_test_registry.teardown()
test_registry.teardown()


@pytest.mark.integration
def test_commit_for_read_only_user():
fd, registry_path = mkstemp()
registry_config = RegistryConfig(path=registry_path, cache_ttl_seconds=600)
write_registry = Registry("project", registry_config, None)

entity = Entity(
name="driver_car_id",
description="Car driver id",
tags={"team": "matchmaking"},
)

project = "project"

# Register Entity without committing
write_registry.apply_entity(entity, project, commit=False)
assert write_registry.cached_registry_proto
project_obj = write_registry.cached_registry_proto.projects[0]
assert project == Project.from_proto(project_obj).name
assert_project(project, write_registry, True)

# Retrieving the entity should still succeed
entities = write_registry.list_entities(project, allow_cache=True, tags=entity.tags)
entity = entities[0]
assert (
len(entities) == 1
and entity.name == "driver_car_id"
and entity.description == "Car driver id"
and "team" in entity.tags
and entity.tags["team"] == "matchmaking"
)

# commit from the original registry
write_registry.commit()

# Reconstruct the new registry in order to read the newly written store
with mock.patch.object(
Registry,
"commit",
side_effect=Exception("Read only users are not allowed to commit"),
):
read_registry = Registry("project", registry_config, None)
entities = read_registry.list_entities(project, tags=entity.tags)
assert len(entities) == 1

write_registry.teardown()
@@ -0,0 +1,129 @@
from datetime import datetime
from unittest.mock import MagicMock, patch

from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
SparkOfflineStore,
SparkOfflineStoreConfig,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
SparkSource,
)
from feast.infra.offline_stores.offline_store import RetrievalJob
from feast.repo_config import RepoConfig


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_spark_session):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_nested_batch_source",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="nested_timestamp",
field_mapping={
"event_header.event_published_datetime_utc": "nested_timestamp",
},
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_header.event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, nested_timestamp, created_timestamp
FROM (
SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_header.event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000')
) t2
WHERE feast_row_ = 1""" # noqa: W293

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_without_nested_timestamp_or_query(
mock_get_spark_session,
):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_batch_source",
description="test_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="event_published_datetime_utc",
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp
FROM (
SELECT key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000')
) t2
WHERE feast_row_ = 1""" # noqa: W293

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()
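The helper introduced in utils.py also rejects nested fields that are left unmapped. That error path is not asserted by the tests in this commit; a hypothetical sketch of such a test could look like this:

import pytest

from feast.utils import _get_fields_with_aliases


def test_nested_field_without_mapping_raises():
    # A '.'-separated field with no alias in field_mappings is rejected (hypothetical test, not part of this commit).
    with pytest.raises(ValueError, match="not allowed in field names"):
        _get_fields_with_aliases(
            fields=["event_header.event_published_datetime_utc"],
            field_mappings={},
        )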
