Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: parameterised headers rest_api_source #2084

Open
wants to merge 13 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _create_request(
path: str,
method: HTTPMethod,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
hooks: Optional[Hooks] = None,
Expand All @@ -109,10 +110,14 @@ def _create_request(
else:
url = join_url(self.base_url, path)

headers = headers or {}
if self.headers:
headers.update(self.headers)

return Request(
method=method,
url=url,
headers=self.headers,
headers=headers,
params=params,
json=json,
auth=auth or self.auth,
Expand All @@ -123,6 +128,7 @@ def _send_request(self, request: Request, **kwargs: Any) -> Response:
logger.info(
f"Making {request.method.upper()} request to {request.url}"
f" with params={request.params}, json={request.json}"
f" with headers={request.headers}"
)

prepared_request = self.session.prepare_request(request)
Expand All @@ -142,6 +148,7 @@ def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) ->
prepared_request = self._create_request(
path=path,
method=method,
headers=kwargs.pop("headers", None),
params=kwargs.pop("params", None),
json=kwargs.pop("json", None),
auth=kwargs.pop("auth", None),
Expand All @@ -160,6 +167,7 @@ def paginate(
path: str = "",
method: HTTPMethodBasic = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
Expand All @@ -173,6 +181,7 @@ def paginate(
path (str): Endpoint path for the request, relative to `base_url`.
method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'.
params (Optional[Dict[str, Any]]): URL parameters for the request.
headers (Optional[Dict[str, Any]]): Headers for the request.
json (Optional[Dict[str, Any]]): JSON payload for the request.
auth (Optional[AuthBase): Authentication configuration for the request.
paginator (Optional[BasePaginator]): Paginator instance for handling
Expand Down Expand Up @@ -210,7 +219,13 @@ def paginate(
hooks["response"] = [raise_for_status]

request = self._create_request(
path=path, method=method, params=params, json=json, auth=auth, hooks=hooks
path=path,
headers=headers,
method=method,
params=params,
json=json,
auth=auth,
hooks=hooks,
)

if paginator:
Expand Down
10 changes: 8 additions & 2 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def process(
def paginate_resource(
method: HTTPMethodBasic,
path: str,
headers: Dict[str, Any],
params: Dict[str, Any],
json: Optional[Dict[str, Any]],
paginator: Optional[BasePaginator],
Expand All @@ -313,6 +314,7 @@ def paginate_resource(

yield from client.paginate(
method=method,
headers=headers,
path=path,
params=params,
json=json,
Expand All @@ -327,6 +329,7 @@ def paginate_resource(
)(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
headers=endpoint_config.get("headers"),
params=request_params,
json=request_json,
paginator=paginator,
Expand All @@ -346,6 +349,7 @@ def paginate_dependent_resource(
items: List[Dict[str, Any]],
method: HTTPMethodBasic,
path: str,
headers: Dict[str, Any],
params: Dict[str, Any],
paginator: Optional[BasePaginator],
data_selector: Optional[jsonpath.TJsonPath],
Expand All @@ -368,12 +372,13 @@ def paginate_dependent_resource(
)

for item in items:
formatted_path, parent_record = process_parent_data_item(
path, item, resolved_params, include_from_parent
formatted_path, formatted_headers, parent_record = process_parent_data_item(
path, item, resolved_params, include_from_parent, headers
)

for child_page in client.paginate(
method=method,
headers=formatted_headers,
path=formatted_path,
params=params,
paginator=paginator,
Expand All @@ -392,6 +397,7 @@ def paginate_dependent_resource(
)(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
headers=endpoint_config.get("headers"),
params=base_params,
paginator=paginator,
data_selector=endpoint_config.get("data_selector"),
Expand Down
50 changes: 42 additions & 8 deletions dlt/sources/rest_api/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def expand_and_index_resources(
assert isinstance(endpoint_resource["endpoint"], dict)
_setup_single_entity_endpoint(endpoint_resource["endpoint"])
_bind_path_params(endpoint_resource)
_bind_header_params(endpoint_resource)

resource_name = endpoint_resource["name"]
assert isinstance(
Expand Down Expand Up @@ -410,15 +411,29 @@ def _bind_path_params(resource: EndpointResource) -> None:
# resolved params are bound later
path_params[name] = "{" + name + "}"

if len(resolve_params) > 0:
raise NotImplementedError(
f"Resource {resource['name']} defines resolve params {resolve_params} that are not"
f" bound in path {path}. Resolve query params not supported yet."
)

resource["endpoint"]["path"] = path.format(**path_params)


def _bind_header_params(resource: EndpointResource) -> None:
"""Binds params declared in headers to params available in `params`. Pops the
bound params but skips params of type `resolve` and `incremental`, which are bound later.
"""
assert isinstance(resource["endpoint"], dict) # type guard
params = resource["endpoint"].get("params", {})
bind_params = {}
# copy must be made because size of dict changes during iteration
params_iter = params.copy()
for k, v in params_iter.items():
if not isinstance(v, dict):
bind_params[k] = params.pop(k)
else:
# resolved params are bound later
bind_params[k] = "{" + k + "}"

headers = resource["endpoint"].get("headers", {})
resource["endpoint"]["headers"] = generic_format(headers, bind_params) # type: ignore[typeddict-item]


def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint:
"""Tries to guess if the endpoint refers to a single entity and when detected:
* if `data_selector` was not specified (or is None), "$" is selected
Expand Down Expand Up @@ -569,12 +584,28 @@ def remove_field(response: Response, *args, **kwargs) -> Response:
return None


def generic_format(
to_format: Union[Dict[str, Any], str, List[Any]], param_values: Dict[str, Any]
) -> Union[Dict[str, Any], str, List[Any]]:
if isinstance(to_format, dict):
return {
generic_format(key, param_values): generic_format(val, param_values) # type: ignore[misc]
for key, val in to_format.items()
}
if isinstance(to_format, list):
return [generic_format(item, param_values) for item in to_format]
if isinstance(to_format, str):
return to_format.format(**param_values)
return str(to_format)


def process_parent_data_item(
path: str,
item: Dict[str, Any],
resolved_params: List[ResolvedParam],
include_from_parent: List[str],
) -> Tuple[str, Dict[str, Any]]:
headers: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
parent_resource_name = resolved_params[0].resolve_config["resource"]

param_values = {}
Expand Down Expand Up @@ -607,7 +638,10 @@ def process_parent_data_item(
)
parent_record[child_key] = item[parent_key]

return bound_path, parent_record
if headers is not None:
formatted_headers = generic_format(headers, param_values)
return bound_path, formatted_headers, parent_record # type: ignore[return-value]
return bound_path, {}, parent_record


def _merge_resource_endpoints(
Expand Down
1 change: 1 addition & 0 deletions dlt/sources/rest_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ class Endpoint(TypedDict, total=False):
method: Optional[HTTPMethodBasic]
params: Optional[Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]]
json: Optional[Dict[str, Any]]
headers: Optional[Dict[str, Any]]
paginator: Optional[PaginatorConfig]
data_selector: Optional[jsonpath.TJsonPath]
response_actions: Optional[List[ResponseAction]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ def test_modal_snippet() -> None:
# @@@DLT_SNIPPET_END modal_image

# @@@DLT_SNIPPET_START modal_function
@app.function(
volumes={"/data/": vol},
schedule=modal.Period(days=1),
serialized=True
)
@app.function(volumes={"/data/": vol}, schedule=modal.Period(days=1), serialized=True)
def load_tables() -> None:
import dlt
import os
Expand Down
Loading
Loading