diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py
index bffa9ef74a..25475fcb4c 100644
--- a/sdk/python/feast/data_source.py
+++ b/sdk/python/feast/data_source.py
@@ -176,7 +176,7 @@ class DataSource(ABC):
             was created, used for deduplicating rows.
         field_mapping (optional): A dictionary mapping of column names in this data
            source to feature names in a feature table or view. Only used for feature
-            columns, not entity or timestamp columns.
+            columns and timestamp columns, not entity columns.
         description (optional) A human-readable description.
         tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
         owner (optional): The owner of the data source, typically the email of the primary
@@ -463,9 +463,11 @@ def from_proto(data_source: DataSourceProto):
             description=data_source.description,
             tags=dict(data_source.tags),
             owner=data_source.owner,
-            batch_source=DataSource.from_proto(data_source.batch_source)
-            if data_source.batch_source
-            else None,
+            batch_source=(
+                DataSource.from_proto(data_source.batch_source)
+                if data_source.batch_source
+                else None
+            ),
         )
 
     def to_proto(self) -> DataSourceProto:
@@ -643,9 +645,11 @@ def from_proto(data_source: DataSourceProto):
             description=data_source.description,
             tags=dict(data_source.tags),
             owner=data_source.owner,
-            batch_source=DataSource.from_proto(data_source.batch_source)
-            if data_source.batch_source
-            else None,
+            batch_source=(
+                DataSource.from_proto(data_source.batch_source)
+                if data_source.batch_source
+                else None
+            ),
         )
 
     @staticmethod
diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py
index 3abb6fffd6..53b29cdfc0 100644
--- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py
+++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py
@@ -240,6 +240,8 @@ def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArti
         ) = spark_serialized_artifacts.unserialize()
 
         if feature_view.batch_source.field_mapping is not None:
+            # The Spark offline store already applies the field mapping in its pull_latest_from_table_or_query() call.
+            # This mapping may be needed in the future if this materialization engine supports other offline stores.
             table = _run_pyarrow_field_mapping(
                 table, feature_view.batch_source.field_mapping
             )
diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
index b462607ae1..aeb9e3cd68 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -33,6 +33,7 @@
 from feast.repo_config import FeastConfigBaseModel, RepoConfig
 from feast.saved_dataset import SavedDatasetStorage
 from feast.type_map import spark_schema_to_np_dtypes
+from feast.utils import _get_fields_with_aliases
 
 # Make sure spark warning doesn't raise more than once.
 warnings.simplefilter("once", RuntimeWarning)
@@ -90,16 +91,22 @@ def pull_latest_from_table_or_query(
         if created_timestamp_column:
             timestamps.append(created_timestamp_column)
         timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
-        field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
+        (fields_with_aliases, aliases) = _get_fields_with_aliases(
+            fields=join_key_columns + feature_name_columns + timestamps,
+            field_mappings=data_source.field_mapping,
+        )
+
+        fields_as_string = ", ".join(fields_with_aliases)
+        aliases_as_string = ", ".join(aliases)
         start_date_str = _format_datetime(start_date)
         end_date_str = _format_datetime(end_date)
 
         query = f"""
                 SELECT
-                    {field_string}
+                    {aliases_as_string}
                     {f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""}
                 FROM (
-                    SELECT {field_string},
+                    SELECT {fields_as_string},
                     ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_
                     FROM {from_expression} t1
                     WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}')
@@ -279,14 +286,19 @@ def pull_all_from_table_or_query(
         spark_session = get_spark_session_or_start_new_with_repoconfig(
             store_config=config.offline_store
         )
+        (fields_with_aliases, aliases) = _get_fields_with_aliases(
+            fields=join_key_columns + feature_name_columns + [timestamp_field],
+            field_mappings=data_source.field_mapping,
+        )
+
+        fields_with_alias_string = ", ".join(fields_with_aliases)
 
-        fields = ", ".join(join_key_columns + feature_name_columns + [timestamp_field])
         from_expression = data_source.get_table_query_string()
         start_date = start_date.astimezone(tz=timezone.utc)
         end_date = end_date.astimezone(tz=timezone.utc)
 
         query = f"""
-            SELECT {fields}
+            SELECT {fields_with_alias_string}
             FROM {from_expression}
             WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}'
             """
diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py
index 49af4fe717..51d4bf4f2c 100644
--- a/sdk/python/feast/utils.py
+++ b/sdk/python/feast/utils.py
@@ -106,7 +106,8 @@ def _get_requested_feature_views_to_features_dict(
     on_demand_feature_views: List["OnDemandFeatureView"],
 ) -> Tuple[Dict["FeatureView", List[str]], Dict["OnDemandFeatureView", List[str]]]:
     """Create a dict of FeatureView -> List[Feature] for all requested features.
-    Set full_feature_names to True to have feature names prefixed by their feature view name."""
+    Set full_feature_names to True to have feature names prefixed by their feature view name.
+    """
 
     feature_views_to_feature_map: Dict["FeatureView", List[str]] = defaultdict(list)
     on_demand_feature_views_to_feature_map: Dict["OnDemandFeatureView", List[str]] = (
@@ -212,6 +213,28 @@ def _run_pyarrow_field_mapping(
     return table
 
 
+def _get_fields_with_aliases(
+    fields: List[str],
+    field_mappings: Dict[str, str],
+) -> Tuple[List[str], List[str]]:
+    """
+    Get a list of fields with aliases based on the field mappings.
+    """
+    for field in fields:
+        if "." in field and field not in field_mappings:
+            raise ValueError(
+                f"Feature {field} contains a '.' character, which is not allowed in field names. Use field mappings to rename fields."
+            )
+    fields_with_aliases = [
+        f"{field} AS {field_mappings[field]}" if field in field_mappings else field
+        for field in fields
+    ]
+    aliases = [
+        field_mappings[field] if field in field_mappings else field for field in fields
+    ]
+    return (fields_with_aliases, aliases)
+
+
 def _coerce_datetime(ts):
     """
     Depending on underlying time resolution, arrow to_pydict() sometimes returns pd
@@ -781,9 +804,11 @@ def _populate_response_from_feature_data(
     """
     # Add the feature names to the response.
     requested_feature_refs = [
-        f"{table.projection.name_to_use()}__{feature_name}"
-        if full_feature_names
-        else feature_name
+        (
+            f"{table.projection.name_to_use()}__{feature_name}"
+            if full_feature_names
+            else feature_name
+        )
         for feature_name in requested_features
     ]
     online_features_response.metadata.feature_names.val.extend(requested_feature_refs)
diff --git a/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py b/sdk/python/tests/integration/materialization/contrib/spark/test_spark_materialization_engine.py
similarity index 100%
rename from sdk/python/tests/integration/materialization/contrib/spark/test_spark.py
rename to sdk/python/tests/integration/materialization/contrib/spark/test_spark_materialization_engine.py
diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py
index a194b8ae26..5e06247ebb 100644
--- a/sdk/python/tests/integration/registration/test_universal_registry.py
+++ b/sdk/python/tests/integration/registration/test_universal_registry.py
@@ -1828,3 +1828,51 @@ def test_apply_entity_to_sql_registry_and_reinitialize_sql_registry(test_registr
 
     updated_test_registry.teardown()
     test_registry.teardown()
+
+
+@pytest.mark.integration
+def test_commit_for_read_only_user():
+    fd, registry_path = mkstemp()
+    registry_config = RegistryConfig(path=registry_path, cache_ttl_seconds=600)
+    write_registry = Registry("project", registry_config, None)
+
+    entity = Entity(
+        name="driver_car_id",
+        description="Car driver id",
+        tags={"team": "matchmaking"},
+    )
+
+    project = "project"
+
+    # Register Entity without committing
+    write_registry.apply_entity(entity, project, commit=False)
+    assert write_registry.cached_registry_proto
+    project_obj = write_registry.cached_registry_proto.projects[0]
+    assert project == Project.from_proto(project_obj).name
+    assert_project(project, write_registry, True)
+
+    # Retrieving the entity should still succeed
+    entities = write_registry.list_entities(project, allow_cache=True, tags=entity.tags)
+    entity = entities[0]
+    assert (
+        len(entities) == 1
+        and entity.name == "driver_car_id"
+        and entity.description == "Car driver id"
+        and "team" in entity.tags
+        and entity.tags["team"] == "matchmaking"
+    )
+
+    # Commit from the original registry
+    write_registry.commit()
+
+    # Reconstruct the new registry in order to read the newly written store
+    with mock.patch.object(
+        Registry,
+        "commit",
+        side_effect=Exception("Read only users are not allowed to commit"),
+    ):
+        read_registry = Registry("project", registry_config, None)
+        entities = read_registry.list_entities(project, tags=entity.tags)
+        assert len(entities) == 1
+
+    write_registry.teardown()
diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py
new file mode 100644
index 0000000000..b8f8cc4247
--- /dev/null
+++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py
@@ -0,0 +1,129 @@
+from datetime import datetime
+from unittest.mock import MagicMock, patch
+
+from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
+    SparkOfflineStore,
+    SparkOfflineStoreConfig,
+)
+from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
+    SparkSource,
+)
+from feast.infra.offline_stores.offline_store import RetrievalJob
+from feast.repo_config import RepoConfig
+
+
+@patch(
+    "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
+)
+def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_spark_session):
+    mock_spark_session = MagicMock()
+    mock_get_spark_session.return_value = mock_spark_session
+
+    test_repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(type="spark"),
+    )
+
+    test_data_source = SparkSource(
+        name="test_nested_batch_source",
+        description="test_nested_batch_source",
+        table="offline_store_database_name.offline_store_table_name",
+        timestamp_field="nested_timestamp",
+        field_mapping={
+            "event_header.event_published_datetime_utc": "nested_timestamp",
+        },
+    )
+
+    # Define the parameters for the method
+    join_key_columns = ["key1", "key2"]
+    feature_name_columns = ["feature1", "feature2"]
+    timestamp_field = "event_header.event_published_datetime_utc"
+    created_timestamp_column = "created_timestamp"
+    start_date = datetime(2021, 1, 1)
+    end_date = datetime(2021, 1, 2)
+
+    # Call the method
+    retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
+        config=test_repo_config,
+        data_source=test_data_source,
+        join_key_columns=join_key_columns,
+        feature_name_columns=feature_name_columns,
+        timestamp_field=timestamp_field,
+        created_timestamp_column=created_timestamp_column,
+        start_date=start_date,
+        end_date=end_date,
+    )
+
+    expected_query = """SELECT
+                    key1, key2, feature1, feature2, nested_timestamp, created_timestamp
+
+                FROM (
+                    SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp, created_timestamp,
+                    ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_header.event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
+                    FROM `offline_store_database_name`.`offline_store_table_name` t1
+                    WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000')
+                ) t2
+                WHERE feast_row_ = 1"""  # noqa: W293
+
+    assert isinstance(retrieval_job, RetrievalJob)
+    assert retrieval_job.query.strip() == expected_query.strip()
+
+
+@patch(
+    "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
+)
+def test_pull_latest_from_table_without_nested_timestamp_or_query(
+    mock_get_spark_session,
+):
+    mock_spark_session = MagicMock()
+    mock_get_spark_session.return_value = mock_spark_session
+
+    test_repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(type="spark"),
+    )
+
+    test_data_source = SparkSource(
+        name="test_batch_source",
+        description="test_batch_source",
+        table="offline_store_database_name.offline_store_table_name",
+        timestamp_field="event_published_datetime_utc",
+    )
+
+    # Define the parameters for the method
+    join_key_columns = ["key1", "key2"]
+    feature_name_columns = ["feature1", "feature2"]
["feature1", "feature2"] + timestamp_field = "event_published_datetime_utc" + created_timestamp_column = "created_timestamp" + start_date = datetime(2021, 1, 1) + end_date = datetime(2021, 1, 2) + + # Call the method + retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query( + config=test_repo_config, + data_source=test_data_source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_date, + end_date=end_date, + ) + + expected_query = """SELECT + key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp + + FROM ( + SELECT key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp, + ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_ + FROM `offline_store_database_name`.`offline_store_table_name` t1 + WHERE event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000') + ) t2 + WHERE feast_row_ = 1""" # noqa: W293 + + assert isinstance(retrieval_job, RetrievalJob) + assert retrieval_job.query.strip() == expected_query.strip()