Skip to content

Commit

Permalink
feat: Enhance customization of Trino connections when using Trino-bas…
Browse files Browse the repository at this point in the history
…ed Offline Stores (feast-dev#3699)

* feat: Enhance customization of Trino connections when using Trino-based Offline Stores

Signed-off-by: boliri <[email protected]>

* docs: Add new connection parameters to Trino Offline Store's reference

Signed-off-by: boliri <[email protected]>

---------

Signed-off-by: boliri <[email protected]>
  • Loading branch information
boliri authored and james-crabtree-sp committed Sep 14, 2023
1 parent b536cdb commit b8ccb7a
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 60 deletions.
41 changes: 41 additions & 0 deletions docs/reference/offline-stores/trino.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,47 @@ offline_store:
catalog: memory
connector:
type: memory
user: trino
source: feast-trino-offline-store
http-scheme: https
ssl-verify: false
x-trino-extra-credential-header: foo=bar, baz=qux

# enables authentication in Trino connections, pick the one you need
# if you don't need authentication, you can safely remove the whole auth block
auth:
# Basic Auth
type: basic
config:
username: foo
password: $FOO

# Certificate
type: certificate
config:
cert-file: /path/to/cert/file
key-file: /path/to/key/file

# JWT
type: jwt
config:
token: $JWT_TOKEN

# OAuth2 (no config required)
type: oauth2

# Kerberos
type: kerberos
config:
config-file: /path/to/kerberos/config/file
service-name: foo
mutual-authentication: true
force-preemptive: true
hostname-override: custom-hostname
sanitize-mutual-error-response: true
principal: principal-name
delegate: true
ca-bundle-file: /path/to/ca/bundle/file
online_store:
path: data/online_store.db
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def __init__(
catalog="memory",
host="localhost",
port=self.exposed_port,
source="trino-python-client",
http_scheme="http",
verify=False,
extra_credential=None,
auth=None,
)

def teardown(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import uuid
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
import pyarrow
from pydantic import StrictStr
from trino.auth import Authentication
from pydantic import Field, FilePath, SecretStr, StrictBool, StrictStr, root_validator
from trino.auth import (
BasicAuthentication,
CertificateAuthentication,
JWTAuthentication,
KerberosAuthentication,
OAuth2Authentication,
)

from feast.data_source import DataSource
from feast.errors import InvalidEntityType
Expand All @@ -32,6 +38,87 @@
from feast.usage import log_exceptions_and_usage


class BasicAuthModel(FeastConfigBaseModel):
    """Options accepted under ``auth.config`` when ``auth.type`` is ``basic``."""

    # Account name sent with the basic-auth credentials.
    username: StrictStr

    # Kept as a SecretStr so the value is masked when the model is printed.
    password: SecretStr


class KerberosAuthModel(FeastConfigBaseModel):
    """Options accepted under ``auth.config`` when ``auth.type`` is ``kerberos``.

    Field names are the keyword names handed to the Trino Kerberos
    authentication class; most of them are populated from a dash-separated
    alias in the configuration. ``principal`` and ``delegate`` declare no
    alias.
    """

    # Kerberos configuration file (alias: ``config-file``).
    config: Optional[FilePath] = Field(alias="config-file", default=None)

    # Alias: ``service-name``.
    service_name: Optional[StrictStr] = Field(alias="service-name", default=None)

    # Alias: ``mutual-authentication``; off unless explicitly enabled.
    mutual_authentication: StrictBool = Field(
        alias="mutual-authentication", default=False
    )

    # Alias: ``force-preemptive``; off unless explicitly enabled.
    force_preemptive: StrictBool = Field(alias="force-preemptive", default=False)

    # Alias: ``hostname-override``.
    hostname_override: Optional[StrictStr] = Field(
        alias="hostname-override", default=None
    )

    # Alias: ``sanitize-mutual-error-response``; the only flag that defaults to True.
    sanitize_mutual_error_response: StrictBool = Field(
        alias="sanitize-mutual-error-response", default=True
    )

    principal: Optional[StrictStr]
    delegate: StrictBool = False

    # CA bundle file (alias: ``ca-bundle-file``).
    ca_bundle: Optional[FilePath] = Field(alias="ca-bundle-file", default=None)


class JWTAuthModel(FeastConfigBaseModel):
    """Options accepted under ``auth.config`` when ``auth.type`` is ``jwt``."""

    # Kept as a SecretStr so the token is masked when the model is printed.
    token: SecretStr


class CertificateAuthModel(FeastConfigBaseModel):
    """Options accepted under ``auth.config`` when ``auth.type`` is ``certificate``.

    Both entries are mandatory. They were previously declared with
    ``default=None``, which let a configuration without ``cert-file`` or
    ``key-file`` pass validation and hand ``None`` to the Trino certificate
    authentication class. Marking them required makes a missing entry fail
    at configuration-parsing time instead.
    """

    # Client certificate file (alias: ``cert-file``).
    cert: FilePath = Field(..., alias="cert-file")

    # Private key file matching the certificate (alias: ``key-file``).
    key: FilePath = Field(..., alias="key-file")


# Lookup table keyed by the ``type`` value of the auth configuration.
#   "auth_model": pydantic model that validates the ``config`` mapping
#                 (None for oauth2, which takes no configuration).
#   "trino_auth": Trino client authentication class to instantiate.
CLASSES_BY_AUTH_TYPE = {
    "kerberos": {"auth_model": KerberosAuthModel, "trino_auth": KerberosAuthentication},
    "basic": {"auth_model": BasicAuthModel, "trino_auth": BasicAuthentication},
    "jwt": {"auth_model": JWTAuthModel, "trino_auth": JWTAuthentication},
    "oauth2": {"auth_model": None, "trino_auth": OAuth2Authentication},
    "certificate": {
        "auth_model": CertificateAuthModel,
        "trino_auth": CertificateAuthentication,
    },
}


class AuthConfig(FeastConfigBaseModel):
    """Authentication block of the Trino offline store configuration.

    ``type`` selects the authentication mechanism and ``config`` carries the
    mechanism-specific options, validated by the matching model registered in
    ``CLASSES_BY_AUTH_TYPE``. ``config`` may only be omitted for ``oauth2``.
    """

    type: Literal["kerberos", "basic", "jwt", "oauth2", "certificate"]
    config: Optional[Dict[StrictStr, Any]]

    @root_validator
    def config_only_nullable_for_oauth2(cls, values):
        """Reject a null ``config`` for every auth type except ``oauth2``.

        Raises:
            ValueError: if ``config`` is missing and ``type`` is not ``oauth2``.
        """
        # This validator also runs when a field has already failed validation,
        # in which case that field is absent from ``values``. Look keys up
        # defensively so the field's own error is reported rather than an
        # unrelated KeyError.
        auth_type = values.get("type")
        if auth_type is None:
            return values

        if auth_type != "oauth2" and values.get("config") is None:
            raise ValueError(f"config cannot be null for auth type '{auth_type}'")

        return values

    def to_trino_auth(self):
        """Build the Trino client authentication object for this configuration.

        Returns:
            An instance of the Trino authentication class registered for
            ``self.type`` in ``CLASSES_BY_AUTH_TYPE``.
        """
        auth_type = self.type
        classes = CLASSES_BY_AUTH_TYPE[auth_type]
        trino_auth_cls = classes["trino_auth"]

        # oauth2 has no options model and takes no constructor arguments.
        if auth_type == "oauth2":
            return trino_auth_cls()

        model = classes["auth_model"](**self.config)

        # ``dict()`` keeps SecretStr fields wrapped, and their string form is
        # the masked placeholder rather than the credential. Unwrap them so the
        # Trino client receives the real password/token.
        kwargs = {
            name: value.get_secret_value() if isinstance(value, SecretStr) else value
            for name, value in model.dict().items()
        }
        return trino_auth_cls(**kwargs)


class TrinoOfflineStoreConfig(FeastConfigBaseModel):
"""Online store config for Trino"""

Expand All @@ -47,6 +134,23 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
catalog: StrictStr
""" Catalog of the Trino cluster """

user: StrictStr
""" User of the Trino cluster """

source: Optional[StrictStr] = "trino-python-client"
""" ID of the feast's Trino Python client, useful for debugging """

http_scheme: Literal["http", "https"] = Field(default="http", alias="http-scheme")
""" HTTP scheme that should be used while establishing a connection to the Trino cluster """

verify: StrictBool = Field(default=True, alias="ssl-verify")
""" Whether the SSL certificate emited by the Trino cluster should be verified or not """

extra_credential: Optional[StrictStr] = Field(
default=None, alias="x-trino-extra-credential-header"
)
""" Specifies the HTTP header X-Trino-Extra-Credential, e.g. user1=pwd1, user2=pwd2 """

connector: Dict[str, str]
"""
Trino connector to use as well as potential extra parameters.
Expand All @@ -59,6 +163,16 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
dataset: StrictStr = "feast"
""" (optional) Trino Dataset name for temporary tables """

auth: Optional[AuthConfig]
"""
(optional) Authentication mechanism to use when connecting to Trino. Supported options are:
- kerberos
- basic
- jwt
- oauth2
- certificate
"""


class TrinoRetrievalJob(RetrievalJob):
def __init__(
Expand Down Expand Up @@ -162,9 +276,6 @@ def pull_latest_from_table_or_query(
created_timestamp_column: Optional[str],
start_date: datetime,
end_date: datetime,
user: Optional[str] = None,
auth: Optional[Authentication] = None,
http_scheme: Optional[str] = None,
) -> TrinoRetrievalJob:
assert isinstance(config.offline_store, TrinoOfflineStoreConfig)
assert isinstance(data_source, TrinoSource)
Expand All @@ -181,9 +292,7 @@ def pull_latest_from_table_or_query(
timestamps.append(created_timestamp_column)
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
client = _get_trino_client(
config=config, user=user, auth=auth, http_scheme=http_scheme
)
client = _get_trino_client(config=config)

query = f"""
SELECT
Expand Down Expand Up @@ -216,17 +325,12 @@ def get_historical_features(
registry: Registry,
project: str,
full_feature_names: bool = False,
user: Optional[str] = None,
auth: Optional[Authentication] = None,
http_scheme: Optional[str] = None,
) -> TrinoRetrievalJob:
assert isinstance(config.offline_store, TrinoOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, TrinoSource)

client = _get_trino_client(
config=config, user=user, auth=auth, http_scheme=http_scheme
)
client = _get_trino_client(config=config)

table_reference = _get_table_reference_for_new_entity(
catalog=config.offline_store.catalog,
Expand Down Expand Up @@ -307,17 +411,12 @@ def pull_all_from_table_or_query(
timestamp_field: str,
start_date: datetime,
end_date: datetime,
user: Optional[str] = None,
auth: Optional[Authentication] = None,
http_scheme: Optional[str] = None,
) -> RetrievalJob:
assert isinstance(config.offline_store, TrinoOfflineStoreConfig)
assert isinstance(data_source, TrinoSource)
from_expression = data_source.get_table_query_string()

client = _get_trino_client(
config=config, user=user, auth=auth, http_scheme=http_scheme
)
client = _get_trino_client(config=config)
field_string = ", ".join(
join_key_columns + feature_name_columns + [timestamp_field]
)
Expand Down Expand Up @@ -378,21 +477,22 @@ def _upload_entity_df_and_get_entity_schema(
# TODO: Ensure that the table expires after some time


def _get_trino_client(
config: RepoConfig,
user: Optional[str],
auth: Optional[Any],
http_scheme: Optional[str],
) -> Trino:
client = Trino(
user=user,
catalog=config.offline_store.catalog,
def _get_trino_client(config: RepoConfig) -> Trino:
auth = None
if config.offline_store.auth is not None:
auth = config.offline_store.auth.to_trino_auth()

return Trino(
host=config.offline_store.host,
port=config.offline_store.port,
user=config.offline_store.user,
catalog=config.offline_store.catalog,
source=config.offline_store.source,
http_scheme=config.offline_store.http_scheme,
verify=config.offline_store.verify,
extra_credential=config.offline_store.extra_credential,
auth=auth,
http_scheme=http_scheme,
)
return client


def _get_entity_df_event_timestamp_range(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import datetime
import os
import signal
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -30,34 +29,27 @@ class QueryStatus(Enum):
class Trino:
def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = None,
user: Optional[str] = None,
catalog: Optional[str] = None,
auth: Optional[Any] = None,
http_scheme: Optional[str] = None,
source: Optional[str] = None,
extra_credential: Optional[str] = None,
host: str,
port: int,
user: str,
catalog: str,
source: Optional[str],
http_scheme: str,
verify: bool,
extra_credential: Optional[str],
auth: Optional[trino.Authentication],
):
self.host = host or os.getenv("TRINO_HOST")
self.port = port or os.getenv("TRINO_PORT")
self.user = user or os.getenv("TRINO_USER")
self.catalog = catalog or os.getenv("TRINO_CATALOG")
self.auth = auth or os.getenv("TRINO_AUTH")
self.http_scheme = http_scheme or os.getenv("TRINO_HTTP_SCHEME")
self.source = source or os.getenv("TRINO_SOURCE")
self.extra_credential = extra_credential or os.getenv("TRINO_EXTRA_CREDENTIAL")
self.host = host
self.port = port
self.user = user
self.catalog = catalog
self.source = source
self.http_scheme = http_scheme
self.verify = verify
self.extra_credential = extra_credential
self.auth = auth
self._cursor: Optional[Cursor] = None

if self.host is None:
raise ValueError("TRINO_HOST must be set if not passed in")
if self.port is None:
raise ValueError("TRINO_PORT must be set if not passed in")
if self.user is None:
raise ValueError("TRINO_USER must be set if not passed in")
if self.catalog is None:
raise ValueError("TRINO_CATALOG must be set if not passed in")

def _get_cursor(self) -> Cursor:
if self._cursor is None:
headers = (
Expand All @@ -70,9 +62,10 @@ def _get_cursor(self) -> Cursor:
port=self.port,
user=self.user,
catalog=self.catalog,
auth=self.auth,
http_scheme=self.http_scheme,
source=self.source,
http_scheme=self.http_scheme,
verify=self.verify,
auth=self.auth,
http_headers=headers,
).cursor()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,20 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
auth = None
if config.offline_store.auth is not None:
auth = config.offline_store.auth.to_trino_auth()

client = Trino(
catalog=config.offline_store.catalog,
host=config.offline_store.host,
port=config.offline_store.port,
user=config.offline_store.user,
source=config.offline_store.source,
http_scheme=config.offline_store.http_scheme,
verify=config.offline_store.verify,
extra_credential=config.offline_store.extra_credential,
auth=auth,
)
if self.table:
table_schema = client.execute_query(
Expand Down

0 comments on commit b8ccb7a

Please sign in to comment.