diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 8986d753f3..a0322fe01e 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -484,3 +484,18 @@ def decorator( return func return decorator + + +def add_value_to_literal(literal: Any, value: Any) -> None: + """Extends a Literal at runtime with a new value. + + Args: + literal (Type[Any]): Literal to extend + value (Any): Value to add + + """ + type_args = get_args(literal) + + if value not in type_args: + type_args += (value,) + literal.__args__ = type_args diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index d03a4fd59b..bf62c6c4f7 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -20,6 +20,7 @@ from dlt.common.configuration import resolve_configuration from dlt.common.schema.utils import merge_columns from dlt.common.utils import update_dict_nested, exclude_keys +from dlt.common.typing import add_value_to_literal from dlt.common import jsonpath from dlt.extract.incremental import Incremental @@ -64,6 +65,8 @@ ResponseActionDict, Endpoint, EndpointResource, + AuthType, + PaginatorType, ) @@ -103,6 +106,7 @@ def register_paginator( "Your custom paginator has to be a subclass of BasePaginator" ) PAGINATOR_MAP[paginator_name] = paginator_class + add_value_to_literal(PaginatorType, paginator_name) def get_paginator_class(paginator_name: str) -> Type[BasePaginator]: @@ -153,6 +157,8 @@ def register_auth( ) AUTH_MAP[auth_name] = auth_class + add_value_to_literal(AuthType, auth_name) + def get_auth_class(auth_type: str) -> Type[AuthConfigBase]: try: @@ -285,7 +291,7 @@ def build_resource_dependency_graph( resolved_param_map[resource_name] = None break assert isinstance(endpoint_resource["endpoint"], dict) - # connect transformers to resources via resolved params + # find resolved parameters to connect dependent resources resolved_params = _find_resolved_params(endpoint_resource["endpoint"]) # set of resources in resolved params diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index 2749e3ebb1..e81c3e7fa2 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -43,6 +43,7 @@ is_union_type, is_annotated, is_callable_type, + add_value_to_literal, ) @@ -293,3 +294,19 @@ def test_secret_type() -> None: assert TSecretStrValue("x_str") == "x_str" assert TSecretStrValue({}) == "{}" + + +def test_add_value_to_literal() -> None: + TestLiteral = Literal["red", "blue"] + + add_value_to_literal(TestLiteral, "green") + + assert get_args(TestLiteral) == ("red", "blue", "green") + + add_value_to_literal(TestLiteral, "red") + assert get_args(TestLiteral) == ("red", "blue", "green") + + TestSingleLiteral = Literal["red"] + add_value_to_literal(TestSingleLiteral, "green") + add_value_to_literal(TestSingleLiteral, "blue") + assert get_args(TestSingleLiteral) == ("red", "green", "blue") diff --git a/tests/sources/rest_api/configurations/test_custom_auth_config.py b/tests/sources/rest_api/configurations/test_custom_auth_config.py index 1a5a2e58a3..52cdb95735 100644 --- a/tests/sources/rest_api/configurations/test_custom_auth_config.py +++ b/tests/sources/rest_api/configurations/test_custom_auth_config.py @@ -5,7 +5,7 @@ from dlt.sources import rest_api from dlt.sources.helpers.rest_client.auth import APIKeyAuth, OAuth2ClientCredentials -from dlt.sources.rest_api.typing import ApiKeyAuthConfig, AuthConfig +from dlt.sources.rest_api.typing import ApiKeyAuthConfig, AuthConfig, RESTAPIConfig class CustomOAuth2(OAuth2ClientCredentials): @@ -77,3 +77,18 @@ class NotAuthConfigBase: "not_an_auth_config_base", NotAuthConfigBase # type: ignore ) assert e.match("Invalid auth: NotAuthConfigBase.") + + def test_valid_config_raises_no_error(self, custom_auth_config: AuthConfig) -> None: + rest_api.config_setup.register_auth("custom_oauth_2", CustomOAuth2) + + valid_config: RESTAPIConfig = { + "client": { + "base_url": "https://example.com", + "auth": custom_auth_config, + }, + "resources": ["test"], + } + + rest_api.rest_api_source(valid_config) + + del rest_api.config_setup.AUTH_MAP["custom_oauth_2"] diff --git a/tests/sources/rest_api/configurations/test_custom_paginator_config.py b/tests/sources/rest_api/configurations/test_custom_paginator_config.py index f8ac060218..975ab10176 100644 --- a/tests/sources/rest_api/configurations/test_custom_paginator_config.py +++ b/tests/sources/rest_api/configurations/test_custom_paginator_config.py @@ -4,7 +4,7 @@ from dlt.sources import rest_api from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator -from dlt.sources.rest_api.typing import PaginatorConfig +from dlt.sources.rest_api.typing import PaginatorConfig, RESTAPIConfig class CustomPaginator(JSONLinkPaginator): @@ -67,3 +67,13 @@ class NotAPaginator: with pytest.raises(ValueError) as e: rest_api.config_setup.register_paginator("not_a_paginator", NotAPaginator) # type: ignore[arg-type] assert e.match("Invalid paginator: NotAPaginator.") + + def test_test_valid_config_raises_no_error(self, custom_paginator_config) -> None: + rest_api.config_setup.register_paginator("custom_paginator", CustomPaginator) + + valid_config: RESTAPIConfig = { + "client": {"base_url": "https://example.com", "paginator": custom_paginator_config}, + "resources": ["test"], + } + + rest_api.rest_api_source(valid_config)