Skip to content

Commit

Permalink
Fix validation error in for custom auth classes (#2129)
Browse files Browse the repository at this point in the history
  • Loading branch information
burnash authored Dec 12, 2024
1 parent 4e5a240 commit 3a8dfa7
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 3 deletions.
15 changes: 15 additions & 0 deletions dlt/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,18 @@ def decorator(
return func

return decorator


def add_value_to_literal(literal: Any, value: Any) -> None:
"""Extends a Literal at runtime with a new value.
Args:
literal (Type[Any]): Literal to extend
value (Any): Value to add
"""
type_args = get_args(literal)

if value not in type_args:
type_args += (value,)
literal.__args__ = type_args
8 changes: 7 additions & 1 deletion dlt/sources/rest_api/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dlt.common.configuration import resolve_configuration
from dlt.common.schema.utils import merge_columns
from dlt.common.utils import update_dict_nested, exclude_keys
from dlt.common.typing import add_value_to_literal
from dlt.common import jsonpath

from dlt.extract.incremental import Incremental
Expand Down Expand Up @@ -64,6 +65,8 @@
ResponseActionDict,
Endpoint,
EndpointResource,
AuthType,
PaginatorType,
)


Expand Down Expand Up @@ -103,6 +106,7 @@ def register_paginator(
"Your custom paginator has to be a subclass of BasePaginator"
)
PAGINATOR_MAP[paginator_name] = paginator_class
add_value_to_literal(PaginatorType, paginator_name)


def get_paginator_class(paginator_name: str) -> Type[BasePaginator]:
Expand Down Expand Up @@ -153,6 +157,8 @@ def register_auth(
)
AUTH_MAP[auth_name] = auth_class

add_value_to_literal(AuthType, auth_name)


def get_auth_class(auth_type: str) -> Type[AuthConfigBase]:
try:
Expand Down Expand Up @@ -285,7 +291,7 @@ def build_resource_dependency_graph(
resolved_param_map[resource_name] = None
break
assert isinstance(endpoint_resource["endpoint"], dict)
# connect transformers to resources via resolved params
# find resolved parameters to connect dependent resources
resolved_params = _find_resolved_params(endpoint_resource["endpoint"])

# set of resources in resolved params
Expand Down
17 changes: 17 additions & 0 deletions tests/common/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
is_union_type,
is_annotated,
is_callable_type,
add_value_to_literal,
)


Expand Down Expand Up @@ -293,3 +294,19 @@ def test_secret_type() -> None:

assert TSecretStrValue("x_str") == "x_str"
assert TSecretStrValue({}) == "{}"


def test_add_value_to_literal() -> None:
TestLiteral = Literal["red", "blue"]

add_value_to_literal(TestLiteral, "green")

assert get_args(TestLiteral) == ("red", "blue", "green")

add_value_to_literal(TestLiteral, "red")
assert get_args(TestLiteral) == ("red", "blue", "green")

TestSingleLiteral = Literal["red"]
add_value_to_literal(TestSingleLiteral, "green")
add_value_to_literal(TestSingleLiteral, "blue")
assert get_args(TestSingleLiteral) == ("red", "green", "blue")
17 changes: 16 additions & 1 deletion tests/sources/rest_api/configurations/test_custom_auth_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dlt.sources import rest_api
from dlt.sources.helpers.rest_client.auth import APIKeyAuth, OAuth2ClientCredentials
from dlt.sources.rest_api.typing import ApiKeyAuthConfig, AuthConfig
from dlt.sources.rest_api.typing import ApiKeyAuthConfig, AuthConfig, RESTAPIConfig


class CustomOAuth2(OAuth2ClientCredentials):
Expand Down Expand Up @@ -77,3 +77,18 @@ class NotAuthConfigBase:
"not_an_auth_config_base", NotAuthConfigBase # type: ignore
)
assert e.match("Invalid auth: NotAuthConfigBase.")

def test_valid_config_raises_no_error(self, custom_auth_config: AuthConfig) -> None:
rest_api.config_setup.register_auth("custom_oauth_2", CustomOAuth2)

valid_config: RESTAPIConfig = {
"client": {
"base_url": "https://example.com",
"auth": custom_auth_config,
},
"resources": ["test"],
}

rest_api.rest_api_source(valid_config)

del rest_api.config_setup.AUTH_MAP["custom_oauth_2"]
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dlt.sources import rest_api
from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator
from dlt.sources.rest_api.typing import PaginatorConfig
from dlt.sources.rest_api.typing import PaginatorConfig, RESTAPIConfig


class CustomPaginator(JSONLinkPaginator):
Expand Down Expand Up @@ -67,3 +67,13 @@ class NotAPaginator:
with pytest.raises(ValueError) as e:
rest_api.config_setup.register_paginator("not_a_paginator", NotAPaginator) # type: ignore[arg-type]
assert e.match("Invalid paginator: NotAPaginator.")

def test_test_valid_config_raises_no_error(self, custom_paginator_config) -> None:
rest_api.config_setup.register_paginator("custom_paginator", CustomPaginator)

valid_config: RESTAPIConfig = {
"client": {"base_url": "https://example.com", "paginator": custom_paginator_config},
"resources": ["test"],
}

rest_api.rest_api_source(valid_config)

0 comments on commit 3a8dfa7

Please sign in to comment.