From 08a37c18fbf060671ebd56499a5ea53a7fce2450 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 3 Jan 2025 21:22:33 -0500 Subject: [PATCH 1/2] fix(ingest): better correctness on the emitter -> graph conversion Will help us unify these two types in the future. --- .../src/datahub/cli/cli_utils.py | 11 +- .../src/datahub/emitter/rest_emitter.py | 197 ++++++++++-------- .../src/datahub/ingestion/graph/client.py | 25 ++- .../src/datahub/ingestion/graph/config.py | 2 +- .../tests/unit/sdk/test_rest_emitter.py | 32 +-- 5 files changed, 156 insertions(+), 111 deletions(-) diff --git a/metadata-ingestion/src/datahub/cli/cli_utils.py b/metadata-ingestion/src/datahub/cli/cli_utils.py index f80181192ba583..ea58f6fa2c0208 100644 --- a/metadata-ingestion/src/datahub/cli/cli_utils.py +++ b/metadata-ingestion/src/datahub/cli/cli_utils.py @@ -3,7 +3,7 @@ import time import typing from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import click import requests @@ -33,6 +33,15 @@ def first_non_null(ls: List[Optional[str]]) -> Optional[str]: return next((el for el in ls if el is not None and el.strip() != ""), None) +_T = TypeVar("_T") + + +def value_or(value: Optional[_T], default: _T) -> _T: + # Normally we'd use `value or default`. However, that runs into issues + # value is falsey but not None. + return value if value is not None else default + + def parse_run_restli_response(response: requests.Response) -> dict: response_json = response.json() if response.status_code != 200: diff --git a/metadata-ingestion/src/datahub/emitter/rest_emitter.py b/metadata-ingestion/src/datahub/emitter/rest_emitter.py index 04242c8bf45d2b..28818b759304b7 100644 --- a/metadata-ingestion/src/datahub/emitter/rest_emitter.py +++ b/metadata-ingestion/src/datahub/emitter/rest_emitter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import json import logging @@ -12,8 +14,12 @@ from datahub import nice_version_name from datahub.cli import config_utils -from datahub.cli.cli_utils import ensure_has_system_metadata, fixup_gms_url -from datahub.configuration.common import ConfigurationError, OperationalError +from datahub.cli.cli_utils import ensure_has_system_metadata, fixup_gms_url, value_or +from datahub.configuration.common import ( + ConfigModel, + ConfigurationError, + OperationalError, +) from datahub.emitter.generic_emitter import Emitter from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.emitter.request_helper import make_curl_command @@ -30,10 +36,8 @@ logger = logging.getLogger(__name__) -_DEFAULT_CONNECT_TIMEOUT_SEC = 30 # 30 seconds should be plenty to connect -_DEFAULT_READ_TIMEOUT_SEC = ( - 30 # Any ingest call taking longer than 30 seconds should be abandoned -) +_DEFAULT_TIMEOUT_SEC = 30 # 30 seconds should be plenty to connect +_TIMEOUT_LOWER_BOUND_SEC = 1 # if below this, we log a warning _DEFAULT_RETRY_STATUS_CODES = [ # Additional status codes to retry on 429, 500, @@ -60,15 +64,76 @@ ) +class RequestsSessionConfig(ConfigModel): + timeout: float | tuple[float, float] | None = _DEFAULT_TIMEOUT_SEC + + retry_status_codes: list[int] = _DEFAULT_RETRY_STATUS_CODES + retry_methods: list[str] = _DEFAULT_RETRY_METHODS + retry_max_times: int = _DEFAULT_RETRY_MAX_TIMES + + extra_headers: dict[str, str] = {} + + ca_certificate_path: Optional[str] = None + client_certificate_path: Optional[str] = None + disable_ssl_verification: bool = False + + def build_session(self) -> requests.Session: + session = requests.Session() + + if self.extra_headers: + session.headers.update(self.extra_headers) + + if self.client_certificate_path: + session.cert = self.client_certificate_path + + if self.ca_certificate_path: + session.verify = self.ca_certificate_path + + if self.disable_ssl_verification: + session.verify = False + + try: + # Set raise_on_status to False to propagate errors: + # https://stackoverflow.com/questions/70189330/determine-status-code-from-python-retry-exception + # Must call `raise_for_status` after making a request, which we do + retry_strategy = Retry( + total=self.retry_max_times, + status_forcelist=self.retry_status_codes, + backoff_factor=2, + allowed_methods=self.retry_methods, + raise_on_status=False, + ) + except TypeError: + # Prior to urllib3 1.26, the Retry class used `method_whitelist` instead of `allowed_methods`. + retry_strategy = Retry( + total=self.retry_max_times, + status_forcelist=self.retry_status_codes, + backoff_factor=2, + method_whitelist=self.retry_methods, + raise_on_status=False, + ) + + adapter = HTTPAdapter( + pool_connections=100, pool_maxsize=100, max_retries=retry_strategy + ) + session.mount("http://", adapter) + session.mount("https://", adapter) + + if self.timeout is not None: + # Shim session.request to apply default timeout values. + # Via https://stackoverflow.com/a/59317604. + session.request = functools.partial( # type: ignore + session.request, + timeout=self.timeout, + ) + + return session + + class DataHubRestEmitter(Closeable, Emitter): _gms_server: str _token: Optional[str] _session: requests.Session - _connect_timeout_sec: float = _DEFAULT_CONNECT_TIMEOUT_SEC - _read_timeout_sec: float = _DEFAULT_READ_TIMEOUT_SEC - _retry_status_codes: List[int] = _DEFAULT_RETRY_STATUS_CODES - _retry_methods: List[str] = _DEFAULT_RETRY_METHODS - _retry_max_times: int = _DEFAULT_RETRY_MAX_TIMES def __init__( self, @@ -99,15 +164,13 @@ def __init__( self._session = requests.Session() - self._session.headers.update( - { - "X-RestLi-Protocol-Version": "2.0.0", - "X-DataHub-Py-Cli-Version": nice_version_name(), - "Content-Type": "application/json", - } - ) + headers = { + "X-RestLi-Protocol-Version": "2.0.0", + "X-DataHub-Py-Cli-Version": nice_version_name(), + "Content-Type": "application/json", + } if token: - self._session.headers.update({"Authorization": f"Bearer {token}"}) + headers["Authorization"] = f"Bearer {token}" else: # HACK: When no token is provided but system auth env variables are set, we use them. # Ideally this should simply get passed in as config, instead of being sneakily injected @@ -116,75 +179,43 @@ def __init__( # rest emitter, and the rest sink uses the rest emitter under the hood. system_auth = config_utils.get_system_auth() if system_auth is not None: - self._session.headers.update({"Authorization": system_auth}) + headers["Authorization"] = system_auth - if extra_headers: - self._session.headers.update(extra_headers) - - if client_certificate_path: - self._session.cert = client_certificate_path - - if ca_certificate_path: - self._session.verify = ca_certificate_path - - if disable_ssl_verification: - self._session.verify = False - - self._connect_timeout_sec = ( - connect_timeout_sec or timeout_sec or _DEFAULT_CONNECT_TIMEOUT_SEC - ) - self._read_timeout_sec = ( - read_timeout_sec or timeout_sec or _DEFAULT_READ_TIMEOUT_SEC - ) - - if self._connect_timeout_sec < 1 or self._read_timeout_sec < 1: - logger.warning( - f"Setting timeout values lower than 1 second is not recommended. Your configuration is connect_timeout:{self._connect_timeout_sec}s, read_timeout:{self._read_timeout_sec}s" - ) - - if retry_status_codes is not None: # Only if missing. Empty list is allowed - self._retry_status_codes = retry_status_codes - - if retry_methods is not None: - self._retry_methods = retry_methods - - if retry_max_times: - self._retry_max_times = retry_max_times - - try: - # Set raise_on_status to False to propagate errors: - # https://stackoverflow.com/questions/70189330/determine-status-code-from-python-retry-exception - # Must call `raise_for_status` after making a request, which we do - retry_strategy = Retry( - total=self._retry_max_times, - status_forcelist=self._retry_status_codes, - backoff_factor=2, - allowed_methods=self._retry_methods, - raise_on_status=False, - ) - except TypeError: - # Prior to urllib3 1.26, the Retry class used `method_whitelist` instead of `allowed_methods`. - retry_strategy = Retry( - total=self._retry_max_times, - status_forcelist=self._retry_status_codes, - backoff_factor=2, - method_whitelist=self._retry_methods, - raise_on_status=False, + timeout: float | tuple[float, float] + if connect_timeout_sec is not None or read_timeout_sec is not None: + timeout = ( + connect_timeout_sec or timeout_sec or _DEFAULT_TIMEOUT_SEC, + read_timeout_sec or timeout_sec or _DEFAULT_TIMEOUT_SEC, ) + if ( + timeout[0] < _TIMEOUT_LOWER_BOUND_SEC + or timeout[1] < _TIMEOUT_LOWER_BOUND_SEC + ): + logger.warning( + f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is (connect_timeout, read_timeout) = {timeout} seconds" + ) + else: + timeout = value_or(timeout_sec, _DEFAULT_TIMEOUT_SEC) + if timeout < _TIMEOUT_LOWER_BOUND_SEC: + logger.warning( + f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is timeout = {timeout} seconds" + ) - adapter = HTTPAdapter( - pool_connections=100, pool_maxsize=100, max_retries=retry_strategy - ) - self._session.mount("http://", adapter) - self._session.mount("https://", adapter) - - # Shim session.request to apply default timeout values. - # Via https://stackoverflow.com/a/59317604. - self._session.request = functools.partial( # type: ignore - self._session.request, - timeout=(self._connect_timeout_sec, self._read_timeout_sec), + self._session_config = RequestsSessionConfig( + timeout=timeout, + retry_status_codes=value_or( + retry_status_codes, _DEFAULT_RETRY_STATUS_CODES + ), + retry_methods=value_or(retry_methods, _DEFAULT_RETRY_METHODS), + retry_max_times=value_or(retry_max_times, _DEFAULT_RETRY_MAX_TIMES), + extra_headers={**headers, **(extra_headers or {})}, + ca_certificate_path=ca_certificate_path, + client_certificate_path=client_certificate_path, + disable_ssl_verification=disable_ssl_verification, ) + self._session = self._session_config.build_session() + def test_connection(self) -> None: url = f"{self._gms_server}/config" response = self._session.get(url) diff --git a/metadata-ingestion/src/datahub/ingestion/graph/client.py b/metadata-ingestion/src/datahub/ingestion/graph/client.py index ca9a41172e5b6e..7de6e8130a7ab6 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/client.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/client.py @@ -179,21 +179,24 @@ def frontend_base_url(self) -> str: @classmethod def from_emitter(cls, emitter: DatahubRestEmitter) -> "DataHubGraph": + session_config = emitter._session_config + if isinstance(session_config.timeout, tuple): + # TODO: This is slightly lossy. Eventually, we want to modify the emitter + # to accept a tuple for timeout_sec, and then we'll be able to remove this. + timeout_sec: Optional[float] = session_config.timeout[0] + else: + timeout_sec = session_config.timeout return cls( DatahubClientConfig( server=emitter._gms_server, token=emitter._token, - timeout_sec=emitter._read_timeout_sec, - retry_status_codes=emitter._retry_status_codes, - retry_max_times=emitter._retry_max_times, - extra_headers=emitter._session.headers, - disable_ssl_verification=emitter._session.verify is False, - ca_certificate_path=( - emitter._session.verify - if isinstance(emitter._session.verify, str) - else None - ), - client_certificate_path=emitter._session.cert, + timeout_sec=timeout_sec, + retry_status_codes=session_config.retry_status_codes, + retry_max_times=session_config.retry_max_times, + extra_headers=session_config.extra_headers, + disable_ssl_verification=session_config.disable_ssl_verification, + ca_certificate_path=session_config.ca_certificate_path, + client_certificate_path=session_config.client_certificate_path, ) ) diff --git a/metadata-ingestion/src/datahub/ingestion/graph/config.py b/metadata-ingestion/src/datahub/ingestion/graph/config.py index 5f269e14e1a4af..8f0a5844c97c4b 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/config.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/config.py @@ -10,7 +10,7 @@ class DatahubClientConfig(ConfigModel): # by callers / the CLI, but the actual client should not have any magic. server: str token: Optional[str] = None - timeout_sec: Optional[int] = None + timeout_sec: Optional[float] = None retry_status_codes: Optional[List[int]] = None retry_max_times: Optional[int] = None extra_headers: Optional[Dict[str, str]] = None diff --git a/metadata-ingestion/tests/unit/sdk/test_rest_emitter.py b/metadata-ingestion/tests/unit/sdk/test_rest_emitter.py index b4d7cb17b66f5c..81120dfc87aba3 100644 --- a/metadata-ingestion/tests/unit/sdk/test_rest_emitter.py +++ b/metadata-ingestion/tests/unit/sdk/test_rest_emitter.py @@ -4,39 +4,41 @@ MOCK_GMS_ENDPOINT = "http://fakegmshost:8080" -def test_datahub_rest_emitter_construction(): +def test_datahub_rest_emitter_construction() -> None: emitter = DatahubRestEmitter(MOCK_GMS_ENDPOINT) - assert emitter._connect_timeout_sec == rest_emitter._DEFAULT_CONNECT_TIMEOUT_SEC - assert emitter._read_timeout_sec == rest_emitter._DEFAULT_READ_TIMEOUT_SEC - assert emitter._retry_status_codes == rest_emitter._DEFAULT_RETRY_STATUS_CODES - assert emitter._retry_max_times == rest_emitter._DEFAULT_RETRY_MAX_TIMES + assert emitter._session_config.timeout == rest_emitter._DEFAULT_TIMEOUT_SEC + assert ( + emitter._session_config.retry_status_codes + == rest_emitter._DEFAULT_RETRY_STATUS_CODES + ) + assert ( + emitter._session_config.retry_max_times == rest_emitter._DEFAULT_RETRY_MAX_TIMES + ) -def test_datahub_rest_emitter_timeout_construction(): +def test_datahub_rest_emitter_timeout_construction() -> None: emitter = DatahubRestEmitter( MOCK_GMS_ENDPOINT, connect_timeout_sec=2, read_timeout_sec=4 ) - assert emitter._connect_timeout_sec == 2 - assert emitter._read_timeout_sec == 4 + assert emitter._session_config.timeout == (2, 4) -def test_datahub_rest_emitter_general_timeout_construction(): +def test_datahub_rest_emitter_general_timeout_construction() -> None: emitter = DatahubRestEmitter(MOCK_GMS_ENDPOINT, timeout_sec=2, read_timeout_sec=4) - assert emitter._connect_timeout_sec == 2 - assert emitter._read_timeout_sec == 4 + assert emitter._session_config.timeout == (2, 4) -def test_datahub_rest_emitter_retry_construction(): +def test_datahub_rest_emitter_retry_construction() -> None: emitter = DatahubRestEmitter( MOCK_GMS_ENDPOINT, retry_status_codes=[418], retry_max_times=42, ) - assert emitter._retry_status_codes == [418] - assert emitter._retry_max_times == 42 + assert emitter._session_config.retry_status_codes == [418] + assert emitter._session_config.retry_max_times == 42 -def test_datahub_rest_emitter_extra_params(): +def test_datahub_rest_emitter_extra_params() -> None: emitter = DatahubRestEmitter( MOCK_GMS_ENDPOINT, extra_headers={"key1": "value1", "key2": "value2"} ) From b013c2e5dfc528edb4c66877d1d2e713ed897969 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Tue, 7 Jan 2025 13:06:36 -0500 Subject: [PATCH 2/2] rename to get_or_else --- .../src/datahub/cli/cli_utils.py | 4 +-- .../src/datahub/emitter/rest_emitter.py | 30 ++++++++++++------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/metadata-ingestion/src/datahub/cli/cli_utils.py b/metadata-ingestion/src/datahub/cli/cli_utils.py index ea58f6fa2c0208..ca4a11b41925e5 100644 --- a/metadata-ingestion/src/datahub/cli/cli_utils.py +++ b/metadata-ingestion/src/datahub/cli/cli_utils.py @@ -36,9 +36,9 @@ def first_non_null(ls: List[Optional[str]]) -> Optional[str]: _T = TypeVar("_T") -def value_or(value: Optional[_T], default: _T) -> _T: +def get_or_else(value: Optional[_T], default: _T) -> _T: # Normally we'd use `value or default`. However, that runs into issues - # value is falsey but not None. + # when value is falsey but not None. return value if value is not None else default diff --git a/metadata-ingestion/src/datahub/emitter/rest_emitter.py b/metadata-ingestion/src/datahub/emitter/rest_emitter.py index 044a0afb4415b3..74b8ade7da445b 100644 --- a/metadata-ingestion/src/datahub/emitter/rest_emitter.py +++ b/metadata-ingestion/src/datahub/emitter/rest_emitter.py @@ -5,7 +5,17 @@ import logging import os from json.decoder import JSONDecodeError -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) import requests from deprecated import deprecated @@ -14,7 +24,7 @@ from datahub import nice_version_name from datahub.cli import config_utils -from datahub.cli.cli_utils import ensure_has_system_metadata, fixup_gms_url, value_or +from datahub.cli.cli_utils import ensure_has_system_metadata, fixup_gms_url, get_or_else from datahub.cli.env_utils import get_boolean_env_variable from datahub.configuration.common import ( ConfigModel, @@ -68,13 +78,13 @@ class RequestsSessionConfig(ConfigModel): - timeout: float | tuple[float, float] | None = _DEFAULT_TIMEOUT_SEC + timeout: Union[float, Tuple[float, float], None] = _DEFAULT_TIMEOUT_SEC - retry_status_codes: list[int] = _DEFAULT_RETRY_STATUS_CODES - retry_methods: list[str] = _DEFAULT_RETRY_METHODS + retry_status_codes: List[int] = _DEFAULT_RETRY_STATUS_CODES + retry_methods: List[str] = _DEFAULT_RETRY_METHODS retry_max_times: int = _DEFAULT_RETRY_MAX_TIMES - extra_headers: dict[str, str] = {} + extra_headers: Dict[str, str] = {} ca_certificate_path: Optional[str] = None client_certificate_path: Optional[str] = None @@ -198,7 +208,7 @@ def __init__( f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is (connect_timeout, read_timeout) = {timeout} seconds" ) else: - timeout = value_or(timeout_sec, _DEFAULT_TIMEOUT_SEC) + timeout = get_or_else(timeout_sec, _DEFAULT_TIMEOUT_SEC) if timeout < _TIMEOUT_LOWER_BOUND_SEC: logger.warning( f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is timeout = {timeout} seconds" @@ -206,11 +216,11 @@ def __init__( self._session_config = RequestsSessionConfig( timeout=timeout, - retry_status_codes=value_or( + retry_status_codes=get_or_else( retry_status_codes, _DEFAULT_RETRY_STATUS_CODES ), - retry_methods=value_or(retry_methods, _DEFAULT_RETRY_METHODS), - retry_max_times=value_or(retry_max_times, _DEFAULT_RETRY_MAX_TIMES), + retry_methods=get_or_else(retry_methods, _DEFAULT_RETRY_METHODS), + retry_max_times=get_or_else(retry_max_times, _DEFAULT_RETRY_MAX_TIMES), extra_headers={**headers, **(extra_headers or {})}, ca_certificate_path=ca_certificate_path, client_certificate_path=client_certificate_path,