Skip to content

Commit

Permalink
[Python] Add feast feature store handler for enrichment transform (ap…
Browse files Browse the repository at this point in the history
…ache#30957)

* add feast feature store handler

* add changes, unit test

* remove duplicate test, add doc

* correct string formatting

* add lambda, use filesystems, start test

* update pydoc
  • Loading branch information
riteshghorse authored Apr 26, 2024
1 parent c20e329 commit 3329edb
Show file tree
Hide file tree
Showing 9 changed files with 470 additions and 2 deletions.
1 change: 0 additions & 1 deletion .github/trigger_files/beam_PostCommit_Python.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run"
}

34 changes: 34 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,40 @@
* ([#X](https://github.com/apache/beam/issues/X)).
-->

# [2.57.0] - Unreleased

## Highlights

* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)).
* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)).

## I/Os

* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## New Features / Improvements

* Added Feast feature store handler for enrichment transform (Python) ([#30964](https://github.com/apache/beam/issues/30964)).

## Breaking Changes

* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)).

## Deprecations

* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)).

## Bugfixes

* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## Security Fixes
* Fixed [CVE-YYYY-NNNN](https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN) (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).

## Known Issues

* ([#X](https://github.com/apache/beam/issues/X)).

# [2.56.0] - Unreleased

## Highlights
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import tempfile
from pathlib import Path
from typing import Any
from typing import Callable
from typing import List
from typing import Mapping
from typing import Optional

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel
from feast import FeatureStore

# Names exported as the public API of this module.
__all__ = [
'FeastFeatureStoreEnrichmentHandler',
]

# Signature of the optional user-supplied callable that maps an input
# `beam.Row` to the entity mapping (entity key column name -> entity key
# value) used for feature retrieval.
EntityRowFn = Callable[[beam.Row], Mapping[str, Any]]

_LOGGER = logging.getLogger(__name__)

# Suffix given to the temporary local copy of the feature store YAML file.
LOCAL_FEATURE_STORE_YAML_FILENAME = 'fs_yaml_file.yaml'


def download_fs_yaml_file(gcs_fs_yaml_file: str):
  """Download the feature store config file for Feast to a local temp file.

  Args:
    gcs_fs_yaml_file: path of the feature store YAML file; it is opened
      through `FileSystems.open`.

  Returns:
    A `pathlib.Path` pointing to the local copy. The temporary file is
    created with `delete=False`, so it outlives this call and the caller is
    responsible for removing it.

  Raises:
    RuntimeError: if the file cannot be read or copied locally. The
      underlying error is chained as `__cause__` so the root cause stays
      visible in the traceback.
  """
  try:
    with FileSystems.open(gcs_fs_yaml_file, 'r') as gcs_file:
      with tempfile.NamedTemporaryFile(suffix=LOCAL_FEATURE_STORE_YAML_FILENAME,
                                       delete=False) as local_file:
        local_file.write(gcs_file.read())
        return Path(local_file.name)
  except Exception as err:
    # Chain the original exception instead of discarding it.
    raise RuntimeError(
        'error downloading the file %s locally to load the '
        'Feast feature store.' % gcs_fs_yaml_file) from err


def _validate_feature_names(feature_names, feature_service_name):
"""Check if one of `feature_names` or `feature_service_name` is provided."""
if ((not feature_names and not feature_service_name) or
bool(feature_names and feature_service_name)):
raise ValueError(
'Please provide exactly one of a list of feature names to fetch '
'from online store (`feature_names`) or a feature service name for '
'the Feast online feature store (`feature_service_name`).')


def _validate_feature_store_yaml_path_exists(fs_yaml_file):
  """Raise `ValueError` unless `FileSystems.exists` reports the path."""
  if FileSystems.exists(fs_yaml_file):
    return
  raise ValueError(
      'The feature store yaml path (%s) does not exist.' % fs_yaml_file)


def _validate_entity_key_exists(entity_id, entity_row_fn):
"""Checks if the entity key or a lambda to build entity key exists."""
if ((not entity_row_fn and not entity_id) or
bool(entity_row_fn and entity_id)):
raise ValueError(
"Please specify exactly one of a `entity_id` or a lambda "
"function with `entity_row_fn` to extract the entity id "
"from the input row.")


class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row,
                                                                 beam.Row]):
  """Enrichment handler to interact with the Feast feature store.

  To specify the features to fetch from Feast online store,
  please specify exactly one of `feature_names` or `feature_service_name`.

  Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment`
  transform. To filter the features to enrich, use the `join_fn` param in
  :class:`apache_beam.transforms.enrichment.Enrichment`.
  """
  def __init__(
      self,
      feature_store_yaml_path: str,
      feature_names: Optional[List[str]] = None,
      feature_service_name: Optional[str] = "",
      full_feature_names: Optional[bool] = False,
      entity_id: str = "",
      *,
      entity_row_fn: Optional[EntityRowFn] = None,
      exception_level: ExceptionLevel = ExceptionLevel.WARN,
  ):
    """Initializes an instance of `FeastFeatureStoreEnrichmentHandler`.

    Args:
      feature_store_yaml_path (str): The path to a YAML configuration file for
        the Feast feature store. See
        https://docs.feast.dev/reference/feature-repository/feature-store-yaml
        for configuration options supported by Feast.
      feature_names: A list of feature names to be retrieved from the online
        Feast feature store.
      feature_service_name (str): The name of the feature service containing
        the features to fetch from the online Feast feature store.
      full_feature_names (bool): Whether to use full feature names
        (including namespaces, etc.). Defaults to False.
      entity_id (str): entity name for the entity associated with the features.
        The `entity_id` is used to extract the entity value from the input row.
        Please provide exactly one of `entity_id` or `entity_row_fn`.
      entity_row_fn: a lambda function that takes an input `beam.Row` and
        returns a dictionary with a mapping from the entity key column name to
        entity key value. It is used to build/extract the entity dict for
        feature retrieval. Please provide exactly one of `entity_id` or
        `entity_row_fn`.
        See https://docs.feast.dev/getting-started/concepts/feature-retrieval
        for more information.
      exception_level: a `enum.Enum` value from
        `apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`
        to set the level when `None` feature values are fetched from the
        online Feast store. Defaults to `ExceptionLevel.WARN`.

    Raises:
      ValueError: if not exactly one of `entity_id`/`entity_row_fn` is given,
        if the YAML path does not exist, or if not exactly one of
        `feature_names`/`feature_service_name` is given.
    """
    self.entity_id = entity_id
    self.feature_store_yaml_path = feature_store_yaml_path
    self.feature_names = feature_names
    self.feature_service_name = feature_service_name
    self.full_feature_names = full_feature_names
    self.entity_row_fn = entity_row_fn
    # NOTE(review): stored here, but no method of this class reads it.
    self._exception_level = exception_level
    # Local temp copy of the YAML config; set in __enter__, removed in
    # __exit__.
    self._local_repo_path: Optional[Path] = None
    _validate_entity_key_exists(self.entity_id, self.entity_row_fn)
    _validate_feature_store_yaml_path_exists(self.feature_store_yaml_path)
    _validate_feature_names(self.feature_names, self.feature_service_name)

  def __enter__(self):
    """Connect with the Feast feature store.

    Downloads the YAML config to a local temporary file, builds the
    `FeatureStore` client from it and resolves the features to fetch.

    Raises:
      RuntimeError: if the config cannot be downloaded, the `FeatureStore`
        cannot be created from it, or the feature service lookup fails. The
        underlying error is chained as `__cause__`.
    """
    self._local_repo_path = download_fs_yaml_file(self.feature_store_yaml_path)
    try:
      self.store = FeatureStore(fs_yaml_file=self._local_repo_path)
    except Exception as err:
      # __exit__ is not guaranteed to run when __enter__ fails, so remove the
      # temporary copy here.
      self._remove_local_config()
      raise RuntimeError(
          'Invalid feature store yaml file provided. Make sure '
          'the %s contains the valid configuration for Feast feature store.' %
          self.feature_store_yaml_path) from err
    if self.feature_service_name:
      try:
        self.features = self.store.get_feature_service(
            self.feature_service_name)
      except Exception as err:
        raise RuntimeError(
            'Could not find the feature service %s for the feature '
            'store configured in %s.' %
            (self.feature_service_name, self.feature_store_yaml_path)) from err
    else:
      self.features = self.feature_names

  def __call__(self, request: beam.Row, *args, **kwargs):
    """Fetches feature values for an entity-id from the Feast feature store.

    Args:
      request: the input `beam.Row` to enrich.

    Returns:
      A tuple of the unchanged `request` and a `beam.Row` holding the
      fetched feature values.
    """
    if self.entity_row_fn:
      entity_dict = self.entity_row_fn(request)
    else:
      request_dict = request._asdict()
      entity_dict = {self.entity_id: request_dict[self.entity_id]}
    feature_values = self.store.get_online_features(
        features=self.features,
        entity_rows=[entity_dict],
        full_feature_names=self.full_feature_names).to_dict()
    # get_online_features() returns a list of feature values per entity-id.
    # Since we do this per entity, the list of feature values only contain
    # a single element at position 0.
    response_dict = {k: v[0] for k, v in feature_values.items()}
    return request, beam.Row(**response_dict)

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Clean the instantiated Feast feature store client.

    Also removes the temporary local copy of the YAML config created in
    `__enter__`.
    """
    self.store = None
    self._remove_local_config()

  def _remove_local_config(self):
    """Best-effort removal of the temporary local YAML config copy."""
    local_path = getattr(self, '_local_repo_path', None)
    self._local_repo_path = None
    if local_path is None:
      return
    try:
      local_path.unlink()
    except FileNotFoundError:
      # Already gone; nothing to clean up.
      pass
    except OSError as err:
      # Cleanup must never mask the real outcome of the pipeline step.
      _LOGGER.warning(
          'Could not remove the temporary feature store config file %s: %s',
          local_path,
          err)

  def get_cache_key(self, request: beam.Row) -> str:
    """Returns a string formatted with unique entity-id for the feature values.

    With `entity_id`, the key is built from that field of `request`. With
    `entity_row_fn`, the key is built from the mapping the function returns:
    the single value for a one-entry mapping, or every `name=value` pair
    (sorted by name) for a composite entity key, so that distinct entities
    never share a cache entry.

    Raises:
      ValueError: if `entity_row_fn` returns an empty mapping.
    """
    if self.entity_row_fn:
      entity_dict = self.entity_row_fn(request)
      if not entity_dict:
        raise ValueError(
            '`entity_row_fn` returned an empty entity mapping; cannot build '
            'a cache key.')
      if len(entity_dict) == 1:
        # Read the value from the mapping itself: the entity column name is
        # not required to be a field of the request row.
        (entity_value, ) = entity_dict.values()
      else:
        entity_value = ', '.join(
            '%s=%s' % (name, entity_dict[name]) for name in sorted(entity_dict))
    else:
      entity_value = request._asdict()[self.entity_id]
    # Wrapped in a 1-tuple so a tuple-valued entity is formatted as a whole.
    return 'entity_id: %s' % (entity_value, )
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests Feast feature store enrichment handler for enrichment transform.
See https://s.apache.org/feast-enrichment-test-setup
to set up test feast feature repository.
"""

import unittest
from typing import Any
from typing import Mapping

import pytest

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline

# pylint: disable=ungrouped-imports
try:
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.feast_feature_store import \
FeastFeatureStoreEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store_it_test import ValidateResponse # pylint: disable=line-too-long
except ImportError:
raise unittest.SkipTest(
'Feast feature store test dependencies are not installed.')


def _entity_row_fn(request: beam.Row) -> Mapping[str, Any]:
  """Build the entity mapping for `request` from its `user_id` field."""
  return {'user_id': request.user_id}  # type: ignore[attr-defined]


@pytest.mark.uses_feast
class TestFeastEnrichmentHandler(unittest.TestCase):
  """Integration tests for `FeastFeatureStoreEnrichmentHandler`."""
  def setUp(self) -> None:
    self.feature_store_yaml_file = (
        'gs://apache-beam-testing-enrichment/'
        'feast-feature-store/repos/ecommerce/'
        'feature_repo/feature_store.yaml')
    self.feature_service_name = 'demograph_service'

  def _enrich_and_validate(self, handler):
    """Run the shared three-row pipeline and validate the enriched fields."""
    rows = [
        beam.Row(user_id=user, product_id=product)
        for user, product in ((2, 1), (6, 2), (9, 3))
    ]
    fields = ['user_id', 'product_id', 'state', 'country', 'gender', 'age']
    with TestPipeline(is_integration_test=True) as pipeline:
      _ = (
          pipeline
          | beam.Create(rows)
          | Enrichment(handler)
          | beam.ParDo(ValidateResponse(fields)))

  def test_feast_enrichment(self):
    """Enrich rows using `entity_id` to locate the entity key."""
    self._enrich_and_validate(
        FeastFeatureStoreEnrichmentHandler(
            entity_id='user_id',
            feature_store_yaml_path=self.feature_store_yaml_file,
            feature_service_name=self.feature_service_name,
        ))

  def test_feast_enrichment_bad_feature_service_name(self):
    """Test raising an error when a bad feature service name is given."""
    rows = [beam.Row(user_id=1, product_id=1)]
    handler = FeastFeatureStoreEnrichmentHandler(
        entity_id='user_id',
        feature_store_yaml_path=self.feature_store_yaml_file,
        feature_service_name="bad_name",
    )

    with self.assertRaises(RuntimeError):
      pipeline = beam.Pipeline()
      _ = (pipeline | beam.Create(rows) | Enrichment(handler))
      pipeline.run().wait_until_finish()

  def test_feast_enrichment_with_lambda(self):
    """Enrich rows using `entity_row_fn` to build the entity mapping."""
    self._enrich_and_validate(
        FeastFeatureStoreEnrichmentHandler(
            feature_store_yaml_path=self.feature_store_yaml_file,
            feature_service_name=self.feature_service_name,
            entity_row_fn=_entity_row_fn,
        ))


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
Loading

0 comments on commit 3329edb

Please sign in to comment.