fix(ingest): better correctness on the emitter -> graph conversion #12272

Merged: 3 commits, Jan 7, 2025
11 changes: 10 additions & 1 deletion metadata-ingestion/src/datahub/cli/cli_utils.py
@@ -3,7 +3,7 @@
import time
import typing
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import click
import requests
@@ -33,6 +33,15 @@ def first_non_null(ls: List[Optional[str]]) -> Optional[str]:
return next((el for el in ls if el is not None and el.strip() != ""), None)


_T = TypeVar("_T")


def get_or_else(value: Optional[_T], default: _T) -> _T:
# Normally we'd use `value or default`. However, that runs into issues
# when value is falsey but not None.
return value if value is not None else default


def parse_run_restli_response(response: requests.Response) -> dict:
response_json = response.json()
if response.status_code != 200:
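
For context on the new helper: get_or_else exists because the usual "value or default" idiom also discards values that are falsey but valid, such as 0 or an empty string. A minimal illustration of the difference (the numbers below are made up for the example, not taken from this PR):

from datahub.cli.cli_utils import get_or_else

# "or" silently replaces a deliberately configured 0 with the default:
assert (0 or 30) == 30
# get_or_else only falls back when the value is actually None:
assert get_or_else(0, 30) == 0
assert get_or_else(None, 30) == 30
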
209 changes: 125 additions & 84 deletions metadata-ingestion/src/datahub/emitter/rest_emitter.py
@@ -1,9 +1,21 @@
from __future__ import annotations

import functools
import json
import logging
import os
from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)

import requests
from deprecated import deprecated
@@ -12,9 +24,13 @@

from datahub import nice_version_name
from datahub.cli import config_utils
from datahub.cli.cli_utils import ensure_has_system_metadata, fixup_gms_url
from datahub.cli.cli_utils import ensure_has_system_metadata, fixup_gms_url, get_or_else
from datahub.cli.env_utils import get_boolean_env_variable
from datahub.configuration.common import ConfigurationError, OperationalError
from datahub.configuration.common import (
ConfigModel,
ConfigurationError,
OperationalError,
)
from datahub.emitter.generic_emitter import Emitter
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.request_helper import make_curl_command
@@ -31,10 +47,8 @@

logger = logging.getLogger(__name__)

_DEFAULT_CONNECT_TIMEOUT_SEC = 30 # 30 seconds should be plenty to connect
_DEFAULT_READ_TIMEOUT_SEC = (
30 # Any ingest call taking longer than 30 seconds should be abandoned
)
_DEFAULT_TIMEOUT_SEC = 30 # 30 seconds should be plenty to connect
_TIMEOUT_LOWER_BOUND_SEC = 1 # if below this, we log a warning
_DEFAULT_RETRY_STATUS_CODES = [ # Additional status codes to retry on
429,
500,
@@ -63,15 +77,76 @@
)


class RequestsSessionConfig(ConfigModel):
timeout: Union[float, Tuple[float, float], None] = _DEFAULT_TIMEOUT_SEC

retry_status_codes: List[int] = _DEFAULT_RETRY_STATUS_CODES
retry_methods: List[str] = _DEFAULT_RETRY_METHODS
retry_max_times: int = _DEFAULT_RETRY_MAX_TIMES

extra_headers: Dict[str, str] = {}

ca_certificate_path: Optional[str] = None
client_certificate_path: Optional[str] = None
disable_ssl_verification: bool = False

def build_session(self) -> requests.Session:
session = requests.Session()

if self.extra_headers:
session.headers.update(self.extra_headers)

if self.client_certificate_path:
session.cert = self.client_certificate_path

if self.ca_certificate_path:
session.verify = self.ca_certificate_path

if self.disable_ssl_verification:
session.verify = False

try:
# Set raise_on_status to False to propagate errors:
# https://stackoverflow.com/questions/70189330/determine-status-code-from-python-retry-exception
# Must call `raise_for_status` after making a request, which we do
retry_strategy = Retry(
total=self.retry_max_times,
status_forcelist=self.retry_status_codes,
backoff_factor=2,
allowed_methods=self.retry_methods,
raise_on_status=False,
)
except TypeError:
# Prior to urllib3 1.26, the Retry class used `method_whitelist` instead of `allowed_methods`.
retry_strategy = Retry(
total=self.retry_max_times,
status_forcelist=self.retry_status_codes,
backoff_factor=2,
method_whitelist=self.retry_methods,
raise_on_status=False,
)

adapter = HTTPAdapter(
pool_connections=100, pool_maxsize=100, max_retries=retry_strategy
)
session.mount("http://", adapter)
session.mount("https://", adapter)

if self.timeout is not None:
# Shim session.request to apply default timeout values.
# Via https://stackoverflow.com/a/59317604.
session.request = functools.partial( # type: ignore
session.request,
timeout=self.timeout,
)

return session


class DataHubRestEmitter(Closeable, Emitter):
_gms_server: str
_token: Optional[str]
_session: requests.Session
_connect_timeout_sec: float = _DEFAULT_CONNECT_TIMEOUT_SEC
_read_timeout_sec: float = _DEFAULT_READ_TIMEOUT_SEC
_retry_status_codes: List[int] = _DEFAULT_RETRY_STATUS_CODES
_retry_methods: List[str] = _DEFAULT_RETRY_METHODS
_retry_max_times: int = _DEFAULT_RETRY_MAX_TIMES

def __init__(
self,
@@ -102,15 +177,13 @@

self._session = requests.Session()

self._session.headers.update(
{
"X-RestLi-Protocol-Version": "2.0.0",
"X-DataHub-Py-Cli-Version": nice_version_name(),
"Content-Type": "application/json",
}
)
headers = {
"X-RestLi-Protocol-Version": "2.0.0",
"X-DataHub-Py-Cli-Version": nice_version_name(),
"Content-Type": "application/json",
}
if token:
self._session.headers.update({"Authorization": f"Bearer {token}"})
headers["Authorization"] = f"Bearer {token}"
else:
# HACK: When no token is provided but system auth env variables are set, we use them.
# Ideally this should simply get passed in as config, instead of being sneakily injected
@@ -119,75 +192,43 @@
# rest emitter, and the rest sink uses the rest emitter under the hood.
system_auth = config_utils.get_system_auth()
if system_auth is not None:
self._session.headers.update({"Authorization": system_auth})

if extra_headers:
self._session.headers.update(extra_headers)

if client_certificate_path:
self._session.cert = client_certificate_path

if ca_certificate_path:
self._session.verify = ca_certificate_path

if disable_ssl_verification:
self._session.verify = False

self._connect_timeout_sec = (
connect_timeout_sec or timeout_sec or _DEFAULT_CONNECT_TIMEOUT_SEC
)
self._read_timeout_sec = (
read_timeout_sec or timeout_sec or _DEFAULT_READ_TIMEOUT_SEC
)

if self._connect_timeout_sec < 1 or self._read_timeout_sec < 1:
logger.warning(
f"Setting timeout values lower than 1 second is not recommended. Your configuration is connect_timeout:{self._connect_timeout_sec}s, read_timeout:{self._read_timeout_sec}s"
)

if retry_status_codes is not None: # Only if missing. Empty list is allowed
self._retry_status_codes = retry_status_codes

if retry_methods is not None:
self._retry_methods = retry_methods

if retry_max_times:
self._retry_max_times = retry_max_times
headers["Authorization"] = system_auth

try:
# Set raise_on_status to False to propagate errors:
# https://stackoverflow.com/questions/70189330/determine-status-code-from-python-retry-exception
# Must call `raise_for_status` after making a request, which we do
retry_strategy = Retry(
total=self._retry_max_times,
status_forcelist=self._retry_status_codes,
backoff_factor=2,
allowed_methods=self._retry_methods,
raise_on_status=False,
)
except TypeError:
# Prior to urllib3 1.26, the Retry class used `method_whitelist` instead of `allowed_methods`.
retry_strategy = Retry(
total=self._retry_max_times,
status_forcelist=self._retry_status_codes,
backoff_factor=2,
method_whitelist=self._retry_methods,
raise_on_status=False,
timeout: float | tuple[float, float]
if connect_timeout_sec is not None or read_timeout_sec is not None:
timeout = (
connect_timeout_sec or timeout_sec or _DEFAULT_TIMEOUT_SEC,
read_timeout_sec or timeout_sec or _DEFAULT_TIMEOUT_SEC,
)
if (
timeout[0] < _TIMEOUT_LOWER_BOUND_SEC
or timeout[1] < _TIMEOUT_LOWER_BOUND_SEC
):
logger.warning(
f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is (connect_timeout, read_timeout) = {timeout} seconds"
)
else:
timeout = get_or_else(timeout_sec, _DEFAULT_TIMEOUT_SEC)
if timeout < _TIMEOUT_LOWER_BOUND_SEC:
logger.warning(
f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is timeout = {timeout} seconds"
)

adapter = HTTPAdapter(
pool_connections=100, pool_maxsize=100, max_retries=retry_strategy
)
self._session.mount("http://", adapter)
self._session.mount("https://", adapter)

# Shim session.request to apply default timeout values.
# Via https://stackoverflow.com/a/59317604.
self._session.request = functools.partial( # type: ignore
self._session.request,
timeout=(self._connect_timeout_sec, self._read_timeout_sec),
self._session_config = RequestsSessionConfig(
timeout=timeout,
retry_status_codes=get_or_else(
retry_status_codes, _DEFAULT_RETRY_STATUS_CODES
),
retry_methods=get_or_else(retry_methods, _DEFAULT_RETRY_METHODS),
retry_max_times=get_or_else(retry_max_times, _DEFAULT_RETRY_MAX_TIMES),
extra_headers={**headers, **(extra_headers or {})},
ca_certificate_path=ca_certificate_path,
client_certificate_path=client_certificate_path,
disable_ssl_verification=disable_ssl_verification,
)

self._session = self._session_config.build_session()

def test_connection(self) -> None:
url = f"{self._gms_server}/config"
response = self._session.get(url)
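
The core of this refactor is that session construction now lives in RequestsSessionConfig.build_session rather than being inlined in DataHubRestEmitter.__init__, so the emitter and the graph client can share one configuration object. A rough sketch of using the new class on its own, assuming it is imported from datahub.emitter.rest_emitter (the endpoint and header values are illustrative, not from this PR):

from datahub.emitter.rest_emitter import RequestsSessionConfig

session_config = RequestsSessionConfig(
    timeout=(5, 30),  # (connect, read) timeouts in seconds; a single float is also accepted
    retry_max_times=3,
    extra_headers={"X-Example-Header": "example-value"},
)
session = session_config.build_session()  # retries, default timeout, and headers are pre-applied
response = session.get("https://gms.example.com/config")
response.raise_for_status()
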
25 changes: 14 additions & 11 deletions metadata-ingestion/src/datahub/ingestion/graph/client.py
@@ -179,21 +179,24 @@

@classmethod
def from_emitter(cls, emitter: DatahubRestEmitter) -> "DataHubGraph":
session_config = emitter._session_config
if isinstance(session_config.timeout, tuple):
# TODO: This is slightly lossy. Eventually, we want to modify the emitter
# to accept a tuple for timeout_sec, and then we'll be able to remove this.
timeout_sec: Optional[float] = session_config.timeout[0]
else:
timeout_sec = session_config.timeout
return cls(
DatahubClientConfig(
server=emitter._gms_server,
token=emitter._token,
timeout_sec=emitter._read_timeout_sec,
retry_status_codes=emitter._retry_status_codes,
retry_max_times=emitter._retry_max_times,
extra_headers=emitter._session.headers,
disable_ssl_verification=emitter._session.verify is False,
ca_certificate_path=(
emitter._session.verify
if isinstance(emitter._session.verify, str)
else None
),
client_certificate_path=emitter._session.cert,
timeout_sec=timeout_sec,
retry_status_codes=session_config.retry_status_codes,
retry_max_times=session_config.retry_max_times,
extra_headers=session_config.extra_headers,
disable_ssl_verification=session_config.disable_ssl_verification,
ca_certificate_path=session_config.ca_certificate_path,
client_certificate_path=session_config.client_certificate_path,
)
)

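The TODO in from_emitter above is worth spelling out: when the emitter was configured with separate connect and read timeouts, only the first element of the tuple can be carried over into timeout_sec, so the read timeout is dropped. A small sketch of that behavior, using the same constructor arguments as the unit tests below (the endpoint is a placeholder):

from datahub.emitter.rest_emitter import DatahubRestEmitter

emitter = DatahubRestEmitter(
    "http://fakegmshost:8080", connect_timeout_sec=2, read_timeout_sec=4
)
assert emitter._session_config.timeout == (2, 4)
# DataHubGraph.from_emitter(emitter) would record timeout_sec=2 (the connect timeout);
# the read timeout of 4 seconds is not preserved until timeout_sec accepts a tuple.
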
2 changes: 1 addition & 1 deletion metadata-ingestion/src/datahub/ingestion/graph/config.py
@@ -10,7 +10,7 @@ class DatahubClientConfig(ConfigModel):
# by callers / the CLI, but the actual client should not have any magic.
server: str
token: Optional[str] = None
timeout_sec: Optional[int] = None
timeout_sec: Optional[float] = None
retry_status_codes: Optional[List[int]] = None
retry_max_times: Optional[int] = None
extra_headers: Optional[Dict[str, str]] = None
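
Widening timeout_sec from int to float keeps the graph config in step with the emitter, whose timeout parameters are floats. After this change a fractional value is expressible directly; for example (the server value is a placeholder):

from datahub.ingestion.graph.config import DatahubClientConfig

config = DatahubClientConfig(
    server="http://localhost:8080",
    timeout_sec=2.5,  # fractional timeouts now match the field's declared type
)
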
32 changes: 17 additions & 15 deletions metadata-ingestion/tests/unit/sdk/test_rest_emitter.py
@@ -4,39 +4,41 @@
MOCK_GMS_ENDPOINT = "http://fakegmshost:8080"


def test_datahub_rest_emitter_construction():
def test_datahub_rest_emitter_construction() -> None:
emitter = DatahubRestEmitter(MOCK_GMS_ENDPOINT)
assert emitter._connect_timeout_sec == rest_emitter._DEFAULT_CONNECT_TIMEOUT_SEC
assert emitter._read_timeout_sec == rest_emitter._DEFAULT_READ_TIMEOUT_SEC
assert emitter._retry_status_codes == rest_emitter._DEFAULT_RETRY_STATUS_CODES
assert emitter._retry_max_times == rest_emitter._DEFAULT_RETRY_MAX_TIMES
assert emitter._session_config.timeout == rest_emitter._DEFAULT_TIMEOUT_SEC
assert (
emitter._session_config.retry_status_codes
== rest_emitter._DEFAULT_RETRY_STATUS_CODES
)
assert (
emitter._session_config.retry_max_times == rest_emitter._DEFAULT_RETRY_MAX_TIMES
)


def test_datahub_rest_emitter_timeout_construction():
def test_datahub_rest_emitter_timeout_construction() -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT, connect_timeout_sec=2, read_timeout_sec=4
)
assert emitter._connect_timeout_sec == 2
assert emitter._read_timeout_sec == 4
assert emitter._session_config.timeout == (2, 4)


def test_datahub_rest_emitter_general_timeout_construction():
def test_datahub_rest_emitter_general_timeout_construction() -> None:
emitter = DatahubRestEmitter(MOCK_GMS_ENDPOINT, timeout_sec=2, read_timeout_sec=4)
assert emitter._connect_timeout_sec == 2
assert emitter._read_timeout_sec == 4
assert emitter._session_config.timeout == (2, 4)


def test_datahub_rest_emitter_retry_construction():
def test_datahub_rest_emitter_retry_construction() -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT,
retry_status_codes=[418],
retry_max_times=42,
)
assert emitter._retry_status_codes == [418]
assert emitter._retry_max_times == 42
assert emitter._session_config.retry_status_codes == [418]
assert emitter._session_config.retry_max_times == 42


def test_datahub_rest_emitter_extra_params():
def test_datahub_rest_emitter_extra_params() -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT, extra_headers={"key1": "value1", "key2": "value2"}
)