From 7214e036e8d45aa7edc5ab3e7cf2f9c067515b92 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Sun, 24 Mar 2024 20:46:36 +0300 Subject: [PATCH 01/15] Add RESTClient and tests --- dlt/sources/helpers/rest_client/__init__.py | 1 + dlt/sources/helpers/rest_client/auth.py | 233 ++++++++++++ dlt/sources/helpers/rest_client/client.py | 269 +++++++++++++ dlt/sources/helpers/rest_client/detector.py | 161 ++++++++ dlt/sources/helpers/rest_client/exceptions.py | 5 + dlt/sources/helpers/rest_client/paginators.py | 178 +++++++++ dlt/sources/helpers/rest_client/typing.py | 9 + dlt/sources/helpers/rest_client/utils.py | 24 ++ tests/sources/helpers/rest_client/__init__.py | 0 tests/sources/helpers/rest_client/conftest.py | 196 ++++++++++ .../helpers/rest_client/private_key.pem | 28 ++ .../helpers/rest_client/test_client.py | 173 +++++++++ .../helpers/rest_client/test_detector.py | 360 ++++++++++++++++++ .../helpers/rest_client/test_paginators.py | 82 ++++ 14 files changed, 1719 insertions(+) create mode 100644 dlt/sources/helpers/rest_client/__init__.py create mode 100644 dlt/sources/helpers/rest_client/auth.py create mode 100644 dlt/sources/helpers/rest_client/client.py create mode 100644 dlt/sources/helpers/rest_client/detector.py create mode 100644 dlt/sources/helpers/rest_client/exceptions.py create mode 100644 dlt/sources/helpers/rest_client/paginators.py create mode 100644 dlt/sources/helpers/rest_client/typing.py create mode 100644 dlt/sources/helpers/rest_client/utils.py create mode 100644 tests/sources/helpers/rest_client/__init__.py create mode 100644 tests/sources/helpers/rest_client/conftest.py create mode 100644 tests/sources/helpers/rest_client/private_key.pem create mode 100644 tests/sources/helpers/rest_client/test_client.py create mode 100644 tests/sources/helpers/rest_client/test_detector.py create mode 100644 tests/sources/helpers/rest_client/test_paginators.py diff --git a/dlt/sources/helpers/rest_client/__init__.py b/dlt/sources/helpers/rest_client/__init__.py new file mode 100644 index 0000000000..3264ea4aae --- /dev/null +++ b/dlt/sources/helpers/rest_client/__init__.py @@ -0,0 +1 @@ +from .client import RESTClient # noqa: F401 diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py new file mode 100644 index 0000000000..c305654cbd --- /dev/null +++ b/dlt/sources/helpers/rest_client/auth.py @@ -0,0 +1,233 @@ +from base64 import b64encode +import math +from typing import List, Dict, Final, Literal, Optional, Union, Any, cast, Iterable +from dlt.sources.helpers import requests +from requests.auth import AuthBase +from requests import PreparedRequest # noqa: I251 +import pendulum +import jwt +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes + +from dlt import config, secrets +from dlt.common import logger +from dlt.common.configuration.specs.base_configuration import configspec +from dlt.common.configuration.specs import CredentialsConfiguration +from dlt.common.configuration.specs.exceptions import NativeValueError +from dlt.common.typing import TSecretStrValue + + +TApiKeyLocation = Literal[ + "header", "cookie", "query", "param" +] # Alias for scheme "in" field + + +class AuthConfigBase(AuthBase, CredentialsConfiguration): + """Authenticator base which is both `requests` friendly AuthBase and dlt SPEC + configurable via env variables or toml files + """ + + pass + + +@configspec +class BearerTokenAuth(AuthConfigBase): + type: Final[Literal["http"]] = "http" # noqa: A003 + scheme: Literal["bearer"] = "bearer" + token: TSecretStrValue + + def __init__(self, token: TSecretStrValue = secrets.value) -> None: + self.token = token + + def parse_native_representation(self, value: Any) -> None: + if isinstance(value, str): + self.token = cast(TSecretStrValue, value) + else: + raise NativeValueError( + type(self), + value, + f"BearerTokenAuth token must be a string, got {type(value)}", + ) + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + request.headers["Authorization"] = f"Bearer {self.token}" + return request + + +@configspec +class APIKeyAuth(AuthConfigBase): + type: Final[Literal["apiKey"]] = "apiKey" # noqa: A003 + name: str = "Authorization" + api_key: TSecretStrValue + location: TApiKeyLocation = "header" + + def __init__( + self, + name: str = config.value, + api_key: TSecretStrValue = secrets.value, + location: TApiKeyLocation = "header", + ) -> None: + self.name = name + self.api_key = api_key + self.location = location + + def parse_native_representation(self, value: Any) -> None: + if isinstance(value, str): + self.api_key = cast(TSecretStrValue, value) + else: + raise NativeValueError( + type(self), + value, + f"APIKeyAuth api_key must be a string, got {type(value)}", + ) + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + if self.location == "header": + request.headers[self.name] = self.api_key + elif self.location in ["query", "param"]: + request.prepare_url(request.url, {self.name: self.api_key}) + elif self.location == "cookie": + raise NotImplementedError() + return request + + +@configspec +class HttpBasicAuth(AuthConfigBase): + type: Final[Literal["http"]] = "http" # noqa: A003 + scheme: Literal["basic"] = "basic" + username: str + password: TSecretStrValue + + def __init__( + self, username: str = config.value, password: TSecretStrValue = secrets.value + ) -> None: + self.username = username + self.password = password + + def parse_native_representation(self, value: Any) -> None: + if isinstance(value, Iterable) and not isinstance(value, str): + value = list(value) + if len(value) == 2: + self.username, self.password = value + return + raise NativeValueError( + type(self), + value, + f"HttpBasicAuth username and password must be a tuple of two strings, got {type(value)}", + ) + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + encoded = b64encode(f"{self.username}:{self.password}".encode()).decode() + request.headers["Authorization"] = f"Basic {encoded}" + return request + + +@configspec +class OAuth2AuthBase(AuthConfigBase): + """Base class for oauth2 authenticators. requires access_token""" + + # TODO: Separate class for flows (implicit, authorization_code, client_credentials, etc) + type: Final[Literal["oauth2"]] = "oauth2" # noqa: A003 + access_token: TSecretStrValue + + def __init__(self, access_token: TSecretStrValue = secrets.value) -> None: + self.access_token = access_token + + def parse_native_representation(self, value: Any) -> None: + if isinstance(value, str): + self.access_token = cast(TSecretStrValue, value) + else: + raise NativeValueError( + type(self), + value, + f"OAuth2AuthBase access_token must be a string, got {type(value)}", + ) + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + request.headers["Authorization"] = f"Bearer {self.access_token}" + return request + + +@configspec +class OAuthJWTAuth(BearerTokenAuth): + """This is a form of Bearer auth, actually there's not standard way to declare it in openAPI""" + + format: Final[Literal["JWT"]] = "JWT" # noqa: A003 + client_id: str + private_key: TSecretStrValue + auth_endpoint: str + scopes: Optional[str] = None + headers: Optional[Dict[str, str]] = None + private_key_passphrase: Optional[TSecretStrValue] = None + default_token_expiration: int = 3600 + + def __init__( + self, + client_id: str = config.value, + private_key: TSecretStrValue = secrets.value, + auth_endpoint: str = config.value, + scopes: Optional[Union[str, List[str]]] = None, + headers: Optional[Dict[str, str]] = None, + private_key_passphrase: Optional[TSecretStrValue] = None, + default_token_expiration: int = 3600, + ): + self.client_id = client_id + self.private_key = private_key + self.private_key_passphrase = private_key_passphrase + self.auth_endpoint = auth_endpoint + self.scopes = scopes if isinstance(scopes, str) else " ".join(scopes) + self.headers = headers + self.token = None + self.token_expiry: Optional[pendulum.DateTime] = None + self.default_token_expiration = default_token_expiration + + def __call__(self, r: PreparedRequest) -> PreparedRequest: + if self.token is None or self.is_token_expired(): + self.obtain_token() + r.headers["Authorization"] = f"Bearer {self.token}" + return r + + def is_token_expired(self) -> bool: + return not self.token_expiry or pendulum.now() >= self.token_expiry + + def obtain_token(self) -> None: + payload = self.create_jwt_payload() + data = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": jwt.encode( + payload, self.load_private_key(), algorithm="RS256" + ), + } + + logger.debug(f"Obtaining token from {self.auth_endpoint}") + + response = requests.post(self.auth_endpoint, headers=self.headers, data=data) + response.raise_for_status() + + token_response = response.json() + self.token = token_response["access_token"] + self.token_expiry = pendulum.now().add( + seconds=token_response.get("expires_in", self.default_token_expiration) + ) + + def create_jwt_payload(self) -> Dict[str, Union[str, int]]: + now = pendulum.now() + return { + "iss": self.client_id, + "sub": self.client_id, + "aud": self.auth_endpoint, + "exp": math.floor((now.add(hours=1)).timestamp()), + "iat": math.floor(now.timestamp()), + "scope": self.scopes, + } + + def load_private_key(self) -> PrivateKeyTypes: + private_key_bytes = self.private_key.encode("utf-8") + return serialization.load_pem_private_key( + private_key_bytes, + password=self.private_key_passphrase.encode("utf-8") + if self.private_key_passphrase + else None, + backend=default_backend(), + ) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py new file mode 100644 index 0000000000..12e22c072d --- /dev/null +++ b/dlt/sources/helpers/rest_client/client.py @@ -0,0 +1,269 @@ +from typing import ( + Iterator, + Optional, + List, + Dict, + Any, + TypeVar, + Iterable, + Union, + Callable, + cast, +) +import copy +from urllib.parse import urlparse + +from requests import Session as BaseSession # noqa: I251 + +from dlt.common import logger +from dlt.common import jsonpath +from dlt.sources.helpers.requests.retry import Client +from dlt.sources.helpers.requests import Response, Request + +from .typing import HTTPMethodBasic, HTTPMethod +from .paginators import BasePaginator +from .auth import AuthConfigBase +from .detector import PaginatorFactory, find_records +from .exceptions import IgnoreResponseException + +from .utils import join_url + + +_T = TypeVar("_T") +HookFunction = Callable[[Response, Any, Any], None] +HookEvent = Union[HookFunction, List[HookFunction]] +Hooks = Dict[str, HookEvent] + + +class PageData(List[_T]): + """A list of elements in a single page of results with attached request context. + + The context allows to inspect the response, paginator and authenticator, modify the request + """ + + def __init__( + self, + __iterable: Iterable[_T], + request: Request, + response: Response, + paginator: BasePaginator, + auth: AuthConfigBase, + ): + super().__init__(__iterable) + self.request = request + self.response = response + self.paginator = paginator + self.auth = auth + + +class RESTClient: + """A generic REST client for making requests to an API with support for + pagination and authentication. + + Args: + base_url (str): The base URL of the API to make requests to. + headers (Optional[Dict[str, str]]): Default headers to include in all requests. + auth (Optional[AuthConfigBase]): Authentication configuration for all requests. + paginator (Optional[BasePaginator]): Default paginator for handling paginated responses. + data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for extracting data from responses. + session (BaseSession): HTTP session for making requests. + paginator_factory (Optional[PaginatorFactory]): Factory for creating paginator instances, + used for detecting paginators. + """ + + def __init__( + self, + base_url: str, + headers: Optional[Dict[str, str]] = None, + auth: Optional[AuthConfigBase] = None, + paginator: Optional[BasePaginator] = None, + data_selector: Optional[jsonpath.TJsonPath] = None, + session: BaseSession = None, + paginator_factory: Optional[PaginatorFactory] = None, + ) -> None: + self.base_url = base_url + self.headers = headers + self.auth = auth + + if session: + self._validate_session_raise_for_status(session) + self.session = session + else: + self.session = Client(raise_for_status=False).session + + self.paginator = paginator + self.pagination_factory = paginator_factory or PaginatorFactory() + + self.data_selector = data_selector + + def _validate_session_raise_for_status(self, session: BaseSession) -> None: + # dlt.sources.helpers.requests.session.Session + # has raise_for_status=True by default + if getattr(self.session, "raise_for_status", False): + logger.warning( + "The session provided has raise_for_status enabled. " + "This may cause unexpected behavior." + ) + + def _create_request( + self, + path: str, + method: HTTPMethod, + params: Dict[str, Any], + json: Optional[Dict[str, Any]] = None, + auth: Optional[AuthConfigBase] = None, + hooks: Optional[Hooks] = None, + ) -> Request: + parsed_url = urlparse(path) + if parsed_url.scheme in ("http", "https"): + url = path + else: + url = join_url(self.base_url, path) + + return Request( + method=method, + url=url, + headers=self.headers, + params=params, + json=json, + auth=auth or self.auth, + hooks=hooks, + ) + + def _send_request(self, request: Request) -> Response: + logger.info( + f"Making {request.method.upper()} request to {request.url}" + f" with params={request.params}, json={request.json}" + ) + + prepared_request = self.session.prepare_request(request) + + return self.session.send(prepared_request) + + def request( + self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any + ) -> Response: + prepared_request = self._create_request( + path=path, + method=method, + **kwargs, + ) + return self._send_request(prepared_request) + + def get( + self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> Response: + return self.request(path, method="GET", params=params, **kwargs) + + def post( + self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> Response: + return self.request(path, method="POST", json=json, **kwargs) + + def paginate( + self, + path: str = "", + method: HTTPMethodBasic = "GET", + params: Optional[Dict[str, Any]] = None, + json: Optional[Dict[str, Any]] = None, + auth: Optional[AuthConfigBase] = None, + paginator: Optional[BasePaginator] = None, + data_selector: Optional[jsonpath.TJsonPath] = None, + hooks: Optional[Hooks] = None, + ) -> Iterator[PageData[Any]]: + """Iterates over paginated API responses, yielding pages of data. + + Args: + path (str): Endpoint path for the request, relative to `base_url`. + method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'. + params (Optional[Dict[str, Any]]): URL parameters for the request. + json (Optional[Dict[str, Any]]): JSON payload for the request. + auth (Optional[AuthConfigBase]): Authentication configuration for the request. + paginator (Optional[BasePaginator]): Paginator instance for handling + pagination logic. + data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for + extracting data from the response. + hooks (Optional[Hooks]): Hooks to modify request/response objects. Note that + when hooks are not provided, the default behavior is to raise an exception + on error status codes. + + Yields: + PageData[Any]: A page of data from the paginated API response, along with request and response context. + + Raises: + HTTPError: If the response status code is not a success code. This is raised + by default when hooks are not provided. + + Example: + >>> client = RESTClient(base_url="https://api.example.com") + >>> for page in client.paginate("/search", method="post", json={"query": "foo"}): + >>> print(page) + """ + + paginator = paginator if paginator else copy.deepcopy(self.paginator) + auth = auth or self.auth + data_selector = data_selector or self.data_selector + hooks = hooks or {} + + def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None: + response.raise_for_status() + + if "response" not in hooks: + hooks["response"] = [raise_for_status] + + request = self._create_request( + path=path, method=method, params=params, json=json, auth=auth, hooks=hooks + ) + + while True: + try: + response = self._send_request(request) + except IgnoreResponseException: + break + + if paginator is None: + paginator = self.detect_paginator(response) + + data = self.extract_response(response, data_selector) + paginator.update_state(response) + paginator.update_request(request) + + # yield data with context + yield PageData( + data, request=request, response=response, paginator=paginator, auth=auth + ) + + if not paginator.has_next_page: + break + + def extract_response( + self, response: Response, data_selector: jsonpath.TJsonPath + ) -> List[Any]: + if data_selector: + # we should compile data_selector + data: Any = jsonpath.find_values(data_selector, response.json()) + # extract if single item selected + data = data[0] if isinstance(data, list) and len(data) == 1 else data + else: + data = find_records(response.json()) + # wrap single pages into lists + if not isinstance(data, list): + data = [data] + return cast(List[Any], data) + + def detect_paginator(self, response: Response) -> BasePaginator: + """Detects a paginator for the response and returns it. + + Args: + response (Response): The response to detect the paginator for. + + Returns: + BasePaginator: The paginator instance that was detected. + """ + paginator = self.pagination_factory.create_paginator(response) + if paginator is None: + raise ValueError( + f"No suitable paginator found for the response at {response.url}" + ) + logger.info(f"Detected paginator: {paginator.__class__.__name__}") + return paginator diff --git a/dlt/sources/helpers/rest_client/detector.py b/dlt/sources/helpers/rest_client/detector.py new file mode 100644 index 0000000000..f3af31bb4d --- /dev/null +++ b/dlt/sources/helpers/rest_client/detector.py @@ -0,0 +1,161 @@ +import re +from typing import List, Dict, Any, Tuple, Union, Optional, Callable, Iterable + +from dlt.sources.helpers.requests import Response + +from .paginators import ( + BasePaginator, + HeaderLinkPaginator, + JSONResponsePaginator, + SinglePagePaginator, +) + +RECORD_KEY_PATTERNS = frozenset( + [ + "data", + "items", + "results", + "entries", + "records", + "rows", + "entities", + "payload", + "content", + "objects", + ] +) + +NON_RECORD_KEY_PATTERNS = frozenset( + [ + "meta", + "metadata", + "pagination", + "links", + "extras", + "headers", + ] +) + +NEXT_PAGE_KEY_PATTERNS = frozenset(["next", "nextpage", "nexturl"]) +NEXT_PAGE_DICT_KEY_PATTERNS = frozenset(["href", "url"]) + + +def single_entity_path(path: str) -> bool: + """Checks if path ends with path param indicating that single object is returned""" + return re.search(r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}$", path) is not None + + +def find_all_lists( + dict_: Dict[str, Any], + result: List[Tuple[int, str, List[Any]]] = None, + level: int = 0, +) -> List[Tuple[int, str, List[Any]]]: + """Recursively looks for lists in dict_ and returns tuples + in format (nesting level, dictionary key, list) + """ + if level > 2: + return [] + + for key, value in dict_.items(): + if isinstance(value, list): + result.append((level, key, value)) + elif isinstance(value, dict): + find_all_lists(value, result, level + 1) + + return result + + +def find_records( + response: Union[Dict[str, Any], List[Any], Any], +) -> Union[Dict[str, Any], List[Any], Any]: + # when a list was returned (or in rare case a simple type or null) + if not isinstance(response, dict): + return response + lists = find_all_lists(response, result=[]) + if len(lists) == 0: + # could not detect anything + return response + # we are ordered by nesting level, find the most suitable list + try: + return next( + list_info[2] + for list_info in lists + if list_info[1] in RECORD_KEY_PATTERNS + and list_info[1] not in NON_RECORD_KEY_PATTERNS + ) + except StopIteration: + # return the least nested element + return lists[0][2] + + +def matches_any_pattern(key: str, patterns: Iterable[str]) -> bool: + normalized_key = key.lower() + return any(pattern in normalized_key for pattern in patterns) + + +def find_next_page_path( + dictionary: Dict[str, Any], path: Optional[List[str]] = None +) -> Optional[List[str]]: + if not isinstance(dictionary, dict): + return None + + if path is None: + path = [] + + for key, value in dictionary.items(): + if matches_any_pattern(key, NEXT_PAGE_KEY_PATTERNS): + if isinstance(value, dict): + for dict_key in value: + if matches_any_pattern(dict_key, NEXT_PAGE_DICT_KEY_PATTERNS): + return [*path, key, dict_key] + return [*path, key] + + if isinstance(value, dict): + result = find_next_page_path(value, [*path, key]) + if result: + return result + + return None + + +def header_links_detector(response: Response) -> Optional[HeaderLinkPaginator]: + links_next_key = "next" + + if response.links.get(links_next_key): + return HeaderLinkPaginator() + return None + + +def json_links_detector(response: Response) -> Optional[JSONResponsePaginator]: + dictionary = response.json() + next_path_parts = find_next_page_path(dictionary) + + if not next_path_parts: + return None + + return JSONResponsePaginator(next_url_path=".".join(next_path_parts)) + + +def single_page_detector(response: Response) -> Optional[SinglePagePaginator]: + """This is our fallback paginator, also for results that are single entities""" + return SinglePagePaginator() + + +class PaginatorFactory: + def __init__( + self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None + ): + if detectors is None: + detectors = [ + header_links_detector, + json_links_detector, + single_page_detector, + ] + self.detectors = detectors + + def create_paginator(self, response: Response) -> Optional[BasePaginator]: + for detector in self.detectors: + paginator = detector(response) + if paginator: + return paginator + return None diff --git a/dlt/sources/helpers/rest_client/exceptions.py b/dlt/sources/helpers/rest_client/exceptions.py new file mode 100644 index 0000000000..4b4d555ca7 --- /dev/null +++ b/dlt/sources/helpers/rest_client/exceptions.py @@ -0,0 +1,5 @@ +from dlt.common.exceptions import DltException + + +class IgnoreResponseException(DltException): + pass diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py new file mode 100644 index 0000000000..c098ea667f --- /dev/null +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -0,0 +1,178 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from dlt.sources.helpers.requests import Response, Request +from dlt.common import jsonpath + + +class BasePaginator(ABC): + def __init__(self) -> None: + self._has_next_page = True + self._next_reference: Optional[str] = None + + @property + def has_next_page(self) -> bool: + """ + Check if there is a next page available. + + Returns: + bool: True if there is a next page available, False otherwise. + """ + return self._has_next_page + + @property + def next_reference(self) -> Optional[str]: + return self._next_reference + + @next_reference.setter + def next_reference(self, value: Optional[str]) -> None: + self._next_reference = value + self._has_next_page = value is not None + + @abstractmethod + def update_state(self, response: Response) -> None: + """Update the paginator state based on the response. + + Args: + response (Response): The response object from the API. + """ + ... + + @abstractmethod + def update_request(self, request: Request) -> None: + """ + Update the request object with the next arguments for the API request. + + Args: + request (Request): The request object to be updated. + """ + ... + + +class SinglePagePaginator(BasePaginator): + """A paginator for single-page API responses.""" + + def update_state(self, response: Response) -> None: + self._has_next_page = False + + def update_request(self, request: Request) -> None: + return + + +class OffsetPaginator(BasePaginator): + """A paginator that uses the 'offset' parameter for pagination.""" + + def __init__( + self, + initial_limit: int, + initial_offset: int = 0, + offset_param: str = "offset", + limit_param: str = "limit", + total_path: jsonpath.TJsonPath = "total", + ) -> None: + super().__init__() + self.offset_param = offset_param + self.limit_param = limit_param + self.total_path = jsonpath.compile_path(total_path) + + self.offset = initial_offset + self.limit = initial_limit + + def update_state(self, response: Response) -> None: + values = jsonpath.find_values(self.total_path, response.json()) + total = values[0] if values else None + + if total is None: + raise ValueError( + f"Total count not found in response for {self.__class__.__name__}" + ) + + self.offset += self.limit + + if self.offset >= total: + self._has_next_page = False + + def update_request(self, request: Request) -> None: + if request.params is None: + request.params = {} + + request.params[self.offset_param] = self.offset + request.params[self.limit_param] = self.limit + + +class BaseNextUrlPaginator(BasePaginator): + def update_request(self, request: Request) -> None: + request.url = self.next_reference + + +class HeaderLinkPaginator(BaseNextUrlPaginator): + """A paginator that uses the 'Link' header in HTTP responses + for pagination. + + A good example of this is the GitHub API: + https://docs.github.com/en/rest/guides/traversing-with-pagination + """ + + def __init__(self, links_next_key: str = "next") -> None: + """ + Args: + links_next_key (str, optional): The key (rel ) in the 'Link' header + that contains the next page URL. Defaults to 'next'. + """ + super().__init__() + self.links_next_key = links_next_key + + def update_state(self, response: Response) -> None: + self.next_reference = response.links.get(self.links_next_key, {}).get("url") + + +class JSONResponsePaginator(BaseNextUrlPaginator): + """A paginator that uses a specific key in the JSON response to find + the next page URL. + """ + + def __init__( + self, + next_url_path: jsonpath.TJsonPath = "next", + ): + """ + Args: + next_url_path: The JSON path to the key that contains the next page URL in the response. + Defaults to 'next'. + """ + super().__init__() + self.next_url_path = jsonpath.compile_path(next_url_path) + + def update_state(self, response: Response) -> None: + values = jsonpath.find_values(self.next_url_path, response.json()) + self.next_reference = values[0] if values else None + + +class JSONResponseCursorPaginator(BasePaginator): + """A paginator that uses a cursor query param to paginate. The cursor for the + next page is found in the JSON response. + """ + + def __init__( + self, + cursor_path: jsonpath.TJsonPath = "cursors.next", + cursor_param: str = "after", + ): + """ + Args: + cursor_path: The JSON path to the key that contains the cursor in the response. + cursor_param: The name of the query parameter to be used in the request to get the next page. + """ + super().__init__() + self.cursor_path = jsonpath.compile_path(cursor_path) + self.cursor_param = cursor_param + + def update_state(self, response: Response) -> None: + values = jsonpath.find_values(self.cursor_path, response.json()) + self.next_reference = values[0] if values else None + + def update_request(self, request: Request) -> None: + if request.params is None: + request.params = {} + + request.params[self.cursor_param] = self._next_reference diff --git a/dlt/sources/helpers/rest_client/typing.py b/dlt/sources/helpers/rest_client/typing.py new file mode 100644 index 0000000000..dad9842071 --- /dev/null +++ b/dlt/sources/helpers/rest_client/typing.py @@ -0,0 +1,9 @@ +from typing import ( + Union, + Literal, +) + + +HTTPMethodBasic = Literal["GET", "POST"] +HTTPMethodExtended = Literal["PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"] +HTTPMethod = Union[HTTPMethodBasic, HTTPMethodExtended] diff --git a/dlt/sources/helpers/rest_client/utils.py b/dlt/sources/helpers/rest_client/utils.py new file mode 100644 index 0000000000..6001fad4ec --- /dev/null +++ b/dlt/sources/helpers/rest_client/utils.py @@ -0,0 +1,24 @@ +from functools import reduce +from operator import getitem +from typing import Any, Sequence, Union, Tuple + +from dlt.common import logger +from dlt.extract.source import DltSource + + +def join_url(base_url: str, path: str) -> str: + if not base_url.endswith("/"): + base_url += "/" + return base_url + path.lstrip("/") + + +def check_connection( + source: DltSource, + *resource_names: str, +) -> Tuple[bool, str]: + try: + list(source.with_resources(*resource_names).add_limit(1)) + return (True, "") + except Exception as e: + logger.error(f"Error checking connection: {e}") + return (False, str(e)) diff --git a/tests/sources/helpers/rest_client/__init__.py b/tests/sources/helpers/rest_client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py new file mode 100644 index 0000000000..3868cd4d9f --- /dev/null +++ b/tests/sources/helpers/rest_client/conftest.py @@ -0,0 +1,196 @@ +import re +from typing import NamedTuple, Callable, Pattern, List, TYPE_CHECKING +import json +import base64 + +from urllib.parse import urlsplit, urlunsplit + +import pytest +import requests_mock + +if TYPE_CHECKING: + RequestCallback = Callable[[requests_mock.Request, requests_mock.Context], str] +else: + RequestCallback = Callable + +MOCK_BASE_URL = "https://api.example.com" + + +class Route(NamedTuple): + method: str + pattern: Pattern[str] + callback: RequestCallback + + +class APIRouter: + def __init__(self, base_url: str): + self.routes: List[Route] = [] + self.base_url = base_url + + def _add_route( + self, method: str, pattern: str, func: RequestCallback + ) -> RequestCallback: + compiled_pattern = re.compile(f"{self.base_url}{pattern}") + self.routes.append(Route(method, compiled_pattern, func)) + return func + + def get(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]: + def decorator(func: RequestCallback) -> RequestCallback: + return self._add_route("GET", pattern, func) + + return decorator + + def post(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]: + def decorator(func: RequestCallback) -> RequestCallback: + return self._add_route("POST", pattern, func) + + return decorator + + def register_routes(self, mocker: requests_mock.Mocker) -> None: + for route in self.routes: + mocker.register_uri( + route.method, + route.pattern, + text=route.callback, + ) + + +router = APIRouter(MOCK_BASE_URL) + + +def serialize_page(records, page_number, total_pages, base_url, records_key="data"): + if records_key is None: + return json.dumps(records) + + response = { + records_key: records, + "page": page_number, + "total_pages": total_pages, + } + + if page_number < total_pages: + next_page = page_number + 1 + + scheme, netloc, path, _, _ = urlsplit(base_url) + next_page = urlunsplit([scheme, netloc, path, f"page={next_page}", ""]) + response["next_page"] = next_page + + return json.dumps(response) + + +def generate_posts(count=100): + return [{"id": i, "title": f"Post {i}"} for i in range(count)] + + +def generate_comments(post_id, count=50): + return [{"id": i, "body": f"Comment {i} for post {post_id}"} for i in range(count)] + + +def get_page_number(qs, key="page", default=1): + return int(qs.get(key, [default])[0]) + + +def paginate_response(request, records, page_size=10, records_key="data"): + page_number = get_page_number(request.qs) + total_records = len(records) + total_pages = (total_records + page_size - 1) // page_size + start_index = (page_number - 1) * 10 + end_index = start_index + 10 + records_slice = records[start_index:end_index] + return serialize_page( + records_slice, page_number, total_pages, request.url, records_key + ) + + +@pytest.fixture(scope="module") +def mock_api_server(): + with requests_mock.Mocker() as m: + + @router.get(r"/posts_no_key(\?page=\d+)?$") + def posts_no_key(request, context): + return paginate_response(request, generate_posts(), records_key=None) + + @router.get(r"/posts(\?page=\d+)?$") + def posts(request, context): + return paginate_response(request, generate_posts()) + + @router.get(r"/posts/(\d+)/comments") + def post_comments(request, context): + post_id = int(request.url.split("/")[-2]) + return paginate_response(request, generate_comments(post_id)) + + @router.get(r"/posts/\d+$") + def post_detail(request, context): + post_id = request.url.split("/")[-1] + return json.dumps({"id": post_id, "body": f"Post body {post_id}"}) + + @router.get(r"/posts/\d+/some_details_404") + def post_detail_404(request, context): + """Return 404 for post with id > 0. Used to test ignoring 404 errors.""" + post_id = int(request.url.split("/")[-2]) + if post_id < 1: + return json.dumps({"id": post_id, "body": f"Post body {post_id}"}) + else: + context.status_code = 404 + return json.dumps({"error": "Post not found"}) + + @router.get(r"/posts_under_a_different_key$") + def posts_with_results_key(request, context): + return paginate_response( + request, generate_posts(), records_key="many-results" + ) + + @router.get("/protected/posts/basic-auth") + def protected_basic_auth(request, context): + auth = request.headers.get("Authorization") + creds = "user:password" + creds_base64 = base64.b64encode(creds.encode()).decode() + if auth == f"Basic {creds_base64}": + return paginate_response(request, generate_posts()) + context.status_code = 401 + return json.dumps({"error": "Unauthorized"}) + + @router.get("/protected/posts/bearer-token") + def protected_bearer_token(request, context): + auth = request.headers.get("Authorization") + if auth == "Bearer test-token": + return paginate_response(request, generate_posts()) + context.status_code = 401 + return json.dumps({"error": "Unauthorized"}) + + @router.get("/protected/posts/bearer-token-plain-text-error") + def protected_bearer_token_plain_text_erorr(request, context): + auth = request.headers.get("Authorization") + if auth == "Bearer test-token": + return paginate_response(request, generate_posts()) + context.status_code = 401 + return "Unauthorized" + + @router.get("/protected/posts/api-key") + def protected_api_key(request, context): + api_key = request.headers.get("x-api-key") + if api_key == "test-api-key": + return paginate_response(request, generate_posts()) + context.status_code = 401 + return json.dumps({"error": "Unauthorized"}) + + @router.post("/oauth/token") + def oauth_token(request, context): + return json.dumps( + { + "access_token": "test-token", + "expires_in": 3600, + } + ) + + @router.post("/auth/refresh") + def refresh_token(request, context): + body = request.json() + if body.get("refresh_token") == "valid-refresh-token": + return json.dumps({"access_token": "new-valid-token"}) + context.status_code = 401 + return json.dumps({"error": "Invalid refresh token"}) + + router.register_routes(m) + + yield m diff --git a/tests/sources/helpers/rest_client/private_key.pem b/tests/sources/helpers/rest_client/private_key.pem new file mode 100644 index 0000000000..ce4592157b --- /dev/null +++ b/tests/sources/helpers/rest_client/private_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDQQxVECHvO2Gs9 +MaRlD0HG5IpoJ3jhuG+nTgDEY7AU75nO74juOZuQR6AxO5nS/QeZS6bbjrzgz9P4 +vtDTksuSwXrgFJF1M5qiYwLZBr3ZNQA/e/D39+L2735craFsy8x6Xz5OCSCWaAyu +ufOMl1Yt2vRsDZ+x0OPPvKgUCBkgRMDxPbf4kuWnG/f4Z6czt3oReE6SiriT7EXS +ucNccSzgVs9HRopJ0M7jcbWPwGUfSlA3IO1G5sAEfVCihpzFlC7OoB+qAKj0wnAZ +Kr6gOuEFneoNUlErpLaeQwdRE+h61s5JybxZhFgr69n6kYIPG8ra6spVyB13WYt1 +FMEtL4P1AgMBAAECggEALv0vx2OdoaApZAt3Etk0J17JzrG3P8CIKqi6GhV+9V5R +JwRbMhrb21wZy/ntXVI7XG5aBbhJK/UgV8Of5Ni+Z0yRv4zMe/PqfCCYVCTGAYPI +nEpH5n7u3fXP3jPL0/sQlfy2108OY/kygVrR1YMQzfRUyStywGFIAUdI6gogtyt7 +cjh07mmMc8HUMhAVyluE5hpQCLDv5Xige2PY7zv1TqhI3OoJFi27VeBCSyI7x/94 +GM1XpzdFcvYPNPo6aE9vGnDq8TfYwjy+hkY+D9DRpnEmVEXmeBdsxsSD+ybyprO1 +C2sytiV9d3wJ96fhsYupLK88EGxU2uhmFntHuasMQQKBgQD9cWVo7B18FCV/NAdS +nV3KzNtlIrGRFZ7FMZuVZ/ZjOpvzbTVbla3YbRjTkXYpK9Meo8KczwzxQ2TQ1qxY +67SrhfFRRWzktMWqwBSKHPIig+DnqUCUo7OSA0pN+u6yUvFWdINZucB+yMWtgRrj +8GuAMXD/vaoCiNrHVf2V191fwQKBgQDSXP3cqBjBtDLP3qFwDzOG8cR9qiiDvesQ +DXf5seV/rBCXZvkw81t+PGz0O/UrUonv/FqxQR0GqpAdX1ZM3Jko0WxbfoCgsT0u +1aSzcMq1JQt0CI77T8tIPYvym9FO+Jz89kX0WliL/I7GLsmG5EYBK/+dcJBh1QCE +VaMCgrbxNQKBgB10zYWJU8/1A3qqUGOQuLL2ZlV11892BNMEdgHCaIeV60Q6oCX5 +2o+59lW4pVQZrNr1y4uwIN/1pkUDflqDYqdA1RBOEl7uh77Vvk1jGd1bGIu0RzY/ +ZIKG8V7o2E9Pho820YFfLnlN2nPU+owdiFEI7go7QAQ1ZcAfRW7h/O/BAoGBAJg+ +IKO/LBuUFGoIT4HQHpR9CJ2BtkyR+Drn5HpbWyKpHmDUb2gT15VmmduwQOEXnSiH +1AMQgrc+XYpEYyrBRD8cQXV9+g1R+Fua1tXevXWX19AkGYab2xzvHgd46WRj3Qne +GgacFBVLtPCND+CF+HwEobwJqRSEmRks+QpqG4g5AoGAXpw9CZb+gYfwl2hphFGO +kT/NOfk8PN7WeZAe7ktStZByiGhHWaxqYE0q5favhNG6tMxSdmSOzYF8liHWuvJm +cDHqNVJeTGT8rjW7Iz08wj5F+ZAJYCMkM9aDpDUKJIHnOwYZCGfZxRJCiHTReyR7 +u03hoszfCn13l85qBnYlwaw= +-----END PRIVATE KEY----- diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py new file mode 100644 index 0000000000..9984dabb06 --- /dev/null +++ b/tests/sources/helpers/rest_client/test_client.py @@ -0,0 +1,173 @@ +import os +import pytest +from typing import Any, cast +from dlt.common.typing import TSecretStrValue +from dlt.sources.helpers.requests import Response, Request +from dlt.sources.helpers.rest_client import RESTClient +from dlt.sources.helpers.rest_client.client import Hooks +from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator + +from dlt.sources.helpers.rest_client.auth import AuthConfigBase +from dlt.sources.helpers.rest_client.auth import ( + BearerTokenAuth, + APIKeyAuth, + HttpBasicAuth, + OAuthJWTAuth, +) +from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException + + +def load_private_key(name="private_key.pem"): + key_path = os.path.join(os.path.dirname(__file__), name) + with open(key_path, "r", encoding="utf-8") as key_file: + return key_file.read() + + +TEST_PRIVATE_KEY = load_private_key() + + +@pytest.fixture +def rest_client() -> RESTClient: + return RESTClient( + base_url="https://api.example.com", + headers={"Accept": "application/json"}, + ) + + +@pytest.mark.usefixtures("mock_api_server") +class TestRESTClient: + def _assert_pagination(self, pages): + for i, page in enumerate(pages): + assert page == [ + {"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10) + ] + + def test_get_single_resource(self, rest_client): + response = rest_client.get("/posts/1") + assert response.status_code == 200 + assert response.json() == {"id": "1", "body": "Post body 1"} + + def test_pagination(self, rest_client: RESTClient): + pages_iter = rest_client.paginate( + "/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + ) + + pages = list(pages_iter) + + self._assert_pagination(pages) + + def test_page_context(self, rest_client: RESTClient) -> None: + for page in rest_client.paginate( + "/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + auth=AuthConfigBase(), + ): + # response that produced data + assert isinstance(page.response, Response) + # updated request + assert isinstance(page.request, Request) + # make request url should be same as next link in paginator + if page.paginator.has_next_page: + assert page.paginator.next_reference == page.request.url + + def test_default_paginator(self, rest_client: RESTClient): + pages_iter = rest_client.paginate("/posts") + + pages = list(pages_iter) + + self._assert_pagination(pages) + + def test_paginate_with_hooks(self, rest_client: RESTClient): + def response_hook(response: Response, *args: Any, **kwargs: Any) -> None: + if response.status_code == 404: + raise IgnoreResponseException + + hooks: Hooks = { + "response": response_hook, + } + + pages_iter = rest_client.paginate( + "/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + hooks=hooks, + ) + + pages = list(pages_iter) + + self._assert_pagination(pages) + + pages_iter = rest_client.paginate( + "/posts/1/some_details_404", + paginator=JSONResponsePaginator(), + hooks=hooks, + ) + + pages = list(pages_iter) + assert pages == [] + + def test_basic_auth_success(self, rest_client: RESTClient): + response = rest_client.get( + "/protected/posts/basic-auth", + auth=HttpBasicAuth("user", cast(TSecretStrValue, "password")), + ) + assert response.status_code == 200 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + pages_iter = rest_client.paginate( + "/protected/posts/basic-auth", + auth=HttpBasicAuth("user", cast(TSecretStrValue, "password")), + ) + + pages = list(pages_iter) + self._assert_pagination(pages) + + def test_bearer_token_auth_success(self, rest_client: RESTClient): + response = rest_client.get( + "/protected/posts/bearer-token", + auth=BearerTokenAuth(cast(TSecretStrValue, "test-token")), + ) + assert response.status_code == 200 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + pages_iter = rest_client.paginate( + "/protected/posts/bearer-token", + auth=BearerTokenAuth(cast(TSecretStrValue, "test-token")), + ) + + pages = list(pages_iter) + self._assert_pagination(pages) + + def test_api_key_auth_success(self, rest_client: RESTClient): + response = rest_client.get( + "/protected/posts/api-key", + auth=APIKeyAuth( + name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key") + ), + ) + assert response.status_code == 200 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + def test_oauth_jwt_auth_success(self, rest_client: RESTClient): + auth = OAuthJWTAuth( + client_id="test-client-id", + private_key=TEST_PRIVATE_KEY, + auth_endpoint="https://api.example.com/oauth/token", + scopes=["read", "write"], + headers={"Content-Type": "application/json"}, + ) + + response = rest_client.get( + "/protected/posts/bearer-token", + auth=auth, + ) + + assert response.status_code == 200 + assert "test-token" in response.request.headers["Authorization"] + + pages_iter = rest_client.paginate( + "/protected/posts/bearer-token", + auth=auth, + ) + + self._assert_pagination(list(pages_iter)) diff --git a/tests/sources/helpers/rest_client/test_detector.py b/tests/sources/helpers/rest_client/test_detector.py new file mode 100644 index 0000000000..a9af1d36a4 --- /dev/null +++ b/tests/sources/helpers/rest_client/test_detector.py @@ -0,0 +1,360 @@ +import pytest +from dlt.common import jsonpath + +from dlt.sources.helpers.rest_client.detector import ( + find_records, + find_next_page_path, + single_entity_path, +) + + +TEST_RESPONSES = [ + { + "response": { + "data": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + "pagination": {"offset": 0, "limit": 2, "total": 100}, + }, + "expected": { + "type": "offset_limit", + "records_path": "data", + }, + }, + { + "response": { + "items": [ + {"id": 11, "title": "Page Item 1"}, + {"id": 12, "title": "Page Item 2"}, + ], + "page_info": {"current_page": 1, "items_per_page": 2, "total_pages": 50}, + }, + "expected": { + "type": "page_number", + "records_path": "items", + }, + }, + { + "response": { + "products": [ + {"id": 101, "name": "Product 1"}, + {"id": 102, "name": "Product 2"}, + ], + "next_cursor": "eyJpZCI6MTAyfQ==", + }, + "expected": { + "type": "cursor", + "records_path": "products", + "next_path": ["next_cursor"], + }, + }, + { + "response": { + "results": [ + {"id": 201, "description": "Result 1"}, + {"id": 202, "description": "Result 2"}, + ], + "cursors": {"next": "NjM=", "previous": "MTk="}, + }, + "expected": { + "type": "cursor", + "records_path": "results", + "next_path": ["cursors", "next"], + }, + }, + { + "response": { + "entries": [{"id": 31, "value": "Entry 1"}, {"id": 32, "value": "Entry 2"}], + "next_id": 33, + "limit": 2, + }, + "expected": { + "type": "cursor", + "records_path": "entries", + "next_path": ["next_id"], + }, + }, + { + "response": { + "comments": [ + {"id": 51, "text": "Comment 1"}, + {"id": 52, "text": "Comment 2"}, + ], + "page_number": 3, + "total_pages": 15, + }, + "expected": { + "type": "page_number", + "records_path": "comments", + }, + }, + { + "response": { + "count": 1023, + "next": "https://api.example.org/accounts/?page=5", + "previous": "https://api.example.org/accounts/?page=3", + "results": [{"id": 1, "name": "Account 1"}, {"id": 2, "name": "Account 2"}], + }, + "expected": { + "type": "json_link", + "records_path": "results", + "next_path": ["next"], + }, + }, + { + "response": { + "_embedded": { + "items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}] + }, + "_links": { + "first": {"href": "http://api.example.com/items?page=0&size=2"}, + "self": {"href": "http://api.example.com/items?page=1&size=2"}, + "next": {"href": "http://api.example.com/items?page=2&size=2"}, + "last": {"href": "http://api.example.com/items?page=50&size=2"}, + }, + "page": {"size": 2, "totalElements": 100, "totalPages": 50, "number": 1}, + }, + "expected": { + "type": "json_link", + "records_path": "_embedded.items", + "next_path": ["_links", "next", "href"], + }, + }, + { + "response": { + "items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + "meta": { + "currentPage": 1, + "pageSize": 2, + "totalPages": 50, + "totalItems": 100, + }, + "links": { + "firstPage": "/items?page=1&limit=2", + "previousPage": "/items?page=0&limit=2", + "nextPage": "/items?page=2&limit=2", + "lastPage": "/items?page=50&limit=2", + }, + }, + "expected": { + "type": "json_link", + "records_path": "items", + "next_path": ["links", "nextPage"], + }, + }, + { + "response": { + "data": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + "pagination": { + "currentPage": 1, + "pageSize": 2, + "totalPages": 5, + "totalItems": 10, + }, + }, + "expected": { + "type": "page_number", + "records_path": "data", + }, + }, + { + "response": { + "items": [{"id": 1, "title": "Item 1"}, {"id": 2, "title": "Item 2"}], + "pagination": {"page": 1, "perPage": 2, "total": 10, "totalPages": 5}, + }, + "expected": { + "type": "page_number", + "records_path": "items", + }, + }, + { + "response": { + "data": [ + {"id": 1, "description": "Item 1"}, + {"id": 2, "description": "Item 2"}, + ], + "meta": { + "currentPage": 1, + "itemsPerPage": 2, + "totalItems": 10, + "totalPages": 5, + }, + "links": { + "first": "/api/items?page=1", + "previous": None, + "next": "/api/items?page=2", + "last": "/api/items?page=5", + }, + }, + "expected": { + "type": "json_link", + "records_path": "data", + "next_path": ["links", "next"], + }, + }, + { + "response": { + "page": 2, + "per_page": 10, + "total": 100, + "pages": 10, + "data": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + }, + "expected": { + "type": "page_number", + "records_path": "data", + }, + }, + { + "response": { + "currentPage": 1, + "pageSize": 10, + "totalPages": 5, + "totalRecords": 50, + "items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + }, + "expected": { + "type": "page_number", + "records_path": "items", + }, + }, + { + "response": { + "articles": [ + {"id": 21, "headline": "Article 1"}, + {"id": 22, "headline": "Article 2"}, + ], + "paging": {"current": 3, "size": 2, "total": 60}, + }, + "expected": { + "type": "page_number", + "records_path": "articles", + }, + }, + { + "response": { + "feed": [ + {"id": 41, "content": "Feed Content 1"}, + {"id": 42, "content": "Feed Content 2"}, + ], + "offset": 40, + "limit": 2, + "total_count": 200, + }, + "expected": { + "type": "offset_limit", + "records_path": "feed", + }, + }, + { + "response": { + "query_results": [ + {"id": 81, "snippet": "Result Snippet 1"}, + {"id": 82, "snippet": "Result Snippet 2"}, + ], + "page_details": { + "number": 1, + "size": 2, + "total_elements": 50, + "total_pages": 25, + }, + }, + "expected": { + "type": "page_number", + "records_path": "query_results", + }, + }, + { + "response": { + "posts": [ + {"id": 91, "title": "Blog Post 1"}, + {"id": 92, "title": "Blog Post 2"}, + ], + "pagination_details": { + "current_page": 4, + "posts_per_page": 2, + "total_posts": 100, + "total_pages": 50, + }, + }, + "expected": { + "type": "page_number", + "records_path": "posts", + }, + }, + { + "response": { + "catalog": [ + {"id": 101, "product_name": "Product A"}, + {"id": 102, "product_name": "Product B"}, + ], + "page_metadata": { + "index": 1, + "size": 2, + "total_items": 20, + "total_pages": 10, + }, + }, + "expected": { + "type": "page_number", + "records_path": "catalog", + }, + }, +] + + +@pytest.mark.parametrize("test_case", TEST_RESPONSES) +def test_find_records(test_case): + response = test_case["response"] + expected = test_case["expected"]["records_path"] + r = find_records(response) + # all of them look fine mostly because those are simple cases... + # case 7 fails because it is nested but in fact we select a right response + # assert r is create_nested_accessor(expected)(response) + assert r == jsonpath.find_values(expected, response)[0] + + +@pytest.mark.parametrize("test_case", TEST_RESPONSES) +def test_find_next_page_key(test_case): + response = test_case["response"] + expected = test_case.get("expected").get( + "next_path", None + ) # Some cases may not have next_path + assert find_next_page_path(response) == expected + + +@pytest.mark.skip +@pytest.mark.parametrize( + "path", + [ + "/users/{user_id}", + "/api/v1/products/{product_id}/", + "/api/v1/products/{product_id}//", + "/api/v1/products/{product_id}?param1=value1", + "/api/v1/products/{product_id}#section", + "/api/v1/products/{product_id}/#section", + "/users/{user_id}/posts/{post_id}", + "/users/{user_id}/posts/{post_id}/comments/{comment_id}", + "{entity}", + "/{entity}", + "/{user_123}", + ], +) +def test_single_entity_path_valid(path): + assert single_entity_path(path) is True + + +@pytest.mark.parametrize( + "path", + [ + "/users/user_id", + "/api/v1/products/product_id/", + "/users/{user_id}/details", + "/", + "/{}", + "/users/{123}", + "/users/{user-id}", + "/users/{user id}", + "/users/{user_id}/{", # Invalid ending + ], +) +def test_single_entity_path_invalid(path): + assert single_entity_path(path) is False diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py new file mode 100644 index 0000000000..cc4dea65dc --- /dev/null +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -0,0 +1,82 @@ +import pytest +from unittest.mock import Mock + +from requests.models import Response + +from dlt.sources.helpers.rest_client.paginators import ( + SinglePagePaginator, + OffsetPaginator, + HeaderLinkPaginator, + JSONResponsePaginator, +) + + +class TestHeaderLinkPaginator: + def test_update_state_with_next(self): + paginator = HeaderLinkPaginator() + response = Mock(Response) + response.links = {"next": {"url": "http://example.com/next"}} + paginator.update_state(response) + assert paginator.next_reference == "http://example.com/next" + assert paginator.has_next_page is True + + def test_update_state_without_next(self): + paginator = HeaderLinkPaginator() + response = Mock(Response) + response.links = {} + paginator.update_state(response) + assert paginator.has_next_page is False + + +class TestJSONResponsePaginator: + def test_update_state_with_next(self): + paginator = JSONResponsePaginator() + response = Mock( + Response, json=lambda: {"next": "http://example.com/next", "results": []} + ) + paginator.update_state(response) + assert paginator.next_reference == "http://example.com/next" + assert paginator.has_next_page is True + + def test_update_state_without_next(self): + paginator = JSONResponsePaginator() + response = Mock(Response, json=lambda: {"results": []}) + paginator.update_state(response) + assert paginator.next_reference is None + assert paginator.has_next_page is False + + +class TestSinglePagePaginator: + def test_update_state(self): + paginator = SinglePagePaginator() + response = Mock(Response) + paginator.update_state(response) + assert paginator.has_next_page is False + + def test_update_state_with_next(self): + paginator = SinglePagePaginator() + response = Mock( + Response, json=lambda: {"next": "http://example.com/next", "results": []} + ) + response.links = {"next": {"url": "http://example.com/next"}} + paginator.update_state(response) + assert paginator.has_next_page is False + + +class TestOffsetPaginator: + def test_update_state(self): + paginator = OffsetPaginator(initial_offset=0, initial_limit=10) + response = Mock(Response, json=lambda: {"total": 20}) + paginator.update_state(response) + assert paginator.offset == 10 + assert paginator.has_next_page is True + + # Test for reaching the end + paginator.update_state(response) + assert paginator.has_next_page is False + + def test_update_state_without_total(self): + paginator = OffsetPaginator(0, 10) + response = Mock(Response, json=lambda: {}) + with pytest.raises(ValueError): + paginator.update_state(response) From 600f2cde44438297bed167cfa1a80921c66c2807 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 09:56:09 +0300 Subject: [PATCH 02/15] Add PyJWT --- poetry.lock | 33 +-------------------------------- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 32 deletions(-) diff --git a/poetry.lock b/poetry.lock index 96e730bf3a..b3f9dfa368 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4681,16 +4681,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -7081,7 +7071,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -7089,16 +7078,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -7115,7 +7096,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -7123,7 +7103,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -8011,7 +7990,6 @@ files = [ {file = "SQLAlchemy-1.4.49-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:03db81b89fe7ef3857b4a00b63dedd632d6183d4ea5a31c5d8a92e000a41fc71"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:95b9df9afd680b7a3b13b38adf6e3a38995da5e162cc7524ef08e3be4e5ed3e1"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a63e43bf3f668c11bb0444ce6e809c1227b8f067ca1068898f3008a273f52b09"}, - {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca46de16650d143a928d10842939dab208e8d8c3a9a8757600cae9b7c579c5cd"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f835c050ebaa4e48b18403bed2c0fda986525896efd76c245bdd4db995e51a4c"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c21b172dfb22e0db303ff6419451f0cac891d2e911bb9fbf8003d717f1bcf91"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-win32.whl", hash = "sha256:5fb1ebdfc8373b5a291485757bd6431de8d7ed42c27439f543c81f6c8febd729"}, @@ -8021,35 +7999,26 @@ files = [ {file = "SQLAlchemy-1.4.49-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5debe7d49b8acf1f3035317e63d9ec8d5e4d904c6e75a2a9246a119f5f2fdf3d"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win32.whl", hash = "sha256:82b08e82da3756765c2e75f327b9bf6b0f043c9c3925fb95fb51e1567fa4ee87"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win_amd64.whl", hash = "sha256:171e04eeb5d1c0d96a544caf982621a1711d078dbc5c96f11d6469169bd003f1"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f23755c384c2969ca2f7667a83f7c5648fcf8b62a3f2bbd883d805454964a800"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8396e896e08e37032e87e7fbf4a15f431aa878c286dc7f79e616c2feacdb366c"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66da9627cfcc43bbdebd47bfe0145bb662041472393c03b7802253993b6b7c90"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-win32.whl", hash = "sha256:9a06e046ffeb8a484279e54bda0a5abfd9675f594a2e38ef3133d7e4d75b6214"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-win_amd64.whl", hash = "sha256:7cf8b90ad84ad3a45098b1c9f56f2b161601e4670827d6b892ea0e884569bd1d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:36e58f8c4fe43984384e3fbe6341ac99b6b4e083de2fe838f0fdb91cebe9e9cb"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b31e67ff419013f99ad6f8fc73ee19ea31585e1e9fe773744c0f3ce58c039c30"}, - {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc22807a7e161c0d8f3da34018ab7c97ef6223578fcdd99b1d3e7ed1100a5db"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c14b29d9e1529f99efd550cd04dbb6db6ba5d690abb96d52de2bff4ed518bc95"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c40f3470e084d31247aea228aa1c39bbc0904c2b9ccbf5d3cfa2ea2dac06f26d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-win32.whl", hash = "sha256:706bfa02157b97c136547c406f263e4c6274a7b061b3eb9742915dd774bbc264"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-win_amd64.whl", hash = "sha256:a7f7b5c07ae5c0cfd24c2db86071fb2a3d947da7bd487e359cc91e67ac1c6d2e"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:4afbbf5ef41ac18e02c8dc1f86c04b22b7a2125f2a030e25bbb4aff31abb224b"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24e300c0c2147484a002b175f4e1361f102e82c345bf263242f0449672a4bccf"}, - {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:393cd06c3b00b57f5421e2133e088df9cabcececcea180327e43b937b5a7caa5"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:201de072b818f8ad55c80d18d1a788729cccf9be6d9dc3b9d8613b053cd4836d"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7653ed6817c710d0c95558232aba799307d14ae084cc9b1f4c389157ec50df5c"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-win32.whl", hash = "sha256:647e0b309cb4512b1f1b78471fdaf72921b6fa6e750b9f891e09c6e2f0e5326f"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-win_amd64.whl", hash = "sha256:ab73ed1a05ff539afc4a7f8cf371764cdf79768ecb7d2ec691e3ff89abbc541e"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:37ce517c011560d68f1ffb28af65d7e06f873f191eb3a73af5671e9c3fada08a"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1878ce508edea4a879015ab5215546c444233881301e97ca16fe251e89f1c55"}, - {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95ab792ca493891d7a45a077e35b418f68435efb3e1706cb8155e20e86a9013c"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0e8e608983e6f85d0852ca61f97e521b62e67969e6e640fe6c6b575d4db68557"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccf956da45290df6e809ea12c54c02ace7f8ff4d765d6d3dfb3655ee876ce58d"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win32.whl", hash = "sha256:f167c8175ab908ce48bd6550679cc6ea20ae169379e73c7720a28f89e53aa532"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win_amd64.whl", hash = "sha256:45806315aae81a0c202752558f0df52b42d11dd7ba0097bf71e253b4215f34f4"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:b6d0c4b15d65087738a6e22e0ff461b407533ff65a73b818089efc8eb2b3e1de"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a843e34abfd4c797018fd8d00ffffa99fd5184c421f190b6ca99def4087689bd"}, - {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:738d7321212941ab19ba2acf02a68b8ee64987b248ffa2101630e8fccb549e0d"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1c890421651b45a681181301b3497e4d57c0d01dc001e10438a40e9a9c25ee77"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d26f280b8f0a8f497bc10573849ad6dc62e671d2468826e5c748d04ed9e670d5"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-win32.whl", hash = "sha256:ec2268de67f73b43320383947e74700e95c6770d0c68c4e615e9897e46296294"}, @@ -9066,4 +9035,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "99658baf1bfda2ac065bda897637cae0eb122c76777688a7d606df0ef06c7fcc" +content-hash = "c43c7a1ff57aef576ae43d825ef45215783d12e5d7fd9b6f870862db967b1fb1" diff --git a/pyproject.toml b/pyproject.toml index de5f8055c5..4ed1f77823 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ pyodbc = {version = "^4.0.39", optional = true} qdrant-client = {version = "^1.6.4", optional = true, extras = ["fastembed"]} databricks-sql-connector = {version = ">=2.9.3,<3.0.0", optional = true} dbt-databricks = {version = "^1.7.3", optional = true} +pyjwt = "^2.8.0" [tool.poetry.extras] dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"] From 02d9fecb0f0931b5eb9958f96aa981aacd67fec5 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 10:59:29 +0300 Subject: [PATCH 03/15] Add initial version of `rest_client.paginate()` --- dlt/sources/helpers/rest_client/__init__.py | 45 ++++++++++ dlt/sources/helpers/rest_client/client.py | 7 +- dlt/sources/helpers/rest_client/typing.py | 8 ++ dlt/sources/helpers/rest_client/utils.py | 16 +++- .../helpers/rest_client/test_client.py | 12 ++- .../sources/helpers/rest_client/test_utils.py | 90 +++++++++++++++++++ 6 files changed, 168 insertions(+), 10 deletions(-) create mode 100644 tests/sources/helpers/rest_client/test_utils.py diff --git a/dlt/sources/helpers/rest_client/__init__.py b/dlt/sources/helpers/rest_client/__init__.py index 3264ea4aae..fd5d558018 100644 --- a/dlt/sources/helpers/rest_client/__init__.py +++ b/dlt/sources/helpers/rest_client/__init__.py @@ -1 +1,46 @@ +from typing import Optional, Dict, Iterator, Union, Any + +from dlt.common import jsonpath + from .client import RESTClient # noqa: F401 +from .client import PageData +from .auth import AuthConfigBase +from .paginators import BasePaginator +from .typing import HTTPMethodBasic, Hooks + + +def paginate( + url: str, + method: HTTPMethodBasic = "GET", + headers: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Any]] = None, + json: Optional[Dict[str, Any]] = None, + auth: AuthConfigBase = None, + paginator: Union[str, BasePaginator] = None, + data_selector: Optional[jsonpath.TJsonPath] = None, + hooks: Optional[Hooks] = None, +) -> Iterator[PageData[Any]]: + """ + Paginate over a REST API endpoint. + + Args: + url: URL to paginate over. + **kwargs: Keyword arguments to pass to `RESTClient.paginate`. + + Returns: + Iterator[Page]: Iterator over pages. + """ + client = RESTClient( + base_url=url, + headers=headers, + ) + return client.paginate( + path="", + method=method, + params=params, + json=json, + auth=auth, + paginator=paginator, + data_selector=data_selector, + hooks=hooks, + ) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 12e22c072d..4b5625eebe 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -6,8 +6,6 @@ Any, TypeVar, Iterable, - Union, - Callable, cast, ) import copy @@ -20,7 +18,7 @@ from dlt.sources.helpers.requests.retry import Client from dlt.sources.helpers.requests import Response, Request -from .typing import HTTPMethodBasic, HTTPMethod +from .typing import HTTPMethodBasic, HTTPMethod, Hooks from .paginators import BasePaginator from .auth import AuthConfigBase from .detector import PaginatorFactory, find_records @@ -30,9 +28,6 @@ _T = TypeVar("_T") -HookFunction = Callable[[Response, Any, Any], None] -HookEvent = Union[HookFunction, List[HookFunction]] -Hooks = Dict[str, HookEvent] class PageData(List[_T]): diff --git a/dlt/sources/helpers/rest_client/typing.py b/dlt/sources/helpers/rest_client/typing.py index dad9842071..626aee4877 100644 --- a/dlt/sources/helpers/rest_client/typing.py +++ b/dlt/sources/helpers/rest_client/typing.py @@ -1,9 +1,17 @@ from typing import ( + List, + Dict, Union, Literal, + Callable, + Any, ) +from dlt.sources.helpers.requests import Response HTTPMethodBasic = Literal["GET", "POST"] HTTPMethodExtended = Literal["PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"] HTTPMethod = Union[HTTPMethodBasic, HTTPMethodExtended] +HookFunction = Callable[[Response, Any, Any], None] +HookEvent = Union[HookFunction, List[HookFunction]] +Hooks = Dict[str, HookEvent] diff --git a/dlt/sources/helpers/rest_client/utils.py b/dlt/sources/helpers/rest_client/utils.py index 6001fad4ec..8732437a88 100644 --- a/dlt/sources/helpers/rest_client/utils.py +++ b/dlt/sources/helpers/rest_client/utils.py @@ -1,14 +1,24 @@ -from functools import reduce -from operator import getitem -from typing import Any, Sequence, Union, Tuple +from typing import Tuple from dlt.common import logger from dlt.extract.source import DltSource def join_url(base_url: str, path: str) -> str: + if base_url is None: + raise ValueError("Base URL must be provided or set to an empty string.") + + if base_url == "": + return path + + if path == "": + return base_url + + # Normalize the base URL + base_url = base_url.rstrip("/") if not base_url.endswith("/"): base_url += "/" + return base_url + path.lstrip("/") diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 9984dabb06..17d445042c 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -3,7 +3,7 @@ from typing import Any, cast from dlt.common.typing import TSecretStrValue from dlt.sources.helpers.requests import Response, Request -from dlt.sources.helpers.rest_client import RESTClient +from dlt.sources.helpers.rest_client import RESTClient, paginate from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator @@ -171,3 +171,13 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient): ) self._assert_pagination(list(pages_iter)) + + def test_paginate_function(self, rest_client: RESTClient): + pages_iter = paginate( + "https://api.example.com/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + ) + + pages = list(pages_iter) + + self._assert_pagination(pages) diff --git a/tests/sources/helpers/rest_client/test_utils.py b/tests/sources/helpers/rest_client/test_utils.py new file mode 100644 index 0000000000..0de9729a42 --- /dev/null +++ b/tests/sources/helpers/rest_client/test_utils.py @@ -0,0 +1,90 @@ +import pytest +from dlt.sources.helpers.rest_client.utils import join_url + + +@pytest.mark.parametrize( + "base_url, path, expected", + [ + # Normal cases + ( + "http://example.com", + "path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com/", + "/path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com/", + "path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com", + "//path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com///", + "//path/to/resource", + "http://example.com/path/to/resource", + ), + # Trailing and leading slashes + ("http://example.com/", "/", "http://example.com/"), + ("http://example.com", "/", "http://example.com/"), + ("http://example.com/", "///", "http://example.com/"), + ("http://example.com", "///", "http://example.com/"), + ("/", "path/to/resource", "/path/to/resource"), + ("/", "/path/to/resource", "/path/to/resource"), + # Empty strings + ("", "", ""), + ( + "", + "http://example.com/path/to/resource", + "http://example.com/path/to/resource", + ), + ("", "path/to/resource", "path/to/resource"), + ("http://example.com", "", "http://example.com"), + # Query parameters and fragments + ( + "http://example.com", + "path/to/resource?query=123", + "http://example.com/path/to/resource?query=123", + ), + ( + "http://example.com/", + "path/to/resource#fragment", + "http://example.com/path/to/resource#fragment", + ), + # Special characters in the path + ( + "http://example.com", + "/path/to/resource with spaces", + "http://example.com/path/to/resource with spaces", + ), + ("http://example.com", "/path/with/中文", "http://example.com/path/with/中文"), + # Protocols and subdomains + ("https://sub.example.com", "path", "https://sub.example.com/path"), + ("ftp://example.com", "/path", "ftp://example.com/path"), + # Missing protocol in base_url + ("example.com", "path", "example.com/path"), + ], +) +def test_join_url(base_url, path, expected): + assert join_url(base_url, path) == expected + + +@pytest.mark.parametrize( + "base_url, path, exception", + [ + (None, "path", ValueError), + ("http://example.com", None, AttributeError), + (123, "path", AttributeError), + ("http://example.com", 123, AttributeError), + ], +) +def test_join_url_invalid_input_types(base_url, path, exception): + with pytest.raises(exception): + join_url(base_url, path) From b52c97ba15c73fe4245c0ce86d9800cd5786a0b3 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 11:41:39 +0300 Subject: [PATCH 04/15] Export `rest_client.paginate` to `helpers.requests` module --- dlt/sources/helpers/requests/__init__.py | 4 +++- tests/sources/helpers/rest_client/test_client.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dlt/sources/helpers/requests/__init__.py b/dlt/sources/helpers/requests/__init__.py index 3e29a2cf52..d76e24ec42 100644 --- a/dlt/sources/helpers/requests/__init__.py +++ b/dlt/sources/helpers/requests/__init__.py @@ -15,11 +15,12 @@ from requests.exceptions import ChunkedEncodingError from dlt.sources.helpers.requests.retry import Client from dlt.sources.helpers.requests.session import Session +from dlt.sources.helpers.rest_client import paginate from dlt.common.configuration.specs import RunConfiguration client = Client() -get, post, put, patch, delete, options, head, request = ( +get, post, put, patch, delete, options, head, request, paginate = ( client.get, client.post, client.put, @@ -28,6 +29,7 @@ client.options, client.head, client.request, + paginate, ) diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 17d445042c..568aff4b78 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -2,8 +2,8 @@ import pytest from typing import Any, cast from dlt.common.typing import TSecretStrValue -from dlt.sources.helpers.requests import Response, Request -from dlt.sources.helpers.rest_client import RESTClient, paginate +from dlt.sources.helpers.requests import Response, Request, paginate +from dlt.sources.helpers.rest_client import RESTClient from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator From b8825adf622ef5735d78848d5ae30bc2da63f29b Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 11:47:19 +0300 Subject: [PATCH 05/15] Fix the typing error --- dlt/sources/helpers/rest_client/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/sources/helpers/rest_client/__init__.py b/dlt/sources/helpers/rest_client/__init__.py index fd5d558018..b2fb0a2351 100644 --- a/dlt/sources/helpers/rest_client/__init__.py +++ b/dlt/sources/helpers/rest_client/__init__.py @@ -16,7 +16,7 @@ def paginate( params: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, auth: AuthConfigBase = None, - paginator: Union[str, BasePaginator] = None, + paginator: Optional[BasePaginator] = None, data_selector: Optional[jsonpath.TJsonPath] = None, hooks: Optional[Hooks] = None, ) -> Iterator[PageData[Any]]: From 3c15854f5ab44c66aa35839e242fc77c7e54fc3e Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 11:55:43 +0300 Subject: [PATCH 06/15] Use dlt.common.json --- tests/sources/helpers/rest_client/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py index 3868cd4d9f..d5217f0d77 100644 --- a/tests/sources/helpers/rest_client/conftest.py +++ b/tests/sources/helpers/rest_client/conftest.py @@ -1,6 +1,5 @@ import re from typing import NamedTuple, Callable, Pattern, List, TYPE_CHECKING -import json import base64 from urllib.parse import urlsplit, urlunsplit @@ -8,6 +7,8 @@ import pytest import requests_mock +from dlt.common import json + if TYPE_CHECKING: RequestCallback = Callable[[requests_mock.Request, requests_mock.Context], str] else: From 7ff539190930800c255e5e4d952cbeedd5728018 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 17:32:22 +0300 Subject: [PATCH 07/15] Add dependency checks for PyJWT and cryptography in auth module --- dlt/sources/helpers/rest_client/auth.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index c305654cbd..90a936e801 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -5,10 +5,20 @@ from requests.auth import AuthBase from requests import PreparedRequest # noqa: I251 import pendulum -import jwt -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes + +from dlt.common.exceptions import MissingDependencyException + +try: + import jwt +except ModuleNotFoundError: + raise MissingDependencyException("dlt OAuth helpers", ["PyJWT"]) + +try: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes +except ModuleNotFoundError: + raise MissingDependencyException("dlt OAuth helpers", ["cryptography"]) from dlt import config, secrets from dlt.common import logger From 9b3d353b9728c1bbf3608135ecf6f5cc6e70cf3a Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 17:41:10 +0300 Subject: [PATCH 08/15] Remove unused imports and check_connection function from rest_client utils --- dlt/sources/helpers/rest_client/utils.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/dlt/sources/helpers/rest_client/utils.py b/dlt/sources/helpers/rest_client/utils.py index 8732437a88..7fe91655c5 100644 --- a/dlt/sources/helpers/rest_client/utils.py +++ b/dlt/sources/helpers/rest_client/utils.py @@ -1,9 +1,3 @@ -from typing import Tuple - -from dlt.common import logger -from dlt.extract.source import DltSource - - def join_url(base_url: str, path: str) -> str: if base_url is None: raise ValueError("Base URL must be provided or set to an empty string.") @@ -20,15 +14,3 @@ def join_url(base_url: str, path: str) -> str: base_url += "/" return base_url + path.lstrip("/") - - -def check_connection( - source: DltSource, - *resource_names: str, -) -> Tuple[bool, str]: - try: - list(source.with_resources(*resource_names).add_limit(1)) - return (True, "") - except Exception as e: - logger.error(f"Error checking connection: {e}") - return (False, str(e)) From be95676511b1e565910bad7ad9e2fe1158cca6ad Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 18:16:04 +0300 Subject: [PATCH 09/15] Refactor pagination assertion into a standalone function --- tests/sources/helpers/rest_client/conftest.py | 7 ++++++ .../helpers/rest_client/test_client.py | 22 ++++++++----------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py index d5217f0d77..7eec090db6 100644 --- a/tests/sources/helpers/rest_client/conftest.py +++ b/tests/sources/helpers/rest_client/conftest.py @@ -195,3 +195,10 @@ def refresh_token(request, context): router.register_routes(m) yield m + + +def assert_pagination(pages, expected_start=0, page_size=10): + for i, page in enumerate(pages): + assert page == [ + {"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10) + ] diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 568aff4b78..aa5f042212 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -16,6 +16,8 @@ ) from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException +from .conftest import assert_pagination + def load_private_key(name="private_key.pem"): key_path = os.path.join(os.path.dirname(__file__), name) @@ -36,12 +38,6 @@ def rest_client() -> RESTClient: @pytest.mark.usefixtures("mock_api_server") class TestRESTClient: - def _assert_pagination(self, pages): - for i, page in enumerate(pages): - assert page == [ - {"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10) - ] - def test_get_single_resource(self, rest_client): response = rest_client.get("/posts/1") assert response.status_code == 200 @@ -55,7 +51,7 @@ def test_pagination(self, rest_client: RESTClient): pages = list(pages_iter) - self._assert_pagination(pages) + assert_pagination(pages) def test_page_context(self, rest_client: RESTClient) -> None: for page in rest_client.paginate( @@ -76,7 +72,7 @@ def test_default_paginator(self, rest_client: RESTClient): pages = list(pages_iter) - self._assert_pagination(pages) + assert_pagination(pages) def test_paginate_with_hooks(self, rest_client: RESTClient): def response_hook(response: Response, *args: Any, **kwargs: Any) -> None: @@ -95,7 +91,7 @@ def response_hook(response: Response, *args: Any, **kwargs: Any) -> None: pages = list(pages_iter) - self._assert_pagination(pages) + assert_pagination(pages) pages_iter = rest_client.paginate( "/posts/1/some_details_404", @@ -120,7 +116,7 @@ def test_basic_auth_success(self, rest_client: RESTClient): ) pages = list(pages_iter) - self._assert_pagination(pages) + assert_pagination(pages) def test_bearer_token_auth_success(self, rest_client: RESTClient): response = rest_client.get( @@ -136,7 +132,7 @@ def test_bearer_token_auth_success(self, rest_client: RESTClient): ) pages = list(pages_iter) - self._assert_pagination(pages) + assert_pagination(pages) def test_api_key_auth_success(self, rest_client: RESTClient): response = rest_client.get( @@ -170,7 +166,7 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient): auth=auth, ) - self._assert_pagination(list(pages_iter)) + assert_pagination(list(pages_iter)) def test_paginate_function(self, rest_client: RESTClient): pages_iter = paginate( @@ -180,4 +176,4 @@ def test_paginate_function(self, rest_client: RESTClient): pages = list(pages_iter) - self._assert_pagination(pages) + assert_pagination(pages) From b1ee49ab10c4426f1aa4a9d12b6d8bdb77a42a73 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 18:23:04 +0300 Subject: [PATCH 10/15] Move `paginate` function test to new file `test_requests_paginate.py` --- .../sources/helpers/rest_client/test_client.py | 12 +----------- .../rest_client/test_requests_paginate.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 11 deletions(-) create mode 100644 tests/sources/helpers/rest_client/test_requests_paginate.py diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index aa5f042212..7a4c55f9a6 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -2,7 +2,7 @@ import pytest from typing import Any, cast from dlt.common.typing import TSecretStrValue -from dlt.sources.helpers.requests import Response, Request, paginate +from dlt.sources.helpers.requests import Response, Request from dlt.sources.helpers.rest_client import RESTClient from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator @@ -167,13 +167,3 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient): ) assert_pagination(list(pages_iter)) - - def test_paginate_function(self, rest_client: RESTClient): - pages_iter = paginate( - "https://api.example.com/posts", - paginator=JSONResponsePaginator(next_url_path="next_page"), - ) - - pages = list(pages_iter) - - assert_pagination(pages) diff --git a/tests/sources/helpers/rest_client/test_requests_paginate.py b/tests/sources/helpers/rest_client/test_requests_paginate.py new file mode 100644 index 0000000000..5ea137c735 --- /dev/null +++ b/tests/sources/helpers/rest_client/test_requests_paginate.py @@ -0,0 +1,17 @@ +import pytest + +from dlt.sources.helpers.requests import paginate +from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator +from .conftest import assert_pagination + + +@pytest.mark.usefixtures("mock_api_server") +def test_requests_paginate(): + pages_iter = paginate( + "https://api.example.com/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + ) + + pages = list(pages_iter) + + assert_pagination(pages) From bc1422210518c9ff9863fe02a07235d9a8c088fe Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 18:30:39 +0300 Subject: [PATCH 11/15] Remove PyJWT from deps --- poetry.lock | 2 +- pyproject.toml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index b3f9dfa368..f598c3e8b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -9035,4 +9035,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "c43c7a1ff57aef576ae43d825ef45215783d12e5d7fd9b6f870862db967b1fb1" +content-hash = "99658baf1bfda2ac065bda897637cae0eb122c76777688a7d606df0ef06c7fcc" diff --git a/pyproject.toml b/pyproject.toml index 4ed1f77823..de5f8055c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,6 @@ pyodbc = {version = "^4.0.39", optional = true} qdrant-client = {version = "^1.6.4", optional = true, extras = ["fastembed"]} databricks-sql-connector = {version = ">=2.9.3,<3.0.0", optional = true} dbt-databricks = {version = "^1.7.3", optional = true} -pyjwt = "^2.8.0" [tool.poetry.extras] dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"] From eaeeca33b441f0f1c9d82aadb6473d5ac74a0749 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 19:12:20 +0300 Subject: [PATCH 12/15] Remove explicit initializers and meta fields from configspec classes --- dlt/sources/helpers/rest_client/auth.py | 69 +++++-------------------- 1 file changed, 14 insertions(+), 55 deletions(-) diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 90a936e801..6e156a1c68 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -43,12 +43,7 @@ class AuthConfigBase(AuthBase, CredentialsConfiguration): @configspec class BearerTokenAuth(AuthConfigBase): - type: Final[Literal["http"]] = "http" # noqa: A003 - scheme: Literal["bearer"] = "bearer" - token: TSecretStrValue - - def __init__(self, token: TSecretStrValue = secrets.value) -> None: - self.token = token + token: TSecretStrValue = None def parse_native_representation(self, value: Any) -> None: if isinstance(value, str): @@ -67,21 +62,10 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: @configspec class APIKeyAuth(AuthConfigBase): - type: Final[Literal["apiKey"]] = "apiKey" # noqa: A003 name: str = "Authorization" - api_key: TSecretStrValue + api_key: TSecretStrValue = None location: TApiKeyLocation = "header" - def __init__( - self, - name: str = config.value, - api_key: TSecretStrValue = secrets.value, - location: TApiKeyLocation = "header", - ) -> None: - self.name = name - self.api_key = api_key - self.location = location - def parse_native_representation(self, value: Any) -> None: if isinstance(value, str): self.api_key = cast(TSecretStrValue, value) @@ -104,16 +88,8 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: @configspec class HttpBasicAuth(AuthConfigBase): - type: Final[Literal["http"]] = "http" # noqa: A003 - scheme: Literal["basic"] = "basic" - username: str - password: TSecretStrValue - - def __init__( - self, username: str = config.value, password: TSecretStrValue = secrets.value - ) -> None: - self.username = username - self.password = password + username: str = "" + password: TSecretStrValue = None def parse_native_representation(self, value: Any) -> None: if isinstance(value, Iterable) and not isinstance(value, str): @@ -138,11 +114,7 @@ class OAuth2AuthBase(AuthConfigBase): """Base class for oauth2 authenticators. requires access_token""" # TODO: Separate class for flows (implicit, authorization_code, client_credentials, etc) - type: Final[Literal["oauth2"]] = "oauth2" # noqa: A003 - access_token: TSecretStrValue - - def __init__(self, access_token: TSecretStrValue = secrets.value) -> None: - self.access_token = access_token + access_token: TSecretStrValue = None def parse_native_representation(self, value: Any) -> None: if isinstance(value, str): @@ -164,33 +136,20 @@ class OAuthJWTAuth(BearerTokenAuth): """This is a form of Bearer auth, actually there's not standard way to declare it in openAPI""" format: Final[Literal["JWT"]] = "JWT" # noqa: A003 - client_id: str - private_key: TSecretStrValue - auth_endpoint: str - scopes: Optional[str] = None + client_id: str = None + private_key: TSecretStrValue = None + auth_endpoint: str = None + scopes: Optional[Union[str, List[str]]] = None headers: Optional[Dict[str, str]] = None private_key_passphrase: Optional[TSecretStrValue] = None default_token_expiration: int = 3600 - def __init__( - self, - client_id: str = config.value, - private_key: TSecretStrValue = secrets.value, - auth_endpoint: str = config.value, - scopes: Optional[Union[str, List[str]]] = None, - headers: Optional[Dict[str, str]] = None, - private_key_passphrase: Optional[TSecretStrValue] = None, - default_token_expiration: int = 3600, - ): - self.client_id = client_id - self.private_key = private_key - self.private_key_passphrase = private_key_passphrase - self.auth_endpoint = auth_endpoint - self.scopes = scopes if isinstance(scopes, str) else " ".join(scopes) - self.headers = headers + def __post_init__(self) -> None: + self.scopes = ( + self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) + ) self.token = None self.token_expiry: Optional[pendulum.DateTime] = None - self.default_token_expiration = default_token_expiration def __call__(self, r: PreparedRequest) -> PreparedRequest: if self.token is None or self.is_token_expired(): @@ -229,7 +188,7 @@ def create_jwt_payload(self) -> Dict[str, Union[str, int]]: "aud": self.auth_endpoint, "exp": math.floor((now.add(hours=1)).timestamp()), "iat": math.floor(now.timestamp()), - "scope": self.scopes, + "scope": cast(str, self.scopes), } def load_private_key(self) -> PrivateKeyTypes: From 6e02d1d8cac07e8415c058e58a1c4b88caa4ed33 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 19:51:06 +0300 Subject: [PATCH 13/15] Implement lazy loading for jwt and cryptography in auth --- dlt/sources/helpers/rest_client/auth.py | 43 ++++++++++++++++--------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 6e156a1c68..cd5f616f3a 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -1,6 +1,17 @@ from base64 import b64encode import math -from typing import List, Dict, Final, Literal, Optional, Union, Any, cast, Iterable +from typing import ( + List, + Dict, + Final, + Literal, + Optional, + Union, + Any, + cast, + Iterable, + TYPE_CHECKING, +) from dlt.sources.helpers import requests from requests.auth import AuthBase from requests import PreparedRequest # noqa: I251 @@ -8,25 +19,16 @@ from dlt.common.exceptions import MissingDependencyException -try: - import jwt -except ModuleNotFoundError: - raise MissingDependencyException("dlt OAuth helpers", ["PyJWT"]) - -try: - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import serialization - from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes -except ModuleNotFoundError: - raise MissingDependencyException("dlt OAuth helpers", ["cryptography"]) - -from dlt import config, secrets from dlt.common import logger from dlt.common.configuration.specs.base_configuration import configspec from dlt.common.configuration.specs import CredentialsConfiguration from dlt.common.configuration.specs.exceptions import NativeValueError from dlt.common.typing import TSecretStrValue +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes +else: + PrivateKeyTypes = Any TApiKeyLocation = Literal[ "header", "cookie", "query", "param" @@ -161,6 +163,11 @@ def is_token_expired(self) -> bool: return not self.token_expiry or pendulum.now() >= self.token_expiry def obtain_token(self) -> None: + try: + import jwt + except ModuleNotFoundError: + raise MissingDependencyException("dlt OAuth helpers", ["PyJWT"]) + payload = self.create_jwt_payload() data = { "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", @@ -191,7 +198,13 @@ def create_jwt_payload(self) -> Dict[str, Union[str, int]]: "scope": cast(str, self.scopes), } - def load_private_key(self) -> PrivateKeyTypes: + def load_private_key(self) -> "PrivateKeyTypes": + try: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + except ModuleNotFoundError: + raise MissingDependencyException("dlt OAuth helpers", ["cryptography"]) + private_key_bytes = self.private_key.encode("utf-8") return serialization.load_pem_private_key( private_key_bytes, From b9bb0d354aa281fb84838f39d7e3f25e7fe9b040 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 19:52:15 +0300 Subject: [PATCH 14/15] Set username default to None --- dlt/sources/helpers/rest_client/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index cd5f616f3a..5d7a2f7eb2 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -90,7 +90,7 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: @configspec class HttpBasicAuth(AuthConfigBase): - username: str = "" + username: str = None password: TSecretStrValue = None def parse_native_representation(self, value: Any) -> None: From b05b20c3707041301f10afd6a56b977a95bceb40 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 21:23:21 +0300 Subject: [PATCH 15/15] Add PyJWT to dev dependencies --- poetry.lock | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index f598c3e8b3..a7c3979625 100644 --- a/poetry.lock +++ b/poetry.lock @@ -9035,4 +9035,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "99658baf1bfda2ac065bda897637cae0eb122c76777688a7d606df0ef06c7fcc" +content-hash = "e6e43e82afedfa274c91f3fd13dbbddd9cac64f386d2f5f1c4564ff6f5784cd2" diff --git a/pyproject.toml b/pyproject.toml index de5f8055c5..62a45c86f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ google-api-python-client = ">=1.7.11" pytest-asyncio = "^0.23.5" types-sqlalchemy = "^1.4.53.38" ruff = "^0.3.2" +pyjwt = "^2.8.0" [tool.poetry.group.pipeline] optional=true