From 338f4763c2f8163b2c7607cf77920e8f15125630 Mon Sep 17 00:00:00 2001 From: octavia-squidington-iii Date: Sun, 10 Nov 2024 05:13:52 +0000 Subject: [PATCH] Auto-fix lint and format issues --- airbyte_cdk/config_observation.py | 13 +- airbyte_cdk/connector.py | 22 +- .../connector_builder_handler.py | 36 +- airbyte_cdk/connector_builder/main.py | 30 +- .../connector_builder/message_grouper.py | 120 +- airbyte_cdk/destinations/destination.py | 60 +- .../destinations/vector_db_based/config.py | 50 +- .../vector_db_based/document_processor.py | 60 +- .../destinations/vector_db_based/embedder.py | 40 +- .../vector_db_based/test_utils.py | 18 +- .../destinations/vector_db_based/utils.py | 10 +- .../destinations/vector_db_based/writer.py | 19 +- airbyte_cdk/entrypoint.py | 108 +- airbyte_cdk/exception_handler.py | 16 +- airbyte_cdk/logger.py | 12 +- airbyte_cdk/models/airbyte_protocol.py | 6 +- .../models/airbyte_protocol_serializers.py | 12 +- airbyte_cdk/sources/abstract_source.py | 80 +- .../concurrent_read_processor.py | 61 +- .../concurrent_source/concurrent_source.py | 32 +- .../concurrent_source_adapter.py | 33 +- .../concurrent_source/thread_pool_manager.py | 12 +- .../sources/connector_state_manager.py | 39 +- .../sources/declarative/async_job/job.py | 4 +- .../declarative/async_job/job_orchestrator.py | 82 +- .../declarative/async_job/job_tracker.py | 30 +- .../declarative/async_job/repository.py | 4 +- .../auth/declarative_authenticator.py | 4 +- airbyte_cdk/sources/declarative/auth/jwt.py | 34 +- airbyte_cdk/sources/declarative/auth/oauth.py | 46 +- .../auth/selective_authenticator.py | 4 +- airbyte_cdk/sources/declarative/auth/token.py | 33 +- .../declarative/checks/check_stream.py | 16 +- .../declarative/checks/connection_checker.py | 4 +- .../concurrency_level/concurrency_level.py | 14 +- .../concurrent_declarative_source.py | 130 ++- .../declarative/datetime/min_max_datetime.py | 24 +- .../sources/declarative/declarative_source.py | 4 +- .../sources/declarative/declarative_stream.py | 33 +- .../sources/declarative/decoders/decoder.py | 4 +- .../declarative/decoders/json_decoder.py | 4 +- .../decoders/pagination_decoder_decorator.py | 4 +- .../declarative/decoders/xml_decoder.py | 8 +- .../declarative/extractors/dpath_extractor.py | 8 +- .../declarative/extractors/record_filter.py | 31 +- .../declarative/extractors/record_selector.py | 13 +- .../extractors/response_to_file_extractor.py | 20 +- .../incremental/datetime_based_cursor.py | 123 +- .../incremental/global_substream_cursor.py | 30 +- .../incremental/per_partition_cursor.py | 61 +- .../incremental/per_partition_with_global.py | 24 +- .../resumable_full_refresh_cursor.py | 4 +- .../interpolation/interpolated_boolean.py | 22 +- .../interpolation/interpolated_mapping.py | 6 +- .../interpolated_nested_mapping.py | 13 +- .../interpolation/interpolated_string.py | 8 +- .../interpolation/interpolation.py | 8 +- .../declarative/interpolation/jinja.py | 8 +- .../declarative/interpolation/macros.py | 23 +- .../manifest_declarative_source.py | 110 +- ...legacy_to_per_partition_state_migration.py | 9 +- .../models/declarative_component_schema.py | 52 +- .../parsers/manifest_component_transformer.py | 17 +- .../parsers/manifest_reference_resolver.py | 11 +- .../parsers/model_to_component_factory.py | 1009 +++++++++++++---- .../cartesian_product_stream_slicer.py | 36 +- .../list_partition_router.py | 32 +- .../substream_partition_router.py | 64 +- .../constant_backoff_strategy.py | 8 +- .../backoff_strategies/header_helper.py | 4 +- .../wait_time_from_header_backoff_strategy.py | 22 +- ...until_time_from_header_backoff_strategy.py | 20 +- .../error_handlers/composite_error_handler.py | 4 +- .../error_handlers/default_error_handler.py | 24 +- .../default_http_response_filter.py | 21 +- .../error_handlers/http_response_filter.py | 57 +- .../requesters/http_job_repository.py | 37 +- .../declarative/requesters/http_requester.py | 73 +- .../paginators/default_paginator.py | 63 +- .../requesters/paginators/no_pagination.py | 4 +- .../requesters/paginators/paginator.py | 4 +- .../strategies/cursor_pagination_strategy.py | 22 +- .../paginators/strategies/offset_increment.py | 31 +- .../paginators/strategies/page_increment.py | 12 +- .../strategies/pagination_strategy.py | 4 +- .../paginators/strategies/stop_condition.py | 8 +- ...datetime_based_request_options_provider.py | 21 +- .../default_request_options_provider.py | 4 +- ...erpolated_nested_request_input_provider.py | 27 +- .../interpolated_request_input_provider.py | 23 +- .../interpolated_request_options_provider.py | 35 +- .../declarative/requesters/requester.py | 4 +- .../declarative/retrievers/async_retriever.py | 16 +- .../retrievers/simple_retriever.py | 129 ++- .../schema/default_schema_loader.py | 4 +- airbyte_cdk/sources/declarative/spec/spec.py | 10 +- .../stream_slicers/stream_slicer.py | 4 +- .../declarative/transformations/add_fields.py | 15 +- .../transformations/remove_fields.py | 8 +- airbyte_cdk/sources/declarative/types.py | 9 +- .../declarative/yaml_declarative_source.py | 4 +- .../sources/embedded/base_integration.py | 18 +- airbyte_cdk/sources/embedded/catalog.py | 20 +- airbyte_cdk/sources/embedded/runner.py | 22 +- airbyte_cdk/sources/embedded/tools.py | 4 +- ...stract_file_based_availability_strategy.py | 16 +- ...efault_file_based_availability_strategy.py | 30 +- .../config/abstract_file_based_spec.py | 18 +- .../sources/file_based/config/csv_format.py | 30 +- .../config/file_based_stream_config.py | 8 +- .../file_based/config/unstructured_format.py | 8 +- .../default_discovery_policy.py | 9 +- airbyte_cdk/sources/file_based/exceptions.py | 28 +- .../sources/file_based/file_based_source.py | 106 +- .../file_based/file_based_stream_reader.py | 21 +- .../file_based/file_types/avro_parser.py | 75 +- .../file_based/file_types/csv_parser.py | 109 +- .../file_based/file_types/excel_parser.py | 30 +- .../file_based/file_types/file_transfer.py | 10 +- .../file_based/file_types/file_type_parser.py | 5 +- .../file_based/file_types/jsonl_parser.py | 25 +- .../file_based/file_types/parquet_parser.py | 72 +- .../file_types/unstructured_parser.py | 79 +- .../sources/file_based/schema_helpers.py | 43 +- .../abstract_schema_validation_policy.py | 4 +- .../default_schema_validation_policies.py | 21 +- .../stream/abstract_file_based_stream.py | 40 +- .../file_based/stream/concurrent/adapters.py | 58 +- .../abstract_concurrent_file_based_cursor.py | 4 +- .../cursor/file_based_concurrent_cursor.py | 70 +- .../cursor/file_based_final_state_cursor.py | 22 +- .../cursor/abstract_file_based_cursor.py | 4 +- .../cursor/default_file_based_cursor.py | 35 +- .../stream/default_file_based_stream.py | 87 +- airbyte_cdk/sources/http_logger.py | 6 +- airbyte_cdk/sources/message/repository.py | 22 +- airbyte_cdk/sources/source.py | 18 +- .../sources/streams/availability_strategy.py | 12 +- airbyte_cdk/sources/streams/call_rate.py | 82 +- .../streams/checkpoint/checkpoint_reader.py | 38 +- ...substream_resumable_full_refresh_cursor.py | 8 +- .../sources/streams/concurrent/adapters.py | 87 +- .../sources/streams/concurrent/cursor.py | 73 +- .../streams/concurrent/default_stream.py | 11 +- .../sources/streams/concurrent/helpers.py | 8 +- .../streams/concurrent/partition_enqueuer.py | 11 +- .../streams/concurrent/partition_reader.py | 5 +- .../streams/concurrent/partitions/record.py | 12 +- .../streams/concurrent/partitions/types.py | 8 +- .../abstract_stream_state_converter.py | 32 +- .../datetime_stream_state_converter.py | 36 +- airbyte_cdk/sources/streams/core.py | 99 +- .../streams/http/availability_strategy.py | 4 +- .../error_handlers/default_error_mapping.py | 5 +- .../http/error_handlers/error_handler.py | 4 +- .../http_status_error_handler.py | 21 +- .../http/error_handlers/response_models.py | 12 +- .../sources/streams/http/exceptions.py | 3 +- airbyte_cdk/sources/streams/http/http.py | 165 ++- .../sources/streams/http/http_client.py | 116 +- .../sources/streams/http/rate_limiting.py | 30 +- .../requests_native_auth/abstract_oauth.py | 25 +- .../http/requests_native_auth/oauth.py | 49 +- .../http/requests_native_auth/token.py | 16 +- airbyte_cdk/sources/types.py | 6 +- airbyte_cdk/sources/utils/record_helper.py | 15 +- airbyte_cdk/sources/utils/schema_helpers.py | 12 +- airbyte_cdk/sources/utils/slice_logger.py | 5 +- airbyte_cdk/sources/utils/transform.py | 33 +- airbyte_cdk/sql/exceptions.py | 25 +- airbyte_cdk/sql/secrets.py | 4 +- airbyte_cdk/sql/shared/catalog_providers.py | 17 +- airbyte_cdk/sql/shared/sql_processor.py | 58 +- airbyte_cdk/test/catalog_builder.py | 21 +- airbyte_cdk/test/entrypoint_wrapper.py | 35 +- airbyte_cdk/test/mock_http/mocker.py | 52 +- airbyte_cdk/test/mock_http/request.py | 12 +- airbyte_cdk/test/mock_http/response.py | 4 +- .../test/mock_http/response_builder.py | 36 +- airbyte_cdk/test/state_builder.py | 12 +- airbyte_cdk/test/utils/data.py | 8 +- airbyte_cdk/test/utils/http_mocking.py | 4 +- airbyte_cdk/utils/airbyte_secrets_utils.py | 4 +- airbyte_cdk/utils/analytics_message.py | 12 +- airbyte_cdk/utils/datetime_format_inferrer.py | 5 +- airbyte_cdk/utils/mapping_helpers.py | 4 +- airbyte_cdk/utils/message_utils.py | 12 +- airbyte_cdk/utils/print_buffer.py | 7 +- airbyte_cdk/utils/schema_inferrer.py | 39 +- .../utils/spec_schema_transformations.py | 4 +- airbyte_cdk/utils/traced_exception.py | 42 +- bin/generate_component_manifest_files.py | 22 +- docs/generate.py | 4 +- reference_docs/generate_rst_schema.py | 14 +- unit_tests/conftest.py | 8 +- .../test_connector_builder_handler.py | 291 ++++- .../connector_builder/test_message_grouper.py | 280 ++++- unit_tests/destinations/test_destination.py | 46 +- .../document_processor_test.py | 10 +- .../vector_db_based/embedder_test.py | 44 +- .../vector_db_based/writer_test.py | 35 +- .../test_concurrent_source_adapter.py | 39 +- .../declarative/async_job/test_integration.py | 24 +- .../async_job/test_job_orchestrator.py | 126 +- .../declarative/async_job/test_job_tracker.py | 5 +- .../sources/declarative/auth/test_jwt.py | 8 +- .../sources/declarative/auth/test_oauth.py | 44 +- .../auth/test_selective_authenticator.py | 4 +- .../auth/test_session_token_auth.py | 24 +- .../declarative/auth/test_token_auth.py | 75 +- .../declarative/auth/test_token_provider.py | 13 +- .../declarative/checks/test_check_stream.py | 29 +- .../test_concurrency_level.py | 49 +- .../datetime/test_datetime_parser.py | 39 +- .../datetime/test_min_max_datetime.py | 107 +- .../declarative/decoders/test_json_decoder.py | 19 +- .../test_pagination_decoder_decorator.py | 5 +- .../declarative/decoders/test_xml_decoder.py | 17 +- .../extractors/test_dpath_extractor.py | 38 +- .../extractors/test_record_filter.py | 161 ++- .../extractors/test_record_selector.py | 56 +- .../incremental/test_datetime_based_cursor.py | 549 +++++++-- .../incremental/test_per_partition_cursor.py | 279 ++++- .../test_per_partition_cursor_integration.py | 251 +++- .../test_resumable_full_refresh_cursor.py | 5 +- .../declarative/interpolation/test_filters.py | 4 +- .../test_interpolated_boolean.py | 16 +- .../test_interpolated_mapping.py | 18 +- .../test_interpolated_nested_mapping.py | 14 +- .../declarative/interpolation/test_jinja.py | 80 +- .../declarative/interpolation/test_macros.py | 53 +- .../test_legacy_to_per_partition_migration.py | 133 ++- .../test_manifest_component_transformer.py | 85 +- .../test_manifest_reference_resolver.py | 38 +- .../test_model_to_component_factory.py | 872 ++++++++++---- .../declarative/parsers/testing_components.py | 9 +- ...test_cartesian_product_partition_router.py | 167 ++- .../test_list_partition_router.py | 80 +- .../test_parent_state_stream.py | 667 ++++++++--- .../test_single_partition_router.py | 4 +- .../test_substream_partition_router.py | 408 +++++-- .../test_constant_backoff.py | 15 +- .../test_exponential_backoff.py | 4 +- .../backoff_strategies/test_header_helper.py | 28 +- .../test_wait_time_from_header.py | 21 +- .../test_wait_until_time_from_header.py | 76 +- .../test_composite_error_handler.py | 81 +- .../test_default_error_handler.py | 60 +- .../test_default_http_response_filter.py | 8 +- .../test_http_response_filter.py | 65 +- .../test_cursor_pagination_strategy.py | 12 +- .../paginators/test_default_paginator.py | 100 +- .../paginators/test_offset_increment.py | 46 +- .../paginators/test_page_increment.py | 37 +- .../paginators/test_request_option.py | 21 +- .../paginators/test_stop_condition.py | 42 +- ...datetime_based_request_options_provider.py | 120 +- ...t_interpolated_request_options_provider.py | 191 +++- .../requesters/test_http_job_repository.py | 45 +- .../requesters/test_http_requester.py | 304 ++++- ...est_interpolated_request_input_provider.py | 30 +- .../retrievers/test_simple_retriever.py | 149 ++- .../schema/test_default_schema_loader.py | 5 +- .../schema/test_json_file_schema_loader.py | 31 +- .../sources/declarative/spec/test_spec.py | 20 +- .../test_concurrent_declarative_source.py | 626 ++++++++-- .../declarative/test_declarative_stream.py | 38 +- .../test_manifest_declarative_source.py | 670 +++++++++-- unit_tests/sources/declarative/test_types.py | 56 +- .../transformations/test_add_fields.py | 50 +- .../test_keys_to_lower_transformation.py | 4 +- .../transformations/test_remove_fields.py | 111 +- .../embedded/test_embedded_integration.py | 24 +- ...efault_file_based_availability_strategy.py | 32 +- .../config/test_abstract_file_based_spec.py | 21 +- .../file_based/config/test_csv_format.py | 7 +- .../config/test_file_based_stream_config.py | 68 +- .../test_default_discovery_policy.py | 4 +- .../file_based/file_types/test_avro_parser.py | 201 +++- .../file_based/file_types/test_csv_parser.py | 223 +++- .../file_types/test_excel_parser.py | 23 +- .../file_types/test_jsonl_parser.py | 80 +- .../file_types/test_parquet_parser.py | 397 +++++-- .../file_types/test_unstructured_parser.py | 130 ++- unit_tests/sources/file_based/helpers.py | 31 +- .../file_based/in_memory_files_source.py | 68 +- .../file_based/scenarios/avro_scenarios.py | 168 ++- .../file_based/scenarios/check_scenarios.py | 12 +- .../concurrent_incremental_scenarios.py | 69 +- .../file_based/scenarios/csv_scenarios.py | 200 +++- .../file_based/scenarios/excel_scenarios.py | 51 +- .../scenarios/file_based_source_builder.py | 30 +- .../scenarios/incremental_scenarios.py | 39 +- .../file_based/scenarios/jsonl_scenarios.py | 49 +- .../file_based/scenarios/parquet_scenarios.py | 79 +- .../file_based/scenarios/scenario_builder.py | 56 +- .../scenarios/unstructured_scenarios.py | 8 +- .../scenarios/user_input_schema_scenarios.py | 52 +- .../scenarios/validation_policy_scenarios.py | 8 +- .../test_default_schema_validation_policy.py | 68 +- .../stream/concurrent/test_adapters.py | 150 ++- .../test_file_based_concurrent_cursor.py | 226 +++- .../stream/test_default_file_based_cursor.py | 231 +++- .../stream/test_default_file_based_stream.py | 35 +- .../file_based/test_file_based_scenarios.py | 19 +- .../test_file_based_stream_reader.py | 123 +- .../sources/file_based/test_scenarios.py | 80 +- .../sources/file_based/test_schema_helpers.py | 121 +- .../sources/fixtures/source_test_fixture.py | 9 +- unit_tests/sources/message/test_repository.py | 45 +- .../mock_server_tests/mock_source_fixture.py | 64 +- .../airbyte_message_assertions.py | 4 +- .../test_mock_server_abstract_source.py | 249 +++- .../test_resumable_full_refresh.py | 165 ++- .../checkpoint/test_checkpoint_reader.py | 56 +- ...substream_resumable_full_refresh_cursor.py | 44 +- .../scenarios/incremental_scenarios.py | 73 +- .../scenarios/stream_facade_builder.py | 33 +- .../scenarios/stream_facade_scenarios.py | 19 +- .../scenarios/test_concurrent_scenarios.py | 4 +- ...hread_based_concurrent_stream_scenarios.py | 58 +- ..._based_concurrent_stream_source_builder.py | 40 +- .../streams/concurrent/scenarios/utils.py | 14 +- .../streams/concurrent/test_adapters.py | 158 ++- .../test_concurrent_read_processor.py | 180 ++- .../sources/streams/concurrent/test_cursor.py | 453 ++++++-- .../test_datetime_state_converter.py | 33 +- .../streams/concurrent/test_default_stream.py | 30 +- .../concurrent/test_partition_enqueuer.py | 30 +- .../concurrent/test_partition_reader.py | 22 +- .../concurrent/test_thread_pool_manager.py | 8 +- .../test_default_backoff_strategy.py | 4 +- .../test_http_status_error_handler.py | 37 +- .../test_json_error_message_parser.py | 14 +- .../error_handlers/test_response_models.py | 32 +- .../test_requests_native_auth.py | 163 ++- .../http/test_availability_strategy.py | 15 +- unit_tests/sources/streams/http/test_http.py | 263 ++++- .../sources/streams/http/test_http_client.py | 202 +++- unit_tests/sources/streams/test_call_rate.py | 94 +- .../sources/streams/test_stream_read.py | 217 +++- .../sources/streams/test_streams_core.py | 49 +- .../streams/utils/test_stream_helper.py | 7 +- unit_tests/sources/test_abstract_source.py | 484 ++++++-- unit_tests/sources/test_config.py | 11 +- .../sources/test_connector_state_manager.py | 157 ++- unit_tests/sources/test_http_logger.py | 61 +- unit_tests/sources/test_integration_source.py | 101 +- unit_tests/sources/test_source.py | 107 +- unit_tests/sources/test_source_read.py | 102 +- .../sources/utils/test_record_helper.py | 16 +- .../sources/utils/test_schema_helpers.py | 6 +- unit_tests/sources/utils/test_slice_logger.py | 85 +- unit_tests/sources/utils/test_transform.py | 111 +- unit_tests/test/mock_http/test_matcher.py | 8 +- unit_tests/test/mock_http/test_mocker.py | 47 +- unit_tests/test/mock_http/test_request.py | 76 +- .../test/mock_http/test_response_builder.py | 52 +- unit_tests/test/test_entrypoint_wrapper.py | 93 +- unit_tests/test_config_observation.py | 14 +- unit_tests/test_connector.py | 5 +- unit_tests/test_counter.py | 8 +- unit_tests/test_entrypoint.py | 356 ++++-- unit_tests/test_exception_handler.py | 7 +- unit_tests/test_secure_logger.py | 42 +- .../utils/test_datetime_format_inferrer.py | 62 +- unit_tests/utils/test_message_utils.py | 4 +- unit_tests/utils/test_rate_limiting.py | 8 +- unit_tests/utils/test_schema_inferrer.py | 141 ++- unit_tests/utils/test_secret_utils.py | 74 +- unit_tests/utils/test_stream_status_utils.py | 16 +- unit_tests/utils/test_traced_exception.py | 49 +- 372 files changed, 18596 insertions(+), 4697 deletions(-) diff --git a/airbyte_cdk/config_observation.py b/airbyte_cdk/config_observation.py index 94a3d64a..764174f0 100644 --- a/airbyte_cdk/config_observation.py +++ b/airbyte_cdk/config_observation.py @@ -23,7 +23,10 @@ class ObservedDict(dict): # type: ignore # disallow_any_generics is set to True, and dict is equivalent to dict[Any] def __init__( - self, non_observed_mapping: MutableMapping[Any, Any], observer: ConfigObserver, update_on_unchanged_value: bool = True + self, + non_observed_mapping: MutableMapping[Any, Any], + observer: ConfigObserver, + update_on_unchanged_value: bool = True, ) -> None: non_observed_mapping = copy(non_observed_mapping) self.observer = observer @@ -69,11 +72,15 @@ def update(self) -> None: emit_configuration_as_airbyte_control_message(self.config) -def observe_connector_config(non_observed_connector_config: MutableMapping[str, Any]) -> ObservedDict: +def observe_connector_config( + non_observed_connector_config: MutableMapping[str, Any], +) -> ObservedDict: if isinstance(non_observed_connector_config, ObservedDict): raise ValueError("This connector configuration is already observed") connector_config_observer = ConfigObserver() - observed_connector_config = ObservedDict(non_observed_connector_config, connector_config_observer) + observed_connector_config = ObservedDict( + non_observed_connector_config, connector_config_observer + ) connector_config_observer.set_config(observed_connector_config) return observed_connector_config diff --git a/airbyte_cdk/connector.py b/airbyte_cdk/connector.py index 299f814e..29cfc968 100644 --- a/airbyte_cdk/connector.py +++ b/airbyte_cdk/connector.py @@ -11,7 +11,11 @@ from typing import Any, Generic, Mapping, Optional, Protocol, TypeVar import yaml -from airbyte_cdk.models import AirbyteConnectionStatus, ConnectorSpecification, ConnectorSpecificationSerializer +from airbyte_cdk.models import ( + AirbyteConnectionStatus, + ConnectorSpecification, + ConnectorSpecificationSerializer, +) def load_optional_package_file(package: str, filename: str) -> Optional[bytes]: @@ -53,7 +57,9 @@ def _read_json_file(file_path: str) -> Any: try: return json.loads(contents) except json.JSONDecodeError as error: - raise ValueError(f"Could not read json file {file_path}: {error}. Please ensure that it is a valid JSON.") + raise ValueError( + f"Could not read json file {file_path}: {error}. Please ensure that it is a valid JSON." + ) @staticmethod def write_config(config: TConfig, config_path: str) -> None: @@ -72,7 +78,9 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification: json_spec = load_optional_package_file(package, "spec.json") if yaml_spec and json_spec: - raise RuntimeError("Found multiple spec files in the package. Only one of spec.yaml or spec.json should be provided.") + raise RuntimeError( + "Found multiple spec files in the package. Only one of spec.yaml or spec.json should be provided." + ) if yaml_spec: spec_obj = yaml.load(yaml_spec, Loader=yaml.SafeLoader) @@ -80,7 +88,9 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification: try: spec_obj = json.loads(json_spec) except json.JSONDecodeError as error: - raise ValueError(f"Could not read json spec file: {error}. Please ensure that it is a valid JSON.") + raise ValueError( + f"Could not read json spec file: {error}. Please ensure that it is a valid JSON." + ) else: raise FileNotFoundError("Unable to find spec.yaml or spec.json in the package.") @@ -101,7 +111,9 @@ def write_config(config: Mapping[str, Any], config_path: str) -> None: ... class DefaultConnectorMixin: # can be overridden to change an input config - def configure(self: _WriteConfigProtocol, config: Mapping[str, Any], temp_dir: str) -> Mapping[str, Any]: + def configure( + self: _WriteConfigProtocol, config: Mapping[str, Any], temp_dir: str + ) -> Mapping[str, Any]: config_path = os.path.join(temp_dir, "config.json") self.write_config(config, config_path) return config diff --git a/airbyte_cdk/connector_builder/connector_builder_handler.py b/airbyte_cdk/connector_builder/connector_builder_handler.py index b3cfd9a0..44d1bfe1 100644 --- a/airbyte_cdk/connector_builder/connector_builder_handler.py +++ b/airbyte_cdk/connector_builder/connector_builder_handler.py @@ -7,12 +7,19 @@ from typing import Any, List, Mapping from airbyte_cdk.connector_builder.message_grouper import MessageGrouper -from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog +from airbyte_cdk.models import ( + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStateMessage, + ConfiguredAirbyteCatalog, +) from airbyte_cdk.models import Type from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, +) from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.traced_exception import AirbyteTracedException @@ -34,7 +41,9 @@ class TestReadLimits: def get_limits(config: Mapping[str, Any]) -> TestReadLimits: command_config = config.get("__test_read_config", {}) - max_pages_per_slice = command_config.get(MAX_PAGES_PER_SLICE_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE + max_pages_per_slice = ( + command_config.get(MAX_PAGES_PER_SLICE_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE + ) max_slices = command_config.get(MAX_SLICES_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_SLICES max_records = command_config.get(MAX_RECORDS_KEY) or DEFAULT_MAXIMUM_RECORDS return TestReadLimits(max_records, max_pages_per_slice, max_slices) @@ -64,15 +73,24 @@ def read_stream( ) -> AirbyteMessage: try: handler = MessageGrouper(limits.max_pages_per_slice, limits.max_slices, limits.max_records) - stream_name = configured_catalog.streams[0].stream.name # The connector builder only supports a single stream - stream_read = handler.get_message_groups(source, config, configured_catalog, state, limits.max_records) + stream_name = configured_catalog.streams[ + 0 + ].stream.name # The connector builder only supports a single stream + stream_read = handler.get_message_groups( + source, config, configured_catalog, state, limits.max_records + ) return AirbyteMessage( type=MessageType.RECORD, - record=AirbyteRecordMessage(data=dataclasses.asdict(stream_read), stream=stream_name, emitted_at=_emitted_at()), + record=AirbyteRecordMessage( + data=dataclasses.asdict(stream_read), stream=stream_name, emitted_at=_emitted_at() + ), ) except Exception as exc: error = AirbyteTracedException.from_exception( - exc, message=filter_secrets(f"Error reading stream with config={config} and catalog={configured_catalog}: {str(exc)}") + exc, + message=filter_secrets( + f"Error reading stream with config={config} and catalog={configured_catalog}: {str(exc)}" + ), ) return error.as_airbyte_message() @@ -88,7 +106,9 @@ def resolve_manifest(source: ManifestDeclarativeSource) -> AirbyteMessage: ), ) except Exception as exc: - error = AirbyteTracedException.from_exception(exc, message=f"Error resolving manifest: {str(exc)}") + error = AirbyteTracedException.from_exception( + exc, message=f"Error resolving manifest: {str(exc)}" + ) return error.as_airbyte_message() diff --git a/airbyte_cdk/connector_builder/main.py b/airbyte_cdk/connector_builder/main.py index 9e6fe188..35ba7e46 100644 --- a/airbyte_cdk/connector_builder/main.py +++ b/airbyte_cdk/connector_builder/main.py @@ -7,7 +7,13 @@ from typing import Any, List, Mapping, Optional, Tuple from airbyte_cdk.connector import BaseConnector -from airbyte_cdk.connector_builder.connector_builder_handler import TestReadLimits, create_source, get_limits, read_stream, resolve_manifest +from airbyte_cdk.connector_builder.connector_builder_handler import ( + TestReadLimits, + create_source, + get_limits, + read_stream, + resolve_manifest, +) from airbyte_cdk.entrypoint import AirbyteEntrypoint from airbyte_cdk.models import ( AirbyteMessage, @@ -22,11 +28,17 @@ from orjson import orjson -def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]: +def get_config_and_catalog_from_args( + args: List[str], +) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]: # TODO: Add functionality for the `debug` logger. # Currently, no one `debug` level log will be displayed during `read` a stream for a connector created through `connector-builder`. parsed_args = AirbyteEntrypoint.parse_args(args) - config_path, catalog_path, state_path = parsed_args.config, parsed_args.catalog, parsed_args.state + config_path, catalog_path, state_path = ( + parsed_args.config, + parsed_args.catalog, + parsed_args.state, + ) if parsed_args.command != "read": raise ValueError("Only read commands are allowed for Connector Builder requests.") @@ -64,7 +76,9 @@ def handle_connector_builder_request( if command == "resolve_manifest": return resolve_manifest(source) elif command == "test_read": - assert catalog is not None, "`test_read` requires a valid `ConfiguredAirbyteCatalog`, got None." + assert ( + catalog is not None + ), "`test_read` requires a valid `ConfiguredAirbyteCatalog`, got None." return read_stream(source, config, catalog, state, limits) else: raise ValueError(f"Unrecognized command {command}.") @@ -75,7 +89,9 @@ def handle_request(args: List[str]) -> str: limits = get_limits(config) source = create_source(config, limits) return orjson.dumps( - AirbyteMessageSerializer.dump(handle_connector_builder_request(source, command, config, catalog, state, limits)) + AirbyteMessageSerializer.dump( + handle_connector_builder_request(source, command, config, catalog, state, limits) + ) ).decode() # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage @@ -83,6 +99,8 @@ def handle_request(args: List[str]) -> str: try: print(handle_request(sys.argv[1:])) except Exception as exc: - error = AirbyteTracedException.from_exception(exc, message=f"Error handling request: {str(exc)}") + error = AirbyteTracedException.from_exception( + exc, message=f"Error handling request: {str(exc)}" + ) m = error.as_airbyte_message() print(orjson.dumps(AirbyteMessageSerializer.dump(m)).decode()) diff --git a/airbyte_cdk/connector_builder/message_grouper.py b/airbyte_cdk/connector_builder/message_grouper.py index e21ffd61..aa3a4293 100644 --- a/airbyte_cdk/connector_builder/message_grouper.py +++ b/airbyte_cdk/connector_builder/message_grouper.py @@ -45,7 +45,9 @@ def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit: self._max_slices = max_slices self._max_record_limit = max_record_limit - def _pk_to_nested_and_composite_field(self, field: Optional[Union[str, List[str], List[List[str]]]]) -> List[List[str]]: + def _pk_to_nested_and_composite_field( + self, field: Optional[Union[str, List[str], List[List[str]]]] + ) -> List[List[str]]: if not field: return [[]] @@ -58,7 +60,9 @@ def _pk_to_nested_and_composite_field(self, field: Optional[Union[str, List[str] return field # type: ignore # the type of field is expected to be List[List[str]] here - def _cursor_field_to_nested_and_composite_field(self, field: Union[str, List[str]]) -> List[List[str]]: + def _cursor_field_to_nested_and_composite_field( + self, field: Union[str, List[str]] + ) -> List[List[str]]: if not field: return [[]] @@ -80,8 +84,12 @@ def get_message_groups( record_limit: Optional[int] = None, ) -> StreamRead: if record_limit is not None and not (1 <= record_limit <= self._max_record_limit): - raise ValueError(f"Record limit must be between 1 and {self._max_record_limit}. Got {record_limit}") - stream = source.streams(config)[0] # The connector builder currently only supports reading from a single stream at a time + raise ValueError( + f"Record limit must be between 1 and {self._max_record_limit}. Got {record_limit}" + ) + stream = source.streams(config)[ + 0 + ] # The connector builder currently only supports reading from a single stream at a time schema_inferrer = SchemaInferrer( self._pk_to_nested_and_composite_field(stream.primary_key), self._cursor_field_to_nested_and_composite_field(stream.cursor_field), @@ -104,7 +112,11 @@ def get_message_groups( record_limit, ): if isinstance(message_group, AirbyteLogMessage): - log_messages.append(LogMessage(**{"message": message_group.message, "level": message_group.level.value})) + log_messages.append( + LogMessage( + **{"message": message_group.message, "level": message_group.level.value} + ) + ) elif isinstance(message_group, AirbyteTraceMessage): if message_group.type == TraceType.ERROR: log_messages.append( @@ -118,7 +130,10 @@ def get_message_groups( ) ) elif isinstance(message_group, AirbyteControlMessage): - if not latest_config_update or latest_config_update.emitted_at <= message_group.emitted_at: + if ( + not latest_config_update + or latest_config_update.emitted_at <= message_group.emitted_at + ): latest_config_update = message_group elif isinstance(message_group, AuxiliaryRequest): auxiliary_requests.append(message_group) @@ -142,7 +157,9 @@ def get_message_groups( test_read_limit_reached=self._has_reached_limit(slices), auxiliary_requests=auxiliary_requests, inferred_schema=schema, - latest_config_update=self._clean_config(latest_config_update.connectorConfig.config) if latest_config_update else None, + latest_config_update=self._clean_config(latest_config_update.connectorConfig.config) + if latest_config_update + else None, inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(), ) @@ -152,7 +169,15 @@ def _get_message_groups( schema_inferrer: SchemaInferrer, datetime_format_inferrer: DatetimeFormatInferrer, limit: int, - ) -> Iterable[Union[StreamReadPages, AirbyteControlMessage, AirbyteLogMessage, AirbyteTraceMessage, AuxiliaryRequest]]: + ) -> Iterable[ + Union[ + StreamReadPages, + AirbyteControlMessage, + AirbyteLogMessage, + AirbyteTraceMessage, + AuxiliaryRequest, + ] + ]: """ Message groups are partitioned according to when request log messages are received. Subsequent response log messages and record messages belong to the prior request log message and when we encounter another request, append the latest @@ -180,10 +205,17 @@ def _get_message_groups( while records_count < limit and (message := next(messages, None)): json_object = self._parse_json(message.log) if message.type == MessageType.LOG else None if json_object is not None and not isinstance(json_object, dict): - raise ValueError(f"Expected log message to be a dict, got {json_object} of type {type(json_object)}") + raise ValueError( + f"Expected log message to be a dict, got {json_object} of type {type(json_object)}" + ) json_message: Optional[Dict[str, JsonType]] = json_object if self._need_to_close_page(at_least_one_page_in_group, message, json_message): - self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records) + self._close_page( + current_page_request, + current_page_response, + current_slice_pages, + current_page_records, + ) current_page_request = None current_page_response = None @@ -200,7 +232,9 @@ def _get_message_groups( current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message current_slice_pages = [] at_least_one_page_in_group = False - elif message.type == MessageType.LOG and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX): # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message + elif message.type == MessageType.LOG and message.log.message.startswith( + SliceLogger.SLICE_LOG_PREFIX + ): # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message # parsing the first slice current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message elif message.type == MessageType.LOG: @@ -208,14 +242,22 @@ def _get_message_groups( if self._is_auxiliary_http_request(json_message): airbyte_cdk = json_message.get("airbyte_cdk", {}) if not isinstance(airbyte_cdk, dict): - raise ValueError(f"Expected airbyte_cdk to be a dict, got {airbyte_cdk} of type {type(airbyte_cdk)}") + raise ValueError( + f"Expected airbyte_cdk to be a dict, got {airbyte_cdk} of type {type(airbyte_cdk)}" + ) stream = airbyte_cdk.get("stream", {}) if not isinstance(stream, dict): - raise ValueError(f"Expected stream to be a dict, got {stream} of type {type(stream)}") - title_prefix = "Parent stream: " if stream.get("is_substream", False) else "" + raise ValueError( + f"Expected stream to be a dict, got {stream} of type {type(stream)}" + ) + title_prefix = ( + "Parent stream: " if stream.get("is_substream", False) else "" + ) http = json_message.get("http", {}) if not isinstance(http, dict): - raise ValueError(f"Expected http to be a dict, got {http} of type {type(http)}") + raise ValueError( + f"Expected http to be a dict, got {http} of type {type(http)}" + ) yield AuxiliaryRequest( title=title_prefix + str(http.get("title", None)), description=str(http.get("description", None)), @@ -236,13 +278,21 @@ def _get_message_groups( records_count += 1 schema_inferrer.accumulate(message.record) datetime_format_inferrer.accumulate(message.record) - elif message.type == MessageType.CONTROL and message.control.type == OrchestratorType.CONNECTOR_CONFIG: # type: ignore[union-attr] # AirbyteMessage with MessageType.CONTROL has control.type + elif ( + message.type == MessageType.CONTROL + and message.control.type == OrchestratorType.CONNECTOR_CONFIG + ): # type: ignore[union-attr] # AirbyteMessage with MessageType.CONTROL has control.type yield message.control elif message.type == MessageType.STATE: latest_state_message = message.state # type: ignore[assignment] else: if current_page_request or current_page_response or current_page_records: - self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records) + self._close_page( + current_page_request, + current_page_response, + current_slice_pages, + current_page_records, + ) yield StreamReadSlices( pages=current_slice_pages, slice_descriptor=current_slice_descriptor, @@ -250,11 +300,18 @@ def _get_message_groups( ) @staticmethod - def _need_to_close_page(at_least_one_page_in_group: bool, message: AirbyteMessage, json_message: Optional[Dict[str, Any]]) -> bool: + def _need_to_close_page( + at_least_one_page_in_group: bool, + message: AirbyteMessage, + json_message: Optional[Dict[str, Any]], + ) -> bool: return ( at_least_one_page_in_group and message.type == MessageType.LOG - and (MessageGrouper._is_page_http_request(json_message) or message.log.message.startswith("slice:")) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message + and ( + MessageGrouper._is_page_http_request(json_message) + or message.log.message.startswith("slice:") + ) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message ) @staticmethod @@ -262,7 +319,9 @@ def _is_page_http_request(json_message: Optional[Dict[str, Any]]) -> bool: if not json_message: return False else: - return MessageGrouper._is_http_log(json_message) and not MessageGrouper._is_auxiliary_http_request(json_message) + return MessageGrouper._is_http_log( + json_message + ) and not MessageGrouper._is_auxiliary_http_request(json_message) @staticmethod def _is_http_log(message: Dict[str, JsonType]) -> bool: @@ -293,7 +352,11 @@ def _close_page( Close a page when parsing message groups """ current_slice_pages.append( - StreamReadPages(request=current_page_request, response=current_page_response, records=deepcopy(current_page_records)) # type: ignore + StreamReadPages( + request=current_page_request, + response=current_page_response, + records=deepcopy(current_page_records), + ) # type: ignore ) current_page_records.clear() @@ -307,7 +370,9 @@ def _read_stream( # the generator can raise an exception # iterate over the generated messages. if next raise an exception, catch it and yield it as an AirbyteLogMessage try: - yield from AirbyteEntrypoint(source).read(source.spec(self.logger), config, configured_catalog, state) + yield from AirbyteEntrypoint(source).read( + source.spec(self.logger), config, configured_catalog, state + ) except AirbyteTracedException as traced_exception: # Look for this message which indicates that it is the "final exception" raised by AbstractSource. # If it matches, don't yield this as we don't need to show this in the Builder. @@ -315,13 +380,16 @@ def _read_stream( # is that this message will be shown in the Builder. if ( traced_exception.message is not None - and "During the sync, the following streams did not sync successfully" in traced_exception.message + and "During the sync, the following streams did not sync successfully" + in traced_exception.message ): return yield traced_exception.as_airbyte_message() except Exception as e: error_message = f"{e.args[0] if len(e.args) > 0 else str(e)}" - yield AirbyteTracedException.from_exception(e, message=error_message).as_airbyte_message() + yield AirbyteTracedException.from_exception( + e, message=error_message + ).as_airbyte_message() @staticmethod def _parse_json(log_message: AirbyteLogMessage) -> JsonType: @@ -349,7 +417,9 @@ def _create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpR def _create_response_from_log_message(json_http_message: Dict[str, Any]) -> HttpResponse: response = json_http_message.get("http", {}).get("response", {}) body = response.get("body", {}).get("content", "") - return HttpResponse(status=response.get("status_code"), body=body, headers=response.get("headers")) + return HttpResponse( + status=response.get("status_code"), body=body, headers=response.get("headers") + ) def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool: if len(slices) >= self._max_slices: diff --git a/airbyte_cdk/destinations/destination.py b/airbyte_cdk/destinations/destination.py index 308fb2e2..febf4a1b 100644 --- a/airbyte_cdk/destinations/destination.py +++ b/airbyte_cdk/destinations/destination.py @@ -11,7 +11,13 @@ from airbyte_cdk.connector import Connector from airbyte_cdk.exception_handler import init_uncaught_exception_handler -from airbyte_cdk.models import AirbyteMessage, AirbyteMessageSerializer, ConfiguredAirbyteCatalog, ConfiguredAirbyteCatalogSerializer, Type +from airbyte_cdk.models import ( + AirbyteMessage, + AirbyteMessageSerializer, + ConfiguredAirbyteCatalog, + ConfiguredAirbyteCatalogSerializer, + Type, +) from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit from airbyte_cdk.utils.traced_exception import AirbyteTracedException from orjson import orjson @@ -24,7 +30,10 @@ class Destination(Connector, ABC): @abstractmethod def write( - self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, input_messages: Iterable[AirbyteMessage] + self, + config: Mapping[str, Any], + configured_catalog: ConfiguredAirbyteCatalog, + input_messages: Iterable[AirbyteMessage], ) -> Iterable[AirbyteMessage]: """Implement to define how the connector writes data to the destination""" @@ -38,15 +47,24 @@ def _parse_input_stream(self, input_stream: io.TextIOWrapper) -> Iterable[Airbyt try: yield AirbyteMessageSerializer.load(orjson.loads(line)) except orjson.JSONDecodeError: - logger.info(f"ignoring input which can't be deserialized as Airbyte Message: {line}") + logger.info( + f"ignoring input which can't be deserialized as Airbyte Message: {line}" + ) def _run_write( - self, config: Mapping[str, Any], configured_catalog_path: str, input_stream: io.TextIOWrapper + self, + config: Mapping[str, Any], + configured_catalog_path: str, + input_stream: io.TextIOWrapper, ) -> Iterable[AirbyteMessage]: - catalog = ConfiguredAirbyteCatalogSerializer.load(orjson.loads(open(configured_catalog_path).read())) + catalog = ConfiguredAirbyteCatalogSerializer.load( + orjson.loads(open(configured_catalog_path).read()) + ) input_messages = self._parse_input_stream(input_stream) logger.info("Begin writing to the destination...") - yield from self.write(config=config, configured_catalog=catalog, input_messages=input_messages) + yield from self.write( + config=config, configured_catalog=catalog, input_messages=input_messages + ) logger.info("Writing complete.") def parse_args(self, args: List[str]) -> argparse.Namespace: @@ -60,18 +78,30 @@ def parse_args(self, args: List[str]) -> argparse.Namespace: subparsers = main_parser.add_subparsers(title="commands", dest="command") # spec - subparsers.add_parser("spec", help="outputs the json configuration specification", parents=[parent_parser]) + subparsers.add_parser( + "spec", help="outputs the json configuration specification", parents=[parent_parser] + ) # check - check_parser = subparsers.add_parser("check", help="checks the config can be used to connect", parents=[parent_parser]) + check_parser = subparsers.add_parser( + "check", help="checks the config can be used to connect", parents=[parent_parser] + ) required_check_parser = check_parser.add_argument_group("required named arguments") - required_check_parser.add_argument("--config", type=str, required=True, help="path to the json configuration file") + required_check_parser.add_argument( + "--config", type=str, required=True, help="path to the json configuration file" + ) # write - write_parser = subparsers.add_parser("write", help="Writes data to the destination", parents=[parent_parser]) + write_parser = subparsers.add_parser( + "write", help="Writes data to the destination", parents=[parent_parser] + ) write_required = write_parser.add_argument_group("required named arguments") - write_required.add_argument("--config", type=str, required=True, help="path to the JSON configuration file") - write_required.add_argument("--catalog", type=str, required=True, help="path to the configured catalog JSON file") + write_required.add_argument( + "--config", type=str, required=True, help="path to the JSON configuration file" + ) + write_required.add_argument( + "--catalog", type=str, required=True, help="path to the configured catalog JSON file" + ) parsed_args = main_parser.parse_args(args) cmd = parsed_args.command @@ -109,7 +139,11 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]: elif cmd == "write": # Wrap in UTF-8 to override any other input encodings wrapped_stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf-8") - yield from self._run_write(config=config, configured_catalog_path=parsed_args.catalog, input_stream=wrapped_stdin) + yield from self._run_write( + config=config, + configured_catalog_path=parsed_args.catalog, + input_stream=wrapped_stdin, + ) def run(self, args: List[str]) -> None: init_uncaught_exception_handler(logger) diff --git a/airbyte_cdk/destinations/vector_db_based/config.py b/airbyte_cdk/destinations/vector_db_based/config.py index 90de6b77..904f40d3 100644 --- a/airbyte_cdk/destinations/vector_db_based/config.py +++ b/airbyte_cdk/destinations/vector_db_based/config.py @@ -17,7 +17,11 @@ class SeparatorSplitterConfigModel(BaseModel): title="Separators", description='List of separator strings to split text fields by. The separator itself needs to be wrapped in double quotes, e.g. to split by the dot character, use ".". To split by a newline, use "\\n".', ) - keep_separator: bool = Field(default=False, title="Keep separator", description="Whether to keep the separator in the resulting chunks") + keep_separator: bool = Field( + default=False, + title="Keep separator", + description="Whether to keep the separator in the resulting chunks", + ) class Config(OneOfOptionConfig): title = "By Separator" @@ -68,18 +72,20 @@ class CodeSplitterConfigModel(BaseModel): class Config(OneOfOptionConfig): title = "By Programming Language" - description = ( - "Split the text by suitable delimiters based on the programming language. This is useful for splitting code into chunks." - ) + description = "Split the text by suitable delimiters based on the programming language. This is useful for splitting code into chunks." discriminator = "mode" -TextSplitterConfigModel = Union[SeparatorSplitterConfigModel, MarkdownHeaderSplitterConfigModel, CodeSplitterConfigModel] +TextSplitterConfigModel = Union[ + SeparatorSplitterConfigModel, MarkdownHeaderSplitterConfigModel, CodeSplitterConfigModel +] class FieldNameMappingConfigModel(BaseModel): from_field: str = Field(title="From field name", description="The field name in the source") - to_field: str = Field(title="To field name", description="The field name to use in the destination") + to_field: str = Field( + title="To field name", description="The field name to use in the destination" + ) class ProcessingConfigModel(BaseModel): @@ -132,9 +138,7 @@ class OpenAIEmbeddingConfigModel(BaseModel): class Config(OneOfOptionConfig): title = "OpenAI" - description = ( - "Use the OpenAI API to embed text. This option is using the text-embedding-ada-002 model with 1536 embedding dimensions." - ) + description = "Use the OpenAI API to embed text. This option is using the text-embedding-ada-002 model with 1536 embedding dimensions." discriminator = "mode" @@ -142,7 +146,10 @@ class OpenAICompatibleEmbeddingConfigModel(BaseModel): mode: Literal["openai_compatible"] = Field("openai_compatible", const=True) api_key: str = Field(title="API key", default="", airbyte_secret=True) base_url: str = Field( - ..., title="Base URL", description="The base URL for your OpenAI-compatible service", examples=["https://your-service-name.com"] + ..., + title="Base URL", + description="The base URL for your OpenAI-compatible service", + examples=["https://your-service-name.com"], ) model_name: str = Field( title="Model name", @@ -151,7 +158,9 @@ class OpenAICompatibleEmbeddingConfigModel(BaseModel): examples=["text-embedding-ada-002"], ) dimensions: int = Field( - title="Embedding dimensions", description="The number of dimensions the embedding model is generating", examples=[1536, 384] + title="Embedding dimensions", + description="The number of dimensions the embedding model is generating", + examples=[1536, 384], ) class Config(OneOfOptionConfig): @@ -199,10 +208,16 @@ class Config(OneOfOptionConfig): class FromFieldEmbeddingConfigModel(BaseModel): mode: Literal["from_field"] = Field("from_field", const=True) field_name: str = Field( - ..., title="Field name", description="Name of the field in the record that contains the embedding", examples=["embedding", "vector"] + ..., + title="Field name", + description="Name of the field in the record that contains the embedding", + examples=["embedding", "vector"], ) dimensions: int = Field( - ..., title="Embedding dimensions", description="The number of dimensions the embedding model is generating", examples=[1536, 384] + ..., + title="Embedding dimensions", + description="The number of dimensions the embedding model is generating", + examples=[1536, 384], ) class Config(OneOfOptionConfig): @@ -241,7 +256,14 @@ class VectorDBConfigModel(BaseModel): FakeEmbeddingConfigModel, AzureOpenAIEmbeddingConfigModel, OpenAICompatibleEmbeddingConfigModel, - ] = Field(..., title="Embedding", description="Embedding configuration", discriminator="mode", group="embedding", type="object") + ] = Field( + ..., + title="Embedding", + description="Embedding configuration", + discriminator="mode", + group="embedding", + type="object", + ) processing: ProcessingConfigModel omit_raw_text: bool = Field( default=False, diff --git a/airbyte_cdk/destinations/vector_db_based/document_processor.py b/airbyte_cdk/destinations/vector_db_based/document_processor.py index 45b6e4d7..6e1723cb 100644 --- a/airbyte_cdk/destinations/vector_db_based/document_processor.py +++ b/airbyte_cdk/destinations/vector_db_based/document_processor.py @@ -8,9 +8,18 @@ from typing import Any, Dict, List, Mapping, Optional, Tuple import dpath -from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel, SeparatorSplitterConfigModel, TextSplitterConfigModel +from airbyte_cdk.destinations.vector_db_based.config import ( + ProcessingConfigModel, + SeparatorSplitterConfigModel, + TextSplitterConfigModel, +) from airbyte_cdk.destinations.vector_db_based.utils import create_stream_identifier -from airbyte_cdk.models import AirbyteRecordMessage, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, DestinationSyncMode +from airbyte_cdk.models import ( + AirbyteRecordMessage, + ConfiguredAirbyteCatalog, + ConfiguredAirbyteStream, + DestinationSyncMode, +) from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType from langchain.text_splitter import Language, RecursiveCharacterTextSplitter from langchain.utils import stringify_dict @@ -30,7 +39,14 @@ class Chunk: embedding: Optional[List[float]] = None -headers_to_split_on = ["(?:^|\n)# ", "(?:^|\n)## ", "(?:^|\n)### ", "(?:^|\n)#### ", "(?:^|\n)##### ", "(?:^|\n)###### "] +headers_to_split_on = [ + "(?:^|\n)# ", + "(?:^|\n)## ", + "(?:^|\n)### ", + "(?:^|\n)#### ", + "(?:^|\n)##### ", + "(?:^|\n)###### ", +] class DocumentProcessor: @@ -64,7 +80,10 @@ def check_config(config: ProcessingConfigModel) -> Optional[str]: return None def _get_text_splitter( - self, chunk_size: int, chunk_overlap: int, splitter_config: Optional[TextSplitterConfigModel] + self, + chunk_size: int, + chunk_overlap: int, + splitter_config: Optional[TextSplitterConfigModel], ) -> RecursiveCharacterTextSplitter: if splitter_config is None: splitter_config = SeparatorSplitterConfigModel(mode="separator") @@ -89,14 +108,20 @@ def _get_text_splitter( return RecursiveCharacterTextSplitter.from_tiktoken_encoder( chunk_size=chunk_size, chunk_overlap=chunk_overlap, - separators=RecursiveCharacterTextSplitter.get_separators_for_language(Language(splitter_config.language)), + separators=RecursiveCharacterTextSplitter.get_separators_for_language( + Language(splitter_config.language) + ), disallowed_special=(), ) def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCatalog): - self.streams = {create_stream_identifier(stream.stream): stream for stream in catalog.streams} + self.streams = { + create_stream_identifier(stream.stream): stream for stream in catalog.streams + } - self.splitter = self._get_text_splitter(config.chunk_size, config.chunk_overlap, config.text_splitter) + self.splitter = self._get_text_splitter( + config.chunk_size, config.chunk_overlap, config.text_splitter + ) self.text_fields = config.text_fields self.metadata_fields = config.metadata_fields self.field_name_mappings = config.field_name_mappings @@ -119,10 +144,18 @@ def process(self, record: AirbyteRecordMessage) -> Tuple[List[Chunk], Optional[s failure_type=FailureType.config_error, ) chunks = [ - Chunk(page_content=chunk_document.page_content, metadata=chunk_document.metadata, record=record) + Chunk( + page_content=chunk_document.page_content, + metadata=chunk_document.metadata, + record=record, + ) for chunk_document in self._split_document(doc) ] - id_to_delete = doc.metadata[METADATA_RECORD_ID_FIELD] if METADATA_RECORD_ID_FIELD in doc.metadata else None + id_to_delete = ( + doc.metadata[METADATA_RECORD_ID_FIELD] + if METADATA_RECORD_ID_FIELD in doc.metadata + else None + ) return chunks, id_to_delete def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document]: @@ -133,7 +166,9 @@ def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document] metadata = self._extract_metadata(record) return Document(page_content=text, metadata=metadata) - def _extract_relevant_fields(self, record: AirbyteRecordMessage, fields: Optional[List[str]]) -> Dict[str, Any]: + def _extract_relevant_fields( + self, record: AirbyteRecordMessage, fields: Optional[List[str]] + ) -> Dict[str, Any]: relevant_fields = {} if fields and len(fields) > 0: for field in fields: @@ -156,7 +191,10 @@ def _extract_primary_key(self, record: AirbyteRecordMessage) -> Optional[str]: stream_identifier = create_stream_identifier(record) current_stream: ConfiguredAirbyteStream = self.streams[stream_identifier] # if the sync mode is deduping, use the primary key to upsert existing records instead of appending new ones - if not current_stream.primary_key or current_stream.destination_sync_mode != DestinationSyncMode.append_dedup: + if ( + not current_stream.primary_key + or current_stream.destination_sync_mode != DestinationSyncMode.append_dedup + ): return None primary_key = [] diff --git a/airbyte_cdk/destinations/vector_db_based/embedder.py b/airbyte_cdk/destinations/vector_db_based/embedder.py index c18592f3..4ec56fbf 100644 --- a/airbyte_cdk/destinations/vector_db_based/embedder.py +++ b/airbyte_cdk/destinations/vector_db_based/embedder.py @@ -92,7 +92,9 @@ def embed_documents(self, documents: List[Document]) -> List[Optional[List[float batches = create_chunks(documents, batch_size=embedding_batch_size) embeddings: List[Optional[List[float]]] = [] for batch in batches: - embeddings.extend(self.embeddings.embed_documents([chunk.page_content for chunk in batch])) + embeddings.extend( + self.embeddings.embed_documents([chunk.page_content for chunk in batch]) + ) return embeddings @property @@ -103,7 +105,12 @@ def embedding_dimensions(self) -> int: class OpenAIEmbedder(BaseOpenAIEmbedder): def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int): - super().__init__(OpenAIEmbeddings(openai_api_key=config.openai_key, max_retries=15, disallowed_special=()), chunk_size) # type: ignore + super().__init__( + OpenAIEmbeddings( + openai_api_key=config.openai_key, max_retries=15, disallowed_special=() + ), + chunk_size, + ) # type: ignore class AzureOpenAIEmbedder(BaseOpenAIEmbedder): @@ -131,7 +138,9 @@ class CohereEmbedder(Embedder): def __init__(self, config: CohereEmbeddingConfigModel): super().__init__() # Client is set internally - self.embeddings = CohereEmbeddings(cohere_api_key=config.cohere_key, model="embed-english-light-v2.0") # type: ignore + self.embeddings = CohereEmbeddings( + cohere_api_key=config.cohere_key, model="embed-english-light-v2.0" + ) # type: ignore def check(self) -> Optional[str]: try: @@ -141,7 +150,10 @@ def check(self) -> Optional[str]: return None def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: - return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents])) + return cast( + List[Optional[List[float]]], + self.embeddings.embed_documents([document.page_content for document in documents]), + ) @property def embedding_dimensions(self) -> int: @@ -162,7 +174,10 @@ def check(self) -> Optional[str]: return None def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: - return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents])) + return cast( + List[Optional[List[float]]], + self.embeddings.embed_documents([document.page_content for document in documents]), + ) @property def embedding_dimensions(self) -> int: @@ -189,7 +204,10 @@ def __init__(self, config: OpenAICompatibleEmbeddingConfigModel): def check(self) -> Optional[str]: deployment_mode = os.environ.get("DEPLOYMENT_MODE", "") - if deployment_mode.casefold() == CLOUD_DEPLOYMENT_MODE and not self.config.base_url.startswith("https://"): + if ( + deployment_mode.casefold() == CLOUD_DEPLOYMENT_MODE + and not self.config.base_url.startswith("https://") + ): return "Base URL must start with https://" try: @@ -199,7 +217,10 @@ def check(self) -> Optional[str]: return None def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: - return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents])) + return cast( + List[Optional[List[float]]], + self.embeddings.embed_documents([document.page_content for document in documents]), + ) @property def embedding_dimensions(self) -> int: @@ -273,6 +294,9 @@ def create_from_config( processing_config: ProcessingConfigModel, ) -> Embedder: if embedding_config.mode == "azure_openai" or embedding_config.mode == "openai": - return cast(Embedder, embedder_map[embedding_config.mode](embedding_config, processing_config.chunk_size)) + return cast( + Embedder, + embedder_map[embedding_config.mode](embedding_config, processing_config.chunk_size), + ) else: return cast(Embedder, embedder_map[embedding_config.mode](embedding_config)) diff --git a/airbyte_cdk/destinations/vector_db_based/test_utils.py b/airbyte_cdk/destinations/vector_db_based/test_utils.py index 7f8cfe5f..a2f3d3d8 100644 --- a/airbyte_cdk/destinations/vector_db_based/test_utils.py +++ b/airbyte_cdk/destinations/vector_db_based/test_utils.py @@ -26,12 +26,19 @@ class BaseIntegrationTest(unittest.TestCase): It provides helper methods to create Airbyte catalogs, records and state messages. """ - def _get_configured_catalog(self, destination_mode: DestinationSyncMode) -> ConfiguredAirbyteCatalog: - stream_schema = {"type": "object", "properties": {"str_col": {"type": "str"}, "int_col": {"type": "integer"}}} + def _get_configured_catalog( + self, destination_mode: DestinationSyncMode + ) -> ConfiguredAirbyteCatalog: + stream_schema = { + "type": "object", + "properties": {"str_col": {"type": "str"}, "int_col": {"type": "integer"}}, + } overwrite_stream = ConfiguredAirbyteStream( stream=AirbyteStream( - name="mystream", json_schema=stream_schema, supported_sync_modes=[SyncMode.incremental, SyncMode.full_refresh] + name="mystream", + json_schema=stream_schema, + supported_sync_modes=[SyncMode.incremental, SyncMode.full_refresh], ), primary_key=[["int_col"]], sync_mode=SyncMode.incremental, @@ -45,7 +52,10 @@ def _state(self, data: Dict[str, Any]) -> AirbyteMessage: def _record(self, stream: str, str_value: str, int_value: int) -> AirbyteMessage: return AirbyteMessage( - type=Type.RECORD, record=AirbyteRecordMessage(stream=stream, data={"str_col": str_value, "int_col": int_value}, emitted_at=0) + type=Type.RECORD, + record=AirbyteRecordMessage( + stream=stream, data={"str_col": str_value, "int_col": int_value}, emitted_at=0 + ), ) def setUp(self) -> None: diff --git a/airbyte_cdk/destinations/vector_db_based/utils.py b/airbyte_cdk/destinations/vector_db_based/utils.py index b0d4edeb..dbb1f471 100644 --- a/airbyte_cdk/destinations/vector_db_based/utils.py +++ b/airbyte_cdk/destinations/vector_db_based/utils.py @@ -10,7 +10,11 @@ def format_exception(exception: Exception) -> str: - return str(exception) + "\n" + "".join(traceback.TracebackException.from_exception(exception).format()) + return ( + str(exception) + + "\n" + + "".join(traceback.TracebackException.from_exception(exception).format()) + ) def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[Tuple[Any, ...]]: @@ -26,4 +30,6 @@ def create_stream_identifier(stream: Union[AirbyteStream, AirbyteRecordMessage]) if isinstance(stream, AirbyteStream): return str(stream.name if stream.namespace is None else f"{stream.namespace}_{stream.name}") else: - return str(stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}") + return str( + stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}" + ) diff --git a/airbyte_cdk/destinations/vector_db_based/writer.py b/airbyte_cdk/destinations/vector_db_based/writer.py index 0f764c36..268e49ef 100644 --- a/airbyte_cdk/destinations/vector_db_based/writer.py +++ b/airbyte_cdk/destinations/vector_db_based/writer.py @@ -27,7 +27,12 @@ class Writer: """ def __init__( - self, processing_config: ProcessingConfigModel, indexer: Indexer, embedder: Embedder, batch_size: int, omit_raw_text: bool + self, + processing_config: ProcessingConfigModel, + indexer: Indexer, + embedder: Embedder, + batch_size: int, + omit_raw_text: bool, ) -> None: self.processing_config = processing_config self.indexer = indexer @@ -54,7 +59,9 @@ def _process_batch(self) -> None: self.indexer.delete(ids, namespace, stream) for (namespace, stream), chunks in self.chunks.items(): - embeddings = self.embedder.embed_documents([self._convert_to_document(chunk) for chunk in chunks]) + embeddings = self.embedder.embed_documents( + [self._convert_to_document(chunk) for chunk in chunks] + ) for i, document in enumerate(chunks): document.embedding = embeddings[i] if self.omit_raw_text: @@ -63,7 +70,9 @@ def _process_batch(self) -> None: self._init_batch() - def write(self, configured_catalog: ConfiguredAirbyteCatalog, input_messages: Iterable[AirbyteMessage]) -> Iterable[AirbyteMessage]: + def write( + self, configured_catalog: ConfiguredAirbyteCatalog, input_messages: Iterable[AirbyteMessage] + ) -> Iterable[AirbyteMessage]: self.processor = DocumentProcessor(self.processing_config, configured_catalog) self.indexer.pre_sync(configured_catalog) for message in input_messages: @@ -76,7 +85,9 @@ def write(self, configured_catalog: ConfiguredAirbyteCatalog, input_messages: It record_chunks, record_id_to_delete = self.processor.process(message.record) self.chunks[(message.record.namespace, message.record.stream)].extend(record_chunks) if record_id_to_delete is not None: - self.ids_to_delete[(message.record.namespace, message.record.stream)].append(record_id_to_delete) + self.ids_to_delete[(message.record.namespace, message.record.stream)].append( + record_id_to_delete + ) self.number_of_chunks += len(record_chunks) if self.number_of_chunks >= self.batch_size: self._process_batch() diff --git a/airbyte_cdk/entrypoint.py b/airbyte_cdk/entrypoint.py index 945d2840..5a979a94 100644 --- a/airbyte_cdk/entrypoint.py +++ b/airbyte_cdk/entrypoint.py @@ -62,33 +62,54 @@ def __init__(self, source: Source): def parse_args(args: List[str]) -> argparse.Namespace: # set up parent parsers parent_parser = argparse.ArgumentParser(add_help=False) - parent_parser.add_argument("--debug", action="store_true", help="enables detailed debug logs related to the sync") + parent_parser.add_argument( + "--debug", action="store_true", help="enables detailed debug logs related to the sync" + ) main_parser = argparse.ArgumentParser() subparsers = main_parser.add_subparsers(title="commands", dest="command") # spec - subparsers.add_parser("spec", help="outputs the json configuration specification", parents=[parent_parser]) + subparsers.add_parser( + "spec", help="outputs the json configuration specification", parents=[parent_parser] + ) # check - check_parser = subparsers.add_parser("check", help="checks the config can be used to connect", parents=[parent_parser]) + check_parser = subparsers.add_parser( + "check", help="checks the config can be used to connect", parents=[parent_parser] + ) required_check_parser = check_parser.add_argument_group("required named arguments") - required_check_parser.add_argument("--config", type=str, required=True, help="path to the json configuration file") + required_check_parser.add_argument( + "--config", type=str, required=True, help="path to the json configuration file" + ) # discover discover_parser = subparsers.add_parser( - "discover", help="outputs a catalog describing the source's schema", parents=[parent_parser] + "discover", + help="outputs a catalog describing the source's schema", + parents=[parent_parser], ) required_discover_parser = discover_parser.add_argument_group("required named arguments") - required_discover_parser.add_argument("--config", type=str, required=True, help="path to the json configuration file") + required_discover_parser.add_argument( + "--config", type=str, required=True, help="path to the json configuration file" + ) # read - read_parser = subparsers.add_parser("read", help="reads the source and outputs messages to STDOUT", parents=[parent_parser]) + read_parser = subparsers.add_parser( + "read", help="reads the source and outputs messages to STDOUT", parents=[parent_parser] + ) - read_parser.add_argument("--state", type=str, required=False, help="path to the json-encoded state file") + read_parser.add_argument( + "--state", type=str, required=False, help="path to the json-encoded state file" + ) required_read_parser = read_parser.add_argument_group("required named arguments") - required_read_parser.add_argument("--config", type=str, required=True, help="path to the json configuration file") required_read_parser.add_argument( - "--catalog", type=str, required=True, help="path to the catalog used to determine which data to read" + "--config", type=str, required=True, help="path to the json configuration file" + ) + required_read_parser.add_argument( + "--catalog", + type=str, + required=True, + help="path to the catalog used to determine which data to read", ) return main_parser.parse_args(args) @@ -108,11 +129,14 @@ def run(self, parsed_args: argparse.Namespace) -> Iterable[str]: source_spec: ConnectorSpecification = self.source.spec(self.logger) try: with tempfile.TemporaryDirectory() as temp_dir: - os.environ[ENV_REQUEST_CACHE_PATH] = temp_dir # set this as default directory for request_cache to store *.sqlite files + os.environ[ENV_REQUEST_CACHE_PATH] = ( + temp_dir # set this as default directory for request_cache to store *.sqlite files + ) if cmd == "spec": message = AirbyteMessage(type=Type.SPEC, spec=source_spec) yield from [ - self.airbyte_message_to_string(queued_message) for queued_message in self._emit_queued_messages(self.source) + self.airbyte_message_to_string(queued_message) + for queued_message in self._emit_queued_messages(self.source) ] yield self.airbyte_message_to_string(message) else: @@ -120,23 +144,38 @@ def run(self, parsed_args: argparse.Namespace) -> Iterable[str]: config = self.source.configure(raw_config, temp_dir) yield from [ - self.airbyte_message_to_string(queued_message) for queued_message in self._emit_queued_messages(self.source) + self.airbyte_message_to_string(queued_message) + for queued_message in self._emit_queued_messages(self.source) ] if cmd == "check": - yield from map(AirbyteEntrypoint.airbyte_message_to_string, self.check(source_spec, config)) + yield from map( + AirbyteEntrypoint.airbyte_message_to_string, + self.check(source_spec, config), + ) elif cmd == "discover": - yield from map(AirbyteEntrypoint.airbyte_message_to_string, self.discover(source_spec, config)) + yield from map( + AirbyteEntrypoint.airbyte_message_to_string, + self.discover(source_spec, config), + ) elif cmd == "read": config_catalog = self.source.read_catalog(parsed_args.catalog) state = self.source.read_state(parsed_args.state) - yield from map(AirbyteEntrypoint.airbyte_message_to_string, self.read(source_spec, config, config_catalog, state)) + yield from map( + AirbyteEntrypoint.airbyte_message_to_string, + self.read(source_spec, config, config_catalog, state), + ) else: raise Exception("Unexpected command " + cmd) finally: - yield from [self.airbyte_message_to_string(queued_message) for queued_message in self._emit_queued_messages(self.source)] - - def check(self, source_spec: ConnectorSpecification, config: TConfig) -> Iterable[AirbyteMessage]: + yield from [ + self.airbyte_message_to_string(queued_message) + for queued_message in self._emit_queued_messages(self.source) + ] + + def check( + self, source_spec: ConnectorSpecification, config: TConfig + ) -> Iterable[AirbyteMessage]: self.set_up_secret_filter(config, source_spec.connectionSpecification) try: self.validate_connection(source_spec, config) @@ -161,7 +200,10 @@ def check(self, source_spec: ConnectorSpecification, config: TConfig) -> Iterabl raise traced_exc else: yield AirbyteMessage( - type=Type.CONNECTION_STATUS, connectionStatus=AirbyteConnectionStatus(status=Status.FAILED, message=traced_exc.message) + type=Type.CONNECTION_STATUS, + connectionStatus=AirbyteConnectionStatus( + status=Status.FAILED, message=traced_exc.message + ), ) return if check_result.status == Status.SUCCEEDED: @@ -172,7 +214,9 @@ def check(self, source_spec: ConnectorSpecification, config: TConfig) -> Iterabl yield from self._emit_queued_messages(self.source) yield AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=check_result) - def discover(self, source_spec: ConnectorSpecification, config: TConfig) -> Iterable[AirbyteMessage]: + def discover( + self, source_spec: ConnectorSpecification, config: TConfig + ) -> Iterable[AirbyteMessage]: self.set_up_secret_filter(config, source_spec.connectionSpecification) if self.source.check_config_against_spec: self.validate_connection(source_spec, config) @@ -181,7 +225,9 @@ def discover(self, source_spec: ConnectorSpecification, config: TConfig) -> Iter yield from self._emit_queued_messages(self.source) yield AirbyteMessage(type=Type.CATALOG, catalog=catalog) - def read(self, source_spec: ConnectorSpecification, config: TConfig, catalog: Any, state: list[Any]) -> Iterable[AirbyteMessage]: + def read( + self, source_spec: ConnectorSpecification, config: TConfig, catalog: Any, state: list[Any] + ) -> Iterable[AirbyteMessage]: self.set_up_secret_filter(config, source_spec.connectionSpecification) if self.source.check_config_against_spec: self.validate_connection(source_spec, config) @@ -194,16 +240,24 @@ def read(self, source_spec: ConnectorSpecification, config: TConfig, catalog: An yield self.handle_record_counts(message, stream_message_counter) @staticmethod - def handle_record_counts(message: AirbyteMessage, stream_message_count: DefaultDict[HashableStreamDescriptor, float]) -> AirbyteMessage: + def handle_record_counts( + message: AirbyteMessage, stream_message_count: DefaultDict[HashableStreamDescriptor, float] + ) -> AirbyteMessage: match message.type: case Type.RECORD: - stream_message_count[HashableStreamDescriptor(name=message.record.stream, namespace=message.record.namespace)] += 1.0 # type: ignore[union-attr] # record has `stream` and `namespace` + stream_message_count[ + HashableStreamDescriptor( + name=message.record.stream, namespace=message.record.namespace + ) + ] += 1.0 # type: ignore[union-attr] # record has `stream` and `namespace` case Type.STATE: stream_descriptor = message_utils.get_stream_descriptor(message) # Set record count from the counter onto the state message message.state.sourceStats = message.state.sourceStats or AirbyteStateStats() # type: ignore[union-attr] # state has `sourceStats` - message.state.sourceStats.recordCount = stream_message_count.get(stream_descriptor, 0.0) # type: ignore[union-attr] # state has `sourceStats` + message.state.sourceStats.recordCount = stream_message_count.get( + stream_descriptor, 0.0 + ) # type: ignore[union-attr] # state has `sourceStats` # Reset the counter stream_message_count[stream_descriptor] = 0.0 @@ -283,7 +337,9 @@ def filtered_send(self: Any, request: PreparedRequest, **kwargs: Any) -> Respons ) if not parsed_url.hostname: - raise requests.exceptions.InvalidURL("Invalid URL specified: The endpoint that data is being requested from is not a valid URL") + raise requests.exceptions.InvalidURL( + "Invalid URL specified: The endpoint that data is being requested from is not a valid URL" + ) try: is_private = _is_private_url(parsed_url.hostname, parsed_url.port) # type: ignore [arg-type] diff --git a/airbyte_cdk/exception_handler.py b/airbyte_cdk/exception_handler.py index 77fa8898..84aa39ba 100644 --- a/airbyte_cdk/exception_handler.py +++ b/airbyte_cdk/exception_handler.py @@ -11,7 +11,9 @@ from airbyte_cdk.utils.traced_exception import AirbyteTracedException -def assemble_uncaught_exception(exception_type: type[BaseException], exception_value: BaseException) -> AirbyteTracedException: +def assemble_uncaught_exception( + exception_type: type[BaseException], exception_value: BaseException +) -> AirbyteTracedException: if issubclass(exception_type, AirbyteTracedException): return exception_value # type: ignore # validated as part of the previous line return AirbyteTracedException.from_exception(exception_value) @@ -23,7 +25,11 @@ def init_uncaught_exception_handler(logger: logging.Logger) -> None: printed to the console without having secrets removed. """ - def hook_fn(exception_type: type[BaseException], exception_value: BaseException, traceback_: Optional[TracebackType]) -> Any: + def hook_fn( + exception_type: type[BaseException], + exception_value: BaseException, + traceback_: Optional[TracebackType], + ) -> Any: # For developer ergonomics, we want to see the stack trace in the logs when we do a ctrl-c if issubclass(exception_type, KeyboardInterrupt): sys.__excepthook__(exception_type, exception_value, traceback_) @@ -41,6 +47,10 @@ def hook_fn(exception_type: type[BaseException], exception_value: BaseException, def generate_failed_streams_error_message(stream_failures: Mapping[str, List[Exception]]) -> str: failures = "\n".join( - [f"{stream}: {filter_secrets(exception.__repr__())}" for stream, exceptions in stream_failures.items() for exception in exceptions] + [ + f"{stream}: {filter_secrets(exception.__repr__())}" + for stream, exceptions in stream_failures.items() + for exception in exceptions + ] ) return f"During the sync, the following streams did not sync successfully: {failures}" diff --git a/airbyte_cdk/logger.py b/airbyte_cdk/logger.py index 59d4d7dd..055d80e8 100644 --- a/airbyte_cdk/logger.py +++ b/airbyte_cdk/logger.py @@ -7,7 +7,13 @@ import logging.config from typing import Any, Callable, Mapping, Optional, Tuple -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteMessageSerializer, Level, Type +from airbyte_cdk.models import ( + AirbyteLogMessage, + AirbyteMessage, + AirbyteMessageSerializer, + Level, + Type, +) from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from orjson import orjson @@ -68,7 +74,9 @@ def format(self, record: logging.LogRecord) -> str: else: message = super().format(record) message = filter_secrets(message) - log_message = AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message)) + log_message = AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message) + ) return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string @staticmethod diff --git a/airbyte_cdk/models/airbyte_protocol.py b/airbyte_cdk/models/airbyte_protocol.py index 7f12da5b..b5d8683a 100644 --- a/airbyte_cdk/models/airbyte_protocol.py +++ b/airbyte_cdk/models/airbyte_protocol.py @@ -42,7 +42,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: setattr(self, key, value) def __eq__(self, other: object) -> bool: - return False if not isinstance(other, AirbyteStateBlob) else bool(self.__dict__ == other.__dict__) + return ( + False + if not isinstance(other, AirbyteStateBlob) + else bool(self.__dict__ == other.__dict__) + ) # The following dataclasses have been redeclared to include the new version of AirbyteStateBlob diff --git a/airbyte_cdk/models/airbyte_protocol_serializers.py b/airbyte_cdk/models/airbyte_protocol_serializers.py index aeac43f7..129556ac 100644 --- a/airbyte_cdk/models/airbyte_protocol_serializers.py +++ b/airbyte_cdk/models/airbyte_protocol_serializers.py @@ -30,9 +30,15 @@ def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, Dict[str, Any] return AirbyteStateBlobType() if t is AirbyteStateBlob else None -AirbyteStreamStateSerializer = Serializer(AirbyteStreamState, omit_none=True, custom_type_resolver=custom_type_resolver) -AirbyteStateMessageSerializer = Serializer(AirbyteStateMessage, omit_none=True, custom_type_resolver=custom_type_resolver) -AirbyteMessageSerializer = Serializer(AirbyteMessage, omit_none=True, custom_type_resolver=custom_type_resolver) +AirbyteStreamStateSerializer = Serializer( + AirbyteStreamState, omit_none=True, custom_type_resolver=custom_type_resolver +) +AirbyteStateMessageSerializer = Serializer( + AirbyteStateMessage, omit_none=True, custom_type_resolver=custom_type_resolver +) +AirbyteMessageSerializer = Serializer( + AirbyteMessage, omit_none=True, custom_type_resolver=custom_type_resolver +) ConfiguredAirbyteCatalogSerializer = Serializer(ConfiguredAirbyteCatalog, omit_none=True) ConfiguredAirbyteStreamSerializer = Serializer(ConfiguredAirbyteStream, omit_none=True) ConnectorSpecificationSerializer = Serializer(ConnectorSpecification, omit_none=True) diff --git a/airbyte_cdk/sources/abstract_source.py b/airbyte_cdk/sources/abstract_source.py index 3656a88c..34ba816b 100644 --- a/airbyte_cdk/sources/abstract_source.py +++ b/airbyte_cdk/sources/abstract_source.py @@ -4,7 +4,18 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) from airbyte_cdk.exception_handler import generate_failed_streams_error_message from airbyte_cdk.models import ( @@ -30,7 +41,9 @@ from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, split_config from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger, SliceLogger from airbyte_cdk.utils.event_timing import create_timer -from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message +from airbyte_cdk.utils.stream_status_utils import ( + as_airbyte_message as stream_status_as_airbyte_message, +) from airbyte_cdk.utils.traced_exception import AirbyteTracedException _default_message_repository = InMemoryMessageRepository() @@ -43,7 +56,9 @@ class AbstractSource(Source, ABC): """ @abstractmethod - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: """ :param logger: source logger :param config: The user-provided configuration as specified by the source's spec. @@ -109,7 +124,9 @@ def read( # Used direct reference to `stream_instance` instead of `is_stream_exist` to avoid mypy type checking errors if not stream_instance: if not self.raise_exception_on_missing_stream: - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.INCOMPLETE) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.INCOMPLETE + ) continue error_message = ( @@ -129,7 +146,9 @@ def read( timer.start_event(f"Syncing stream {configured_stream.stream.name}") logger.info(f"Marking stream {configured_stream.stream.name} as STARTED") - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.STARTED) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.STARTED + ) yield from self._read_stream( logger=logger, stream_instance=stream_instance, @@ -138,13 +157,19 @@ def read( internal_config=internal_config, ) logger.info(f"Marking stream {configured_stream.stream.name} as STOPPED") - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.COMPLETE) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.COMPLETE + ) except Exception as e: yield from self._emit_queued_messages() - logger.exception(f"Encountered an exception while reading stream {configured_stream.stream.name}") + logger.exception( + f"Encountered an exception while reading stream {configured_stream.stream.name}" + ) logger.info(f"Marking stream {configured_stream.stream.name} as STOPPED") - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.INCOMPLETE) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.INCOMPLETE + ) stream_descriptor = StreamDescriptor(name=configured_stream.stream.name) @@ -152,10 +177,14 @@ def read( traced_exception = e info_message = f"Stopping sync on error from stream {configured_stream.stream.name} because {self.name} does not support continuing syncs on error." else: - traced_exception = self._serialize_exception(stream_descriptor, e, stream_instance=stream_instance) + traced_exception = self._serialize_exception( + stream_descriptor, e, stream_instance=stream_instance + ) info_message = f"{self.name} does not support continuing syncs on error from stream {configured_stream.stream.name}" - yield traced_exception.as_sanitized_airbyte_message(stream_descriptor=stream_descriptor) + yield traced_exception.as_sanitized_airbyte_message( + stream_descriptor=stream_descriptor + ) stream_name_to_exception[stream_instance.name] = traced_exception # type: ignore # use configured_stream if stream_instance is None if self.stop_sync_on_stream_failure: logger.info(info_message) @@ -169,12 +198,16 @@ def read( logger.info(timer.report()) if len(stream_name_to_exception) > 0: - error_message = generate_failed_streams_error_message({key: [value] for key, value in stream_name_to_exception.items()}) # type: ignore # for some reason, mypy can't figure out the types for key and value + error_message = generate_failed_streams_error_message( + {key: [value] for key, value in stream_name_to_exception.items()} + ) # type: ignore # for some reason, mypy can't figure out the types for key and value logger.info(error_message) # We still raise at least one exception when a stream raises an exception because the platform currently relies # on a non-zero exit code to determine if a sync attempt has failed. We also raise the exception as a config_error # type because this combined error isn't actionable, but rather the previously emitted individual errors. - raise AirbyteTracedException(message=error_message, failure_type=FailureType.config_error) + raise AirbyteTracedException( + message=error_message, failure_type=FailureType.config_error + ) logger.info(f"Finished syncing {self.name}") @staticmethod @@ -183,7 +216,9 @@ def _serialize_exception( ) -> AirbyteTracedException: display_message = stream_instance.get_error_display_message(e) if stream_instance else None if display_message: - return AirbyteTracedException.from_exception(e, message=display_message, stream_descriptor=stream_descriptor) + return AirbyteTracedException.from_exception( + e, message=display_message, stream_descriptor=stream_descriptor + ) return AirbyteTracedException.from_exception(e, stream_descriptor=stream_descriptor) @property @@ -199,7 +234,9 @@ def _read_stream( internal_config: InternalConfig, ) -> Iterator[AirbyteMessage]: if internal_config.page_size and isinstance(stream_instance, HttpStream): - logger.info(f"Setting page size for {stream_instance.name} to {internal_config.page_size}") + logger.info( + f"Setting page size for {stream_instance.name} to {internal_config.page_size}" + ) stream_instance.page_size = internal_config.page_size logger.debug( f"Syncing configured stream: {configured_stream.stream.name}", @@ -243,7 +280,9 @@ def _read_stream( if record_counter == 1: logger.info(f"Marking stream {stream_name} as RUNNING") # If we just read the first record of the stream, emit the transition to the RUNNING state - yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.RUNNING) + yield stream_status_as_airbyte_message( + configured_stream.stream, AirbyteStreamStatus.RUNNING + ) yield from self._emit_queued_messages() yield record @@ -254,7 +293,9 @@ def _emit_queued_messages(self) -> Iterable[AirbyteMessage]: yield from self.message_repository.consume_queue() return - def _get_message(self, record_data_or_message: Union[StreamData, AirbyteMessage], stream: Stream) -> AirbyteMessage: + def _get_message( + self, record_data_or_message: Union[StreamData, AirbyteMessage], stream: Stream + ) -> AirbyteMessage: """ Converts the input to an AirbyteMessage if it is a StreamData. Returns the input as is if it is already an AirbyteMessage """ @@ -262,7 +303,12 @@ def _get_message(self, record_data_or_message: Union[StreamData, AirbyteMessage] case AirbyteMessage(): return record_data_or_message case _: - return stream_data_to_airbyte_message(stream.name, record_data_or_message, stream.transformer, stream.get_json_schema()) + return stream_data_to_airbyte_message( + stream.name, + record_data_or_message, + stream.transformer, + stream.get_json_schema(), + ) @property def message_repository(self) -> Union[None, MessageRepository]: diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index 0356f211..1f4a1b81 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -7,7 +7,9 @@ from airbyte_cdk.exception_handler import generate_failed_streams_error_message from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus, FailureType, StreamDescriptor from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( + PartitionGenerationCompletedSentinel, +) from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.message import MessageRepository @@ -20,7 +22,9 @@ from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message from airbyte_cdk.sources.utils.slice_logger import SliceLogger from airbyte_cdk.utils import AirbyteTracedException -from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message +from airbyte_cdk.utils.stream_status_utils import ( + as_airbyte_message as stream_status_as_airbyte_message, +) class ConcurrentReadProcessor: @@ -61,7 +65,9 @@ def __init__( self._streams_done: Set[str] = set() self._exceptions_per_stream_name: dict[str, List[Exception]] = {} - def on_partition_generation_completed(self, sentinel: PartitionGenerationCompletedSentinel) -> Iterable[AirbyteMessage]: + def on_partition_generation_completed( + self, sentinel: PartitionGenerationCompletedSentinel + ) -> Iterable[AirbyteMessage]: """ This method is called when a partition generation is completed. 1. Remove the stream from the list of streams currently generating partitions @@ -72,7 +78,10 @@ def on_partition_generation_completed(self, sentinel: PartitionGenerationComplet self._streams_currently_generating_partitions.remove(sentinel.stream.name) # It is possible for the stream to already be done if no partitions were generated # If the partition generation process was completed and there are no partitions left to process, the stream is done - if self._is_stream_done(stream_name) or len(self._streams_to_running_partitions[stream_name]) == 0: + if ( + self._is_stream_done(stream_name) + or len(self._streams_to_running_partitions[stream_name]) == 0 + ): yield from self._on_stream_is_done(stream_name) if self._stream_instances_to_start_partition_generation: yield self.start_next_partition_generator() # type:ignore # None may be yielded @@ -87,10 +96,14 @@ def on_partition(self, partition: Partition) -> None: stream_name = partition.stream_name() self._streams_to_running_partitions[stream_name].add(partition) if self._slice_logger.should_log_slice_message(self._logger): - self._message_repository.emit_message(self._slice_logger.create_slice_log_message(partition.to_slice())) + self._message_repository.emit_message( + self._slice_logger.create_slice_log_message(partition.to_slice()) + ) self._thread_pool_manager.submit(self._partition_reader.process_partition, partition) - def on_partition_complete_sentinel(self, sentinel: PartitionCompleteSentinel) -> Iterable[AirbyteMessage]: + def on_partition_complete_sentinel( + self, sentinel: PartitionCompleteSentinel + ) -> Iterable[AirbyteMessage]: """ This method is called when a partition is completed. 1. Close the partition @@ -112,7 +125,10 @@ def on_partition_complete_sentinel(self, sentinel: PartitionCompleteSentinel) -> if partition in partitions_running: partitions_running.remove(partition) # If all partitions were generated and this was the last one, the stream is done - if partition.stream_name() not in self._streams_currently_generating_partitions and len(partitions_running) == 0: + if ( + partition.stream_name() not in self._streams_currently_generating_partitions + and len(partitions_running) == 0 + ): yield from self._on_stream_is_done(partition.stream_name()) yield from self._message_repository.consume_queue() @@ -139,7 +155,9 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]: if message.type == MessageType.RECORD: if self._record_counter[stream.name] == 0: self._logger.info(f"Marking stream {stream.name} as RUNNING") - yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), AirbyteStreamStatus.RUNNING) + yield stream_status_as_airbyte_message( + stream.as_airbyte_stream(), AirbyteStreamStatus.RUNNING + ) self._record_counter[stream.name] += 1 stream.cursor.observe(record) yield message @@ -152,13 +170,17 @@ def on_exception(self, exception: StreamThreadException) -> Iterable[AirbyteMess 2. Raise the exception """ self._flag_exception(exception.stream_name, exception.exception) - self._logger.exception(f"Exception while syncing stream {exception.stream_name}", exc_info=exception.exception) + self._logger.exception( + f"Exception while syncing stream {exception.stream_name}", exc_info=exception.exception + ) stream_descriptor = StreamDescriptor(name=exception.stream_name) if isinstance(exception.exception, AirbyteTracedException): yield exception.exception.as_airbyte_message(stream_descriptor=stream_descriptor) else: - yield AirbyteTracedException.from_exception(exception, stream_descriptor=stream_descriptor).as_airbyte_message() + yield AirbyteTracedException.from_exception( + exception, stream_descriptor=stream_descriptor + ).as_airbyte_message() def _flag_exception(self, stream_name: str, exception: Exception) -> None: self._exceptions_per_stream_name.setdefault(stream_name, []).append(exception) @@ -192,7 +214,12 @@ def is_done(self) -> bool: 2. There are no more streams to read from 3. All partitions for all streams are closed """ - is_done = all([self._is_stream_done(stream_name) for stream_name in self._stream_name_to_instance.keys()]) + is_done = all( + [ + self._is_stream_done(stream_name) + for stream_name in self._stream_name_to_instance.keys() + ] + ) if is_done and self._exceptions_per_stream_name: error_message = generate_failed_streams_error_message(self._exceptions_per_stream_name) self._logger.info(error_message) @@ -200,7 +227,9 @@ def is_done(self) -> bool: # on a non-zero exit code to determine if a sync attempt has failed. We also raise the exception as a config_error # type because this combined error isn't actionable, but rather the previously emitted individual errors. raise AirbyteTracedException( - message=error_message, internal_message="Concurrent read failure", failure_type=FailureType.config_error + message=error_message, + internal_message="Concurrent read failure", + failure_type=FailureType.config_error, ) return is_done @@ -208,7 +237,9 @@ def _is_stream_done(self, stream_name: str) -> bool: return stream_name in self._streams_done def _on_stream_is_done(self, stream_name: str) -> Iterable[AirbyteMessage]: - self._logger.info(f"Read {self._record_counter[stream_name]} records from {stream_name} stream") + self._logger.info( + f"Read {self._record_counter[stream_name]} records from {stream_name} stream" + ) self._logger.info(f"Marking stream {stream_name} as STOPPED") stream = self._stream_name_to_instance[stream_name] stream.cursor.ensure_at_least_one_state_emitted() @@ -216,6 +247,8 @@ def _on_stream_is_done(self, stream_name: str) -> Iterable[AirbyteMessage]: self._logger.info(f"Finished syncing {stream.name}") self._streams_done.add(stream_name) stream_status = ( - AirbyteStreamStatus.INCOMPLETE if self._exceptions_per_stream_name.get(stream_name, []) else AirbyteStreamStatus.COMPLETE + AirbyteStreamStatus.INCOMPLETE + if self._exceptions_per_stream_name.get(stream_name, []) + else AirbyteStreamStatus.COMPLETE ) yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), stream_status) diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py index 8e49f66a..e5540799 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py @@ -8,7 +8,9 @@ from airbyte_cdk.models import AirbyteMessage from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor -from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( + PartitionGenerationCompletedSentinel, +) from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository @@ -17,7 +19,10 @@ from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record -from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel, QueueItem +from airbyte_cdk.sources.streams.concurrent.partitions.types import ( + PartitionCompleteSentinel, + QueueItem, +) from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger, SliceLogger @@ -41,14 +46,25 @@ def create( timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, ) -> "ConcurrentSource": is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1 - too_many_generator = not is_single_threaded and initial_number_of_partitions_to_generate >= num_workers - assert not too_many_generator, "It is required to have more workers than threads generating partitions" + too_many_generator = ( + not is_single_threaded and initial_number_of_partitions_to_generate >= num_workers + ) + assert ( + not too_many_generator + ), "It is required to have more workers than threads generating partitions" threadpool = ThreadPoolManager( - concurrent.futures.ThreadPoolExecutor(max_workers=num_workers, thread_name_prefix="workerpool"), + concurrent.futures.ThreadPoolExecutor( + max_workers=num_workers, thread_name_prefix="workerpool" + ), logger, ) return ConcurrentSource( - threadpool, logger, slice_logger, message_repository, initial_number_of_partitions_to_generate, timeout_seconds + threadpool, + logger, + slice_logger, + message_repository, + initial_number_of_partitions_to_generate, + timeout_seconds, ) def __init__( @@ -107,7 +123,9 @@ def read( self._threadpool.check_for_errors_and_shutdown() self._logger.info("Finished syncing") - def _submit_initial_partition_generators(self, concurrent_stream_processor: ConcurrentReadProcessor) -> Iterable[AirbyteMessage]: + def _submit_initial_partition_generators( + self, concurrent_stream_processor: ConcurrentReadProcessor + ) -> Iterable[AirbyteMessage]: for _ in range(self._initial_number_partitions_to_generate): status_message = concurrent_stream_processor.start_next_partition_generator() if status_message: diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py index bbffe8f8..c150dc95 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py @@ -15,8 +15,17 @@ from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade -from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, Cursor, CursorField, CursorValueType, FinalStateCursor, GapType -from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import AbstractStreamStateConverter +from airbyte_cdk.sources.streams.concurrent.cursor import ( + ConcurrentCursor, + Cursor, + CursorField, + CursorValueType, + FinalStateCursor, + GapType, +) +from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ( + AbstractStreamStateConverter, +) DEFAULT_LOOKBACK_SECONDS = 0 @@ -43,14 +52,20 @@ def read( abstract_streams = self._select_abstract_streams(config, catalog) concurrent_stream_names = {stream.name for stream in abstract_streams} configured_catalog_for_regular_streams = ConfiguredAirbyteCatalog( - streams=[stream for stream in catalog.streams if stream.stream.name not in concurrent_stream_names] + streams=[ + stream + for stream in catalog.streams + if stream.stream.name not in concurrent_stream_names + ] ) if abstract_streams: yield from self._concurrent_source.read(abstract_streams) if configured_catalog_for_regular_streams.streams: yield from super().read(logger, config, configured_catalog_for_regular_streams, state) - def _select_abstract_streams(self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog) -> List[AbstractStream]: + def _select_abstract_streams( + self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog + ) -> List[AbstractStream]: """ Selects streams that can be processed concurrently and returns their abstract representations. """ @@ -67,7 +82,11 @@ def _select_abstract_streams(self, config: Mapping[str, Any], configured_catalog return abstract_streams def convert_to_concurrent_stream( - self, logger: logging.Logger, stream: Stream, state_manager: ConnectorStateManager, cursor: Optional[Cursor] = None + self, + logger: logging.Logger, + stream: Stream, + state_manager: ConnectorStateManager, + cursor: Optional[Cursor] = None, ) -> Stream: """ Prepares a stream for concurrent processing by initializing or assigning a cursor, @@ -106,7 +125,9 @@ def initialize_cursor( if cursor_field_name: if not isinstance(cursor_field_name, str): - raise ValueError(f"Cursor field type must be a string, but received {type(cursor_field_name).__name__}.") + raise ValueError( + f"Cursor field type must be a string, but received {type(cursor_field_name).__name__}." + ) return ConcurrentCursor( stream.name, diff --git a/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py b/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py index b6933e6b..59f8a1f0 100644 --- a/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py +++ b/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py @@ -37,7 +37,9 @@ def __init__( def prune_to_validate_has_reached_futures_limit(self) -> bool: self._prune_futures(self._futures) if len(self._futures) > self._logging_threshold: - self._logger.warning(f"ThreadPoolManager: The list of futures is getting bigger than expected ({len(self._futures)})") + self._logger.warning( + f"ThreadPoolManager: The list of futures is getting bigger than expected ({len(self._futures)})" + ) return len(self._futures) >= self._max_concurrent_tasks def submit(self, function: Callable[..., Any], *args: Any) -> None: @@ -92,14 +94,18 @@ def check_for_errors_and_shutdown(self) -> None: ) self._stop_and_raise_exception(self._most_recently_seen_exception) - exceptions_from_futures = [f for f in [future.exception() for future in self._futures] if f is not None] + exceptions_from_futures = [ + f for f in [future.exception() for future in self._futures] if f is not None + ] if exceptions_from_futures: exception = RuntimeError(f"Failed reading with errors: {exceptions_from_futures}") self._stop_and_raise_exception(exception) else: futures_not_done = [f for f in self._futures if not f.done()] if futures_not_done: - exception = RuntimeError(f"Failed reading with futures not done: {futures_not_done}") + exception = RuntimeError( + f"Failed reading with futures not done: {futures_not_done}" + ) self._stop_and_raise_exception(exception) else: self._shutdown() diff --git a/airbyte_cdk/sources/connector_state_manager.py b/airbyte_cdk/sources/connector_state_manager.py index 2396029d..56b58127 100644 --- a/airbyte_cdk/sources/connector_state_manager.py +++ b/airbyte_cdk/sources/connector_state_manager.py @@ -6,7 +6,14 @@ from dataclasses import dataclass from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union -from airbyte_cdk.models import AirbyteMessage, AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType, AirbyteStreamState, StreamDescriptor +from airbyte_cdk.models import ( + AirbyteMessage, + AirbyteStateBlob, + AirbyteStateMessage, + AirbyteStateType, + AirbyteStreamState, + StreamDescriptor, +) from airbyte_cdk.models import Type as MessageType @@ -42,19 +49,25 @@ def __init__(self, state: Optional[List[AirbyteStateMessage]] = None): ) self.per_stream_states = per_stream_states - def get_stream_state(self, stream_name: str, namespace: Optional[str]) -> MutableMapping[str, Any]: + def get_stream_state( + self, stream_name: str, namespace: Optional[str] + ) -> MutableMapping[str, Any]: """ Retrieves the state of a given stream based on its descriptor (name + namespace). :param stream_name: Name of the stream being fetched :param namespace: Namespace of the stream being fetched :return: The per-stream state for a stream """ - stream_state: AirbyteStateBlob | None = self.per_stream_states.get(HashableStreamDescriptor(name=stream_name, namespace=namespace)) + stream_state: AirbyteStateBlob | None = self.per_stream_states.get( + HashableStreamDescriptor(name=stream_name, namespace=namespace) + ) if stream_state: return copy.deepcopy({k: v for k, v in stream_state.__dict__.items()}) return {} - def update_state_for_stream(self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]) -> None: + def update_state_for_stream( + self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any] + ) -> None: """ Overwrites the state blob of a specific stream based on the provided stream name and optional namespace :param stream_name: The name of the stream whose state is being updated @@ -79,7 +92,8 @@ def create_state_message(self, stream_name: str, namespace: Optional[str]) -> Ai state=AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name=stream_name, namespace=namespace), stream_state=stream_state + stream_descriptor=StreamDescriptor(name=stream_name, namespace=namespace), + stream_state=stream_state, ), ), ) @@ -88,7 +102,10 @@ def create_state_message(self, stream_name: str, namespace: Optional[str]) -> Ai def _extract_from_state_message( cls, state: Optional[List[AirbyteStateMessage]], - ) -> Tuple[Optional[AirbyteStateBlob], MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]]]: + ) -> Tuple[ + Optional[AirbyteStateBlob], + MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]], + ]: """ Takes an incoming list of state messages or a global state message and extracts state attributes according to type which can then be assigned to the new state manager being instantiated @@ -105,7 +122,8 @@ def _extract_from_state_message( shared_state = copy.deepcopy(global_state.shared_state, {}) # type: ignore[union-attr] # global_state has shared_state streams = { HashableStreamDescriptor( - name=per_stream_state.stream_descriptor.name, namespace=per_stream_state.stream_descriptor.namespace + name=per_stream_state.stream_descriptor.name, + namespace=per_stream_state.stream_descriptor.namespace, ): per_stream_state.stream_state for per_stream_state in global_state.stream_states # type: ignore[union-attr] # global_state has shared_state } @@ -117,7 +135,8 @@ def _extract_from_state_message( namespace=per_stream_state.stream.stream_descriptor.namespace, # type: ignore[union-attr] # stream has stream_descriptor ): per_stream_state.stream.stream_state # type: ignore[union-attr] # stream has stream_state for per_stream_state in state - if per_stream_state.type == AirbyteStateType.STREAM and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True + if per_stream_state.type == AirbyteStateType.STREAM + and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True } return None, streams @@ -131,5 +150,7 @@ def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, ) @staticmethod - def _is_per_stream_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool: + def _is_per_stream_state( + state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]], + ) -> bool: return isinstance(state, List) diff --git a/airbyte_cdk/sources/declarative/async_job/job.py b/airbyte_cdk/sources/declarative/async_job/job.py index 5b4f1c7a..b075b61e 100644 --- a/airbyte_cdk/sources/declarative/async_job/job.py +++ b/airbyte_cdk/sources/declarative/async_job/job.py @@ -18,7 +18,9 @@ class AsyncJob: it and call `ApiJob.update_status`, `ApiJob.status` will not reflect the actual API side status. """ - def __init__(self, api_job_id: str, job_parameters: StreamSlice, timeout: Optional[timedelta] = None) -> None: + def __init__( + self, api_job_id: str, job_parameters: StreamSlice, timeout: Optional[timedelta] = None + ) -> None: self._api_job_id = api_job_id self._job_parameters = job_parameters self._status = AsyncJobStatus.RUNNING diff --git a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py index ddd7b8b3..d94885fa 100644 --- a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py +++ b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py @@ -6,13 +6,28 @@ import traceback import uuid from datetime import timedelta -from typing import Any, Generator, Generic, Iterable, List, Mapping, Optional, Set, Tuple, Type, TypeVar +from typing import ( + Any, + Generator, + Generic, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, +) from airbyte_cdk import StreamSlice from airbyte_cdk.logger import lazy_log from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.async_job.job import AsyncJob -from airbyte_cdk.sources.declarative.async_job.job_tracker import ConcurrentJobLimitReached, JobTracker +from airbyte_cdk.sources.declarative.async_job.job_tracker import ( + ConcurrentJobLimitReached, + JobTracker, +) from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus from airbyte_cdk.sources.message import MessageRepository @@ -36,7 +51,12 @@ def __init__(self, jobs: List[AsyncJob], stream_slice: StreamSlice) -> None: self._stream_slice = stream_slice def has_reached_max_attempt(self) -> bool: - return any(map(lambda attempt_count: attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS, self._attempts_per_job.values())) + return any( + map( + lambda attempt_count: attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS, + self._attempts_per_job.values(), + ) + ) def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> None: current_attempt_count = self._attempts_per_job.pop(job_to_replace, None) @@ -119,7 +139,12 @@ def add_at_the_beginning(self, item: T) -> None: class AsyncJobOrchestrator: _WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS = 5 - _KNOWN_JOB_STATUSES = {AsyncJobStatus.COMPLETED, AsyncJobStatus.FAILED, AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT} + _KNOWN_JOB_STATUSES = { + AsyncJobStatus.COMPLETED, + AsyncJobStatus.FAILED, + AsyncJobStatus.RUNNING, + AsyncJobStatus.TIMED_OUT, + } _RUNNING_ON_API_SIDE_STATUS = {AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT} def __init__( @@ -176,7 +201,11 @@ def _start_jobs(self) -> None: for partition in self._running_partitions: self._replace_failed_jobs(partition) - if self._has_bulk_parent and self._running_partitions and self._slice_iterator.has_next(): + if ( + self._has_bulk_parent + and self._running_partitions + and self._slice_iterator.has_next() + ): LOGGER.debug( "This AsyncJobOrchestrator is operating as a child of a bulk stream hence we limit the number of concurrent jobs on the child until there are no more parent slices to avoid the child taking all the API job budget" ) @@ -192,7 +221,9 @@ def _start_jobs(self) -> None: if at_least_one_slice_consumed_from_slice_iterator_during_current_iteration: # this means a slice has been consumed but the job couldn't be create therefore we need to put it back at the beginning of the _slice_iterator self._slice_iterator.add_at_the_beginning(_slice) # type: ignore # we know it's not None here because `ConcurrentJobLimitReached` happens during the for loop - LOGGER.debug("Waiting before creating more jobs as the limit of concurrent jobs has been reached. Will try again later...") + LOGGER.debug( + "Waiting before creating more jobs as the limit of concurrent jobs has been reached. Will try again later..." + ) def _start_job(self, _slice: StreamSlice, previous_job_id: Optional[str] = None) -> AsyncJob: if previous_job_id: @@ -212,7 +243,9 @@ def _start_job(self, _slice: StreamSlice, previous_job_id: Optional[str] = None) raise exception return self._keep_api_budget_with_failed_job(_slice, exception, id_to_replace) - def _keep_api_budget_with_failed_job(self, _slice: StreamSlice, exception: Exception, intent: str) -> AsyncJob: + def _keep_api_budget_with_failed_job( + self, _slice: StreamSlice, exception: Exception, intent: str + ) -> AsyncJob: """ We have a mechanism to retry job. It is used when a job status is FAILED or TIMED_OUT. The easiest way to retry is to have this job as created in a failed state and leverage the retry for failed/timed out jobs. This way, we don't have to have another process for @@ -221,7 +254,11 @@ def _keep_api_budget_with_failed_job(self, _slice: StreamSlice, exception: Excep LOGGER.warning( f"Could not start job for slice {_slice}. Job will be flagged as failed and retried if max number of attempts not reached: {exception}" ) - traced_exception = exception if isinstance(exception, AirbyteTracedException) else AirbyteTracedException.from_exception(exception) + traced_exception = ( + exception + if isinstance(exception, AirbyteTracedException) + else AirbyteTracedException.from_exception(exception) + ) # Even though we're not sure this will break the stream, we will emit here for simplicity's sake. If we wanted to be more accurate, # we would keep the exceptions in-memory until we know that we have reached the max attempt. self._message_repository.emit_message(traced_exception.as_airbyte_message()) @@ -241,7 +278,12 @@ def _get_running_jobs(self) -> Set[AsyncJob]: Returns: Set[AsyncJob]: A set of AsyncJob objects that are currently running. """ - return {job for partition in self._running_partitions for job in partition.jobs if job.status() == AsyncJobStatus.RUNNING} + return { + job + for partition in self._running_partitions + for job in partition.jobs + if job.status() == AsyncJobStatus.RUNNING + } def _update_jobs_status(self) -> None: """ @@ -283,14 +325,18 @@ def _process_completed_partition(self, partition: AsyncPartition) -> None: partition (AsyncPartition): The completed partition to process. """ job_ids = list(map(lambda job: job.api_job_id(), {job for job in partition.jobs})) - LOGGER.info(f"The following jobs for stream slice {partition.stream_slice} have been completed: {job_ids}.") + LOGGER.info( + f"The following jobs for stream slice {partition.stream_slice} have been completed: {job_ids}." + ) # It is important to remove the jobs from the job tracker before yielding the partition as the caller might try to schedule jobs # but won't be able to as all jobs slots are taken even though job is done. for job in partition.jobs: self._job_tracker.remove_job(job.api_job_id()) - def _process_running_partitions_and_yield_completed_ones(self) -> Generator[AsyncPartition, Any, None]: + def _process_running_partitions_and_yield_completed_ones( + self, + ) -> Generator[AsyncPartition, Any, None]: """ Process the running partitions. @@ -392,7 +438,9 @@ def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]: self._wait_on_status_update() except Exception as exception: if self._is_breaking_exception(exception): - LOGGER.warning(f"Caught exception that stops the processing of the jobs: {exception}") + LOGGER.warning( + f"Caught exception that stops the processing of the jobs: {exception}" + ) self._abort_all_running_jobs() raise exception @@ -406,7 +454,12 @@ def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]: # call of `create_and_get_completed_partitions` knows that there was an issue with some partitions and the sync is incomplete. raise AirbyteTracedException( message="", - internal_message="\n".join([filter_secrets(exception.__repr__()) for exception in self._non_breaking_exceptions]), + internal_message="\n".join( + [ + filter_secrets(exception.__repr__()) + for exception in self._non_breaking_exceptions + ] + ), failure_type=FailureType.config_error, ) @@ -425,7 +478,8 @@ def _abort_all_running_jobs(self) -> None: def _is_breaking_exception(self, exception: Exception) -> bool: return isinstance(exception, self._exceptions_to_break_on) or ( - isinstance(exception, AirbyteTracedException) and exception.failure_type == FailureType.config_error + isinstance(exception, AirbyteTracedException) + and exception.failure_type == FailureType.config_error ) def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]: diff --git a/airbyte_cdk/sources/declarative/async_job/job_tracker.py b/airbyte_cdk/sources/declarative/async_job/job_tracker.py index 54fbd26d..b47fc4ca 100644 --- a/airbyte_cdk/sources/declarative/async_job/job_tracker.py +++ b/airbyte_cdk/sources/declarative/async_job/job_tracker.py @@ -21,25 +21,39 @@ def __init__(self, limit: int): self._lock = threading.Lock() def try_to_get_intent(self) -> str: - lazy_log(LOGGER, logging.DEBUG, lambda: f"JobTracker - Trying to acquire lock by thread {threading.get_native_id()}...") + lazy_log( + LOGGER, + logging.DEBUG, + lambda: f"JobTracker - Trying to acquire lock by thread {threading.get_native_id()}...", + ) with self._lock: if self._has_reached_limit(): - raise ConcurrentJobLimitReached("Can't allocate more jobs right now: limit already reached") + raise ConcurrentJobLimitReached( + "Can't allocate more jobs right now: limit already reached" + ) intent = f"intent_{str(uuid.uuid4())}" - lazy_log(LOGGER, logging.DEBUG, lambda: f"JobTracker - Thread {threading.get_native_id()} has acquired {intent}!") + lazy_log( + LOGGER, + logging.DEBUG, + lambda: f"JobTracker - Thread {threading.get_native_id()} has acquired {intent}!", + ) self._jobs.add(intent) return intent def add_job(self, intent_or_job_id: str, job_id: str) -> None: if intent_or_job_id not in self._jobs: - raise ValueError(f"Can't add job: Unknown intent or job id, known values are {self._jobs}") + raise ValueError( + f"Can't add job: Unknown intent or job id, known values are {self._jobs}" + ) if intent_or_job_id == job_id: # Nothing to do here as the ID to replace is the same return lazy_log( - LOGGER, logging.DEBUG, lambda: f"JobTracker - Thread {threading.get_native_id()} replacing job {intent_or_job_id} by {job_id}!" + LOGGER, + logging.DEBUG, + lambda: f"JobTracker - Thread {threading.get_native_id()} replacing job {intent_or_job_id} by {job_id}!", ) with self._lock: self._jobs.add(job_id) @@ -49,7 +63,11 @@ def remove_job(self, job_id: str) -> None: """ If the job is not allocated as a running job, this method does nothing and it won't raise. """ - lazy_log(LOGGER, logging.DEBUG, lambda: f"JobTracker - Thread {threading.get_native_id()} removing job {job_id}") + lazy_log( + LOGGER, + logging.DEBUG, + lambda: f"JobTracker - Thread {threading.get_native_id()} removing job {job_id}", + ) with self._lock: self._jobs.discard(job_id) diff --git a/airbyte_cdk/sources/declarative/async_job/repository.py b/airbyte_cdk/sources/declarative/async_job/repository.py index b2de8659..21581ec4 100644 --- a/airbyte_cdk/sources/declarative/async_job/repository.py +++ b/airbyte_cdk/sources/declarative/async_job/repository.py @@ -26,7 +26,9 @@ def abort(self, job: AsyncJob) -> None: Called when we need to stop on the API side. This method can raise NotImplementedError as not all the APIs will support aborting jobs. """ - raise NotImplementedError("Either the API or the AsyncJobRepository implementation do not support aborting jobs") + raise NotImplementedError( + "Either the API or the AsyncJobRepository implementation do not support aborting jobs" + ) @abstractmethod def delete(self, job: AsyncJob) -> None: diff --git a/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py b/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py index 5517f546..b749718f 100644 --- a/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py +++ b/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py @@ -5,7 +5,9 @@ from dataclasses import InitVar, dataclass from typing import Any, Mapping, Union -from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import AbstractHeaderAuthenticator +from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import ( + AbstractHeaderAuthenticator, +) @dataclass diff --git a/airbyte_cdk/sources/declarative/auth/jwt.py b/airbyte_cdk/sources/declarative/auth/jwt.py index e24ee793..4095635d 100644 --- a/airbyte_cdk/sources/declarative/auth/jwt.py +++ b/airbyte_cdk/sources/declarative/auth/jwt.py @@ -75,22 +75,32 @@ class JwtAuthenticator(DeclarativeAuthenticator): def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters) - self._algorithm = JwtAlgorithm(self.algorithm) if isinstance(self.algorithm, str) else self.algorithm + self._algorithm = ( + JwtAlgorithm(self.algorithm) if isinstance(self.algorithm, str) else self.algorithm + ) self._base64_encode_secret_key = ( InterpolatedBoolean(self.base64_encode_secret_key, parameters=parameters) if isinstance(self.base64_encode_secret_key, str) else self.base64_encode_secret_key ) self._token_duration = self.token_duration - self._header_prefix = InterpolatedString.create(self.header_prefix, parameters=parameters) if self.header_prefix else None + self._header_prefix = ( + InterpolatedString.create(self.header_prefix, parameters=parameters) + if self.header_prefix + else None + ) self._kid = InterpolatedString.create(self.kid, parameters=parameters) if self.kid else None self._typ = InterpolatedString.create(self.typ, parameters=parameters) if self.typ else None self._cty = InterpolatedString.create(self.cty, parameters=parameters) if self.cty else None self._iss = InterpolatedString.create(self.iss, parameters=parameters) if self.iss else None self._sub = InterpolatedString.create(self.sub, parameters=parameters) if self.sub else None self._aud = InterpolatedString.create(self.aud, parameters=parameters) if self.aud else None - self._additional_jwt_headers = InterpolatedMapping(self.additional_jwt_headers or {}, parameters=parameters) - self._additional_jwt_payload = InterpolatedMapping(self.additional_jwt_payload or {}, parameters=parameters) + self._additional_jwt_headers = InterpolatedMapping( + self.additional_jwt_headers or {}, parameters=parameters + ) + self._additional_jwt_payload = InterpolatedMapping( + self.additional_jwt_payload or {}, parameters=parameters + ) def _get_jwt_headers(self) -> dict[str, Any]: """ " @@ -98,7 +108,9 @@ def _get_jwt_headers(self) -> dict[str, Any]: """ headers = self._additional_jwt_headers.eval(self.config) if any(prop in headers for prop in ["kid", "alg", "typ", "cty"]): - raise ValueError("'kid', 'alg', 'typ', 'cty' are reserved headers and should not be set as part of 'additional_jwt_headers'") + raise ValueError( + "'kid', 'alg', 'typ', 'cty' are reserved headers and should not be set as part of 'additional_jwt_headers'" + ) if self._kid: headers["kid"] = self._kid.eval(self.config) @@ -139,7 +151,11 @@ def _get_secret_key(self) -> str: Returns the secret key used to sign the JWT. """ secret_key: str = self._secret_key.eval(self.config) - return base64.b64encode(secret_key.encode()).decode() if self._base64_encode_secret_key else secret_key + return ( + base64.b64encode(secret_key.encode()).decode() + if self._base64_encode_secret_key + else secret_key + ) def _get_signed_token(self) -> Union[str, Any]: """ @@ -167,4 +183,8 @@ def auth_header(self) -> str: @property def token(self) -> str: - return f"{self._get_header_prefix()} {self._get_signed_token()}" if self._get_header_prefix() else self._get_signed_token() + return ( + f"{self._get_header_prefix()} {self._get_signed_token()}" + if self._get_header_prefix() + else self._get_signed_token() + ) diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index b68dbcf1..773d2818 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -10,8 +10,12 @@ from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository -from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator -from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import SingleUseRefreshTokenOauth2Authenticator +from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import ( + AbstractOauth2Authenticator, +) +from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import ( + SingleUseRefreshTokenOauth2Authenticator, +) @dataclass @@ -57,31 +61,49 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut def __post_init__(self, parameters: Mapping[str, Any]) -> None: super().__init__() - self._token_refresh_endpoint = InterpolatedString.create(self.token_refresh_endpoint, parameters=parameters) + self._token_refresh_endpoint = InterpolatedString.create( + self.token_refresh_endpoint, parameters=parameters + ) self._client_id = InterpolatedString.create(self.client_id, parameters=parameters) self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters) if self.refresh_token is not None: - self._refresh_token: Optional[InterpolatedString] = InterpolatedString.create(self.refresh_token, parameters=parameters) + self._refresh_token: Optional[InterpolatedString] = InterpolatedString.create( + self.refresh_token, parameters=parameters + ) else: self._refresh_token = None - self.access_token_name = InterpolatedString.create(self.access_token_name, parameters=parameters) - self.expires_in_name = InterpolatedString.create(self.expires_in_name, parameters=parameters) + self.access_token_name = InterpolatedString.create( + self.access_token_name, parameters=parameters + ) + self.expires_in_name = InterpolatedString.create( + self.expires_in_name, parameters=parameters + ) self.grant_type = InterpolatedString.create(self.grant_type, parameters=parameters) - self._refresh_request_body = InterpolatedMapping(self.refresh_request_body or {}, parameters=parameters) + self._refresh_request_body = InterpolatedMapping( + self.refresh_request_body or {}, parameters=parameters + ) self._token_expiry_date: pendulum.DateTime = ( - pendulum.parse(InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval(self.config)) # type: ignore # pendulum.parse returns a datetime in this context + pendulum.parse( + InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval( + self.config + ) + ) # type: ignore # pendulum.parse returns a datetime in this context if self.token_expiry_date else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints ) self._access_token: Optional[str] = None # access_token is initialized by a setter if self.get_grant_type() == "refresh_token" and self._refresh_token is None: - raise ValueError("OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`") + raise ValueError( + "OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`" + ) def get_token_refresh_endpoint(self) -> str: refresh_token: str = self._token_refresh_endpoint.eval(self.config) if not refresh_token: - raise ValueError("OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter") + raise ValueError( + "OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter" + ) return refresh_token def get_client_id(self) -> str: @@ -139,7 +161,9 @@ def _message_repository(self) -> MessageRepository: @dataclass -class DeclarativeSingleUseRefreshTokenOauth2Authenticator(SingleUseRefreshTokenOauth2Authenticator, DeclarativeAuthenticator): +class DeclarativeSingleUseRefreshTokenOauth2Authenticator( + SingleUseRefreshTokenOauth2Authenticator, DeclarativeAuthenticator +): """ Declarative version of SingleUseRefreshTokenOauth2Authenticator which can be used in declarative connectors. """ diff --git a/airbyte_cdk/sources/declarative/auth/selective_authenticator.py b/airbyte_cdk/sources/declarative/auth/selective_authenticator.py index e3f39a0a..11a2ae7d 100644 --- a/airbyte_cdk/sources/declarative/auth/selective_authenticator.py +++ b/airbyte_cdk/sources/declarative/auth/selective_authenticator.py @@ -29,7 +29,9 @@ def __new__( # type: ignore[misc] try: selected_key = str(dpath.get(config, authenticator_selection_path)) except KeyError as err: - raise ValueError("The path from `authenticator_selection_path` is not found in the config.") from err + raise ValueError( + "The path from `authenticator_selection_path` is not found in the config." + ) from err try: return authenticators[selected_key] diff --git a/airbyte_cdk/sources/declarative/auth/token.py b/airbyte_cdk/sources/declarative/auth/token.py index a2b64fce..dc35eb45 100644 --- a/airbyte_cdk/sources/declarative/auth/token.py +++ b/airbyte_cdk/sources/declarative/auth/token.py @@ -11,7 +11,10 @@ from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator from airbyte_cdk.sources.declarative.auth.token_provider import TokenProvider from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.types import Config from cachetools import TTLCache, cached @@ -42,7 +45,9 @@ class ApiKeyAuthenticator(DeclarativeAuthenticator): parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._field_name = InterpolatedString.create(self.request_option.field_name, parameters=parameters) + self._field_name = InterpolatedString.create( + self.request_option.field_name, parameters=parameters + ) @property def auth_header(self) -> str: @@ -127,7 +132,9 @@ def auth_header(self) -> str: @property def token(self) -> str: - auth_string = f"{self._username.eval(self.config)}:{self._password.eval(self.config)}".encode("utf8") + auth_string = ( + f"{self._username.eval(self.config)}:{self._password.eval(self.config)}".encode("utf8") + ) b64_encoded = base64.b64encode(auth_string).decode("utf8") return f"Basic {b64_encoded}" @@ -164,7 +171,9 @@ def get_new_session_token(api_url: str, username: str, password: str, response_k ) response.raise_for_status() if not response.ok: - raise ConnectionError(f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}") + raise ConnectionError( + f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}" + ) return str(response.json()[response_key]) @@ -208,9 +217,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._api_url = InterpolatedString.create(self.api_url, parameters=parameters) self._header = InterpolatedString.create(self.header, parameters=parameters) self._session_token = InterpolatedString.create(self.session_token, parameters=parameters) - self._session_token_response_key = InterpolatedString.create(self.session_token_response_key, parameters=parameters) + self._session_token_response_key = InterpolatedString.create( + self.session_token_response_key, parameters=parameters + ) self._login_url = InterpolatedString.create(self.login_url, parameters=parameters) - self._validate_session_url = InterpolatedString.create(self.validate_session_url, parameters=parameters) + self._validate_session_url = InterpolatedString.create( + self.validate_session_url, parameters=parameters + ) self.logger = logging.getLogger("airbyte") @@ -232,7 +245,9 @@ def token(self) -> str: self.logger.info("Using generated session token by username and password") return get_new_session_token(api_url, username, password, session_token_response_key) - raise ConnectionError("Invalid credentials: session token is not valid or provide username and password") + raise ConnectionError( + "Invalid credentials: session token is not valid or provide username and password" + ) def is_valid_session_token(self) -> bool: try: @@ -251,4 +266,6 @@ def is_valid_session_token(self) -> bool: self.logger.info("Connection check for source is successful.") return True else: - raise ConnectionError(f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}") + raise ConnectionError( + f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}" + ) diff --git a/airbyte_cdk/sources/declarative/checks/check_stream.py b/airbyte_cdk/sources/declarative/checks/check_stream.py index baf056d3..c45159ec 100644 --- a/airbyte_cdk/sources/declarative/checks/check_stream.py +++ b/airbyte_cdk/sources/declarative/checks/check_stream.py @@ -27,22 +27,30 @@ class CheckStream(ConnectionChecker): def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._parameters = parameters - def check_connection(self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Any]: + def check_connection( + self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Any]: streams = source.streams(config=config) stream_name_to_stream = {s.name: s for s in streams} if len(streams) == 0: return False, f"No streams to connect to from source {source}" for stream_name in self.stream_names: if stream_name not in stream_name_to_stream.keys(): - raise ValueError(f"{stream_name} is not part of the catalog. Expected one of {stream_name_to_stream.keys()}.") + raise ValueError( + f"{stream_name} is not part of the catalog. Expected one of {stream_name_to_stream.keys()}." + ) stream = stream_name_to_stream[stream_name] availability_strategy = HttpAvailabilityStrategy() try: - stream_is_available, reason = availability_strategy.check_availability(stream, logger) + stream_is_available, reason = availability_strategy.check_availability( + stream, logger + ) if not stream_is_available: return False, reason except Exception as error: - logger.error(f"Encountered an error trying to connect to stream {stream_name}. Error: \n {traceback.format_exc()}") + logger.error( + f"Encountered an error trying to connect to stream {stream_name}. Error: \n {traceback.format_exc()}" + ) return False, f"Unable to connect to stream {stream_name} - {error}" return True, None diff --git a/airbyte_cdk/sources/declarative/checks/connection_checker.py b/airbyte_cdk/sources/declarative/checks/connection_checker.py index 908e659b..fd1d1bba 100644 --- a/airbyte_cdk/sources/declarative/checks/connection_checker.py +++ b/airbyte_cdk/sources/declarative/checks/connection_checker.py @@ -15,7 +15,9 @@ class ConnectionChecker(ABC): """ @abstractmethod - def check_connection(self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Any]: + def check_connection( + self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Any]: """ Tests if the input configuration can be used to successfully connect to the integration e.g: if a provided Stripe API token can be used to connect to the Stripe API. diff --git a/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py b/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py index a86c4f8f..f5cd24f0 100644 --- a/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py +++ b/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py @@ -28,15 +28,23 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.default_concurrency, int): self._default_concurrency: Union[int, InterpolatedString] = self.default_concurrency elif "config" in self.default_concurrency and not self.max_concurrency: - raise ValueError("ConcurrencyLevel requires that max_concurrency be defined if the default_concurrency can be used-specified") + raise ValueError( + "ConcurrencyLevel requires that max_concurrency be defined if the default_concurrency can be used-specified" + ) else: - self._default_concurrency = InterpolatedString.create(self.default_concurrency, parameters=parameters) + self._default_concurrency = InterpolatedString.create( + self.default_concurrency, parameters=parameters + ) def get_concurrency_level(self) -> int: if isinstance(self._default_concurrency, InterpolatedString): evaluated_default_concurrency = self._default_concurrency.eval(config=self.config) if not isinstance(evaluated_default_concurrency, int): raise ValueError("default_concurrency did not evaluate to an integer") - return min(evaluated_default_concurrency, self.max_concurrency) if self.max_concurrency else evaluated_default_concurrency + return ( + min(evaluated_default_concurrency, self.max_concurrency) + if self.max_concurrency + else evaluated_default_concurrency + ) else: return self._default_concurrency diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 8c8239ba..62e0b578 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -5,7 +5,12 @@ import logging from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union -from airbyte_cdk.models import AirbyteCatalog, AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog +from airbyte_cdk.models import ( + AirbyteCatalog, + AirbyteMessage, + AirbyteStateMessage, + ConfiguredAirbyteCatalog, +) from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.declarative.concurrency_level import ConcurrencyLevel @@ -14,9 +19,15 @@ from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.declarative.interpolation import InterpolatedString from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource -from airbyte_cdk.sources.declarative.models.declarative_component_schema import ConcurrencyLevel as ConcurrencyLevelModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import DatetimeBasedCursor as DatetimeBasedCursorModel -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + ConcurrencyLevel as ConcurrencyLevelModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + DatetimeBasedCursor as DatetimeBasedCursorModel, +) +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, +) from airbyte_cdk.sources.declarative.requesters import HttpRequester from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever from airbyte_cdk.sources.declarative.transformations.add_fields import AddFields @@ -25,7 +36,9 @@ from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte_cdk.sources.streams.concurrent.adapters import CursorPartitionGenerator -from airbyte_cdk.sources.streams.concurrent.availability_strategy import AlwaysAvailableAvailabilityStrategy +from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( + AlwaysAvailableAvailabilityStrategy, +) from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.sources.streams.concurrent.helpers import get_primary_key_from_stream @@ -62,7 +75,9 @@ def __init__( # any other arguments, but the existing entrypoint.py isn't designed to support this. Just noting this # for our future improvements to the CDK. if config: - self._concurrent_streams, self._synchronous_streams = self._group_streams(config=config or {}) + self._concurrent_streams, self._synchronous_streams = self._group_streams( + config=config or {} + ) else: self._concurrent_streams = None self._synchronous_streams = None @@ -70,10 +85,14 @@ def __init__( concurrency_level_from_manifest = self._source_config.get("concurrency_level") if concurrency_level_from_manifest: concurrency_level_component = self._constructor.create_component( - model_type=ConcurrencyLevelModel, component_definition=concurrency_level_from_manifest, config=config or {} + model_type=ConcurrencyLevelModel, + component_definition=concurrency_level_from_manifest, + config=config or {}, ) if not isinstance(concurrency_level_component, ConcurrencyLevel): - raise ValueError(f"Expected to generate a ConcurrencyLevel component, but received {concurrency_level_component.__class__}") + raise ValueError( + f"Expected to generate a ConcurrencyLevel component, but received {concurrency_level_component.__class__}" + ) concurrency_level = concurrency_level_component.get_concurrency_level() initial_number_of_partitions_to_generate = max( @@ -101,9 +120,13 @@ def read( # ConcurrentReadProcessor pops streams that are finished being read so before syncing, the names of the concurrent # streams must be saved so that they can be removed from the catalog before starting synchronous streams if self._concurrent_streams: - concurrent_stream_names = set([concurrent_stream.name for concurrent_stream in self._concurrent_streams]) + concurrent_stream_names = set( + [concurrent_stream.name for concurrent_stream in self._concurrent_streams] + ) - selected_concurrent_streams = self._select_streams(streams=self._concurrent_streams, configured_catalog=catalog) + selected_concurrent_streams = self._select_streams( + streams=self._concurrent_streams, configured_catalog=catalog + ) # It would appear that passing in an empty set of streams causes an infinite loop in ConcurrentReadProcessor. # This is also evident in concurrent_source_adapter.py so I'll leave this out of scope to fix for now if selected_concurrent_streams: @@ -123,7 +146,11 @@ def read( def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog: concurrent_streams = self._concurrent_streams or [] synchronous_streams = self._synchronous_streams or [] - return AirbyteCatalog(streams=[stream.as_airbyte_stream() for stream in concurrent_streams + synchronous_streams]) + return AirbyteCatalog( + streams=[ + stream.as_airbyte_stream() for stream in concurrent_streams + synchronous_streams + ] + ) def streams(self, config: Mapping[str, Any]) -> List[Stream]: """ @@ -136,25 +163,34 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: """ return super().streams(config) - def _group_streams(self, config: Mapping[str, Any]) -> Tuple[List[AbstractStream], List[Stream]]: + def _group_streams( + self, config: Mapping[str, Any] + ) -> Tuple[List[AbstractStream], List[Stream]]: concurrent_streams: List[AbstractStream] = [] synchronous_streams: List[Stream] = [] state_manager = ConnectorStateManager(state=self._state) # type: ignore # state is always in the form of List[AirbyteStateMessage]. The ConnectorStateManager should use generics, but this can be done later - name_to_stream_mapping = {stream["name"]: stream for stream in self.resolved_manifest["streams"]} + name_to_stream_mapping = { + stream["name"]: stream for stream in self.resolved_manifest["streams"] + } for declarative_stream in self.streams(config=config): # Some low-code sources use a combination of DeclarativeStream and regular Python streams. We can't inspect # these legacy Python streams the way we do low-code streams to determine if they are concurrent compatible, # so we need to treat them as synchronous if isinstance(declarative_stream, DeclarativeStream): - datetime_based_cursor_component_definition = name_to_stream_mapping[declarative_stream.name].get("incremental_sync") + datetime_based_cursor_component_definition = name_to_stream_mapping[ + declarative_stream.name + ].get("incremental_sync") if ( datetime_based_cursor_component_definition - and datetime_based_cursor_component_definition.get("type", "") == DatetimeBasedCursorModel.__name__ - and self._stream_supports_concurrent_partition_processing(declarative_stream=declarative_stream) + and datetime_based_cursor_component_definition.get("type", "") + == DatetimeBasedCursorModel.__name__ + and self._stream_supports_concurrent_partition_processing( + declarative_stream=declarative_stream + ) and hasattr(declarative_stream.retriever, "stream_slicer") and isinstance(declarative_stream.retriever.stream_slicer, DatetimeBasedCursor) ): @@ -162,26 +198,34 @@ def _group_streams(self, config: Mapping[str, Any]) -> Tuple[List[AbstractStream stream_name=declarative_stream.name, namespace=declarative_stream.namespace ) - cursor, connector_state_converter = self._constructor.create_concurrent_cursor_from_datetime_based_cursor( - state_manager=state_manager, - model_type=DatetimeBasedCursorModel, - component_definition=datetime_based_cursor_component_definition, - stream_name=declarative_stream.name, - stream_namespace=declarative_stream.namespace, - config=config or {}, - stream_state=stream_state, + cursor, connector_state_converter = ( + self._constructor.create_concurrent_cursor_from_datetime_based_cursor( + state_manager=state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=datetime_based_cursor_component_definition, + stream_name=declarative_stream.name, + stream_namespace=declarative_stream.namespace, + config=config or {}, + stream_state=stream_state, + ) ) # This is an optimization so that we don't invoke any cursor or state management flows within the # low-code framework because state management is handled through the ConcurrentCursor. - if declarative_stream and declarative_stream.retriever and isinstance(declarative_stream.retriever, SimpleRetriever): + if ( + declarative_stream + and declarative_stream.retriever + and isinstance(declarative_stream.retriever, SimpleRetriever) + ): # Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is # called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor # for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and # ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized # with state. if declarative_stream.retriever.cursor: - declarative_stream.retriever.cursor.set_initial_state(stream_state=stream_state) + declarative_stream.retriever.cursor.set_initial_state( + stream_state=stream_state + ) declarative_stream.retriever.cursor = None partition_generator = CursorPartitionGenerator( @@ -212,7 +256,9 @@ def _group_streams(self, config: Mapping[str, Any]) -> Tuple[List[AbstractStream return concurrent_streams, synchronous_streams - def _stream_supports_concurrent_partition_processing(self, declarative_stream: DeclarativeStream) -> bool: + def _stream_supports_concurrent_partition_processing( + self, declarative_stream: DeclarativeStream + ) -> bool: """ Many connectors make use of stream_state during interpolation on a per-partition basis under the assumption that state is updated sequentially. Because the concurrent CDK engine processes different partitions in parallel, @@ -224,7 +270,9 @@ def _stream_supports_concurrent_partition_processing(self, declarative_stream: D cdk-migrations.md for the full list of connectors. """ - if isinstance(declarative_stream.retriever, SimpleRetriever) and isinstance(declarative_stream.retriever.requester, HttpRequester): + if isinstance(declarative_stream.retriever, SimpleRetriever) and isinstance( + declarative_stream.retriever.requester, HttpRequester + ): http_requester = declarative_stream.retriever.requester if "stream_state" in http_requester._path.string: self.logger.warning( @@ -241,14 +289,19 @@ def _stream_supports_concurrent_partition_processing(self, declarative_stream: D record_selector = declarative_stream.retriever.record_selector if isinstance(record_selector, RecordSelector): - if record_selector.record_filter and "stream_state" in record_selector.record_filter.condition: + if ( + record_selector.record_filter + and "stream_state" in record_selector.record_filter.condition + ): self.logger.warning( f"Low-code stream '{declarative_stream.name}' uses interpolation of stream_state in the RecordFilter which is not thread-safe. Defaulting to synchronous processing" ) return False for add_fields in [ - transformation for transformation in record_selector.transformations if isinstance(transformation, AddFields) + transformation + for transformation in record_selector.transformations + if isinstance(transformation, AddFields) ]: for field in add_fields.fields: if isinstance(field.value, str) and "stream_state" in field.value: @@ -256,7 +309,10 @@ def _stream_supports_concurrent_partition_processing(self, declarative_stream: D f"Low-code stream '{declarative_stream.name}' uses interpolation of stream_state in the AddFields which is not thread-safe. Defaulting to synchronous processing" ) return False - if isinstance(field.value, InterpolatedString) and "stream_state" in field.value.string: + if ( + isinstance(field.value, InterpolatedString) + and "stream_state" in field.value.string + ): self.logger.warning( f"Low-code stream '{declarative_stream.name}' uses interpolation of stream_state in the AddFields which is not thread-safe. Defaulting to synchronous processing" ) @@ -264,7 +320,9 @@ def _stream_supports_concurrent_partition_processing(self, declarative_stream: D return True @staticmethod - def _select_streams(streams: List[AbstractStream], configured_catalog: ConfiguredAirbyteCatalog) -> List[AbstractStream]: + def _select_streams( + streams: List[AbstractStream], configured_catalog: ConfiguredAirbyteCatalog + ) -> List[AbstractStream]: stream_name_to_instance: Mapping[str, AbstractStream] = {s.name: s for s in streams} abstract_streams: List[AbstractStream] = [] for configured_stream in configured_catalog.streams: @@ -279,4 +337,10 @@ def _remove_concurrent_streams_from_catalog( catalog: ConfiguredAirbyteCatalog, concurrent_stream_names: set[str], ) -> ConfiguredAirbyteCatalog: - return ConfiguredAirbyteCatalog(streams=[stream for stream in catalog.streams if stream.stream.name not in concurrent_stream_names]) + return ConfiguredAirbyteCatalog( + streams=[ + stream + for stream in catalog.streams + if stream.stream.name not in concurrent_stream_names + ] + ) diff --git a/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py b/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py index 2694da27..b53f50ff 100644 --- a/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py +++ b/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py @@ -40,10 +40,20 @@ class MinMaxDatetime: def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.datetime = InterpolatedString.create(self.datetime, parameters=parameters or {}) self._parser = DatetimeParser() - self.min_datetime = InterpolatedString.create(self.min_datetime, parameters=parameters) if self.min_datetime else None # type: ignore - self.max_datetime = InterpolatedString.create(self.max_datetime, parameters=parameters) if self.max_datetime else None # type: ignore + self.min_datetime = ( + InterpolatedString.create(self.min_datetime, parameters=parameters) + if self.min_datetime + else None + ) # type: ignore + self.max_datetime = ( + InterpolatedString.create(self.max_datetime, parameters=parameters) + if self.max_datetime + else None + ) # type: ignore - def get_datetime(self, config: Mapping[str, Any], **additional_parameters: Mapping[str, Any]) -> dt.datetime: + def get_datetime( + self, config: Mapping[str, Any], **additional_parameters: Mapping[str, Any] + ) -> dt.datetime: """ Evaluates and returns the datetime :param config: The user-provided configuration as specified by the source's spec @@ -55,7 +65,9 @@ def get_datetime(self, config: Mapping[str, Any], **additional_parameters: Mappi if not datetime_format: datetime_format = "%Y-%m-%dT%H:%M:%S.%f%z" - time = self._parser.parse(str(self.datetime.eval(config, **additional_parameters)), datetime_format) # type: ignore # datetime is always cast to an interpolated string + time = self._parser.parse( + str(self.datetime.eval(config, **additional_parameters)), datetime_format + ) # type: ignore # datetime is always cast to an interpolated string if self.min_datetime: min_time = str(self.min_datetime.eval(config, **additional_parameters)) # type: ignore # min_datetime is always cast to an interpolated string @@ -93,6 +105,8 @@ def create( if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance( interpolated_string_or_min_max_datetime, str ): - return MinMaxDatetime(datetime=interpolated_string_or_min_max_datetime, parameters=parameters) + return MinMaxDatetime( + datetime=interpolated_string_or_min_max_datetime, parameters=parameters + ) else: return interpolated_string_or_min_max_datetime diff --git a/airbyte_cdk/sources/declarative/declarative_source.py b/airbyte_cdk/sources/declarative/declarative_source.py index 9135f2a9..77bf427a 100644 --- a/airbyte_cdk/sources/declarative/declarative_source.py +++ b/airbyte_cdk/sources/declarative/declarative_source.py @@ -20,7 +20,9 @@ class DeclarativeSource(AbstractSource): def connection_checker(self) -> ConnectionChecker: """Returns the ConnectionChecker to use for the `check` operation""" - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Any]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Any]: """ :param logger: The source logger :param config: The user-provided configuration as specified by the source's spec. diff --git a/airbyte_cdk/sources/declarative/declarative_stream.py b/airbyte_cdk/sources/declarative/declarative_stream.py index 09ce080c..12cdd333 100644 --- a/airbyte_cdk/sources/declarative/declarative_stream.py +++ b/airbyte_cdk/sources/declarative/declarative_stream.py @@ -6,14 +6,23 @@ from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union from airbyte_cdk.models import SyncMode -from airbyte_cdk.sources.declarative.incremental import GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor +from airbyte_cdk.sources.declarative.incremental import ( + GlobalSubstreamCursor, + PerPartitionCursor, + PerPartitionWithGlobalCursor, +) from airbyte_cdk.sources.declarative.interpolation import InterpolatedString from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader -from airbyte_cdk.sources.streams.checkpoint import CheckpointMode, CheckpointReader, Cursor, CursorBasedCheckpointReader +from airbyte_cdk.sources.streams.checkpoint import ( + CheckpointMode, + CheckpointReader, + Cursor, + CursorBasedCheckpointReader, +) from airbyte_cdk.sources.streams.core import Stream from airbyte_cdk.sources.types import Config, StreamSlice @@ -50,7 +59,11 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.stream_cursor_field, str) else self.stream_cursor_field ) - self._schema_loader = self.schema_loader if self.schema_loader else DefaultSchemaLoader(config=self.config, parameters=parameters) + self._schema_loader = ( + self.schema_loader + if self.schema_loader + else DefaultSchemaLoader(config=self.config, parameters=parameters) + ) @property # type: ignore def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: @@ -133,7 +146,9 @@ def read_records( # empty slice which seems to make sense. stream_slice = StreamSlice(partition={}, cursor_slice={}) if not isinstance(stream_slice, StreamSlice): - raise ValueError(f"DeclarativeStream does not support stream_slices that are not StreamSlice. Got {stream_slice}") + raise ValueError( + f"DeclarativeStream does not support stream_slices that are not StreamSlice. Got {stream_slice}" + ) yield from self.retriever.read_records(self.get_json_schema(), stream_slice) # type: ignore # records are of the correct type def get_json_schema(self) -> Mapping[str, Any]: # type: ignore @@ -146,7 +161,11 @@ def get_json_schema(self) -> Mapping[str, Any]: # type: ignore return self._schema_loader.get_json_schema() def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[StreamSlice]]: """ Override to define the slices for this stream. See the stream slicing section of the docs for more information. @@ -200,7 +219,9 @@ def _get_checkpoint_reader( cursor = self.get_cursor() checkpoint_mode = self._checkpoint_mode - if isinstance(cursor, (GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor)): + if isinstance( + cursor, (GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor) + ): self.has_multiple_slices = True return CursorBasedCheckpointReader( stream_slices=mappings_or_slices, diff --git a/airbyte_cdk/sources/declarative/decoders/decoder.py b/airbyte_cdk/sources/declarative/decoders/decoder.py index 4e8fdd64..5fa9dc8f 100644 --- a/airbyte_cdk/sources/declarative/decoders/decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/decoder.py @@ -22,7 +22,9 @@ def is_stream_response(self) -> bool: """ @abstractmethod - def decode(self, response: requests.Response) -> Generator[MutableMapping[str, Any], None, None]: + def decode( + self, response: requests.Response + ) -> Generator[MutableMapping[str, Any], None, None]: """ Decodes a requests.Response into a Mapping[str, Any] or an array :param response: the response to decode diff --git a/airbyte_cdk/sources/declarative/decoders/json_decoder.py b/airbyte_cdk/sources/declarative/decoders/json_decoder.py index b2c25c33..986bbd87 100644 --- a/airbyte_cdk/sources/declarative/decoders/json_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/json_decoder.py @@ -37,7 +37,9 @@ def decode(self, response: requests.Response) -> Generator[Mapping[str, Any], No else: yield from body_json except requests.exceptions.JSONDecodeError: - logger.warning(f"Response cannot be parsed into json: {response.status_code=}, {response.text=}") + logger.warning( + f"Response cannot be parsed into json: {response.status_code=}, {response.text=}" + ) yield {} diff --git a/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py b/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py index dadb717a..fa37607b 100644 --- a/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py +++ b/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py @@ -28,7 +28,9 @@ def decoder(self) -> Decoder: def is_stream_response(self) -> bool: return self._decoder.is_stream_response() - def decode(self, response: requests.Response) -> Generator[MutableMapping[str, Any], None, None]: + def decode( + self, response: requests.Response + ) -> Generator[MutableMapping[str, Any], None, None]: if self._decoder.is_stream_response(): logger.warning("Response is streamed and therefore will not be decoded for pagination.") yield {} diff --git a/airbyte_cdk/sources/declarative/decoders/xml_decoder.py b/airbyte_cdk/sources/declarative/decoders/xml_decoder.py index 7b598ba8..6fb0477e 100644 --- a/airbyte_cdk/sources/declarative/decoders/xml_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/xml_decoder.py @@ -78,7 +78,9 @@ class XmlDecoder(Decoder): def is_stream_response(self) -> bool: return False - def decode(self, response: requests.Response) -> Generator[MutableMapping[str, Any], None, None]: + def decode( + self, response: requests.Response + ) -> Generator[MutableMapping[str, Any], None, None]: body_xml = response.text try: body_json = xmltodict.parse(body_xml) @@ -89,5 +91,7 @@ def decode(self, response: requests.Response) -> Generator[MutableMapping[str, A else: yield from body_json except ExpatError as exc: - logger.warning(f"Response cannot be parsed from XML: {response.status_code=}, {response.text=}, {exc=}") + logger.warning( + f"Response cannot be parsed from XML: {response.status_code=}, {response.text=}, {exc=}" + ) yield {} diff --git a/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py b/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py index 512d6919..0878c31a 100644 --- a/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py +++ b/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py @@ -58,10 +58,14 @@ class DpathExtractor(RecordExtractor): decoder: Decoder = field(default_factory=lambda: JsonDecoder(parameters={})) def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._field_path = [InterpolatedString.create(path, parameters=parameters) for path in self.field_path] + self._field_path = [ + InterpolatedString.create(path, parameters=parameters) for path in self.field_path + ] for path_index in range(len(self.field_path)): if isinstance(self.field_path[path_index], str): - self._field_path[path_index] = InterpolatedString.create(self.field_path[path_index], parameters=parameters) + self._field_path[path_index] = InterpolatedString.create( + self.field_path[path_index], parameters=parameters + ) def extract_records(self, response: requests.Response) -> Iterable[MutableMapping[Any, Any]]: for body in self.decoder.decode(response): diff --git a/airbyte_cdk/sources/declarative/extractors/record_filter.py b/airbyte_cdk/sources/declarative/extractors/record_filter.py index f396224c..e84e229f 100644 --- a/airbyte_cdk/sources/declarative/extractors/record_filter.py +++ b/airbyte_cdk/sources/declarative/extractors/record_filter.py @@ -5,7 +5,11 @@ from dataclasses import InitVar, dataclass from typing import Any, Iterable, Mapping, Optional, Union -from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor, GlobalSubstreamCursor, PerPartitionWithGlobalCursor +from airbyte_cdk.sources.declarative.incremental import ( + DatetimeBasedCursor, + GlobalSubstreamCursor, + PerPartitionWithGlobalCursor, +) from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @@ -24,7 +28,9 @@ class RecordFilter: condition: str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._filter_interpolator = InterpolatedBoolean(condition=self.condition, parameters=parameters) + self._filter_interpolator = InterpolatedBoolean( + condition=self.condition, parameters=parameters + ) def filter_records( self, @@ -68,7 +74,9 @@ def _cursor_field(self) -> str: @property def _start_date_from_config(self) -> datetime.datetime: - return self._date_time_based_cursor._start_datetime.get_datetime(self._date_time_based_cursor.config) + return self._date_time_based_cursor._start_datetime.get_datetime( + self._date_time_based_cursor.config + ) @property def _end_datetime(self) -> datetime.datetime: @@ -81,20 +89,29 @@ def filter_records( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Iterable[Mapping[str, Any]]: - state_value = self._get_state_value(stream_state, stream_slice or StreamSlice(partition={}, cursor_slice={})) + state_value = self._get_state_value( + stream_state, stream_slice or StreamSlice(partition={}, cursor_slice={}) + ) filter_date: datetime.datetime = self._get_filter_date(state_value) records = ( record for record in records - if self._end_datetime >= self._date_time_based_cursor.parse_date(record[self._cursor_field]) >= filter_date + if self._end_datetime + >= self._date_time_based_cursor.parse_date(record[self._cursor_field]) + >= filter_date ) if self.condition: records = super().filter_records( - records=records, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + records=records, + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ) yield from records - def _get_state_value(self, stream_state: StreamState, stream_slice: StreamSlice) -> Optional[str]: + def _get_state_value( + self, stream_state: StreamState, stream_slice: StreamSlice + ) -> Optional[str]: """ Return cursor_value or None in case it was not found. Cursor_value may be empty if: diff --git a/airbyte_cdk/sources/declarative/extractors/record_selector.py b/airbyte_cdk/sources/declarative/extractors/record_selector.py index eed33d85..caaa4be2 100644 --- a/airbyte_cdk/sources/declarative/extractors/record_selector.py +++ b/airbyte_cdk/sources/declarative/extractors/record_selector.py @@ -61,7 +61,9 @@ def select_records( :return: List of Records selected from the response """ all_data: Iterable[Mapping[str, Any]] = self.extractor.extract_records(response) - yield from self.filter_and_transform(all_data, stream_state, records_schema, stream_slice, next_page_token) + yield from self.filter_and_transform( + all_data, stream_state, records_schema, stream_slice, next_page_token + ) def filter_and_transform( self, @@ -106,7 +108,10 @@ def _filter( ) -> Iterable[Mapping[str, Any]]: if self.record_filter: yield from self.record_filter.filter_records( - records, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + records, + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ) else: yield from records @@ -119,5 +124,7 @@ def _transform( ) -> Iterable[Mapping[str, Any]]: for record in records: for transformation in self.transformations: - transformation.transform(record, config=self.config, stream_state=stream_state, stream_slice=stream_slice) # type: ignore # record has type Mapping[str, Any], but Dict[str, Any] expected + transformation.transform( + record, config=self.config, stream_state=stream_state, stream_slice=stream_slice + ) # type: ignore # record has type Mapping[str, Any], but Dict[str, Any] expected yield record diff --git a/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py b/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py index 48ef69e1..8be2f6b6 100644 --- a/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py +++ b/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py @@ -68,7 +68,9 @@ def _filter_null_bytes(self, b: bytes) -> bytes: res = b.replace(b"\x00", b"") if len(res) < len(b): - self.logger.warning("Filter 'null' bytes from string, size reduced %d -> %d chars", len(b), len(res)) + self.logger.warning( + "Filter 'null' bytes from string, size reduced %d -> %d chars", len(b), len(res) + ) return res def _save_to_file(self, response: requests.Response) -> Tuple[str, str]: @@ -106,9 +108,13 @@ def _save_to_file(self, response: requests.Response) -> Tuple[str, str]: if os.path.isfile(tmp_file): return tmp_file, response_encoding else: - raise ValueError(f"The IO/Error occured while verifying binary data. Tmp file {tmp_file} doesn't exist.") + raise ValueError( + f"The IO/Error occured while verifying binary data. Tmp file {tmp_file} doesn't exist." + ) - def _read_with_chunks(self, path: str, file_encoding: str, chunk_size: int = 100) -> Iterable[Mapping[str, Any]]: + def _read_with_chunks( + self, path: str, file_encoding: str, chunk_size: int = 100 + ) -> Iterable[Mapping[str, Any]]: """ Reads data from a file in chunks and yields each row as a dictionary. @@ -126,7 +132,9 @@ def _read_with_chunks(self, path: str, file_encoding: str, chunk_size: int = 100 try: with open(path, "r", encoding=file_encoding) as data: - chunks = pd.read_csv(data, chunksize=chunk_size, iterator=True, dialect="unix", dtype=object) + chunks = pd.read_csv( + data, chunksize=chunk_size, iterator=True, dialect="unix", dtype=object + ) for chunk in chunks: chunk = chunk.replace({nan: None}).to_dict(orient="records") for row in chunk: @@ -140,7 +148,9 @@ def _read_with_chunks(self, path: str, file_encoding: str, chunk_size: int = 100 # remove binary tmp file, after data is read os.remove(path) - def extract_records(self, response: Optional[requests.Response] = None) -> Iterable[Mapping[str, Any]]: + def extract_records( + self, response: Optional[requests.Response] = None + ) -> Iterable[Mapping[str, Any]]: """ Extracts records from the given response by: 1) Saving the result to a tmp file diff --git a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py index e58d2256..3977623d 100644 --- a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py @@ -13,7 +13,10 @@ from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState from isodate import Duration, duration_isoformat, parse_duration @@ -72,27 +75,41 @@ class DatetimeBasedCursor(DeclarativeCursor): cursor_datetime_formats: List[str] = field(default_factory=lambda: []) def __post_init__(self, parameters: Mapping[str, Any]) -> None: - if (self.step and not self.cursor_granularity) or (not self.step and self.cursor_granularity): + if (self.step and not self.cursor_granularity) or ( + not self.step and self.cursor_granularity + ): raise ValueError( f"If step is defined, cursor_granularity should be as well and vice-versa. " f"Right now, step is `{self.step}` and cursor_granularity is `{self.cursor_granularity}`" ) self._start_datetime = MinMaxDatetime.create(self.start_datetime, parameters) - self._end_datetime = None if not self.end_datetime else MinMaxDatetime.create(self.end_datetime, parameters) + self._end_datetime = ( + None if not self.end_datetime else MinMaxDatetime.create(self.end_datetime, parameters) + ) self._timezone = datetime.timezone.utc self._interpolation = JinjaInterpolation() self._step = ( - self._parse_timedelta(InterpolatedString.create(self.step, parameters=parameters).eval(self.config)) + self._parse_timedelta( + InterpolatedString.create(self.step, parameters=parameters).eval(self.config) + ) if self.step else datetime.timedelta.max ) self._cursor_granularity = self._parse_timedelta(self.cursor_granularity) self.cursor_field = InterpolatedString.create(self.cursor_field, parameters=parameters) - self._lookback_window = InterpolatedString.create(self.lookback_window, parameters=parameters) if self.lookback_window else None - self._partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters) - self._partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters) + self._lookback_window = ( + InterpolatedString.create(self.lookback_window, parameters=parameters) + if self.lookback_window + else None + ) + self._partition_field_start = InterpolatedString.create( + self.partition_field_start or "start_time", parameters=parameters + ) + self._partition_field_end = InterpolatedString.create( + self.partition_field_end or "end_time", parameters=parameters + ) self._parser = DatetimeParser() # If datetime format is not specified then start/end datetime should inherit it from the stream slicer @@ -114,7 +131,9 @@ def set_initial_state(self, stream_state: StreamState) -> None: :param stream_state: The state of the stream as returned by get_stream_state """ - self._cursor = stream_state.get(self.cursor_field.eval(self.config)) if stream_state else None # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__ + self._cursor = ( + stream_state.get(self.cursor_field.eval(self.config)) if stream_state else None + ) # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__ def observe(self, stream_slice: StreamSlice, record: Record) -> None: """ @@ -131,28 +150,38 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: start_field = self._partition_field_start.eval(self.config) end_field = self._partition_field_end.eval(self.config) - is_highest_observed_cursor_value = not self._highest_observed_cursor_field_value or self.parse_date( - record_cursor_value - ) > self.parse_date(self._highest_observed_cursor_field_value) + is_highest_observed_cursor_value = ( + not self._highest_observed_cursor_field_value + or self.parse_date(record_cursor_value) + > self.parse_date(self._highest_observed_cursor_field_value) + ) if ( - self._is_within_daterange_boundaries(record, stream_slice.get(start_field), stream_slice.get(end_field)) # type: ignore # we know that stream_slices for these cursors will use a string representing an unparsed date + self._is_within_daterange_boundaries( + record, stream_slice.get(start_field), stream_slice.get(end_field) + ) # type: ignore # we know that stream_slices for these cursors will use a string representing an unparsed date and is_highest_observed_cursor_value ): self._highest_observed_cursor_field_value = record_cursor_value def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: if stream_slice.partition: - raise ValueError(f"Stream slice {stream_slice} should not have a partition. Got {stream_slice.partition}.") + raise ValueError( + f"Stream slice {stream_slice} should not have a partition. Got {stream_slice.partition}." + ) cursor_value_str_by_cursor_value_datetime = dict( map( # we need to ensure the cursor value is preserved as is in the state else the CATs might complain of something like # 2023-01-04T17:30:19.000Z' <= '2023-01-04T17:30:19.000000Z' lambda datetime_str: (self.parse_date(datetime_str), datetime_str), # type: ignore # because of the filter on the next line, this will only be called with a str - filter(lambda item: item, [self._cursor, self._highest_observed_cursor_field_value]), + filter( + lambda item: item, [self._cursor, self._highest_observed_cursor_field_value] + ), ) ) self._cursor = ( - cursor_value_str_by_cursor_value_datetime[max(cursor_value_str_by_cursor_value_datetime.keys())] + cursor_value_str_by_cursor_value_datetime[ + max(cursor_value_str_by_cursor_value_datetime.keys()) + ] if cursor_value_str_by_cursor_value_datetime else None ) @@ -175,11 +204,19 @@ def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[S # through each slice and does not belong to a specific slice. We just return stream state as it is. return self.get_stream_state() - def _calculate_earliest_possible_value(self, end_datetime: datetime.datetime) -> datetime.datetime: - lookback_delta = self._parse_timedelta(self._lookback_window.eval(self.config) if self._lookback_window else "P0D") - earliest_possible_start_datetime = min(self._start_datetime.get_datetime(self.config), end_datetime) + def _calculate_earliest_possible_value( + self, end_datetime: datetime.datetime + ) -> datetime.datetime: + lookback_delta = self._parse_timedelta( + self._lookback_window.eval(self.config) if self._lookback_window else "P0D" + ) + earliest_possible_start_datetime = min( + self._start_datetime.get_datetime(self.config), end_datetime + ) try: - cursor_datetime = self._calculate_cursor_datetime_from_state(self.get_stream_state()) - lookback_delta + cursor_datetime = ( + self._calculate_cursor_datetime_from_state(self.get_stream_state()) - lookback_delta + ) except OverflowError: # cursor_datetime defers to the minimum date if it does not exist in the state. Trying to subtract # a timedelta from the minimum datetime results in an OverflowError @@ -200,7 +237,9 @@ def select_best_end_datetime(self) -> datetime.datetime: return now return min(self._end_datetime.get_datetime(self.config), now) - def _calculate_cursor_datetime_from_state(self, stream_state: Mapping[str, Any]) -> datetime.datetime: + def _calculate_cursor_datetime_from_state( + self, stream_state: Mapping[str, Any] + ) -> datetime.datetime: if self.cursor_field.eval(self.config, stream_state=stream_state) in stream_state: # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__ return self.parse_date(stream_state[self.cursor_field.eval(self.config)]) # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__ return datetime.datetime.min.replace(tzinfo=datetime.timezone.utc) @@ -209,7 +248,10 @@ def _format_datetime(self, dt: datetime.datetime) -> str: return self._parser.format(dt, self.datetime_format) def _partition_daterange( - self, start: datetime.datetime, end: datetime.datetime, step: Union[datetime.timedelta, Duration] + self, + start: datetime.datetime, + end: datetime.datetime, + step: Union[datetime.timedelta, Duration], ) -> List[StreamSlice]: start_field = self._partition_field_start.eval(self.config) end_field = self._partition_field_end.eval(self.config) @@ -220,7 +262,11 @@ def _partition_daterange( end_date = self._get_date(next_start - self._cursor_granularity, end, min) dates.append( StreamSlice( - partition={}, cursor_slice={start_field: self._format_datetime(start), end_field: self._format_datetime(end_date)} + partition={}, + cursor_slice={ + start_field: self._format_datetime(start), + end_field: self._format_datetime(end_date), + }, ) ) start = next_start @@ -231,7 +277,9 @@ def _is_within_date_range(self, start: datetime.datetime, end: datetime.datetime return start < end return start <= end - def _evaluate_next_start_date_safely(self, start: datetime.datetime, step: datetime.timedelta) -> datetime.datetime: + def _evaluate_next_start_date_safely( + self, start: datetime.datetime, step: datetime.timedelta + ) -> datetime.datetime: """ Given that we set the default step at datetime.timedelta.max, we will generate an OverflowError when evaluating the next start_date This method assumes that users would never enter a step that would generate an overflow. Given that would be the case, the code @@ -308,7 +356,9 @@ def request_kwargs(self) -> Mapping[str, Any]: # Never update kwargs return {} - def _get_request_options(self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]: + def _get_request_options( + self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + ) -> Mapping[str, Any]: options: MutableMapping[str, Any] = {} if not stream_slice: return options @@ -332,11 +382,18 @@ def should_be_synced(self, record: Record) -> bool: ) return True latest_possible_cursor_value = self.select_best_end_datetime() - earliest_possible_cursor_value = self._calculate_earliest_possible_value(latest_possible_cursor_value) - return self._is_within_daterange_boundaries(record, earliest_possible_cursor_value, latest_possible_cursor_value) + earliest_possible_cursor_value = self._calculate_earliest_possible_value( + latest_possible_cursor_value + ) + return self._is_within_daterange_boundaries( + record, earliest_possible_cursor_value, latest_possible_cursor_value + ) def _is_within_daterange_boundaries( - self, record: Record, start_datetime_boundary: Union[datetime.datetime, str], end_datetime_boundary: Union[datetime.datetime, str] + self, + record: Record, + start_datetime_boundary: Union[datetime.datetime, str], + end_datetime_boundary: Union[datetime.datetime, str], ) -> bool: cursor_field = self.cursor_field.eval(self.config) # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__ record_cursor_value = record.get(cursor_field) @@ -350,7 +407,9 @@ def _is_within_daterange_boundaries( start_datetime_boundary = self.parse_date(start_datetime_boundary) if isinstance(end_datetime_boundary, str): end_datetime_boundary = self.parse_date(end_datetime_boundary) - return start_datetime_boundary <= self.parse_date(record_cursor_value) <= end_datetime_boundary + return ( + start_datetime_boundary <= self.parse_date(record_cursor_value) <= end_datetime_boundary + ) def _send_log(self, level: Level, message: str) -> None: if self.message_repository: @@ -380,8 +439,12 @@ def set_runtime_lookback_window(self, lookback_window_in_seconds: int) -> None: :param lookback_window_in_seconds: The lookback duration in seconds to potentially update to. """ runtime_lookback_window = duration_isoformat(timedelta(seconds=lookback_window_in_seconds)) - config_lookback = parse_duration(self._lookback_window.eval(self.config) if self._lookback_window else "P0D") + config_lookback = parse_duration( + self._lookback_window.eval(self.config) if self._lookback_window else "P0D" + ) # Check if the new runtime lookback window is greater than the current config lookback if parse_duration(runtime_lookback_window) > config_lookback: - self._lookback_window = InterpolatedString.create(runtime_lookback_window, parameters={}) + self._lookback_window = InterpolatedString.create( + runtime_lookback_window, parameters={} + ) diff --git a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py index f7454ef0..b912eb9a 100644 --- a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py @@ -84,7 +84,9 @@ def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: Partiti self._partition_router = partition_router self._timer = Timer() self._lock = threading.Lock() - self._slice_semaphore = threading.Semaphore(0) # Start with 0, indicating no slices being tracked + self._slice_semaphore = threading.Semaphore( + 0 + ) # Start with 0, indicating no slices being tracked self._all_slices_yielded = False self._lookback_window: Optional[int] = None self._current_partition: Optional[Mapping[str, Any]] = None @@ -116,7 +118,9 @@ def stream_slices(self) -> Iterable[StreamSlice]: ) self.start_slices_generation() - for slice, last, state in iterate_with_last_flag_and_state(slice_generator, self._partition_router.get_stream_state): + for slice, last, state in iterate_with_last_flag_and_state( + slice_generator, self._partition_router.get_stream_state + ): self._parent_state = state self.register_slice(last) yield slice @@ -124,7 +128,8 @@ def stream_slices(self) -> Iterable[StreamSlice]: def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]: slice_generator = ( - StreamSlice(partition=partition, cursor_slice=cursor_slice) for cursor_slice in self._stream_cursor.stream_slices() + StreamSlice(partition=partition, cursor_slice=cursor_slice) + for cursor_slice in self._stream_cursor.stream_slices() ) yield from slice_generator @@ -199,10 +204,14 @@ def _inject_lookback_into_stream_cursor(self, lookback_window: int) -> None: if hasattr(self._stream_cursor, "set_runtime_lookback_window"): self._stream_cursor.set_runtime_lookback_window(lookback_window) else: - raise ValueError("The cursor class for Global Substream Cursor does not have a set_runtime_lookback_window method") + raise ValueError( + "The cursor class for Global Substream Cursor does not have a set_runtime_lookback_window method" + ) def observe(self, stream_slice: StreamSlice, record: Record) -> None: - self._stream_cursor.observe(StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), record) + self._stream_cursor.observe( + StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), record + ) def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: """ @@ -220,7 +229,9 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: self._slice_semaphore.acquire() if self._all_slices_yielded and self._slice_semaphore._value == 0: self._lookback_window = self._timer.finish() - self._stream_cursor.close_slice(StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), *args) + self._stream_cursor.close_slice( + StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), *args + ) def get_stream_state(self) -> StreamState: state: dict[str, Any] = {"state": self._stream_cursor.get_stream_state()} @@ -322,12 +333,15 @@ def should_be_synced(self, record: Record) -> bool: def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: return self._stream_cursor.is_greater_than_or_equal( - self._convert_record_to_cursor_record(first), self._convert_record_to_cursor_record(second) + self._convert_record_to_cursor_record(first), + self._convert_record_to_cursor_record(second), ) @staticmethod def _convert_record_to_cursor_record(record: Record) -> Record: return Record( record.data, - StreamSlice(partition={}, cursor_slice=record.associated_slice.cursor_slice) if record.associated_slice else None, + StreamSlice(partition={}, cursor_slice=record.associated_slice.cursor_slice) + if record.associated_slice + else None, ) diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py index 86236ec9..a6449d81 100644 --- a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py @@ -8,7 +8,9 @@ from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter -from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer +from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import ( + PerPartitionKeySerializer, +) from airbyte_cdk.sources.types import Record, StreamSlice, StreamState logger = logging.getLogger("airbyte") @@ -67,12 +69,18 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition)) if not cursor: - partition_state = self._state_to_migrate_from if self._state_to_migrate_from else self._NO_CURSOR_STATE + partition_state = ( + self._state_to_migrate_from + if self._state_to_migrate_from + else self._NO_CURSOR_STATE + ) cursor = self._create_cursor(partition_state) self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor for cursor_slice in cursor.stream_slices(): - yield StreamSlice(partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields) + yield StreamSlice( + partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields + ) def _ensure_partition_limit(self) -> None: """ @@ -80,7 +88,9 @@ def _ensure_partition_limit(self) -> None: """ while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1: self._over_limit += 1 - oldest_partition = self._cursor_per_partition.popitem(last=False)[0] # Remove the oldest partition + oldest_partition = self._cursor_per_partition.popitem(last=False)[ + 0 + ] # Remove the oldest partition logger.warning( f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._over_limit}." ) @@ -128,7 +138,9 @@ def set_initial_state(self, stream_state: StreamState) -> None: else: for state in stream_state["states"]: - self._cursor_per_partition[self._to_partition_key(state["partition"])] = self._create_cursor(state["cursor"]) + self._cursor_per_partition[self._to_partition_key(state["partition"])] = ( + self._create_cursor(state["cursor"]) + ) # set default state for missing partitions if it is per partition with fallback to global if "state" in stream_state: @@ -214,7 +226,9 @@ def get_request_params( stream_state=stream_state, stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}), next_page_token=next_page_token, - ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_params( + ) | self._cursor_per_partition[ + self._to_partition_key(stream_slice.partition) + ].get_request_params( stream_state=stream_state, stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, @@ -234,7 +248,9 @@ def get_request_headers( stream_state=stream_state, stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}), next_page_token=next_page_token, - ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_headers( + ) | self._cursor_per_partition[ + self._to_partition_key(stream_slice.partition) + ].get_request_headers( stream_state=stream_state, stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, @@ -254,7 +270,9 @@ def get_request_body_data( stream_state=stream_state, stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}), next_page_token=next_page_token, - ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_body_data( + ) | self._cursor_per_partition[ + self._to_partition_key(stream_slice.partition) + ].get_request_body_data( stream_state=stream_state, stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, @@ -274,7 +292,9 @@ def get_request_body_json( stream_state=stream_state, stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}), next_page_token=next_page_token, - ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_body_json( + ) | self._cursor_per_partition[ + self._to_partition_key(stream_slice.partition) + ].get_request_body_json( stream_state=stream_state, stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, @@ -283,32 +303,43 @@ def get_request_body_json( raise ValueError("A partition needs to be provided in order to get request body json") def should_be_synced(self, record: Record) -> bool: - return self._get_cursor(record).should_be_synced(self._convert_record_to_cursor_record(record)) + return self._get_cursor(record).should_be_synced( + self._convert_record_to_cursor_record(record) + ) def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: if not first.associated_slice or not second.associated_slice: - raise ValueError(f"Both records should have an associated slice but got {first.associated_slice} and {second.associated_slice}") + raise ValueError( + f"Both records should have an associated slice but got {first.associated_slice} and {second.associated_slice}" + ) if first.associated_slice.partition != second.associated_slice.partition: raise ValueError( f"To compare records, partition should be the same but got {first.associated_slice.partition} and {second.associated_slice.partition}" ) return self._get_cursor(first).is_greater_than_or_equal( - self._convert_record_to_cursor_record(first), self._convert_record_to_cursor_record(second) + self._convert_record_to_cursor_record(first), + self._convert_record_to_cursor_record(second), ) @staticmethod def _convert_record_to_cursor_record(record: Record) -> Record: return Record( record.data, - StreamSlice(partition={}, cursor_slice=record.associated_slice.cursor_slice) if record.associated_slice else None, + StreamSlice(partition={}, cursor_slice=record.associated_slice.cursor_slice) + if record.associated_slice + else None, ) def _get_cursor(self, record: Record) -> DeclarativeCursor: if not record.associated_slice: - raise ValueError("Invalid state as stream slices that are emitted should refer to an existing cursor") + raise ValueError( + "Invalid state as stream slices that are emitted should refer to an existing cursor" + ) partition_key = self._to_partition_key(record.associated_slice.partition) if partition_key not in self._cursor_per_partition: - raise ValueError("Invalid state as stream slices that are emitted should refer to an existing cursor") + raise ValueError( + "Invalid state as stream slices that are emitted should refer to an existing cursor" + ) cursor = self._cursor_per_partition[partition_key] return cursor diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py b/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py index d5ad6b40..346810a1 100644 --- a/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py +++ b/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py @@ -5,8 +5,14 @@ from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor -from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import GlobalSubstreamCursor, iterate_with_last_flag_and_state -from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import CursorFactory, PerPartitionCursor +from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import ( + GlobalSubstreamCursor, + iterate_with_last_flag_and_state, +) +from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import ( + CursorFactory, + PerPartitionCursor, +) from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.types import Record, StreamSlice, StreamState @@ -60,7 +66,12 @@ class PerPartitionWithGlobalCursor(DeclarativeCursor): Suitable for streams where the number of partitions may vary significantly, requiring dynamic switching between per-partition and global state management to ensure data consistency and efficient synchronization. """ - def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter, stream_cursor: DatetimeBasedCursor): + def __init__( + self, + cursor_factory: CursorFactory, + partition_router: PartitionRouter, + stream_cursor: DatetimeBasedCursor, + ): self._partition_router = partition_router self._per_partition_cursor = PerPartitionCursor(cursor_factory, partition_router) self._global_cursor = GlobalSubstreamCursor(stream_cursor, partition_router) @@ -82,7 +93,8 @@ def stream_slices(self) -> Iterable[StreamSlice]: # Generate slices for the current cursor and handle the last slice using the flag self._parent_state = parent_state for slice, is_last_slice, _ in iterate_with_last_flag_and_state( - self._get_active_cursor().generate_slices_from_partition(partition=partition), lambda: None + self._get_active_cursor().generate_slices_from_partition(partition=partition), + lambda: None, ): self._global_cursor.register_slice(is_last_slice and is_last_partition) yield slice @@ -182,7 +194,9 @@ def get_request_body_json( ) def should_be_synced(self, record: Record) -> bool: - return self._global_cursor.should_be_synced(record) or self._per_partition_cursor.should_be_synced(record) + return self._global_cursor.should_be_synced( + record + ) or self._per_partition_cursor.should_be_synced(record) def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: return self._global_cursor.is_greater_than_or_equal(first, second) diff --git a/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py b/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py index 499220a4..a0b4665f 100644 --- a/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py @@ -30,7 +30,9 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # The ResumableFullRefreshCursor doesn't support nested streams yet so receiving a partition is unexpected if stream_slice.partition: - raise ValueError(f"Stream slice {stream_slice} should not have a partition. Got {stream_slice.partition}.") + raise ValueError( + f"Stream slice {stream_slice} should not have a partition. Got {stream_slice.partition}." + ) self._cursor = stream_slice.cursor_slice def should_be_synced(self, record: Record) -> bool: diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py index e3c3d6a6..78569b35 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py @@ -8,7 +8,21 @@ from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config -FALSE_VALUES: Final[List[Any]] = ["False", "false", "{}", "[]", "()", "", "0", "0.0", {}, False, [], (), set()] +FALSE_VALUES: Final[List[Any]] = [ + "False", + "false", + "{}", + "[]", + "()", + "", + "0", + "0.0", + {}, + False, + [], + (), + set(), +] @dataclass @@ -40,7 +54,11 @@ def eval(self, config: Config, **additional_parameters: Any) -> bool: return self.condition else: evaluated = self._interpolation.eval( - self.condition, config, self._default, parameters=self._parameters, **additional_parameters + self.condition, + config, + self._default, + parameters=self._parameters, + **additional_parameters, ) if evaluated in FALSE_VALUES: return False diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py index b0f26e0d..11b2dac9 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py @@ -38,7 +38,11 @@ def eval(self, config: Config, **additional_parameters: Any) -> Dict[str, Any]: valid_value_types = additional_parameters.pop("valid_value_types", None) return { self._interpolation.eval( - name, config, valid_types=valid_key_types, parameters=self._parameters, **additional_parameters + name, + config, + valid_types=valid_key_types, + parameters=self._parameters, + **additional_parameters, ): self._eval(value, config, valid_types=valid_value_types, **additional_parameters) for name, value in self.mapping.items() } diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py index 6c0afde2..82454919 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py @@ -9,7 +9,9 @@ from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config -NestedMappingEntry = Union[dict[str, "NestedMapping"], list["NestedMapping"], str, int, float, bool, None] +NestedMappingEntry = Union[ + dict[str, "NestedMapping"], list["NestedMapping"], str, int, float, bool, None +] NestedMapping = Union[dict[str, NestedMappingEntry], str] @@ -32,12 +34,17 @@ def __post_init__(self, parameters: Optional[Mapping[str, Any]]) -> None: def eval(self, config: Config, **additional_parameters: Any) -> Any: return self._eval(self.mapping, config, **additional_parameters) - def _eval(self, value: Union[NestedMapping, NestedMappingEntry], config: Config, **kwargs: Any) -> Any: + def _eval( + self, value: Union[NestedMapping, NestedMappingEntry], config: Config, **kwargs: Any + ) -> Any: # Recursively interpolate dictionaries and lists if isinstance(value, str): return self._interpolation.eval(value, config, parameters=self._parameters, **kwargs) elif isinstance(value, dict): - interpolated_dict = {self._eval(k, config, **kwargs): self._eval(v, config, **kwargs) for k, v in value.items()} + interpolated_dict = { + self._eval(k, config, **kwargs): self._eval(v, config, **kwargs) + for k, v in value.items() + } return {k: v for k, v in interpolated_dict.items() if v is not None} elif isinstance(value, list): return [self._eval(v, config, **kwargs) for v in value] diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py index 393abc94..542fa806 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py @@ -45,10 +45,14 @@ def eval(self, config: Config, **kwargs: Any) -> Any: if self._is_plain_string is None: # Let's check whether output from evaluation is the same as input. # This indicates occurrence of a plain string, not a template and we can skip Jinja in subsequent runs. - evaluated = self._interpolation.eval(self.string, config, self.default, parameters=self._parameters, **kwargs) + evaluated = self._interpolation.eval( + self.string, config, self.default, parameters=self._parameters, **kwargs + ) self._is_plain_string = self.string == evaluated return evaluated - return self._interpolation.eval(self.string, config, self.default, parameters=self._parameters, **kwargs) + return self._interpolation.eval( + self.string, config, self.default, parameters=self._parameters, **kwargs + ) def __eq__(self, other: Any) -> bool: if not isinstance(other, InterpolatedString): diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolation.py b/airbyte_cdk/sources/declarative/interpolation/interpolation.py index 8a8f05a5..5af61905 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolation.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolation.py @@ -14,7 +14,13 @@ class Interpolation(ABC): """ @abstractmethod - def eval(self, input_str: str, config: Config, default: Optional[str] = None, **additional_options: Any) -> Any: + def eval( + self, + input_str: str, + config: Config, + default: Optional[str] = None, + **additional_options: Any, + ) -> Any: """ Interpolates the input string using the config, and additional options passed as parameter. diff --git a/airbyte_cdk/sources/declarative/interpolation/jinja.py b/airbyte_cdk/sources/declarative/interpolation/jinja.py index 45a93e58..553ef024 100644 --- a/airbyte_cdk/sources/declarative/interpolation/jinja.py +++ b/airbyte_cdk/sources/declarative/interpolation/jinja.py @@ -61,7 +61,9 @@ class JinjaInterpolation(Interpolation): # By default, these Python builtin functions are available in the Jinja context. # We explicitely remove them because of the potential security risk. # Please add a unit test to test_jinja.py when adding a restriction. - RESTRICTED_BUILTIN_FUNCTIONS = ["range"] # The range function can cause very expensive computations + RESTRICTED_BUILTIN_FUNCTIONS = [ + "range" + ] # The range function can cause very expensive computations def __init__(self) -> None: self._environment = StreamPartitionAccessEnvironment() @@ -119,7 +121,9 @@ def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]: undeclared = self._find_undeclared_variables(s) undeclared_not_in_context = {var for var in undeclared if var not in context} if undeclared_not_in_context: - raise ValueError(f"Jinja macro has undeclared variables: {undeclared_not_in_context}. Context: {context}") + raise ValueError( + f"Jinja macro has undeclared variables: {undeclared_not_in_context}. Context: {context}" + ) return self._compile(s).render(context) # type: ignore # from_string is able to handle None except TypeError: # The string is a static value, not a jinja template diff --git a/airbyte_cdk/sources/declarative/interpolation/macros.py b/airbyte_cdk/sources/declarative/interpolation/macros.py index f5044434..ce448c12 100644 --- a/airbyte_cdk/sources/declarative/interpolation/macros.py +++ b/airbyte_cdk/sources/declarative/interpolation/macros.py @@ -104,7 +104,9 @@ def day_delta(num_days: int, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> str: :param num_days: number of days to add to current date time :return: datetime formatted as RFC3339 """ - return (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=num_days)).strftime(format) + return ( + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=num_days) + ).strftime(format) def duration(datestring: str) -> Union[datetime.timedelta, isodate.Duration]: @@ -117,7 +119,9 @@ def duration(datestring: str) -> Union[datetime.timedelta, isodate.Duration]: return parse_duration(datestring) # type: ignore # mypy thinks this returns Any for some reason -def format_datetime(dt: Union[str, datetime.datetime], format: str, input_format: Optional[str] = None) -> str: +def format_datetime( + dt: Union[str, datetime.datetime], format: str, input_format: Optional[str] = None +) -> str: """ Converts datetime to another format @@ -130,11 +134,22 @@ def format_datetime(dt: Union[str, datetime.datetime], format: str, input_format """ if isinstance(dt, datetime.datetime): return dt.strftime(format) - dt_datetime = datetime.datetime.strptime(dt, input_format) if input_format else _str_to_datetime(dt) + dt_datetime = ( + datetime.datetime.strptime(dt, input_format) if input_format else _str_to_datetime(dt) + ) if format == "%s": return str(int(dt_datetime.timestamp())) return dt_datetime.strftime(format) -_macros_list = [now_utc, today_utc, timestamp, max, day_delta, duration, format_datetime, today_with_timezone] +_macros_list = [ + now_utc, + today_utc, + timestamp, + max, + day_delta, + duration, + format_datetime, + today_with_timezone, +] macros = {f.__name__: f for f in _macros_list} diff --git a/airbyte_cdk/sources/declarative/manifest_declarative_source.py b/airbyte_cdk/sources/declarative/manifest_declarative_source.py index 842d4e94..05fbee7a 100644 --- a/airbyte_cdk/sources/declarative/manifest_declarative_source.py +++ b/airbyte_cdk/sources/declarative/manifest_declarative_source.py @@ -20,16 +20,30 @@ ) from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CheckStream as CheckStreamModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import DeclarativeStream as DeclarativeStreamModel +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CheckStream as CheckStreamModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + DeclarativeStream as DeclarativeStreamModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import Spec as SpecModel -from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ManifestComponentTransformer -from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import ManifestReferenceResolver -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory +from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ( + ManifestComponentTransformer, +) +from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import ( + ManifestReferenceResolver, +) +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, +) from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams.core import Stream from airbyte_cdk.sources.types import ConnectionDefinition -from airbyte_cdk.sources.utils.slice_logger import AlwaysLogSliceLogger, DebugSliceLogger, SliceLogger +from airbyte_cdk.sources.utils.slice_logger import ( + AlwaysLogSliceLogger, + DebugSliceLogger, + SliceLogger, +) from jsonschema.exceptions import ValidationError from jsonschema.validators import validate @@ -57,13 +71,21 @@ def __init__( manifest["type"] = "DeclarativeSource" resolved_source_config = ManifestReferenceResolver().preprocess_manifest(manifest) - propagated_source_config = ManifestComponentTransformer().propagate_types_and_parameters("", resolved_source_config, {}) + propagated_source_config = ManifestComponentTransformer().propagate_types_and_parameters( + "", resolved_source_config, {} + ) self._source_config = propagated_source_config self._debug = debug self._emit_connector_builder_messages = emit_connector_builder_messages - self._constructor = component_factory if component_factory else ModelToComponentFactory(emit_connector_builder_messages) + self._constructor = ( + component_factory + if component_factory + else ModelToComponentFactory(emit_connector_builder_messages) + ) self._message_repository = self._constructor.get_message_repository() - self._slice_logger: SliceLogger = AlwaysLogSliceLogger() if emit_connector_builder_messages else DebugSliceLogger() + self._slice_logger: SliceLogger = ( + AlwaysLogSliceLogger() if emit_connector_builder_messages else DebugSliceLogger() + ) self._validate_source() @@ -81,20 +103,30 @@ def connection_checker(self) -> ConnectionChecker: if "type" not in check: check["type"] = "CheckStream" check_stream = self._constructor.create_component( - CheckStreamModel, check, dict(), emit_connector_builder_messages=self._emit_connector_builder_messages + CheckStreamModel, + check, + dict(), + emit_connector_builder_messages=self._emit_connector_builder_messages, ) if isinstance(check_stream, ConnectionChecker): return check_stream else: - raise ValueError(f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}") + raise ValueError( + f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}" + ) def streams(self, config: Mapping[str, Any]) -> List[Stream]: - self._emit_manifest_debug_message(extra_args={"source_name": self.name, "parsed_config": json.dumps(self._source_config)}) + self._emit_manifest_debug_message( + extra_args={"source_name": self.name, "parsed_config": json.dumps(self._source_config)} + ) stream_configs = self._stream_configs(self._source_config) source_streams = [ self._constructor.create_component( - DeclarativeStreamModel, stream_config, config, emit_connector_builder_messages=self._emit_connector_builder_messages + DeclarativeStreamModel, + stream_config, + config, + emit_connector_builder_messages=self._emit_connector_builder_messages, ) for stream_config in self._initialize_cache_for_parent_streams(deepcopy(stream_configs)) ] @@ -102,7 +134,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: return source_streams @staticmethod - def _initialize_cache_for_parent_streams(stream_configs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _initialize_cache_for_parent_streams( + stream_configs: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: parent_streams = set() def update_with_cache_parent_configs(parent_configs: list[dict[str, Any]]) -> None: @@ -113,12 +147,16 @@ def update_with_cache_parent_configs(parent_configs: list[dict[str, Any]]) -> No for stream_config in stream_configs: if stream_config.get("incremental_sync", {}).get("parent_stream"): parent_streams.add(stream_config["incremental_sync"]["parent_stream"]["name"]) - stream_config["incremental_sync"]["parent_stream"]["retriever"]["requester"]["use_cache"] = True + stream_config["incremental_sync"]["parent_stream"]["retriever"]["requester"][ + "use_cache" + ] = True elif stream_config.get("retriever", {}).get("partition_router", {}): partition_router = stream_config["retriever"]["partition_router"] - if isinstance(partition_router, dict) and partition_router.get("parent_stream_configs"): + if isinstance(partition_router, dict) and partition_router.get( + "parent_stream_configs" + ): update_with_cache_parent_configs(partition_router["parent_stream_configs"]) elif isinstance(partition_router, list): for router in partition_router: @@ -139,7 +177,9 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification: in the project root. """ self._configure_logger_level(logger) - self._emit_manifest_debug_message(extra_args={"source_name": self.name, "parsed_config": json.dumps(self._source_config)}) + self._emit_manifest_debug_message( + extra_args={"source_name": self.name, "parsed_config": json.dumps(self._source_config)} + ) spec = self._source_config.get("spec") if spec: @@ -176,22 +216,34 @@ def _validate_source(self) -> None: Validates the connector manifest against the declarative component schema """ try: - raw_component_schema = pkgutil.get_data("airbyte_cdk", "sources/declarative/declarative_component_schema.yaml") + raw_component_schema = pkgutil.get_data( + "airbyte_cdk", "sources/declarative/declarative_component_schema.yaml" + ) if raw_component_schema is not None: - declarative_component_schema = yaml.load(raw_component_schema, Loader=yaml.SafeLoader) + declarative_component_schema = yaml.load( + raw_component_schema, Loader=yaml.SafeLoader + ) else: - raise RuntimeError("Failed to read manifest component json schema required for validation") + raise RuntimeError( + "Failed to read manifest component json schema required for validation" + ) except FileNotFoundError as e: - raise FileNotFoundError(f"Failed to read manifest component json schema required for validation: {e}") + raise FileNotFoundError( + f"Failed to read manifest component json schema required for validation: {e}" + ) streams = self._source_config.get("streams") if not streams: - raise ValidationError(f"A valid manifest should have at least one stream defined. Got {streams}") + raise ValidationError( + f"A valid manifest should have at least one stream defined. Got {streams}" + ) try: validate(self._source_config, declarative_component_schema) except ValidationError as e: - raise ValidationError("Validation against json schema defined in declarative_component_schema.yaml schema failed") from e + raise ValidationError( + "Validation against json schema defined in declarative_component_schema.yaml schema failed" + ) from e cdk_version = metadata.version("airbyte_cdk") cdk_major, cdk_minor, cdk_patch = self._get_version_parts(cdk_version, "airbyte-cdk") @@ -200,9 +252,13 @@ def _validate_source(self) -> None: raise RuntimeError( "Manifest version is not defined in the manifest. This is unexpected since it should be a required field. Please contact support." ) - manifest_major, manifest_minor, manifest_patch = self._get_version_parts(manifest_version, "manifest") + manifest_major, manifest_minor, manifest_patch = self._get_version_parts( + manifest_version, "manifest" + ) - if cdk_major < manifest_major or (cdk_major == manifest_major and cdk_minor < manifest_minor): + if cdk_major < manifest_major or ( + cdk_major == manifest_major and cdk_minor < manifest_minor + ): raise ValidationError( f"The manifest version {manifest_version} is greater than the airbyte-cdk package version ({cdk_version}). Your " f"manifest may contain features that are not in the current CDK version." @@ -221,7 +277,9 @@ def _get_version_parts(version: str, version_type: str) -> Tuple[int, int, int]: """ version_parts = re.split(r"\.", version) if len(version_parts) != 3 or not all([part.isdigit() for part in version_parts]): - raise ValidationError(f"The {version_type} version {version} specified is not a valid version format (ex. 1.2.3)") + raise ValidationError( + f"The {version_type} version {version} specified is not a valid version format (ex. 1.2.3)" + ) return tuple(int(part) for part in version_parts) # type: ignore # We already verified there were 3 parts and they are all digits def _stream_configs(self, manifest: Mapping[str, Any]) -> List[Dict[str, Any]]: diff --git a/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py b/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py index 361f81bf..38546168 100644 --- a/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py +++ b/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py @@ -43,7 +43,9 @@ def __init__( self._partition_key_field = InterpolatedString.create( self._get_partition_field(self._partition_router), parameters=self._parameters ).eval(self._config) - self._cursor_field = InterpolatedString.create(self._cursor.cursor_field, parameters=self._parameters).eval(self._config) + self._cursor_field = InterpolatedString.create( + self._cursor.cursor_field, parameters=self._parameters + ).eval(self._config) def _get_partition_field(self, partition_router: SubstreamPartitionRouter) -> str: parent_stream_config = partition_router.parent_stream_configs[0] @@ -85,5 +87,8 @@ def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: return True def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]: - states = [{"partition": {self._partition_key_field: key}, "cursor": value} for key, value in stream_state.items()] + states = [ + {"partition": {self._partition_key_field: key}, "cursor": value} + for key, value in stream_state.items() + ] return {"states": states} diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index 75f34878..43848eae 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -1071,7 +1071,9 @@ class ApiKeyAuthenticator(BaseModel): class AuthFlow(BaseModel): - auth_flow_type: Optional[AuthFlowType] = Field(None, description="The type of auth to use", title="Auth flow type") + auth_flow_type: Optional[AuthFlowType] = Field( + None, description="The type of auth to use", title="Auth flow type" + ) predicate_key: Optional[List[str]] = Field( None, description="JSON path to a field in the connectorSpecification that should exist for the advanced auth to be applicable.", @@ -1214,7 +1216,9 @@ class DefaultErrorHandler(BaseModel): class DefaultPaginator(BaseModel): type: Literal["DefaultPaginator"] - pagination_strategy: Union[CursorPagination, CustomPaginationStrategy, OffsetIncrement, PageIncrement] = Field( + pagination_strategy: Union[ + CursorPagination, CustomPaginationStrategy, OffsetIncrement, PageIncrement + ] = Field( ..., description="Strategy defining how records are paginated.", title="Pagination Strategy", @@ -1383,18 +1387,26 @@ class Config: title="Incremental Sync", ) name: Optional[str] = Field("", description="The stream name.", example=["Users"], title="Name") - primary_key: Optional[PrimaryKey] = Field("", description="The primary key of the stream.", title="Primary Key") - schema_loader: Optional[Union[InlineSchemaLoader, JsonFileSchemaLoader, CustomSchemaLoader]] = Field( - None, - description="Component used to retrieve the schema for the current stream.", - title="Schema Loader", - ) - transformations: Optional[List[Union[AddFields, CustomTransformation, RemoveFields, KeysToLower]]] = Field( + primary_key: Optional[PrimaryKey] = Field( + "", description="The primary key of the stream.", title="Primary Key" + ) + schema_loader: Optional[Union[InlineSchemaLoader, JsonFileSchemaLoader, CustomSchemaLoader]] = ( + Field( + None, + description="Component used to retrieve the schema for the current stream.", + title="Schema Loader", + ) + ) + transformations: Optional[ + List[Union[AddFields, CustomTransformation, RemoveFields, KeysToLower]] + ] = Field( None, description="A list of transformations to be applied to each output record.", title="Transformations", ) - state_migrations: Optional[List[Union[LegacyToPerPartitionStateMigration, CustomStateMigration]]] = Field( + state_migrations: Optional[ + List[Union[LegacyToPerPartitionStateMigration, CustomStateMigration]] + ] = Field( [], description="Array of state migrations to be applied on the input state", title="State Migrations", @@ -1433,12 +1445,16 @@ class SessionTokenAuthenticator(BaseModel): examples=["PT1H", "P1D"], title="Expiration Duration", ) - request_authentication: Union[SessionTokenRequestApiKeyAuthenticator, SessionTokenRequestBearerAuthenticator] = Field( + request_authentication: Union[ + SessionTokenRequestApiKeyAuthenticator, SessionTokenRequestBearerAuthenticator + ] = Field( ..., description="Authentication method to use for requests sent to the API, specifying how to inject the session token.", title="Data Request Authentication", ) - decoder: Optional[Union[JsonDecoder, XmlDecoder]] = Field(None, description="Component used to decode the response.", title="Decoder") + decoder: Optional[Union[JsonDecoder, XmlDecoder]] = Field( + None, description="Component used to decode the response.", title="Decoder" + ) parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") @@ -1481,7 +1497,9 @@ class HttpRequester(BaseModel): description="Authentication method to use for requests sent to the API.", title="Authenticator", ) - error_handler: Optional[Union[DefaultErrorHandler, CustomErrorHandler, CompositeErrorHandler]] = Field( + error_handler: Optional[ + Union[DefaultErrorHandler, CustomErrorHandler, CompositeErrorHandler] + ] = Field( None, description="Error handler component that defines how to handle errors.", title="Error Handler", @@ -1545,7 +1563,9 @@ class ParentStreamConfig(BaseModel): examples=["id", "{{ config['parent_record_id'] }}"], title="Parent Key", ) - stream: DeclarativeStream = Field(..., description="Reference to the parent stream.", title="Parent Stream") + stream: DeclarativeStream = Field( + ..., description="Reference to the parent stream.", title="Parent Stream" + ) partition_field: str = Field( ..., description="While iterating over parent records during a sync, the parent_key value can be referenced by using this field.", @@ -1614,7 +1634,9 @@ class AsyncRetriever(BaseModel): ..., description="Component that describes how to extract records from a HTTP response.", ) - status_mapping: AsyncJobStatusMap = Field(..., description="Async Job Status to Airbyte CDK Async Job Status mapping.") + status_mapping: AsyncJobStatusMap = Field( + ..., description="Async Job Status to Airbyte CDK Async Job Status mapping." + ) status_extractor: Union[CustomRecordExtractor, DpathExtractor] = Field( ..., description="Responsible for fetching the actual status of the async job." ) diff --git a/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py b/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py index 7b8b221c..8cacda3d 100644 --- a/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py +++ b/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py @@ -77,7 +77,10 @@ class ManifestComponentTransformer: def propagate_types_and_parameters( - self, parent_field_identifier: str, declarative_component: Mapping[str, Any], parent_parameters: Mapping[str, Any] + self, + parent_field_identifier: str, + declarative_component: Mapping[str, Any], + parent_parameters: Mapping[str, Any], ) -> Mapping[str, Any]: """ Recursively transforms the specified declarative component and subcomponents to propagate parameters and insert the @@ -119,7 +122,9 @@ def propagate_types_and_parameters( # Parameters should be applied to the current component fields with the existing field taking precedence over parameters if # both exist for parameter_key, parameter_value in current_parameters.items(): - propagated_component[parameter_key] = propagated_component.get(parameter_key) or parameter_value + propagated_component[parameter_key] = ( + propagated_component.get(parameter_key) or parameter_value + ) for field_name, field_value in propagated_component.items(): if isinstance(field_value, dict): @@ -136,8 +141,12 @@ def propagate_types_and_parameters( excluded_parameter = current_parameters.pop(field_name, None) for i, element in enumerate(field_value): if isinstance(element, dict): - parent_type_field_identifier = f"{propagated_component.get('type')}.{field_name}" - field_value[i] = self.propagate_types_and_parameters(parent_type_field_identifier, element, current_parameters) + parent_type_field_identifier = ( + f"{propagated_component.get('type')}.{field_name}" + ) + field_value[i] = self.propagate_types_and_parameters( + parent_type_field_identifier, element, current_parameters + ) if excluded_parameter: current_parameters[field_name] = excluded_parameter diff --git a/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py b/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py index 66bf3d5e..045ea9a2 100644 --- a/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py +++ b/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py @@ -5,7 +5,10 @@ import re from typing import Any, Mapping, Set, Tuple, Union -from airbyte_cdk.sources.declarative.parsers.custom_exceptions import CircularReferenceException, UndefinedReferenceException +from airbyte_cdk.sources.declarative.parsers.custom_exceptions import ( + CircularReferenceException, + UndefinedReferenceException, +) REF_TAG = "$ref" @@ -105,7 +108,11 @@ def preprocess_manifest(self, manifest: Mapping[str, Any]) -> Mapping[str, Any]: def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: Set[Any]) -> Any: if isinstance(node, dict): - evaluated_dict = {k: self._evaluate_node(v, manifest, visited) for k, v in node.items() if not self._is_ref_key(k)} + evaluated_dict = { + k: self._evaluate_node(v, manifest, visited) + for k, v in node.items() + if not self._is_ref_key(k) + } if REF_TAG in node: # The node includes a $ref key, so we splat the referenced value(s) into the evaluated dict evaluated_ref = self._evaluate_node(node[REF_TAG], manifest, visited) diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index e9420dfb..2812ba81 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -9,7 +9,21 @@ import inspect import re from functools import partial -from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Tuple, Type, Union, get_args, get_origin, get_type_hints +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Type, + Union, + get_args, + get_origin, + get_type_hints, +) from airbyte_cdk.models import FailureType, Level from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager @@ -18,9 +32,14 @@ from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator, JwtAuthenticator -from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator, NoAuth +from airbyte_cdk.sources.declarative.auth.declarative_authenticator import ( + DeclarativeAuthenticator, + NoAuth, +) from airbyte_cdk.sources.declarative.auth.jwt import JwtAlgorithm -from airbyte_cdk.sources.declarative.auth.oauth import DeclarativeSingleUseRefreshTokenOauth2Authenticator +from airbyte_cdk.sources.declarative.auth.oauth import ( + DeclarativeSingleUseRefreshTokenOauth2Authenticator, +) from airbyte_cdk.sources.declarative.auth.selective_authenticator import SelectiveAuthenticator from airbyte_cdk.sources.declarative.auth.token import ( ApiKeyAuthenticator, @@ -28,7 +47,11 @@ BearerAuthenticator, LegacySessionTokenAuthenticator, ) -from airbyte_cdk.sources.declarative.auth.token_provider import InterpolatedStringTokenProvider, SessionTokenProvider, TokenProvider +from airbyte_cdk.sources.declarative.auth.token_provider import ( + InterpolatedStringTokenProvider, + SessionTokenProvider, + TokenProvider, +) from airbyte_cdk.sources.declarative.checks import CheckStream from airbyte_cdk.sources.declarative.concurrency_level import ConcurrencyLevel from airbyte_cdk.sources.declarative.datetime import MinMaxDatetime @@ -41,9 +64,18 @@ PaginationDecoderDecorator, XmlDecoder, ) -from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordFilter, RecordSelector, ResponseToFileExtractor -from airbyte_cdk.sources.declarative.extractors.record_filter import ClientSideIncrementalRecordFilterDecorator -from airbyte_cdk.sources.declarative.extractors.record_selector import SCHEMA_TRANSFORMER_TYPE_MAPPING +from airbyte_cdk.sources.declarative.extractors import ( + DpathExtractor, + RecordFilter, + RecordSelector, + ResponseToFileExtractor, +) +from airbyte_cdk.sources.declarative.extractors.record_filter import ( + ClientSideIncrementalRecordFilterDecorator, +) +from airbyte_cdk.sources.declarative.extractors.record_selector import ( + SCHEMA_TRANSFORMER_TYPE_MAPPING, +) from airbyte_cdk.sources.declarative.incremental import ( ChildPartitionResumableFullRefreshCursor, CursorFactory, @@ -56,88 +88,216 @@ ) from airbyte_cdk.sources.declarative.interpolation import InterpolatedString from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping -from airbyte_cdk.sources.declarative.migrations.legacy_to_per_partition_state_migration import LegacyToPerPartitionStateMigration +from airbyte_cdk.sources.declarative.migrations.legacy_to_per_partition_state_migration import ( + LegacyToPerPartitionStateMigration, +) from airbyte_cdk.sources.declarative.models import CustomStateMigration -from airbyte_cdk.sources.declarative.models.declarative_component_schema import AddedFieldDefinition as AddedFieldDefinitionModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import AddFields as AddFieldsModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import ApiKeyAuthenticator as ApiKeyAuthenticatorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import AsyncJobStatusMap as AsyncJobStatusMapModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import AsyncRetriever as AsyncRetrieverModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import BasicHttpAuthenticator as BasicHttpAuthenticatorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import BearerAuthenticator as BearerAuthenticatorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CheckStream as CheckStreamModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CompositeErrorHandler as CompositeErrorHandlerModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import ConcurrencyLevel as ConcurrencyLevelModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import ConstantBackoffStrategy as ConstantBackoffStrategyModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CursorPagination as CursorPaginationModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomAuthenticator as CustomAuthenticatorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomBackoffStrategy as CustomBackoffStrategyModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomErrorHandler as CustomErrorHandlerModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomIncrementalSync as CustomIncrementalSyncModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomPaginationStrategy as CustomPaginationStrategyModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomPartitionRouter as CustomPartitionRouterModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomRecordExtractor as CustomRecordExtractorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomRecordFilter as CustomRecordFilterModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomRequester as CustomRequesterModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomRetriever as CustomRetrieverModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomSchemaLoader as CustomSchemaLoader -from airbyte_cdk.sources.declarative.models.declarative_component_schema import CustomTransformation as CustomTransformationModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import DatetimeBasedCursor as DatetimeBasedCursorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import DeclarativeStream as DeclarativeStreamModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import DefaultErrorHandler as DefaultErrorHandlerModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import DefaultPaginator as DefaultPaginatorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import DpathExtractor as DpathExtractorModel +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + AddedFieldDefinition as AddedFieldDefinitionModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + AddFields as AddFieldsModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + ApiKeyAuthenticator as ApiKeyAuthenticatorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + AsyncJobStatusMap as AsyncJobStatusMapModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + AsyncRetriever as AsyncRetrieverModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + BasicHttpAuthenticator as BasicHttpAuthenticatorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + BearerAuthenticator as BearerAuthenticatorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CheckStream as CheckStreamModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CompositeErrorHandler as CompositeErrorHandlerModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + ConcurrencyLevel as ConcurrencyLevelModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + ConstantBackoffStrategy as ConstantBackoffStrategyModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CursorPagination as CursorPaginationModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomAuthenticator as CustomAuthenticatorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomBackoffStrategy as CustomBackoffStrategyModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomErrorHandler as CustomErrorHandlerModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomIncrementalSync as CustomIncrementalSyncModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomPaginationStrategy as CustomPaginationStrategyModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomPartitionRouter as CustomPartitionRouterModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomRecordExtractor as CustomRecordExtractorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomRecordFilter as CustomRecordFilterModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomRequester as CustomRequesterModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomRetriever as CustomRetrieverModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomSchemaLoader as CustomSchemaLoader, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + CustomTransformation as CustomTransformationModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + DatetimeBasedCursor as DatetimeBasedCursorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + DeclarativeStream as DeclarativeStreamModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + DefaultErrorHandler as DefaultErrorHandlerModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + DefaultPaginator as DefaultPaginatorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + DpathExtractor as DpathExtractorModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( ExponentialBackoffStrategy as ExponentialBackoffStrategyModel, ) -from airbyte_cdk.sources.declarative.models.declarative_component_schema import HttpRequester as HttpRequesterModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import HttpResponseFilter as HttpResponseFilterModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import InlineSchemaLoader as InlineSchemaLoaderModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import IterableDecoder as IterableDecoderModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import JsonDecoder as JsonDecoderModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import JsonFileSchemaLoader as JsonFileSchemaLoaderModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import JsonlDecoder as JsonlDecoderModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import JwtAuthenticator as JwtAuthenticatorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import JwtHeaders as JwtHeadersModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import JwtPayload as JwtPayloadModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import KeysToLower as KeysToLowerModel +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + HttpRequester as HttpRequesterModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + HttpResponseFilter as HttpResponseFilterModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + InlineSchemaLoader as InlineSchemaLoaderModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + IterableDecoder as IterableDecoderModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + JsonDecoder as JsonDecoderModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + JsonFileSchemaLoader as JsonFileSchemaLoaderModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + JsonlDecoder as JsonlDecoderModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + JwtAuthenticator as JwtAuthenticatorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + JwtHeaders as JwtHeadersModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + JwtPayload as JwtPayloadModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + KeysToLower as KeysToLowerModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( LegacySessionTokenAuthenticator as LegacySessionTokenAuthenticatorModel, ) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( LegacyToPerPartitionStateMigration as LegacyToPerPartitionStateMigrationModel, ) -from airbyte_cdk.sources.declarative.models.declarative_component_schema import ListPartitionRouter as ListPartitionRouterModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import MinMaxDatetime as MinMaxDatetimeModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import NoAuth as NoAuthModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import NoPagination as NoPaginationModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import OAuthAuthenticator as OAuthAuthenticatorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import OffsetIncrement as OffsetIncrementModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import PageIncrement as PageIncrementModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import ParentStreamConfig as ParentStreamConfigModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import RecordFilter as RecordFilterModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import RecordSelector as RecordSelectorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import RemoveFields as RemoveFieldsModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import RequestOption as RequestOptionModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import RequestPath as RequestPathModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import SelectiveAuthenticator as SelectiveAuthenticatorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import SessionTokenAuthenticator as SessionTokenAuthenticatorModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import SimpleRetriever as SimpleRetrieverModel +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + ListPartitionRouter as ListPartitionRouterModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + MinMaxDatetime as MinMaxDatetimeModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + NoAuth as NoAuthModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + NoPagination as NoPaginationModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + OAuthAuthenticator as OAuthAuthenticatorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + OffsetIncrement as OffsetIncrementModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + PageIncrement as PageIncrementModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + ParentStreamConfig as ParentStreamConfigModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + RecordFilter as RecordFilterModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + RecordSelector as RecordSelectorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + RemoveFields as RemoveFieldsModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + RequestOption as RequestOptionModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + RequestPath as RequestPathModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + SelectiveAuthenticator as SelectiveAuthenticatorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + SessionTokenAuthenticator as SessionTokenAuthenticatorModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + SimpleRetriever as SimpleRetrieverModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import Spec as SpecModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import SubstreamPartitionRouter as SubstreamPartitionRouterModel +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + SubstreamPartitionRouter as SubstreamPartitionRouterModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ValueType -from airbyte_cdk.sources.declarative.models.declarative_component_schema import WaitTimeFromHeader as WaitTimeFromHeaderModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import WaitUntilTimeFromHeader as WaitUntilTimeFromHeaderModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import XmlDecoder as XmlDecoderModel +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + WaitTimeFromHeader as WaitTimeFromHeaderModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + WaitUntilTimeFromHeader as WaitUntilTimeFromHeaderModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + XmlDecoder as XmlDecoderModel, +) from airbyte_cdk.sources.declarative.partition_routers import ( CartesianProductStreamSlicer, ListPartitionRouter, SinglePartitionRouter, SubstreamPartitionRouter, ) -from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ParentStreamConfig +from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ( + ParentStreamConfig, +) from airbyte_cdk.sources.declarative.requesters import HttpRequester, RequestOption -from airbyte_cdk.sources.declarative.requesters.error_handlers import CompositeErrorHandler, DefaultErrorHandler, HttpResponseFilter +from airbyte_cdk.sources.declarative.requesters.error_handlers import ( + CompositeErrorHandler, + DefaultErrorHandler, + HttpResponseFilter, +) from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies import ( ConstantBackoffStrategy, ExponentialBackoffStrategy, @@ -145,7 +305,11 @@ WaitUntilTimeFromHeaderBackoffStrategy, ) from airbyte_cdk.sources.declarative.requesters.http_job_repository import AsyncHttpJobRepository -from airbyte_cdk.sources.declarative.requesters.paginators import DefaultPaginator, NoPagination, PaginatorTestReadDecorator +from airbyte_cdk.sources.declarative.requesters.paginators import ( + DefaultPaginator, + NoPagination, + PaginatorTestReadDecorator, +) from airbyte_cdk.sources.declarative.requesters.paginators.strategies import ( CursorPaginationStrategy, CursorStopCondition, @@ -162,14 +326,32 @@ ) from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod -from airbyte_cdk.sources.declarative.retrievers import AsyncRetriever, SimpleRetriever, SimpleRetrieverTestReadDecorator -from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader, InlineSchemaLoader, JsonFileSchemaLoader +from airbyte_cdk.sources.declarative.retrievers import ( + AsyncRetriever, + SimpleRetriever, + SimpleRetrieverTestReadDecorator, +) +from airbyte_cdk.sources.declarative.schema import ( + DefaultSchemaLoader, + InlineSchemaLoader, + JsonFileSchemaLoader, +) from airbyte_cdk.sources.declarative.spec import Spec from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer -from airbyte_cdk.sources.declarative.transformations import AddFields, RecordTransformation, RemoveFields +from airbyte_cdk.sources.declarative.transformations import ( + AddFields, + RecordTransformation, + RemoveFields, +) from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition -from airbyte_cdk.sources.declarative.transformations.keys_to_lower_transformation import KeysToLowerTransformation -from airbyte_cdk.sources.message import InMemoryMessageRepository, LogAppenderMessageRepositoryDecorator, MessageRepository +from airbyte_cdk.sources.declarative.transformations.keys_to_lower_transformation import ( + KeysToLowerTransformation, +) +from airbyte_cdk.sources.message import ( + InMemoryMessageRepository, + LogAppenderMessageRepositoryDecorator, + MessageRepository, +) from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( CustomFormatConcurrentStreamStateConverter, @@ -276,7 +458,11 @@ def _init_mappings(self) -> None: self.TYPE_NAME_TO_MODEL = {cls.__name__: cls for cls in self.PYDANTIC_MODEL_TO_CONSTRUCTOR} def create_component( - self, model_type: Type[BaseModel], component_definition: ComponentDefinition, config: Config, **kwargs: Any + self, + model_type: Type[BaseModel], + component_definition: ComponentDefinition, + config: Config, + **kwargs: Any, ) -> Any: """ Takes a given Pydantic model type and Mapping representing a component definition and creates a declarative component and @@ -291,26 +477,38 @@ def create_component( component_type = component_definition.get("type") if component_definition.get("type") != model_type.__name__: - raise ValueError(f"Expected manifest component of type {model_type.__name__}, but received {component_type} instead") + raise ValueError( + f"Expected manifest component of type {model_type.__name__}, but received {component_type} instead" + ) declarative_component_model = model_type.parse_obj(component_definition) if not isinstance(declarative_component_model, model_type): - raise ValueError(f"Expected {model_type.__name__} component, but received {declarative_component_model.__class__.__name__}") + raise ValueError( + f"Expected {model_type.__name__} component, but received {declarative_component_model.__class__.__name__}" + ) - return self._create_component_from_model(model=declarative_component_model, config=config, **kwargs) + return self._create_component_from_model( + model=declarative_component_model, config=config, **kwargs + ) def _create_component_from_model(self, model: BaseModel, config: Config, **kwargs: Any) -> Any: if model.__class__ not in self.PYDANTIC_MODEL_TO_CONSTRUCTOR: - raise ValueError(f"{model.__class__} with attributes {model} is not a valid component type") + raise ValueError( + f"{model.__class__} with attributes {model} is not a valid component type" + ) component_constructor = self.PYDANTIC_MODEL_TO_CONSTRUCTOR.get(model.__class__) if not component_constructor: raise ValueError(f"Could not find constructor for {model.__class__}") return component_constructor(model=model, config=config, **kwargs) @staticmethod - def create_added_field_definition(model: AddedFieldDefinitionModel, config: Config, **kwargs: Any) -> AddedFieldDefinition: - interpolated_value = InterpolatedString.create(model.value, parameters=model.parameters or {}) + def create_added_field_definition( + model: AddedFieldDefinitionModel, config: Config, **kwargs: Any + ) -> AddedFieldDefinition: + interpolated_value = InterpolatedString.create( + model.value, parameters=model.parameters or {} + ) return AddedFieldDefinition( path=model.path, value=interpolated_value, @@ -322,14 +520,18 @@ def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any added_field_definitions = [ self._create_component_from_model( model=added_field_definition_model, - value_type=ModelToComponentFactory._json_schema_type_name_to_type(added_field_definition_model.value_type), + value_type=ModelToComponentFactory._json_schema_type_name_to_type( + added_field_definition_model.value_type + ), config=config, ) for added_field_definition_model in model.fields ] return AddFields(fields=added_field_definitions, parameters=model.parameters or {}) - def create_keys_to_lower_transformation(self, model: KeysToLowerModel, config: Config, **kwargs: Any) -> KeysToLowerTransformation: + def create_keys_to_lower_transformation( + self, model: KeysToLowerModel, config: Config, **kwargs: Any + ) -> KeysToLowerTransformation: return KeysToLowerTransformation() @staticmethod @@ -346,16 +548,25 @@ def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[ @staticmethod def create_api_key_authenticator( - model: ApiKeyAuthenticatorModel, config: Config, token_provider: Optional[TokenProvider] = None, **kwargs: Any + model: ApiKeyAuthenticatorModel, + config: Config, + token_provider: Optional[TokenProvider] = None, + **kwargs: Any, ) -> ApiKeyAuthenticator: if model.inject_into is None and model.header is None: - raise ValueError("Expected either inject_into or header to be set for ApiKeyAuthenticator") + raise ValueError( + "Expected either inject_into or header to be set for ApiKeyAuthenticator" + ) if model.inject_into is not None and model.header is not None: - raise ValueError("inject_into and header cannot be set both for ApiKeyAuthenticator - remove the deprecated header option") + raise ValueError( + "inject_into and header cannot be set both for ApiKeyAuthenticator - remove the deprecated header option" + ) if token_provider is not None and model.api_token != "": - raise ValueError("If token_provider is set, api_token is ignored and has to be set to empty string.") + raise ValueError( + "If token_provider is set, api_token is ignored and has to be set to empty string." + ) request_option = ( RequestOption( @@ -374,7 +585,11 @@ def create_api_key_authenticator( token_provider=( token_provider if token_provider is not None - else InterpolatedStringTokenProvider(api_token=model.api_token or "", config=config, parameters=model.parameters or {}) + else InterpolatedStringTokenProvider( + api_token=model.api_token or "", + config=config, + parameters=model.parameters or {}, + ) ), request_option=request_option, config=config, @@ -393,28 +608,44 @@ def create_legacy_to_per_partition_state_migration( f"LegacyToPerPartitionStateMigrations can only be applied on a DeclarativeStream with a SimpleRetriever. Got {type(retriever)}" ) partition_router = retriever.partition_router - if not isinstance(partition_router, (SubstreamPartitionRouterModel, CustomPartitionRouterModel)): + if not isinstance( + partition_router, (SubstreamPartitionRouterModel, CustomPartitionRouterModel) + ): raise ValueError( f"LegacyToPerPartitionStateMigrations can only be applied on a SimpleRetriever with a Substream partition router. Got {type(partition_router)}" ) if not hasattr(partition_router, "parent_stream_configs"): - raise ValueError("LegacyToPerPartitionStateMigrations can only be applied with a parent stream configuration.") + raise ValueError( + "LegacyToPerPartitionStateMigrations can only be applied with a parent stream configuration." + ) return LegacyToPerPartitionStateMigration( - declarative_stream.retriever.partition_router, declarative_stream.incremental_sync, config, declarative_stream.parameters + declarative_stream.retriever.partition_router, + declarative_stream.incremental_sync, + config, + declarative_stream.parameters, ) # type: ignore # The retriever type was already checked def create_session_token_authenticator( self, model: SessionTokenAuthenticatorModel, config: Config, name: str, **kwargs: Any ) -> Union[ApiKeyAuthenticator, BearerAuthenticator]: - decoder = self._create_component_from_model(model=model.decoder, config=config) if model.decoder else JsonDecoder(parameters={}) + decoder = ( + self._create_component_from_model(model=model.decoder, config=config) + if model.decoder + else JsonDecoder(parameters={}) + ) login_requester = self._create_component_from_model( - model=model.login_requester, config=config, name=f"{name}_login_requester", decoder=decoder + model=model.login_requester, + config=config, + name=f"{name}_login_requester", + decoder=decoder, ) token_provider = SessionTokenProvider( login_requester=login_requester, session_token_path=model.session_token_path, - expiration_duration=parse_duration(model.expiration_duration) if model.expiration_duration else None, + expiration_duration=parse_duration(model.expiration_duration) + if model.expiration_duration + else None, parameters=model.parameters or {}, message_repository=self._message_repository, decoder=decoder, @@ -427,28 +658,46 @@ def create_session_token_authenticator( ) else: return ModelToComponentFactory.create_api_key_authenticator( - ApiKeyAuthenticatorModel(type="ApiKeyAuthenticator", api_token="", inject_into=model.request_authentication.inject_into), # type: ignore # $parameters and headers default to None + ApiKeyAuthenticatorModel( + type="ApiKeyAuthenticator", + api_token="", + inject_into=model.request_authentication.inject_into, + ), # type: ignore # $parameters and headers default to None config=config, token_provider=token_provider, ) @staticmethod - def create_basic_http_authenticator(model: BasicHttpAuthenticatorModel, config: Config, **kwargs: Any) -> BasicHttpAuthenticator: + def create_basic_http_authenticator( + model: BasicHttpAuthenticatorModel, config: Config, **kwargs: Any + ) -> BasicHttpAuthenticator: return BasicHttpAuthenticator( - password=model.password or "", username=model.username, config=config, parameters=model.parameters or {} + password=model.password or "", + username=model.username, + config=config, + parameters=model.parameters or {}, ) @staticmethod def create_bearer_authenticator( - model: BearerAuthenticatorModel, config: Config, token_provider: Optional[TokenProvider] = None, **kwargs: Any + model: BearerAuthenticatorModel, + config: Config, + token_provider: Optional[TokenProvider] = None, + **kwargs: Any, ) -> BearerAuthenticator: if token_provider is not None and model.api_token != "": - raise ValueError("If token_provider is set, api_token is ignored and has to be set to empty string.") + raise ValueError( + "If token_provider is set, api_token is ignored and has to be set to empty string." + ) return BearerAuthenticator( token_provider=( token_provider if token_provider is not None - else InterpolatedStringTokenProvider(api_token=model.api_token or "", config=config, parameters=model.parameters or {}) + else InterpolatedStringTokenProvider( + api_token=model.api_token or "", + config=config, + parameters=model.parameters or {}, + ) ), config=config, parameters=model.parameters or {}, @@ -458,14 +707,21 @@ def create_bearer_authenticator( def create_check_stream(model: CheckStreamModel, config: Config, **kwargs: Any) -> CheckStream: return CheckStream(stream_names=model.stream_names, parameters={}) - def create_composite_error_handler(self, model: CompositeErrorHandlerModel, config: Config, **kwargs: Any) -> CompositeErrorHandler: + def create_composite_error_handler( + self, model: CompositeErrorHandlerModel, config: Config, **kwargs: Any + ) -> CompositeErrorHandler: error_handlers = [ - self._create_component_from_model(model=error_handler_model, config=config) for error_handler_model in model.error_handlers + self._create_component_from_model(model=error_handler_model, config=config) + for error_handler_model in model.error_handlers ] - return CompositeErrorHandler(error_handlers=error_handlers, parameters=model.parameters or {}) + return CompositeErrorHandler( + error_handlers=error_handlers, parameters=model.parameters or {} + ) @staticmethod - def create_concurrency_level(model: ConcurrencyLevelModel, config: Config, **kwargs: Any) -> ConcurrencyLevel: + def create_concurrency_level( + model: ConcurrencyLevelModel, config: Config, **kwargs: Any + ) -> ConcurrencyLevel: return ConcurrencyLevel( default_concurrency=model.default_concurrency, max_concurrency=model.max_concurrency, @@ -486,23 +742,30 @@ def create_concurrent_cursor_from_datetime_based_cursor( ) -> Tuple[ConcurrentCursor, DateTimeStreamStateConverter]: component_type = component_definition.get("type") if component_definition.get("type") != model_type.__name__: - raise ValueError(f"Expected manifest component of type {model_type.__name__}, but received {component_type} instead") + raise ValueError( + f"Expected manifest component of type {model_type.__name__}, but received {component_type} instead" + ) datetime_based_cursor_model = model_type.parse_obj(component_definition) if not isinstance(datetime_based_cursor_model, DatetimeBasedCursorModel): - raise ValueError(f"Expected {model_type.__name__} component, but received {datetime_based_cursor_model.__class__.__name__}") + raise ValueError( + f"Expected {model_type.__name__} component, but received {datetime_based_cursor_model.__class__.__name__}" + ) interpolated_cursor_field = InterpolatedString.create( - datetime_based_cursor_model.cursor_field, parameters=datetime_based_cursor_model.parameters or {} + datetime_based_cursor_model.cursor_field, + parameters=datetime_based_cursor_model.parameters or {}, ) cursor_field = CursorField(interpolated_cursor_field.eval(config=config)) interpolated_partition_field_start = InterpolatedString.create( - datetime_based_cursor_model.partition_field_start or "start_time", parameters=datetime_based_cursor_model.parameters or {} + datetime_based_cursor_model.partition_field_start or "start_time", + parameters=datetime_based_cursor_model.parameters or {}, ) interpolated_partition_field_end = InterpolatedString.create( - datetime_based_cursor_model.partition_field_end or "end_time", parameters=datetime_based_cursor_model.parameters or {} + datetime_based_cursor_model.partition_field_end or "end_time", + parameters=datetime_based_cursor_model.parameters or {}, ) slice_boundary_fields = ( @@ -513,12 +776,17 @@ def create_concurrent_cursor_from_datetime_based_cursor( datetime_format = datetime_based_cursor_model.datetime_format cursor_granularity = ( - parse_duration(datetime_based_cursor_model.cursor_granularity) if datetime_based_cursor_model.cursor_granularity else None + parse_duration(datetime_based_cursor_model.cursor_granularity) + if datetime_based_cursor_model.cursor_granularity + else None ) lookback_window = None interpolated_lookback_window = ( - InterpolatedString.create(datetime_based_cursor_model.lookback_window, parameters=datetime_based_cursor_model.parameters or {}) + InterpolatedString.create( + datetime_based_cursor_model.lookback_window, + parameters=datetime_based_cursor_model.parameters or {}, + ) if datetime_based_cursor_model.lookback_window else None ) @@ -538,21 +806,30 @@ def create_concurrent_cursor_from_datetime_based_cursor( start_date_runtime_value: Union[InterpolatedString, str, MinMaxDatetime] if isinstance(datetime_based_cursor_model.start_datetime, MinMaxDatetimeModel): - start_date_runtime_value = self.create_min_max_datetime(model=datetime_based_cursor_model.start_datetime, config=config) + start_date_runtime_value = self.create_min_max_datetime( + model=datetime_based_cursor_model.start_datetime, config=config + ) else: start_date_runtime_value = datetime_based_cursor_model.start_datetime end_date_runtime_value: Optional[Union[InterpolatedString, str, MinMaxDatetime]] if isinstance(datetime_based_cursor_model.end_datetime, MinMaxDatetimeModel): - end_date_runtime_value = self.create_min_max_datetime(model=datetime_based_cursor_model.end_datetime, config=config) + end_date_runtime_value = self.create_min_max_datetime( + model=datetime_based_cursor_model.end_datetime, config=config + ) else: end_date_runtime_value = datetime_based_cursor_model.end_datetime interpolated_start_date = MinMaxDatetime.create( - interpolated_string_or_min_max_datetime=start_date_runtime_value, parameters=datetime_based_cursor_model.parameters + interpolated_string_or_min_max_datetime=start_date_runtime_value, + parameters=datetime_based_cursor_model.parameters, ) interpolated_end_date = ( - None if not end_date_runtime_value else MinMaxDatetime.create(end_date_runtime_value, datetime_based_cursor_model.parameters) + None + if not end_date_runtime_value + else MinMaxDatetime.create( + end_date_runtime_value, datetime_based_cursor_model.parameters + ) ) # If datetime format is not specified then start/end datetime should inherit it from the stream slicer @@ -563,10 +840,14 @@ def create_concurrent_cursor_from_datetime_based_cursor( start_date = interpolated_start_date.get_datetime(config=config) end_date_provider = ( - partial(interpolated_end_date.get_datetime, config) if interpolated_end_date else connector_state_converter.get_end_provider() + partial(interpolated_end_date.get_datetime, config) + if interpolated_end_date + else connector_state_converter.get_end_provider() ) - if (datetime_based_cursor_model.step and not datetime_based_cursor_model.cursor_granularity) or ( + if ( + datetime_based_cursor_model.step and not datetime_based_cursor_model.cursor_granularity + ) or ( not datetime_based_cursor_model.step and datetime_based_cursor_model.cursor_granularity ): raise ValueError( @@ -577,7 +858,10 @@ def create_concurrent_cursor_from_datetime_based_cursor( # When step is not defined, default to a step size from the starting date to the present moment step_length = datetime.timedelta.max interpolated_step = ( - InterpolatedString.create(datetime_based_cursor_model.step, parameters=datetime_based_cursor_model.parameters or {}) + InterpolatedString.create( + datetime_based_cursor_model.step, + parameters=datetime_based_cursor_model.parameters or {}, + ) if datetime_based_cursor_model.step else None ) @@ -606,7 +890,9 @@ def create_concurrent_cursor_from_datetime_based_cursor( ) @staticmethod - def create_constant_backoff_strategy(model: ConstantBackoffStrategyModel, config: Config, **kwargs: Any) -> ConstantBackoffStrategy: + def create_constant_backoff_strategy( + model: ConstantBackoffStrategyModel, config: Config, **kwargs: Any + ) -> ConstantBackoffStrategy: return ConstantBackoffStrategy( backoff_time_in_seconds=model.backoff_time_in_seconds, config=config, @@ -624,7 +910,9 @@ def create_cursor_pagination( decoder_to_use = decoder else: if not isinstance(decoder, (JsonDecoder, XmlDecoder)): - raise ValueError(f"Provided decoder of {type(decoder)=} is not supported. Please set JsonDecoder or XmlDecoder instead.") + raise ValueError( + f"Provided decoder of {type(decoder)=} is not supported. Please set JsonDecoder or XmlDecoder instead." + ) decoder_to_use = PaginationDecoderDecorator(decoder=decoder) return CursorPaginationStrategy( @@ -660,18 +948,28 @@ def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> # the custom component and this code performs a second parse to convert the sub-fields first into models, then declarative components for model_field, model_value in model_args.items(): # If a custom component field doesn't have a type set, we try to use the type hints to infer the type - if isinstance(model_value, dict) and "type" not in model_value and model_field in component_fields: - derived_type = self._derive_component_type_from_type_hints(component_fields.get(model_field)) + if ( + isinstance(model_value, dict) + and "type" not in model_value + and model_field in component_fields + ): + derived_type = self._derive_component_type_from_type_hints( + component_fields.get(model_field) + ) if derived_type: model_value["type"] = derived_type if self._is_component(model_value): - model_args[model_field] = self._create_nested_component(model, model_field, model_value, config) + model_args[model_field] = self._create_nested_component( + model, model_field, model_value, config + ) elif isinstance(model_value, list): vals = [] for v in model_value: if isinstance(v, dict) and "type" not in v and model_field in component_fields: - derived_type = self._derive_component_type_from_type_hints(component_fields.get(model_field)) + derived_type = self._derive_component_type_from_type_hints( + component_fields.get(model_field) + ) if derived_type: v["type"] = derived_type if self._is_component(v): @@ -680,7 +978,11 @@ def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> vals.append(v) model_args[model_field] = vals - kwargs = {class_field: model_args[class_field] for class_field in component_fields.keys() if class_field in model_args} + kwargs = { + class_field: model_args[class_field] + for class_field in component_fields.keys() + if class_field in model_args + } return custom_component_class(**kwargs) @staticmethod @@ -724,7 +1026,9 @@ def _extract_missing_parameters(error: TypeError) -> List[str]: else: return [] - def _create_nested_component(self, model: Any, model_field: str, model_value: Any, config: Config) -> Any: + def _create_nested_component( + self, model: Any, model_field: str, model_value: Any, config: Config + ) -> Any: type_name = model_value.get("type", None) if not type_name: # If no type is specified, we can assume this is a dictionary object which can be returned instead of a subcomponent @@ -743,16 +1047,29 @@ def _create_nested_component(self, model: Any, model_field: str, model_value: An model_constructor = self.PYDANTIC_MODEL_TO_CONSTRUCTOR.get(parsed_model.__class__) constructor_kwargs = inspect.getfullargspec(model_constructor).kwonlyargs model_parameters = model_value.get("$parameters", {}) - matching_parameters = {kwarg: model_parameters[kwarg] for kwarg in constructor_kwargs if kwarg in model_parameters} - return self._create_component_from_model(model=parsed_model, config=config, **matching_parameters) + matching_parameters = { + kwarg: model_parameters[kwarg] + for kwarg in constructor_kwargs + if kwarg in model_parameters + } + return self._create_component_from_model( + model=parsed_model, config=config, **matching_parameters + ) except TypeError as error: missing_parameters = self._extract_missing_parameters(error) if missing_parameters: raise ValueError( f"Error creating component '{type_name}' with parent custom component {model.class_name}: Please provide " - + ", ".join((f"{type_name}.$parameters.{parameter}" for parameter in missing_parameters)) + + ", ".join( + ( + f"{type_name}.$parameters.{parameter}" + for parameter in missing_parameters + ) + ) ) - raise TypeError(f"Error creating component '{type_name}' with parent custom component {model.class_name}: {error}") + raise TypeError( + f"Error creating component '{type_name}' with parent custom component {model.class_name}: {error}" + ) else: raise ValueError( f"Error creating custom component {model.class_name}. Subcomponent creation has not been implemented for '{type_name}'" @@ -762,18 +1079,26 @@ def _create_nested_component(self, model: Any, model_field: str, model_value: An def _is_component(model_value: Any) -> bool: return isinstance(model_value, dict) and model_value.get("type") is not None - def create_datetime_based_cursor(self, model: DatetimeBasedCursorModel, config: Config, **kwargs: Any) -> DatetimeBasedCursor: + def create_datetime_based_cursor( + self, model: DatetimeBasedCursorModel, config: Config, **kwargs: Any + ) -> DatetimeBasedCursor: start_datetime: Union[str, MinMaxDatetime] = ( - model.start_datetime if isinstance(model.start_datetime, str) else self.create_min_max_datetime(model.start_datetime, config) + model.start_datetime + if isinstance(model.start_datetime, str) + else self.create_min_max_datetime(model.start_datetime, config) ) end_datetime: Union[str, MinMaxDatetime, None] = None if model.is_data_feed and model.end_datetime: raise ValueError("Data feed does not support end_datetime") if model.is_data_feed and model.is_client_side_incremental: - raise ValueError("`Client side incremental` cannot be applied with `data feed`. Choose only 1 from them.") + raise ValueError( + "`Client side incremental` cannot be applied with `data feed`. Choose only 1 from them." + ) if model.end_datetime: end_datetime = ( - model.end_datetime if isinstance(model.end_datetime, str) else self.create_min_max_datetime(model.end_datetime, config) + model.end_datetime + if isinstance(model.end_datetime, str) + else self.create_min_max_datetime(model.end_datetime, config) ) end_time_option = ( @@ -797,7 +1122,9 @@ def create_datetime_based_cursor(self, model: DatetimeBasedCursorModel, config: return DatetimeBasedCursor( cursor_field=model.cursor_field, - cursor_datetime_formats=model.cursor_datetime_formats if model.cursor_datetime_formats else [], + cursor_datetime_formats=model.cursor_datetime_formats + if model.cursor_datetime_formats + else [], cursor_granularity=model.cursor_granularity, datetime_format=model.datetime_format, end_datetime=end_datetime, @@ -814,7 +1141,9 @@ def create_datetime_based_cursor(self, model: DatetimeBasedCursorModel, config: parameters=model.parameters or {}, ) - def create_declarative_stream(self, model: DeclarativeStreamModel, config: Config, **kwargs: Any) -> DeclarativeStream: + def create_declarative_stream( + self, model: DeclarativeStreamModel, config: Config, **kwargs: Any + ) -> DeclarativeStream: # When constructing a declarative stream, we assemble the incremental_sync component and retriever's partition_router field # components if they exist into a single CartesianProductStreamSlicer. This is then passed back as an argument when constructing the # Retriever. This is done in the declarative stream not the retriever to support custom retrievers. The custom create methods in @@ -823,7 +1152,9 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi primary_key = model.primary_key.__root__ if model.primary_key else None stop_condition_on_cursor = ( - model.incremental_sync and hasattr(model.incremental_sync, "is_data_feed") and model.incremental_sync.is_data_feed + model.incremental_sync + and hasattr(model.incremental_sync, "is_data_feed") + and model.incremental_sync.is_data_feed ) client_side_incremental_sync = None if ( @@ -831,13 +1162,25 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi and hasattr(model.incremental_sync, "is_client_side_incremental") and model.incremental_sync.is_client_side_incremental ): - supported_slicers = (DatetimeBasedCursor, GlobalSubstreamCursor, PerPartitionWithGlobalCursor) + supported_slicers = ( + DatetimeBasedCursor, + GlobalSubstreamCursor, + PerPartitionWithGlobalCursor, + ) if combined_slicers and not isinstance(combined_slicers, supported_slicers): - raise ValueError("Unsupported Slicer is used. PerPartitionWithGlobalCursor should be used here instead") + raise ValueError( + "Unsupported Slicer is used. PerPartitionWithGlobalCursor should be used here instead" + ) client_side_incremental_sync = { - "date_time_based_cursor": self._create_component_from_model(model=model.incremental_sync, config=config), + "date_time_based_cursor": self._create_component_from_model( + model=model.incremental_sync, config=config + ), "substream_cursor": ( - combined_slicers if isinstance(combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor)) else None + combined_slicers + if isinstance( + combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor) + ) + else None ), } @@ -877,7 +1220,9 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi transformations = [] if model.transformations: for transformation_model in model.transformations: - transformations.append(self._create_component_from_model(model=transformation_model, config=config)) + transformations.append( + self._create_component_from_model(model=transformation_model, config=config) + ) retriever = self._create_component_from_model( model=model.retriever, config=config, @@ -900,7 +1245,9 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi state_transformations = [] if model.schema_loader: - schema_loader = self._create_component_from_model(model=model.schema_loader, config=config) + schema_loader = self._create_component_from_model( + model=model.schema_loader, config=config + ) else: options = model.parameters or {} if "name" not in options: @@ -918,7 +1265,9 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi parameters=model.parameters or {}, ) - def _merge_stream_slicers(self, model: DeclarativeStreamModel, config: Config) -> Optional[StreamSlicer]: + def _merge_stream_slicers( + self, model: DeclarativeStreamModel, config: Config + ) -> Optional[StreamSlicer]: stream_slicer = None if ( hasattr(model.retriever, "partition_router") @@ -929,50 +1278,85 @@ def _merge_stream_slicers(self, model: DeclarativeStreamModel, config: Config) - if isinstance(stream_slicer_model, list): stream_slicer = CartesianProductStreamSlicer( - [self._create_component_from_model(model=slicer, config=config) for slicer in stream_slicer_model], parameters={} + [ + self._create_component_from_model(model=slicer, config=config) + for slicer in stream_slicer_model + ], + parameters={}, ) else: - stream_slicer = self._create_component_from_model(model=stream_slicer_model, config=config) + stream_slicer = self._create_component_from_model( + model=stream_slicer_model, config=config + ) if model.incremental_sync and stream_slicer: incremental_sync_model = model.incremental_sync - if hasattr(incremental_sync_model, "global_substream_cursor") and incremental_sync_model.global_substream_cursor: - cursor_component = self._create_component_from_model(model=incremental_sync_model, config=config) - return GlobalSubstreamCursor(stream_cursor=cursor_component, partition_router=stream_slicer) + if ( + hasattr(incremental_sync_model, "global_substream_cursor") + and incremental_sync_model.global_substream_cursor + ): + cursor_component = self._create_component_from_model( + model=incremental_sync_model, config=config + ) + return GlobalSubstreamCursor( + stream_cursor=cursor_component, partition_router=stream_slicer + ) else: - cursor_component = self._create_component_from_model(model=incremental_sync_model, config=config) + cursor_component = self._create_component_from_model( + model=incremental_sync_model, config=config + ) return PerPartitionWithGlobalCursor( cursor_factory=CursorFactory( - lambda: self._create_component_from_model(model=incremental_sync_model, config=config), + lambda: self._create_component_from_model( + model=incremental_sync_model, config=config + ), ), partition_router=stream_slicer, stream_cursor=cursor_component, ) elif model.incremental_sync: - return self._create_component_from_model(model=model.incremental_sync, config=config) if model.incremental_sync else None + return ( + self._create_component_from_model(model=model.incremental_sync, config=config) + if model.incremental_sync + else None + ) elif stream_slicer: # For the Full-Refresh sub-streams, we use the nested `ChildPartitionResumableFullRefreshCursor` return PerPartitionCursor( - cursor_factory=CursorFactory(create_function=partial(ChildPartitionResumableFullRefreshCursor, {})), + cursor_factory=CursorFactory( + create_function=partial(ChildPartitionResumableFullRefreshCursor, {}) + ), partition_router=stream_slicer, ) - elif hasattr(model.retriever, "paginator") and model.retriever.paginator and not stream_slicer: + elif ( + hasattr(model.retriever, "paginator") + and model.retriever.paginator + and not stream_slicer + ): # For the regular Full-Refresh streams, we use the high level `ResumableFullRefreshCursor` return ResumableFullRefreshCursor(parameters={}) else: return None - def create_default_error_handler(self, model: DefaultErrorHandlerModel, config: Config, **kwargs: Any) -> DefaultErrorHandler: + def create_default_error_handler( + self, model: DefaultErrorHandlerModel, config: Config, **kwargs: Any + ) -> DefaultErrorHandler: backoff_strategies = [] if model.backoff_strategies: for backoff_strategy_model in model.backoff_strategies: - backoff_strategies.append(self._create_component_from_model(model=backoff_strategy_model, config=config)) + backoff_strategies.append( + self._create_component_from_model(model=backoff_strategy_model, config=config) + ) response_filters = [] if model.response_filters: for response_filter_model in model.response_filters: - response_filters.append(self._create_component_from_model(model=response_filter_model, config=config)) - response_filters.append(HttpResponseFilter(config=config, parameters=model.parameters or {})) + response_filters.append( + self._create_component_from_model(model=response_filter_model, config=config) + ) + response_filters.append( + HttpResponseFilter(config=config, parameters=model.parameters or {}) + ) return DefaultErrorHandler( backoff_strategies=backoff_strategies, @@ -993,17 +1377,25 @@ def create_default_paginator( ) -> Union[DefaultPaginator, PaginatorTestReadDecorator]: if decoder: if not isinstance(decoder, (JsonDecoder, XmlDecoder)): - raise ValueError(f"Provided decoder of {type(decoder)=} is not supported. Please set JsonDecoder or XmlDecoder instead.") + raise ValueError( + f"Provided decoder of {type(decoder)=} is not supported. Please set JsonDecoder or XmlDecoder instead." + ) decoder_to_use = PaginationDecoderDecorator(decoder=decoder) else: decoder_to_use = PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) page_size_option = ( - self._create_component_from_model(model=model.page_size_option, config=config) if model.page_size_option else None + self._create_component_from_model(model=model.page_size_option, config=config) + if model.page_size_option + else None ) page_token_option = ( - self._create_component_from_model(model=model.page_token_option, config=config) if model.page_token_option else None + self._create_component_from_model(model=model.page_token_option, config=config) + if model.page_token_option + else None + ) + pagination_strategy = self._create_component_from_model( + model=model.pagination_strategy, config=config, decoder=decoder_to_use ) - pagination_strategy = self._create_component_from_model(model=model.pagination_strategy, config=config, decoder=decoder_to_use) if cursor_used_for_stop_condition: pagination_strategy = StopConditionPaginationStrategyDecorator( pagination_strategy, CursorStopCondition(cursor_used_for_stop_condition) @@ -1022,29 +1414,55 @@ def create_default_paginator( return paginator def create_dpath_extractor( - self, model: DpathExtractorModel, config: Config, decoder: Optional[Decoder] = None, **kwargs: Any + self, + model: DpathExtractorModel, + config: Config, + decoder: Optional[Decoder] = None, + **kwargs: Any, ) -> DpathExtractor: if decoder: decoder_to_use = decoder else: decoder_to_use = JsonDecoder(parameters={}) model_field_path: List[Union[InterpolatedString, str]] = [x for x in model.field_path] - return DpathExtractor(decoder=decoder_to_use, field_path=model_field_path, config=config, parameters=model.parameters or {}) + return DpathExtractor( + decoder=decoder_to_use, + field_path=model_field_path, + config=config, + parameters=model.parameters or {}, + ) @staticmethod - def create_exponential_backoff_strategy(model: ExponentialBackoffStrategyModel, config: Config) -> ExponentialBackoffStrategy: - return ExponentialBackoffStrategy(factor=model.factor or 5, parameters=model.parameters or {}, config=config) + def create_exponential_backoff_strategy( + model: ExponentialBackoffStrategyModel, config: Config + ) -> ExponentialBackoffStrategy: + return ExponentialBackoffStrategy( + factor=model.factor or 5, parameters=model.parameters or {}, config=config + ) - def create_http_requester(self, model: HttpRequesterModel, decoder: Decoder, config: Config, *, name: str) -> HttpRequester: + def create_http_requester( + self, model: HttpRequesterModel, decoder: Decoder, config: Config, *, name: str + ) -> HttpRequester: authenticator = ( - self._create_component_from_model(model=model.authenticator, config=config, url_base=model.url_base, name=name, decoder=decoder) + self._create_component_from_model( + model=model.authenticator, + config=config, + url_base=model.url_base, + name=name, + decoder=decoder, + ) if model.authenticator else None ) error_handler = ( self._create_component_from_model(model=model.error_handler, config=config) if model.error_handler - else DefaultErrorHandler(backoff_strategies=[], response_filters=[], config=config, parameters=model.parameters or {}) + else DefaultErrorHandler( + backoff_strategies=[], + response_filters=[], + config=config, + parameters=model.parameters or {}, + ) ) request_options_provider = InterpolatedRequestOptionsProvider( @@ -1079,7 +1497,9 @@ def create_http_requester(self, model: HttpRequesterModel, decoder: Decoder, con ) @staticmethod - def create_http_response_filter(model: HttpResponseFilterModel, config: Config, **kwargs: Any) -> HttpResponseFilter: + def create_http_response_filter( + model: HttpResponseFilterModel, config: Config, **kwargs: Any + ) -> HttpResponseFilter: if model.action: action = ResponseAction(model.action.value) else: @@ -1103,7 +1523,9 @@ def create_http_response_filter(model: HttpResponseFilterModel, config: Config, ) @staticmethod - def create_inline_schema_loader(model: InlineSchemaLoaderModel, config: Config, **kwargs: Any) -> InlineSchemaLoader: + def create_inline_schema_loader( + model: InlineSchemaLoaderModel, config: Config, **kwargs: Any + ) -> InlineSchemaLoader: return InlineSchemaLoader(schema=model.schema_ or {}, parameters={}) @staticmethod @@ -1111,11 +1533,15 @@ def create_json_decoder(model: JsonDecoderModel, config: Config, **kwargs: Any) return JsonDecoder(parameters={}) @staticmethod - def create_jsonl_decoder(model: JsonlDecoderModel, config: Config, **kwargs: Any) -> JsonlDecoder: + def create_jsonl_decoder( + model: JsonlDecoderModel, config: Config, **kwargs: Any + ) -> JsonlDecoder: return JsonlDecoder(parameters={}) @staticmethod - def create_iterable_decoder(model: IterableDecoderModel, config: Config, **kwargs: Any) -> IterableDecoder: + def create_iterable_decoder( + model: IterableDecoderModel, config: Config, **kwargs: Any + ) -> IterableDecoder: return IterableDecoder(parameters={}) @staticmethod @@ -1123,11 +1549,17 @@ def create_xml_decoder(model: XmlDecoderModel, config: Config, **kwargs: Any) -> return XmlDecoder(parameters={}) @staticmethod - def create_json_file_schema_loader(model: JsonFileSchemaLoaderModel, config: Config, **kwargs: Any) -> JsonFileSchemaLoader: - return JsonFileSchemaLoader(file_path=model.file_path or "", config=config, parameters=model.parameters or {}) + def create_json_file_schema_loader( + model: JsonFileSchemaLoaderModel, config: Config, **kwargs: Any + ) -> JsonFileSchemaLoader: + return JsonFileSchemaLoader( + file_path=model.file_path or "", config=config, parameters=model.parameters or {} + ) @staticmethod - def create_jwt_authenticator(model: JwtAuthenticatorModel, config: Config, **kwargs: Any) -> JwtAuthenticator: + def create_jwt_authenticator( + model: JwtAuthenticatorModel, config: Config, **kwargs: Any + ) -> JwtAuthenticator: jwt_headers = model.jwt_headers or JwtHeadersModel(kid=None, typ="JWT", cty=None) jwt_payload = model.jwt_payload or JwtPayloadModel(iss=None, sub=None, aud=None) return JwtAuthenticator( @@ -1149,7 +1581,9 @@ def create_jwt_authenticator(model: JwtAuthenticatorModel, config: Config, **kwa ) @staticmethod - def create_list_partition_router(model: ListPartitionRouterModel, config: Config, **kwargs: Any) -> ListPartitionRouter: + def create_list_partition_router( + model: ListPartitionRouterModel, config: Config, **kwargs: Any + ) -> ListPartitionRouter: request_option = ( RequestOption( inject_into=RequestOptionType(model.request_option.inject_into.value), @@ -1168,7 +1602,9 @@ def create_list_partition_router(model: ListPartitionRouterModel, config: Config ) @staticmethod - def create_min_max_datetime(model: MinMaxDatetimeModel, config: Config, **kwargs: Any) -> MinMaxDatetime: + def create_min_max_datetime( + model: MinMaxDatetimeModel, config: Config, **kwargs: Any + ) -> MinMaxDatetime: return MinMaxDatetime( datetime=model.datetime, datetime_format=model.datetime_format or "", @@ -1182,29 +1618,43 @@ def create_no_auth(model: NoAuthModel, config: Config, **kwargs: Any) -> NoAuth: return NoAuth(parameters=model.parameters or {}) @staticmethod - def create_no_pagination(model: NoPaginationModel, config: Config, **kwargs: Any) -> NoPagination: + def create_no_pagination( + model: NoPaginationModel, config: Config, **kwargs: Any + ) -> NoPagination: return NoPagination(parameters={}) - def create_oauth_authenticator(self, model: OAuthAuthenticatorModel, config: Config, **kwargs: Any) -> DeclarativeOauth2Authenticator: + def create_oauth_authenticator( + self, model: OAuthAuthenticatorModel, config: Config, **kwargs: Any + ) -> DeclarativeOauth2Authenticator: if model.refresh_token_updater: # ignore type error because fixing it would have a lot of dependencies, revisit later return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore config, - InterpolatedString.create(model.token_refresh_endpoint, parameters=model.parameters or {}).eval(config), + InterpolatedString.create( + model.token_refresh_endpoint, parameters=model.parameters or {} + ).eval(config), access_token_name=InterpolatedString.create( model.access_token_name or "access_token", parameters=model.parameters or {} ).eval(config), refresh_token_name=model.refresh_token_updater.refresh_token_name, - expires_in_name=InterpolatedString.create(model.expires_in_name or "expires_in", parameters=model.parameters or {}).eval( - config - ), - client_id=InterpolatedString.create(model.client_id, parameters=model.parameters or {}).eval(config), - client_secret=InterpolatedString.create(model.client_secret, parameters=model.parameters or {}).eval(config), + expires_in_name=InterpolatedString.create( + model.expires_in_name or "expires_in", parameters=model.parameters or {} + ).eval(config), + client_id=InterpolatedString.create( + model.client_id, parameters=model.parameters or {} + ).eval(config), + client_secret=InterpolatedString.create( + model.client_secret, parameters=model.parameters or {} + ).eval(config), access_token_config_path=model.refresh_token_updater.access_token_config_path, refresh_token_config_path=model.refresh_token_updater.refresh_token_config_path, token_expiry_date_config_path=model.refresh_token_updater.token_expiry_date_config_path, - grant_type=InterpolatedString.create(model.grant_type or "refresh_token", parameters=model.parameters or {}).eval(config), - refresh_request_body=InterpolatedMapping(model.refresh_request_body or {}, parameters=model.parameters or {}).eval(config), + grant_type=InterpolatedString.create( + model.grant_type or "refresh_token", parameters=model.parameters or {} + ).eval(config), + refresh_request_body=InterpolatedMapping( + model.refresh_request_body or {}, parameters=model.parameters or {} + ).eval(config), scopes=model.scopes, token_expiry_date_format=model.token_expiry_date_format, message_repository=self._message_repository, @@ -1232,7 +1682,9 @@ def create_oauth_authenticator(self, model: OAuthAuthenticatorModel, config: Con ) @staticmethod - def create_offset_increment(model: OffsetIncrementModel, config: Config, decoder: Decoder, **kwargs: Any) -> OffsetIncrement: + def create_offset_increment( + model: OffsetIncrementModel, config: Config, decoder: Decoder, **kwargs: Any + ) -> OffsetIncrement: if isinstance(decoder, PaginationDecoderDecorator): if not isinstance(decoder.decoder, (JsonDecoder, XmlDecoder)): raise ValueError( @@ -1241,7 +1693,9 @@ def create_offset_increment(model: OffsetIncrementModel, config: Config, decoder decoder_to_use = decoder else: if not isinstance(decoder, (JsonDecoder, XmlDecoder)): - raise ValueError(f"Provided decoder of {type(decoder)=} is not supported. Please set JsonDecoder or XmlDecoder instead.") + raise ValueError( + f"Provided decoder of {type(decoder)=} is not supported. Please set JsonDecoder or XmlDecoder instead." + ) decoder_to_use = PaginationDecoderDecorator(decoder=decoder) return OffsetIncrement( page_size=model.page_size, @@ -1252,7 +1706,9 @@ def create_offset_increment(model: OffsetIncrementModel, config: Config, decoder ) @staticmethod - def create_page_increment(model: PageIncrementModel, config: Config, **kwargs: Any) -> PageIncrement: + def create_page_increment( + model: PageIncrementModel, config: Config, **kwargs: Any + ) -> PageIncrement: return PageIncrement( page_size=model.page_size, config=config, @@ -1261,9 +1717,15 @@ def create_page_increment(model: PageIncrementModel, config: Config, **kwargs: A parameters=model.parameters or {}, ) - def create_parent_stream_config(self, model: ParentStreamConfigModel, config: Config, **kwargs: Any) -> ParentStreamConfig: + def create_parent_stream_config( + self, model: ParentStreamConfigModel, config: Config, **kwargs: Any + ) -> ParentStreamConfig: declarative_stream = self._create_component_from_model(model.stream, config=config) - request_option = self._create_component_from_model(model.request_option, config=config) if model.request_option else None + request_option = ( + self._create_component_from_model(model.request_option, config=config) + if model.request_option + else None + ) return ParentStreamConfig( parent_key=model.parent_key, request_option=request_option, @@ -1276,15 +1738,21 @@ def create_parent_stream_config(self, model: ParentStreamConfigModel, config: Co ) @staticmethod - def create_record_filter(model: RecordFilterModel, config: Config, **kwargs: Any) -> RecordFilter: - return RecordFilter(condition=model.condition or "", config=config, parameters=model.parameters or {}) + def create_record_filter( + model: RecordFilterModel, config: Config, **kwargs: Any + ) -> RecordFilter: + return RecordFilter( + condition=model.condition or "", config=config, parameters=model.parameters or {} + ) @staticmethod def create_request_path(model: RequestPathModel, config: Config, **kwargs: Any) -> RequestPath: return RequestPath(parameters={}) @staticmethod - def create_request_option(model: RequestOptionModel, config: Config, **kwargs: Any) -> RequestOption: + def create_request_option( + model: RequestOptionModel, config: Config, **kwargs: Any + ) -> RequestOption: inject_into = RequestOptionType(model.inject_into.value) return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={}) @@ -1299,16 +1767,26 @@ def create_record_selector( **kwargs: Any, ) -> RecordSelector: assert model.schema_normalization is not None # for mypy - extractor = self._create_component_from_model(model=model.extractor, decoder=decoder, config=config) - record_filter = self._create_component_from_model(model.record_filter, config=config) if model.record_filter else None + extractor = self._create_component_from_model( + model=model.extractor, decoder=decoder, config=config + ) + record_filter = ( + self._create_component_from_model(model.record_filter, config=config) + if model.record_filter + else None + ) if client_side_incremental_sync: record_filter = ClientSideIncrementalRecordFilterDecorator( config=config, parameters=model.parameters, - condition=model.record_filter.condition if (model.record_filter and hasattr(model.record_filter, "condition")) else None, + condition=model.record_filter.condition + if (model.record_filter and hasattr(model.record_filter, "condition")) + else None, **client_side_incremental_sync, ) - schema_normalization = TypeTransformer(SCHEMA_TRANSFORMER_TYPE_MAPPING[model.schema_normalization]) + schema_normalization = TypeTransformer( + SCHEMA_TRANSFORMER_TYPE_MAPPING[model.schema_normalization] + ) return RecordSelector( extractor=extractor, @@ -1320,11 +1798,20 @@ def create_record_selector( ) @staticmethod - def create_remove_fields(model: RemoveFieldsModel, config: Config, **kwargs: Any) -> RemoveFields: - return RemoveFields(field_pointers=model.field_pointers, condition=model.condition or "", parameters={}) - - def create_selective_authenticator(self, model: SelectiveAuthenticatorModel, config: Config, **kwargs: Any) -> DeclarativeAuthenticator: - authenticators = {name: self._create_component_from_model(model=auth, config=config) for name, auth in model.authenticators.items()} + def create_remove_fields( + model: RemoveFieldsModel, config: Config, **kwargs: Any + ) -> RemoveFields: + return RemoveFields( + field_pointers=model.field_pointers, condition=model.condition or "", parameters={} + ) + + def create_selective_authenticator( + self, model: SelectiveAuthenticatorModel, config: Config, **kwargs: Any + ) -> DeclarativeAuthenticator: + authenticators = { + name: self._create_component_from_model(model=auth, config=config) + for name, auth in model.authenticators.items() + } # SelectiveAuthenticator will return instance of DeclarativeAuthenticator or raise ValueError error return SelectiveAuthenticator( # type: ignore[abstract] config=config, @@ -1363,8 +1850,14 @@ def create_simple_retriever( client_side_incremental_sync: Optional[Dict[str, Any]] = None, transformations: List[RecordTransformation], ) -> SimpleRetriever: - decoder = self._create_component_from_model(model=model.decoder, config=config) if model.decoder else JsonDecoder(parameters={}) - requester = self._create_component_from_model(model=model.requester, decoder=decoder, config=config, name=name) + decoder = ( + self._create_component_from_model(model=model.decoder, config=config) + if model.decoder + else JsonDecoder(parameters={}) + ) + requester = self._create_component_from_model( + model=model.requester, decoder=decoder, config=config, name=name + ) record_selector = self._create_component_from_model( model=model.record_selector, config=config, @@ -1372,12 +1865,19 @@ def create_simple_retriever( transformations=transformations, client_side_incremental_sync=client_side_incremental_sync, ) - url_base = model.requester.url_base if hasattr(model.requester, "url_base") else requester.get_url_base() + url_base = ( + model.requester.url_base + if hasattr(model.requester, "url_base") + else requester.get_url_base() + ) # Define cursor only if per partition or common incremental support is needed cursor = stream_slicer if isinstance(stream_slicer, DeclarativeCursor) else None - if not isinstance(stream_slicer, DatetimeBasedCursor) or type(stream_slicer) is not DatetimeBasedCursor: + if ( + not isinstance(stream_slicer, DatetimeBasedCursor) + or type(stream_slicer) is not DatetimeBasedCursor + ): # Many of the custom component implementations of DatetimeBasedCursor override get_request_params() (or other methods). # Because we're decoupling RequestOptionsProvider from the Cursor, custom components will eventually need to reimplement # their own RequestOptionsProvider. However, right now the existing StreamSlicer/Cursor still can act as the SimpleRetriever's @@ -1401,7 +1901,9 @@ def create_simple_retriever( else NoPagination(parameters={}) ) - ignore_stream_slicer_parameters_on_paginated_requests = model.ignore_stream_slicer_parameters_on_paginated_requests or False + ignore_stream_slicer_parameters_on_paginated_requests = ( + model.ignore_stream_slicer_parameters_on_paginated_requests or False + ) if self._limit_slices_fetched or self._emit_connector_builder_messages: return SimpleRetrieverTestReadDecorator( @@ -1468,13 +1970,19 @@ def create_async_retriever( config: Config, *, name: str, - primary_key: Optional[Union[str, List[str], List[List[str]]]], # this seems to be needed to match create_simple_retriever + primary_key: Optional[ + Union[str, List[str], List[List[str]]] + ], # this seems to be needed to match create_simple_retriever stream_slicer: Optional[StreamSlicer], client_side_incremental_sync: Optional[Dict[str, Any]] = None, transformations: List[RecordTransformation], **kwargs: Any, ) -> AsyncRetriever: - decoder = self._create_component_from_model(model=model.decoder, config=config) if model.decoder else JsonDecoder(parameters={}) + decoder = ( + self._create_component_from_model(model=model.decoder, config=config) + if model.decoder + else JsonDecoder(parameters={}) + ) record_selector = self._create_component_from_model( model=model.record_selector, config=config, @@ -1484,14 +1992,23 @@ def create_async_retriever( ) stream_slicer = stream_slicer or SinglePartitionRouter(parameters={}) creation_requester = self._create_component_from_model( - model=model.creation_requester, decoder=decoder, config=config, name=f"job creation - {name}" + model=model.creation_requester, + decoder=decoder, + config=config, + name=f"job creation - {name}", ) polling_requester = self._create_component_from_model( - model=model.polling_requester, decoder=decoder, config=config, name=f"job polling - {name}" + model=model.polling_requester, + decoder=decoder, + config=config, + name=f"job polling - {name}", ) job_download_components_name = f"job download - {name}" download_requester = self._create_component_from_model( - model=model.download_requester, decoder=decoder, config=config, name=job_download_components_name + model=model.download_requester, + decoder=decoder, + config=config, + name=job_download_components_name, ) download_retriever = SimpleRetriever( requester=download_requester, @@ -1506,7 +2023,9 @@ def create_async_retriever( primary_key=None, name=job_download_components_name, paginator=( - self._create_component_from_model(model=model.download_paginator, decoder=decoder, config=config, url_base="") + self._create_component_from_model( + model=model.download_paginator, decoder=decoder, config=config, url_base="" + ) if model.download_paginator else NoPagination(parameters={}) ), @@ -1514,17 +2033,31 @@ def create_async_retriever( parameters={}, ) abort_requester = ( - self._create_component_from_model(model=model.abort_requester, decoder=decoder, config=config, name=f"job abort - {name}") + self._create_component_from_model( + model=model.abort_requester, + decoder=decoder, + config=config, + name=f"job abort - {name}", + ) if model.abort_requester else None ) delete_requester = ( - self._create_component_from_model(model=model.delete_requester, decoder=decoder, config=config, name=f"job delete - {name}") + self._create_component_from_model( + model=model.delete_requester, + decoder=decoder, + config=config, + name=f"job delete - {name}", + ) if model.delete_requester else None ) - status_extractor = self._create_component_from_model(model=model.status_extractor, decoder=decoder, config=config, name=name) - urls_extractor = self._create_component_from_model(model=model.urls_extractor, decoder=decoder, config=config, name=name) + status_extractor = self._create_component_from_model( + model=model.status_extractor, decoder=decoder, config=config, name=name + ) + urls_extractor = self._create_component_from_model( + model=model.urls_extractor, decoder=decoder, config=config, name=name + ) job_repository: AsyncJobRepository = AsyncHttpJobRepository( creation_requester=creation_requester, polling_requester=polling_requester, @@ -1540,7 +2073,9 @@ def create_async_retriever( job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator( job_repository, stream_slices, - JobTracker(1), # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1 + JobTracker( + 1 + ), # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1 self._message_repository, has_bulk_parent=False, # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk ), @@ -1566,14 +2101,22 @@ def create_substream_partition_router( if model.parent_stream_configs: parent_stream_configs.extend( [ - self._create_message_repository_substream_wrapper(model=parent_stream_config, config=config) + self._create_message_repository_substream_wrapper( + model=parent_stream_config, config=config + ) for parent_stream_config in model.parent_stream_configs ] ) - return SubstreamPartitionRouter(parent_stream_configs=parent_stream_configs, parameters=model.parameters or {}, config=config) + return SubstreamPartitionRouter( + parent_stream_configs=parent_stream_configs, + parameters=model.parameters or {}, + config=config, + ) - def _create_message_repository_substream_wrapper(self, model: ParentStreamConfigModel, config: Config) -> Any: + def _create_message_repository_substream_wrapper( + self, model: ParentStreamConfigModel, config: Config + ) -> Any: substream_factory = ModelToComponentFactory( limit_pages_fetched_per_slice=self._limit_pages_fetched_per_slice, limit_slices_fetched=self._limit_slices_fetched, @@ -1589,13 +2132,17 @@ def _create_message_repository_substream_wrapper(self, model: ParentStreamConfig return substream_factory._create_component_from_model(model=model, config=config) @staticmethod - def create_wait_time_from_header(model: WaitTimeFromHeaderModel, config: Config, **kwargs: Any) -> WaitTimeFromHeaderBackoffStrategy: + def create_wait_time_from_header( + model: WaitTimeFromHeaderModel, config: Config, **kwargs: Any + ) -> WaitTimeFromHeaderBackoffStrategy: return WaitTimeFromHeaderBackoffStrategy( header=model.header, parameters=model.parameters or {}, config=config, regex=model.regex, - max_waiting_time_in_seconds=model.max_waiting_time_in_seconds if model.max_waiting_time_in_seconds is not None else None, + max_waiting_time_in_seconds=model.max_waiting_time_in_seconds + if model.max_waiting_time_in_seconds is not None + else None, ) @staticmethod @@ -1603,7 +2150,11 @@ def create_wait_until_time_from_header( model: WaitUntilTimeFromHeaderModel, config: Config, **kwargs: Any ) -> WaitUntilTimeFromHeaderBackoffStrategy: return WaitUntilTimeFromHeaderBackoffStrategy( - header=model.header, parameters=model.parameters or {}, config=config, min_wait=model.min_wait, regex=model.regex + header=model.header, + parameters=model.parameters or {}, + config=config, + min_wait=model.min_wait, + regex=model.regex, ) def get_message_repository(self) -> MessageRepository: diff --git a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py index 14898428..8718004b 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py +++ b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py @@ -10,11 +10,15 @@ from typing import Any, Iterable, List, Mapping, Optional from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter -from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import SubstreamPartitionRouter +from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ( + SubstreamPartitionRouter, +) from airbyte_cdk.sources.types import StreamSlice, StreamState -def check_for_substream_in_slicers(slicers: Iterable[PartitionRouter], log_warning: Callable[[str], None]) -> None: +def check_for_substream_in_slicers( + slicers: Iterable[PartitionRouter], log_warning: Callable[[str], None] +) -> None: """ Recursively checks for the presence of SubstreamPartitionRouter within slicers. Logs a warning if a SubstreamPartitionRouter is found within a CartesianProductStreamSlicer. @@ -69,7 +73,11 @@ def get_request_params( return dict( ChainMap( *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons - s.get_request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + s.get_request_params( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) for s in self.stream_slicers ] ) @@ -85,7 +93,11 @@ def get_request_headers( return dict( ChainMap( *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons - s.get_request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + s.get_request_headers( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) for s in self.stream_slicers ] ) @@ -101,7 +113,11 @@ def get_request_body_data( return dict( ChainMap( *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons - s.get_request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + s.get_request_body_data( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) for s in self.stream_slicers ] ) @@ -117,7 +133,11 @@ def get_request_body_json( return dict( ChainMap( *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons - s.get_request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + s.get_request_body_json( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) for s in self.stream_slicers ] ) @@ -130,7 +150,9 @@ def stream_slices(self) -> Iterable[StreamSlice]: partition = dict(ChainMap(*[s.partition for s in stream_slice_tuple])) # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons cursor_slices = [s.cursor_slice for s in stream_slice_tuple if s.cursor_slice] if len(cursor_slices) > 1: - raise ValueError(f"There should only be a single cursor slice. Found {cursor_slices}") + raise ValueError( + f"There should only be a single cursor slice. Found {cursor_slices}" + ) if cursor_slices: cursor_slice = cursor_slices[0] else: diff --git a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py index 564a3119..29b700b0 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py @@ -7,7 +7,10 @@ from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @@ -32,9 +35,13 @@ class ListPartitionRouter(PartitionRouter): def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.values, str): - self.values = InterpolatedString.create(self.values, parameters=parameters).eval(self.config) + self.values = InterpolatedString.create(self.values, parameters=parameters).eval( + self.config + ) self._cursor_field = ( - InterpolatedString(string=self.cursor_field, parameters=parameters) if isinstance(self.cursor_field, str) else self.cursor_field + InterpolatedString(string=self.cursor_field, parameters=parameters) + if isinstance(self.cursor_field, str) + else self.cursor_field ) self._cursor = None @@ -76,10 +83,21 @@ def get_request_body_json( return self._get_request_option(RequestOptionType.body_json, stream_slice) def stream_slices(self) -> Iterable[StreamSlice]: - return [StreamSlice(partition={self._cursor_field.eval(self.config): slice_value}, cursor_slice={}) for slice_value in self.values] - - def _get_request_option(self, request_option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]: - if self.request_option and self.request_option.inject_into == request_option_type and stream_slice: + return [ + StreamSlice( + partition={self._cursor_field.eval(self.config): slice_value}, cursor_slice={} + ) + for slice_value in self.values + ] + + def _get_request_option( + self, request_option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + ) -> Mapping[str, Any]: + if ( + self.request_option + and self.request_option.inject_into == request_option_type + and stream_slice + ): slice_value = stream_slice.get(self._cursor_field.eval(self.config)) if slice_value: return {self.request_option.field_name.eval(self.config): slice_value} # type: ignore # field_name is always casted to InterpolatedString diff --git a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py index 667b673d..4c761d08 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py @@ -11,7 +11,10 @@ from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState from airbyte_cdk.utils import AirbyteTracedException @@ -37,17 +40,22 @@ class ParentStreamConfig: partition_field: Union[InterpolatedString, str] config: Config parameters: InitVar[Mapping[str, Any]] - extra_fields: Optional[Union[List[List[str]], List[List[InterpolatedString]]]] = None # List of field paths (arrays of strings) + extra_fields: Optional[Union[List[List[str]], List[List[InterpolatedString]]]] = ( + None # List of field paths (arrays of strings) + ) request_option: Optional[RequestOption] = None incremental_dependency: bool = False def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.parent_key = InterpolatedString.create(self.parent_key, parameters=parameters) - self.partition_field = InterpolatedString.create(self.partition_field, parameters=parameters) + self.partition_field = InterpolatedString.create( + self.partition_field, parameters=parameters + ) if self.extra_fields: # Create InterpolatedString for each field path in extra_keys self.extra_fields = [ - [InterpolatedString.create(path, parameters=parameters) for path in key_path] for key_path in self.extra_fields + [InterpolatedString.create(path, parameters=parameters) for path in key_path] + for key_path in self.extra_fields ] @@ -106,15 +114,26 @@ def get_request_body_json( # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_json, stream_slice) - def _get_request_option(self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]: + def _get_request_option( + self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + ) -> Mapping[str, Any]: params = {} if stream_slice: for parent_config in self.parent_stream_configs: - if parent_config.request_option and parent_config.request_option.inject_into == option_type: + if ( + parent_config.request_option + and parent_config.request_option.inject_into == option_type + ): key = parent_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string value = stream_slice.get(key) if value: - params.update({parent_config.request_option.field_name.eval(config=self.config): value}) # type: ignore # field_name is always casted to an interpolated string + params.update( + { + parent_config.request_option.field_name.eval( + config=self.config + ): value + } + ) # type: ignore # field_name is always casted to an interpolated string return params def stream_slices(self) -> Iterable[StreamSlice]: @@ -160,11 +179,17 @@ def stream_slices(self) -> Iterable[StreamSlice]: else: continue elif isinstance(parent_record, Record): - parent_partition = parent_record.associated_slice.partition if parent_record.associated_slice else {} + parent_partition = ( + parent_record.associated_slice.partition + if parent_record.associated_slice + else {} + ) parent_record = parent_record.data elif not isinstance(parent_record, Mapping): # The parent_record should only take the form of a Record, AirbyteMessage, or Mapping. Anything else is invalid - raise AirbyteTracedException(message=f"Parent stream returned records as invalid type {type(parent_record)}") + raise AirbyteTracedException( + message=f"Parent stream returned records as invalid type {type(parent_record)}" + ) try: partition_value = dpath.get(parent_record, parent_field) except KeyError: @@ -174,13 +199,18 @@ def stream_slices(self) -> Iterable[StreamSlice]: extracted_extra_fields = self._extract_extra_fields(parent_record, extra_fields) yield StreamSlice( - partition={partition_field: partition_value, "parent_slice": parent_partition or {}}, + partition={ + partition_field: partition_value, + "parent_slice": parent_partition or {}, + }, cursor_slice={}, extra_fields=extracted_extra_fields, ) def _extract_extra_fields( - self, parent_record: Mapping[str, Any] | AirbyteMessage, extra_fields: Optional[List[List[str]]] = None + self, + parent_record: Mapping[str, Any] | AirbyteMessage, + extra_fields: Optional[List[List[str]]] = None, ) -> Mapping[str, Any]: """ Extracts additional fields specified by their paths from the parent record. @@ -198,7 +228,9 @@ def _extract_extra_fields( for extra_field_path in extra_fields: try: extra_field_value = dpath.get(parent_record, extra_field_path) - self.logger.debug(f"Extracted extra_field_path: {extra_field_path} with value: {extra_field_value}") + self.logger.debug( + f"Extracted extra_field_path: {extra_field_path} with value: {extra_field_value}" + ) except KeyError: self.logger.debug(f"Failed to extract extra_field_path: {extra_field_path}") extra_field_value = None @@ -249,7 +281,9 @@ def set_initial_state(self, stream_state: StreamState) -> None: # If `parent_state` doesn't exist and at least one parent stream has an incremental dependency, # copy the child state to parent streams with incremental dependencies. - incremental_dependency = any([parent_config.incremental_dependency for parent_config in self.parent_stream_configs]) + incremental_dependency = any( + [parent_config.incremental_dependency for parent_config in self.parent_stream_configs] + ) if not parent_state and not incremental_dependency: return @@ -263,7 +297,9 @@ def set_initial_state(self, stream_state: StreamState) -> None: if substream_state: for parent_config in self.parent_stream_configs: if parent_config.incremental_dependency: - parent_state[parent_config.stream.name] = {parent_config.stream.cursor_field: substream_state} + parent_state[parent_config.stream.name] = { + parent_config.stream.cursor_field: substream_state + } # Set state for each parent stream with an incremental dependency for parent_config in self.parent_stream_configs: diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py index 96c50ef0..d9213eb9 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py @@ -28,9 +28,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not isinstance(self.backoff_time_in_seconds, InterpolatedString): self.backoff_time_in_seconds = str(self.backoff_time_in_seconds) if isinstance(self.backoff_time_in_seconds, float): - self.backoff_time_in_seconds = InterpolatedString.create(str(self.backoff_time_in_seconds), parameters=parameters) + self.backoff_time_in_seconds = InterpolatedString.create( + str(self.backoff_time_in_seconds), parameters=parameters + ) else: - self.backoff_time_in_seconds = InterpolatedString.create(self.backoff_time_in_seconds, parameters=parameters) + self.backoff_time_in_seconds = InterpolatedString.create( + self.backoff_time_in_seconds, parameters=parameters + ) def backoff_time( self, diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py index e7f5e1f8..60103f34 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py @@ -9,7 +9,9 @@ import requests -def get_numeric_value_from_header(response: requests.Response, header: str, regex: Optional[Pattern[str]]) -> Optional[float]: +def get_numeric_value_from_header( + response: requests.Response, header: str, regex: Optional[Pattern[str]] +) -> Optional[float]: """ Extract a header value from the response as a float :param response: response the extract header value from diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py index 79eb8a7fe..7672bd82 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py @@ -9,8 +9,12 @@ import requests from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.header_helper import get_numeric_value_from_header -from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategy import BackoffStrategy +from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.header_helper import ( + get_numeric_value_from_header, +) +from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategy import ( + BackoffStrategy, +) from airbyte_cdk.sources.types import Config from airbyte_cdk.utils import AirbyteTracedException @@ -33,11 +37,15 @@ class WaitTimeFromHeaderBackoffStrategy(BackoffStrategy): max_waiting_time_in_seconds: Optional[float] = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self.regex = InterpolatedString.create(self.regex, parameters=parameters) if self.regex else None + self.regex = ( + InterpolatedString.create(self.regex, parameters=parameters) if self.regex else None + ) self.header = InterpolatedString.create(self.header, parameters=parameters) def backoff_time( - self, response_or_exception: Optional[Union[requests.Response, requests.RequestException]], attempt_count: int + self, + response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + attempt_count: int, ) -> Optional[float]: header = self.header.eval(config=self.config) # type: ignore # header is always cast to an interpolated stream if self.regex: @@ -48,7 +56,11 @@ def backoff_time( header_value = None if isinstance(response_or_exception, requests.Response): header_value = get_numeric_value_from_header(response_or_exception, header, regex) - if self.max_waiting_time_in_seconds and header_value and header_value >= self.max_waiting_time_in_seconds: + if ( + self.max_waiting_time_in_seconds + and header_value + and header_value >= self.max_waiting_time_in_seconds + ): raise AirbyteTracedException( internal_message=f"Rate limit wait time {header_value} is greater than max waiting time of {self.max_waiting_time_in_seconds} seconds. Stopping the stream...", message="The rate limit is greater than max waiting time has been reached.", diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py index 861f8ba8..4aed7338 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py @@ -10,8 +10,12 @@ import requests from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.header_helper import get_numeric_value_from_header -from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategy import BackoffStrategy +from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.header_helper import ( + get_numeric_value_from_header, +) +from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategy import ( + BackoffStrategy, +) from airbyte_cdk.sources.types import Config @@ -35,12 +39,16 @@ class WaitUntilTimeFromHeaderBackoffStrategy(BackoffStrategy): def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.header = InterpolatedString.create(self.header, parameters=parameters) - self.regex = InterpolatedString.create(self.regex, parameters=parameters) if self.regex else None + self.regex = ( + InterpolatedString.create(self.regex, parameters=parameters) if self.regex else None + ) if not isinstance(self.min_wait, InterpolatedString): self.min_wait = InterpolatedString.create(str(self.min_wait), parameters=parameters) def backoff_time( - self, response_or_exception: Optional[Union[requests.Response, requests.RequestException]], attempt_count: int + self, + response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + attempt_count: int, ) -> Optional[float]: now = time.time() header = self.header.eval(self.config) # type: ignore # header is always cast to an interpolated string @@ -55,7 +63,9 @@ def backoff_time( min_wait = self.min_wait.eval(self.config) # type: ignore # header is always cast to an interpolated string if wait_until is None or not wait_until: return float(min_wait) if min_wait else None - if (isinstance(wait_until, str) and wait_until.isnumeric()) or isinstance(wait_until, numbers.Number): + if (isinstance(wait_until, str) and wait_until.isnumeric()) or isinstance( + wait_until, numbers.Number + ): wait_time = float(wait_until) - now else: return float(min_wait) diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py index bc151fec..717fcba6 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py @@ -54,7 +54,9 @@ def max_retries(self) -> Optional[int]: def max_time(self) -> Optional[int]: return max([error_handler.max_time or 0 for error_handler in self.error_handlers]) - def interpret_response(self, response_or_exception: Optional[Union[requests.Response, Exception]]) -> ErrorResolution: + def interpret_response( + self, response_or_exception: Optional[Union[requests.Response, Exception]] + ) -> ErrorResolution: matched_error_resolution = None for error_handler in self.error_handlers: matched_error_resolution = error_handler.interpret_response(response_or_exception) diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py index c255360f..ad4a6261 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py @@ -6,8 +6,12 @@ from typing import Any, List, Mapping, MutableMapping, Optional, Union import requests -from airbyte_cdk.sources.declarative.requesters.error_handlers.default_http_response_filter import DefaultHttpResponseFilter -from airbyte_cdk.sources.declarative.requesters.error_handlers.http_response_filter import HttpResponseFilter +from airbyte_cdk.sources.declarative.requesters.error_handlers.default_http_response_filter import ( + DefaultHttpResponseFilter, +) +from airbyte_cdk.sources.declarative.requesters.error_handlers.http_response_filter import ( + HttpResponseFilter, +) from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy, ErrorHandler from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( SUCCESS_RESOLUTION, @@ -103,10 +107,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._last_request_to_attempt_count: MutableMapping[requests.PreparedRequest, int] = {} - def interpret_response(self, response_or_exception: Optional[Union[requests.Response, Exception]]) -> ErrorResolution: + def interpret_response( + self, response_or_exception: Optional[Union[requests.Response, Exception]] + ) -> ErrorResolution: if self.response_filters: for response_filter in self.response_filters: - matched_error_resolution = response_filter.matches(response_or_exception=response_or_exception) + matched_error_resolution = response_filter.matches( + response_or_exception=response_or_exception + ) if matched_error_resolution: return matched_error_resolution if isinstance(response_or_exception, requests.Response): @@ -123,12 +131,16 @@ def interpret_response(self, response_or_exception: Optional[Union[requests.Resp ) def backoff_time( - self, response_or_exception: Optional[Union[requests.Response, requests.RequestException]], attempt_count: int = 0 + self, + response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + attempt_count: int = 0, ) -> Optional[float]: backoff = None if self.backoff_strategies: for backoff_strategy in self.backoff_strategies: - backoff = backoff_strategy.backoff_time(response_or_exception=response_or_exception, attempt_count=attempt_count) # type: ignore # attempt_count maintained for compatibility with low code CDK + backoff = backoff_strategy.backoff_time( + response_or_exception=response_or_exception, attempt_count=attempt_count + ) # type: ignore # attempt_count maintained for compatibility with low code CDK if backoff: return backoff return backoff diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py index 2a8eae72..395df5c9 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py @@ -5,13 +5,22 @@ from typing import Optional, Union import requests -from airbyte_cdk.sources.declarative.requesters.error_handlers.http_response_filter import HttpResponseFilter -from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import DEFAULT_ERROR_MAPPING -from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, create_fallback_error_resolution +from airbyte_cdk.sources.declarative.requesters.error_handlers.http_response_filter import ( + HttpResponseFilter, +) +from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import ( + DEFAULT_ERROR_MAPPING, +) +from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( + ErrorResolution, + create_fallback_error_resolution, +) class DefaultHttpResponseFilter(HttpResponseFilter): - def matches(self, response_or_exception: Optional[Union[requests.Response, Exception]]) -> Optional[ErrorResolution]: + def matches( + self, response_or_exception: Optional[Union[requests.Response, Exception]] + ) -> Optional[ErrorResolution]: default_mapped_error_resolution = None if isinstance(response_or_exception, (requests.Response, Exception)): @@ -24,5 +33,7 @@ def matches(self, response_or_exception: Optional[Union[requests.Response, Excep default_mapped_error_resolution = DEFAULT_ERROR_MAPPING.get(mapped_key) return ( - default_mapped_error_resolution if default_mapped_error_resolution else create_fallback_error_resolution(response_or_exception) + default_mapped_error_resolution + if default_mapped_error_resolution + else create_fallback_error_resolution(response_or_exception) ) diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py index 172e1521..366ad687 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py @@ -10,8 +10,13 @@ from airbyte_cdk.sources.declarative.interpolation import InterpolatedString from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean from airbyte_cdk.sources.streams.http.error_handlers import JsonErrorMessageParser -from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import DEFAULT_ERROR_MAPPING -from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction +from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import ( + DEFAULT_ERROR_MAPPING, +) +from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( + ErrorResolution, + ResponseAction, +) from airbyte_cdk.sources.types import Config @@ -43,22 +48,34 @@ class HttpResponseFilter: def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.action is not None: - if self.http_codes is None and self.predicate is None and self.error_message_contains is None: - raise ValueError("HttpResponseFilter requires a filter condition if an action is specified") + if ( + self.http_codes is None + and self.predicate is None + and self.error_message_contains is None + ): + raise ValueError( + "HttpResponseFilter requires a filter condition if an action is specified" + ) elif isinstance(self.action, str): self.action = ResponseAction[self.action] self.http_codes = self.http_codes or set() if isinstance(self.predicate, str): self.predicate = InterpolatedBoolean(condition=self.predicate, parameters=parameters) - self.error_message = InterpolatedString.create(string_or_interpolated=self.error_message, parameters=parameters) + self.error_message = InterpolatedString.create( + string_or_interpolated=self.error_message, parameters=parameters + ) self._error_message_parser = JsonErrorMessageParser() if self.failure_type and isinstance(self.failure_type, str): self.failure_type = FailureType[self.failure_type] - def matches(self, response_or_exception: Optional[Union[requests.Response, Exception]]) -> Optional[ErrorResolution]: + def matches( + self, response_or_exception: Optional[Union[requests.Response, Exception]] + ) -> Optional[ErrorResolution]: filter_action = self._matches_filter(response_or_exception) mapped_key = ( - response_or_exception.status_code if isinstance(response_or_exception, requests.Response) else response_or_exception.__class__ + response_or_exception.status_code + if isinstance(response_or_exception, requests.Response) + else response_or_exception.__class__ ) if isinstance(mapped_key, (int, Exception)): @@ -67,7 +84,11 @@ def matches(self, response_or_exception: Optional[Union[requests.Response, Excep default_mapped_error_resolution = None if filter_action is not None: - default_error_message = default_mapped_error_resolution.error_message if default_mapped_error_resolution else "" + default_error_message = ( + default_mapped_error_resolution.error_message + if default_mapped_error_resolution + else "" + ) error_message = None if isinstance(response_or_exception, requests.Response): error_message = self._create_error_message(response_or_exception) @@ -95,10 +116,14 @@ def matches(self, response_or_exception: Optional[Union[requests.Response, Excep return None - def _match_default_error_mapping(self, mapped_key: Union[int, type[Exception]]) -> Optional[ErrorResolution]: + def _match_default_error_mapping( + self, mapped_key: Union[int, type[Exception]] + ) -> Optional[ErrorResolution]: return DEFAULT_ERROR_MAPPING.get(mapped_key) - def _matches_filter(self, response_or_exception: Optional[Union[requests.Response, Exception]]) -> Optional[ResponseAction]: + def _matches_filter( + self, response_or_exception: Optional[Union[requests.Response, Exception]] + ) -> Optional[ResponseAction]: """ Apply the HTTP filter on the response and return the action to execute if it matches :param response: The HTTP response to evaluate @@ -125,13 +150,17 @@ def _create_error_message(self, response: requests.Response) -> Optional[str]: :param response: The HTTP response which can be used during interpolation :return: The evaluated error message string to be emitted """ - return self.error_message.eval(self.config, response=self._safe_response_json(response), headers=response.headers) # type: ignore # error_message is always cast to an interpolated string + return self.error_message.eval( + self.config, response=self._safe_response_json(response), headers=response.headers + ) # type: ignore # error_message is always cast to an interpolated string def _response_matches_predicate(self, response: requests.Response) -> bool: return ( bool( self.predicate.condition - and self.predicate.eval(None, response=self._safe_response_json(response), headers=response.headers) + and self.predicate.eval( + None, response=self._safe_response_json(response), headers=response.headers + ) ) if self.predicate else False @@ -141,5 +170,7 @@ def _response_contains_error_message(self, response: requests.Response) -> bool: if not self.error_message_contains: return False else: - error_message = self._error_message_parser.parse_response_error_message(response=response) + error_message = self._error_message_parser.parse_response_error_message( + response=response + ) return bool(error_message and self.error_message_contains in error_message) diff --git a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py index 2c425bf8..ff213068 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py +++ b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py @@ -12,8 +12,13 @@ from airbyte_cdk.sources.declarative.async_job.job import AsyncJob from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus -from airbyte_cdk.sources.declarative.extractors.dpath_extractor import DpathExtractor, RecordExtractor -from airbyte_cdk.sources.declarative.extractors.response_to_file_extractor import ResponseToFileExtractor +from airbyte_cdk.sources.declarative.extractors.dpath_extractor import ( + DpathExtractor, + RecordExtractor, +) +from airbyte_cdk.sources.declarative.extractors.response_to_file_extractor import ( + ResponseToFileExtractor, +) from airbyte_cdk.sources.declarative.requesters.requester import Requester from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.sources.types import Record, StreamSlice @@ -35,7 +40,9 @@ class AsyncHttpJobRepository(AsyncJobRepository): urls_extractor: DpathExtractor job_timeout: Optional[timedelta] = None - record_extractor: RecordExtractor = field(init=False, repr=False, default_factory=lambda: ResponseToFileExtractor()) + record_extractor: RecordExtractor = field( + init=False, repr=False, default_factory=lambda: ResponseToFileExtractor() + ) def __post_init__(self) -> None: self._create_job_response_by_id: Dict[str, Response] = {} @@ -55,7 +62,9 @@ def _get_validated_polling_response(self, stream_slice: StreamSlice) -> requests AirbyteTracedException: If the polling request returns an empty response. """ - polling_response: Optional[requests.Response] = self.polling_requester.send_request(stream_slice=stream_slice) + polling_response: Optional[requests.Response] = self.polling_requester.send_request( + stream_slice=stream_slice + ) if polling_response is None: raise AirbyteTracedException( internal_message="Polling Requester received an empty Response.", @@ -100,7 +109,9 @@ def _start_job_and_validate_response(self, stream_slice: StreamSlice) -> request AirbyteTracedException: If no response is received from the creation requester. """ - response: Optional[requests.Response] = self.creation_requester.send_request(stream_slice=stream_slice) + response: Optional[requests.Response] = self.creation_requester.send_request( + stream_slice=stream_slice + ) if not response: raise AirbyteTracedException( internal_message="Always expect a response or an exception from creation_requester", @@ -146,9 +157,17 @@ def update_jobs_status(self, jobs: Iterable[AsyncJob]) -> None: job_status: AsyncJobStatus = self._get_validated_job_status(polling_response) if job_status != job.status(): - lazy_log(LOGGER, logging.DEBUG, lambda: f"Status of job {job.api_job_id()} changed from {job.status()} to {job_status}") + lazy_log( + LOGGER, + logging.DEBUG, + lambda: f"Status of job {job.api_job_id()} changed from {job.status()} to {job_status}", + ) else: - lazy_log(LOGGER, logging.DEBUG, lambda: f"Status of job {job.api_job_id()} is still {job.status()}") + lazy_log( + LOGGER, + logging.DEBUG, + lambda: f"Status of job {job.api_job_id()} is still {job.status()}", + ) job.update_status(job_status) if job_status == AsyncJobStatus.COMPLETED: @@ -166,7 +185,9 @@ def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]: """ - for url in self.urls_extractor.extract_records(self._polling_job_response_by_id[job.api_job_id()]): + for url in self.urls_extractor.extract_records( + self._polling_job_response_by_id[job.api_job_id()] + ): stream_slice: StreamSlice = StreamSlice(partition={"url": url}, cursor_slice={}) for message in self.download_retriever.read_records({}, stream_slice): if isinstance(message, Record): diff --git a/airbyte_cdk/sources/declarative/requesters/http_requester.py b/airbyte_cdk/sources/declarative/requesters/http_requester.py index 15d20981..51ece9f9 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_requester.py +++ b/airbyte_cdk/sources/declarative/requesters/http_requester.py @@ -9,7 +9,10 @@ from urllib.parse import urljoin import requests -from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator, NoAuth +from airbyte_cdk.sources.declarative.auth.declarative_authenticator import ( + DeclarativeAuthenticator, + NoAuth, +) from airbyte_cdk.sources.declarative.decoders import Decoder from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -62,13 +65,19 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._url_base = InterpolatedString.create(self.url_base, parameters=parameters) self._path = InterpolatedString.create(self.path, parameters=parameters) if self.request_options_provider is None: - self._request_options_provider = InterpolatedRequestOptionsProvider(config=self.config, parameters=parameters) + self._request_options_provider = InterpolatedRequestOptionsProvider( + config=self.config, parameters=parameters + ) elif isinstance(self.request_options_provider, dict): - self._request_options_provider = InterpolatedRequestOptionsProvider(config=self.config, **self.request_options_provider) + self._request_options_provider = InterpolatedRequestOptionsProvider( + config=self.config, **self.request_options_provider + ) else: self._request_options_provider = self.request_options_provider self._authenticator = self.authenticator or NoAuth(parameters=parameters) - self._http_method = HttpMethod[self.http_method] if isinstance(self.http_method, str) else self.http_method + self._http_method = ( + HttpMethod[self.http_method] if isinstance(self.http_method, str) else self.http_method + ) self.error_handler = self.error_handler self._parameters = parameters @@ -103,9 +112,17 @@ def get_url_base(self) -> str: return os.path.join(self._url_base.eval(self.config), "") def get_path( - self, *, stream_state: Optional[StreamState], stream_slice: Optional[StreamSlice], next_page_token: Optional[Mapping[str, Any]] + self, + *, + stream_state: Optional[StreamState], + stream_slice: Optional[StreamSlice], + next_page_token: Optional[Mapping[str, Any]], ) -> str: - kwargs = {"stream_state": stream_state, "stream_slice": stream_slice, "next_page_token": next_page_token} + kwargs = { + "stream_state": stream_state, + "stream_slice": stream_slice, + "next_page_token": next_page_token, + } path = str(self._path.eval(self.config, **kwargs)) return path.lstrip("/") @@ -144,7 +161,9 @@ def get_request_body_data( # type: ignore ) -> Union[Mapping[str, Any], str]: return ( self._request_options_provider.get_request_body_data( - stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ) or {} ) @@ -181,7 +200,11 @@ def _get_request_options( """ return combine_mappings( [ - requester_method(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + requester_method( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), auth_options_method(), extra_options, ] @@ -223,14 +246,21 @@ def _request_params( E.g: you might want to define query parameters for paging if next_page_token is not None. """ options = self._get_request_options( - stream_state, stream_slice, next_page_token, self.get_request_params, self.get_authenticator().get_request_params, extra_params + stream_state, + stream_slice, + next_page_token, + self.get_request_params, + self.get_authenticator().get_request_params, + extra_params, ) if isinstance(options, str): raise ValueError("Request params cannot be a string") for k, v in options.items(): if isinstance(v, (dict,)): - raise ValueError(f"Invalid value for `{k}` parameter. The values of request params cannot be an object.") + raise ValueError( + f"Invalid value for `{k}` parameter. The values of request params cannot be an object." + ) return options @@ -305,13 +335,26 @@ def send_request( http_method=self.get_method().value, url=self._join_url( self.get_url_base(), - path or self.get_path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + path + or self.get_path( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), ), request_kwargs={"stream": self.stream_response}, - headers=self._request_headers(stream_state, stream_slice, next_page_token, request_headers), - params=self._request_params(stream_state, stream_slice, next_page_token, request_params), - json=self._request_body_json(stream_state, stream_slice, next_page_token, request_body_json), - data=self._request_body_data(stream_state, stream_slice, next_page_token, request_body_data), + headers=self._request_headers( + stream_state, stream_slice, next_page_token, request_headers + ), + params=self._request_params( + stream_state, stream_slice, next_page_token, request_params + ), + json=self._request_body_json( + stream_state, stream_slice, next_page_token, request_body_json + ), + data=self._request_body_data( + stream_state, stream_slice, next_page_token, request_body_data + ), dedupe_query_params=True, log_formatter=log_formatter, exit_on_rate_limit=self._exit_on_rate_limit, diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py index c92e9977..e26f32de 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py @@ -6,11 +6,20 @@ from typing import Any, Mapping, MutableMapping, Optional, Union import requests -from airbyte_cdk.sources.declarative.decoders import Decoder, JsonDecoder, PaginationDecoderDecorator +from airbyte_cdk.sources.declarative.decoders import ( + Decoder, + JsonDecoder, + PaginationDecoderDecorator, +) from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import PaginationStrategy -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import ( + PaginationStrategy, +) +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState @@ -89,13 +98,17 @@ class DefaultPaginator(Paginator): config: Config url_base: Union[InterpolatedString, str] parameters: InitVar[Mapping[str, Any]] - decoder: Decoder = field(default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={}))) + decoder: Decoder = field( + default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) + ) page_size_option: Optional[RequestOption] = None page_token_option: Optional[Union[RequestPath, RequestOption]] = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.page_size_option and not self.pagination_strategy.get_page_size(): - raise ValueError("page_size_option cannot be set if the pagination strategy does not have a page_size") + raise ValueError( + "page_size_option cannot be set if the pagination strategy does not have a page_size" + ) if isinstance(self.url_base, str): self.url_base = InterpolatedString(string=self.url_base, parameters=parameters) self._token: Optional[Any] = self.pagination_strategy.initial_token @@ -103,14 +116,20 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def next_page_token( self, response: requests.Response, last_page_size: int, last_record: Optional[Record] ) -> Optional[Mapping[str, Any]]: - self._token = self.pagination_strategy.next_page_token(response, last_page_size, last_record) + self._token = self.pagination_strategy.next_page_token( + response, last_page_size, last_record + ) if self._token: return {"next_page_token": self._token} else: return None def path(self) -> Optional[str]: - if self._token and self.page_token_option and isinstance(self.page_token_option, RequestPath): + if ( + self._token + and self.page_token_option + and isinstance(self.page_token_option, RequestPath) + ): # Replace url base to only return the path return str(self._token).replace(self.url_base.eval(self.config), "") # type: ignore # url_base is casted to a InterpolatedString in __post_init__ else: @@ -169,8 +188,14 @@ def _get_request_options(self, option_type: RequestOptionType) -> MutableMapping and self.page_token_option.inject_into == option_type ): options[self.page_token_option.field_name.eval(config=self.config)] = self._token # type: ignore # field_name is always cast to an interpolated string - if self.page_size_option and self.pagination_strategy.get_page_size() and self.page_size_option.inject_into == option_type: - options[self.page_size_option.field_name.eval(config=self.config)] = self.pagination_strategy.get_page_size() # type: ignore # field_name is always cast to an interpolated string + if ( + self.page_size_option + and self.pagination_strategy.get_page_size() + and self.page_size_option.inject_into == option_type + ): + options[self.page_size_option.field_name.eval(config=self.config)] = ( + self.pagination_strategy.get_page_size() + ) # type: ignore # field_name is always cast to an interpolated string return options @@ -184,7 +209,9 @@ class PaginatorTestReadDecorator(Paginator): def __init__(self, decorated: Paginator, maximum_number_of_pages: int = 5) -> None: if maximum_number_of_pages and maximum_number_of_pages < 1: - raise ValueError(f"The maximum number of pages on a test read needs to be strictly positive. Got {maximum_number_of_pages}") + raise ValueError( + f"The maximum number of pages on a test read needs to be strictly positive. Got {maximum_number_of_pages}" + ) self._maximum_number_of_pages = maximum_number_of_pages self._decorated = decorated self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL @@ -208,7 +235,9 @@ def get_request_params( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: - return self._decorated.get_request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + return self._decorated.get_request_params( + stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + ) def get_request_headers( self, @@ -217,7 +246,9 @@ def get_request_headers( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, str]: - return self._decorated.get_request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + return self._decorated.get_request_headers( + stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + ) def get_request_body_data( self, @@ -226,7 +257,9 @@ def get_request_body_data( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Union[Mapping[str, Any], str]: - return self._decorated.get_request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + return self._decorated.get_request_body_data( + stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + ) def get_request_body_json( self, @@ -235,7 +268,9 @@ def get_request_body_json( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: - return self._decorated.get_request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + return self._decorated.get_request_body_json( + stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + ) def reset(self, reset_value: Optional[Any] = None) -> None: self._decorated.reset() diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py index 4065902f..db4eb0ed 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py @@ -57,7 +57,9 @@ def get_request_body_json( ) -> Mapping[str, Any]: return {} - def next_page_token(self, response: requests.Response, last_page_size: int, last_record: Optional[Record]) -> Mapping[str, Any]: + def next_page_token( + self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + ) -> Mapping[str, Any]: return {} def reset(self, reset_value: Optional[Any] = None) -> None: diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py index aebc8241..1bf17d1d 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py @@ -7,7 +7,9 @@ from typing import Any, Mapping, Optional import requests -from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import RequestOptionsProvider +from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( + RequestOptionsProvider, +) from airbyte_cdk.sources.types import Record diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py index 7ba3c1d0..a53a044b 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py @@ -6,10 +6,16 @@ from typing import Any, Dict, Mapping, Optional, Union import requests -from airbyte_cdk.sources.declarative.decoders import Decoder, JsonDecoder, PaginationDecoderDecorator +from airbyte_cdk.sources.declarative.decoders import ( + Decoder, + JsonDecoder, + PaginationDecoderDecorator, +) from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import PaginationStrategy +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import ( + PaginationStrategy, +) from airbyte_cdk.sources.types import Config, Record @@ -31,7 +37,9 @@ class CursorPaginationStrategy(PaginationStrategy): parameters: InitVar[Mapping[str, Any]] page_size: Optional[int] = None stop_condition: Optional[Union[InterpolatedBoolean, str]] = None - decoder: Decoder = field(default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={}))) + decoder: Decoder = field( + default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) + ) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._initial_cursor = None @@ -40,7 +48,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: else: self._cursor_value = self.cursor_value if isinstance(self.stop_condition, str): - self._stop_condition: Optional[InterpolatedBoolean] = InterpolatedBoolean(condition=self.stop_condition, parameters=parameters) + self._stop_condition: Optional[InterpolatedBoolean] = InterpolatedBoolean( + condition=self.stop_condition, parameters=parameters + ) else: self._stop_condition = self.stop_condition @@ -48,7 +58,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def initial_token(self) -> Optional[Any]: return self._initial_cursor - def next_page_token(self, response: requests.Response, last_page_size: int, last_record: Optional[Record]) -> Optional[Any]: + def next_page_token( + self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + ) -> Optional[Any]: decoded_response = next(self.decoder.decode(response)) # The default way that link is presented in requests.Response is a string of various links (last, next, etc). This diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py index 295b0908..9f24b961 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py @@ -6,9 +6,15 @@ from typing import Any, Mapping, Optional, Union import requests -from airbyte_cdk.sources.declarative.decoders import Decoder, JsonDecoder, PaginationDecoderDecorator +from airbyte_cdk.sources.declarative.decoders import ( + Decoder, + JsonDecoder, + PaginationDecoderDecorator, +) from airbyte_cdk.sources.declarative.interpolation import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import PaginationStrategy +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import ( + PaginationStrategy, +) from airbyte_cdk.sources.types import Config, Record @@ -39,14 +45,18 @@ class OffsetIncrement(PaginationStrategy): config: Config page_size: Optional[Union[str, int]] parameters: InitVar[Mapping[str, Any]] - decoder: Decoder = field(default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={}))) + decoder: Decoder = field( + default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) + ) inject_on_first_request: bool = False def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._offset = 0 page_size = str(self.page_size) if isinstance(self.page_size, int) else self.page_size if page_size: - self._page_size: Optional[InterpolatedString] = InterpolatedString(page_size, parameters=parameters) + self._page_size: Optional[InterpolatedString] = InterpolatedString( + page_size, parameters=parameters + ) else: self._page_size = None @@ -56,11 +66,16 @@ def initial_token(self) -> Optional[Any]: return self._offset return None - def next_page_token(self, response: requests.Response, last_page_size: int, last_record: Optional[Record]) -> Optional[Any]: + def next_page_token( + self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + ) -> Optional[Any]: decoded_response = next(self.decoder.decode(response)) # Stop paginating when there are fewer records than the page size or the current page has no records - if (self._page_size and last_page_size < self._page_size.eval(self.config, response=decoded_response)) or last_page_size == 0: + if ( + self._page_size + and last_page_size < self._page_size.eval(self.config, response=decoded_response) + ) or last_page_size == 0: return None else: self._offset += last_page_size @@ -68,7 +83,9 @@ def next_page_token(self, response: requests.Response, last_page_size: int, last def reset(self, reset_value: Optional[Any] = 0) -> None: if not isinstance(reset_value, int): - raise ValueError(f"Reset value {reset_value} for OffsetIncrement pagination strategy was not an integer") + raise ValueError( + f"Reset value {reset_value} for OffsetIncrement pagination strategy was not an integer" + ) else: self._offset = reset_value diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py index 978ac1ab..1ce0a1c8 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py @@ -7,7 +7,9 @@ import requests from airbyte_cdk.sources.declarative.interpolation import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import PaginationStrategy +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import ( + PaginationStrategy, +) from airbyte_cdk.sources.types import Config, Record @@ -43,7 +45,9 @@ def initial_token(self) -> Optional[Any]: return self._page return None - def next_page_token(self, response: requests.Response, last_page_size: int, last_record: Optional[Record]) -> Optional[Any]: + def next_page_token( + self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + ) -> Optional[Any]: # Stop paginating when there are fewer records than the page size or the current page has no records if (self._page_size and last_page_size < self._page_size) or last_page_size == 0: return None @@ -55,7 +59,9 @@ def reset(self, reset_value: Optional[Any] = None) -> None: if reset_value is None: self._page = self.start_from_page elif not isinstance(reset_value, int): - raise ValueError(f"Reset value {reset_value} for PageIncrement pagination strategy was not an integer") + raise ValueError( + f"Reset value {reset_value} for PageIncrement pagination strategy was not an integer" + ) else: self._page = reset_value diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py index 135eb481..0b350d33 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py @@ -24,7 +24,9 @@ def initial_token(self) -> Optional[Any]: """ @abstractmethod - def next_page_token(self, response: requests.Response, last_page_size: int, last_record: Optional[Record]) -> Optional[Any]: + def next_page_token( + self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + ) -> Optional[Any]: """ :param response: response to process :param last_page_size: the number of records read from the response diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py index ca79bfd3..3f322aa9 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py @@ -7,7 +7,9 @@ import requests from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import PaginationStrategy +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import ( + PaginationStrategy, +) from airbyte_cdk.sources.types import Record @@ -35,7 +37,9 @@ def __init__(self, _delegate: PaginationStrategy, stop_condition: PaginationStop self._delegate = _delegate self._stop_condition = stop_condition - def next_page_token(self, response: requests.Response, last_page_size: int, last_record: Optional[Record]) -> Optional[Any]: + def next_page_token( + self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + ) -> Optional[Any]: # We evaluate in reverse order because the assumption is that most of the APIs using data feed structure will return records in # descending order. In terms of performance/memory, we return the records lazily if last_record and self._stop_condition.is_met(last_record): diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py index 453940e7..5ce7c9a3 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py @@ -6,8 +6,13 @@ from typing import Any, Mapping, MutableMapping, Optional, Union from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType -from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import RequestOptionsProvider +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) +from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( + RequestOptionsProvider, +) from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @@ -26,8 +31,12 @@ class DatetimeBasedRequestOptionsProvider(RequestOptionsProvider): partition_field_end: Optional[str] = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters) - self._partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters) + self._partition_field_start = InterpolatedString.create( + self.partition_field_start or "start_time", parameters=parameters + ) + self._partition_field_end = InterpolatedString.create( + self.partition_field_end or "end_time", parameters=parameters + ) def get_request_params( self, @@ -65,7 +74,9 @@ def get_request_body_json( ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_json, stream_slice) - def _get_request_options(self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]: + def _get_request_options( + self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + ) -> Mapping[str, Any]: options: MutableMapping[str, Any] = {} if not stream_slice: return options diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py index 42d8ee70..449da977 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py @@ -5,7 +5,9 @@ from dataclasses import InitVar, dataclass from typing import Any, Mapping, Optional, Union -from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import RequestOptionsProvider +from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( + RequestOptionsProvider, +) from airbyte_cdk.sources.types import StreamSlice, StreamState diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py index 1880ce82..6403417c 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py @@ -5,7 +5,10 @@ from dataclasses import InitVar, dataclass, field from typing import Any, Mapping, Optional, Union -from airbyte_cdk.sources.declarative.interpolation.interpolated_nested_mapping import InterpolatedNestedMapping, NestedMapping +from airbyte_cdk.sources.declarative.interpolation.interpolated_nested_mapping import ( + InterpolatedNestedMapping, + NestedMapping, +) from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @@ -19,15 +22,23 @@ class InterpolatedNestedRequestInputProvider: parameters: InitVar[Mapping[str, Any]] request_inputs: Optional[Union[str, NestedMapping]] = field(default=None) config: Config = field(default_factory=dict) - _interpolator: Optional[Union[InterpolatedString, InterpolatedNestedMapping]] = field(init=False, repr=False, default=None) - _request_inputs: Optional[Union[str, NestedMapping]] = field(init=False, repr=False, default=None) + _interpolator: Optional[Union[InterpolatedString, InterpolatedNestedMapping]] = field( + init=False, repr=False, default=None + ) + _request_inputs: Optional[Union[str, NestedMapping]] = field( + init=False, repr=False, default=None + ) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._request_inputs = self.request_inputs or {} if isinstance(self._request_inputs, str): - self._interpolator = InterpolatedString(self._request_inputs, default="", parameters=parameters) + self._interpolator = InterpolatedString( + self._request_inputs, default="", parameters=parameters + ) else: - self._interpolator = InterpolatedNestedMapping(self._request_inputs, parameters=parameters) + self._interpolator = InterpolatedNestedMapping( + self._request_inputs, parameters=parameters + ) def eval_request_inputs( self, @@ -43,5 +54,9 @@ def eval_request_inputs( :param next_page_token: The pagination token :return: The request inputs to set on an outgoing HTTP request """ - kwargs = {"stream_state": stream_state, "stream_slice": stream_slice, "next_page_token": next_page_token} + kwargs = { + "stream_state": stream_state, + "stream_slice": stream_slice, + "next_page_token": next_page_token, + } return self._interpolator.eval(self.config, **kwargs) # type: ignore # self._interpolator is always initialized with a value and will not be None diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py index 8a3ed7d8..0278df35 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py @@ -19,13 +19,19 @@ class InterpolatedRequestInputProvider: parameters: InitVar[Mapping[str, Any]] request_inputs: Optional[Union[str, Mapping[str, str]]] = field(default=None) config: Config = field(default_factory=dict) - _interpolator: Optional[Union[InterpolatedString, InterpolatedMapping]] = field(init=False, repr=False, default=None) - _request_inputs: Optional[Union[str, Mapping[str, str]]] = field(init=False, repr=False, default=None) + _interpolator: Optional[Union[InterpolatedString, InterpolatedMapping]] = field( + init=False, repr=False, default=None + ) + _request_inputs: Optional[Union[str, Mapping[str, str]]] = field( + init=False, repr=False, default=None + ) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._request_inputs = self.request_inputs or {} if isinstance(self._request_inputs, str): - self._interpolator = InterpolatedString(self._request_inputs, default="", parameters=parameters) + self._interpolator = InterpolatedString( + self._request_inputs, default="", parameters=parameters + ) else: self._interpolator = InterpolatedMapping(self._request_inputs, parameters=parameters) @@ -47,9 +53,16 @@ def eval_request_inputs( :param valid_value_types: A tuple of types that the interpolator should allow :return: The request inputs to set on an outgoing HTTP request """ - kwargs = {"stream_state": stream_state, "stream_slice": stream_slice, "next_page_token": next_page_token} + kwargs = { + "stream_state": stream_state, + "stream_slice": stream_slice, + "next_page_token": next_page_token, + } interpolated_value = self._interpolator.eval( # type: ignore # self._interpolator is always initialized with a value and will not be None - self.config, valid_key_types=valid_key_types, valid_value_types=valid_value_types, **kwargs + self.config, + valid_key_types=valid_key_types, + valid_value_types=valid_value_types, + **kwargs, ) if isinstance(interpolated_value, dict): diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py index 413a8bb1..bd8cfc17 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py @@ -9,8 +9,12 @@ from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_nested_request_input_provider import ( InterpolatedNestedRequestInputProvider, ) -from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_input_provider import InterpolatedRequestInputProvider -from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import RequestOptionsProvider +from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_input_provider import ( + InterpolatedRequestInputProvider, +) +from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( + RequestOptionsProvider, +) from airbyte_cdk.sources.source import ExperimentalClassWarning from airbyte_cdk.sources.types import Config, StreamSlice, StreamState from deprecated import deprecated @@ -50,7 +54,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.request_body_json = {} if self.request_body_json and self.request_body_data: - raise ValueError("RequestOptionsProvider should only contain either 'request_body_data' or 'request_body_json' not both") + raise ValueError( + "RequestOptionsProvider should only contain either 'request_body_data' or 'request_body_json' not both" + ) self._parameter_interpolator = InterpolatedRequestInputProvider( config=self.config, request_inputs=self.request_parameters, parameters=parameters @@ -73,7 +79,11 @@ def get_request_params( next_page_token: Optional[Mapping[str, Any]] = None, ) -> MutableMapping[str, Any]: interpolated_value = self._parameter_interpolator.eval_request_inputs( - stream_state, stream_slice, next_page_token, valid_key_types=(str,), valid_value_types=ValidRequestTypes + stream_state, + stream_slice, + next_page_token, + valid_key_types=(str,), + valid_value_types=ValidRequestTypes, ) if isinstance(interpolated_value, dict): return interpolated_value @@ -86,7 +96,9 @@ def get_request_headers( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: - return self._headers_interpolator.eval_request_inputs(stream_state, stream_slice, next_page_token) + return self._headers_interpolator.eval_request_inputs( + stream_state, stream_slice, next_page_token + ) def get_request_body_data( self, @@ -110,9 +122,14 @@ def get_request_body_json( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: - return self._body_json_interpolator.eval_request_inputs(stream_state, stream_slice, next_page_token) + return self._body_json_interpolator.eval_request_inputs( + stream_state, stream_slice, next_page_token + ) - @deprecated("This class is temporary and used to incrementally deliver low-code to concurrent", category=ExperimentalClassWarning) + @deprecated( + "This class is temporary and used to incrementally deliver low-code to concurrent", + category=ExperimentalClassWarning, + ) def request_options_contain_stream_state(self) -> bool: """ Temporary helper method used as we move low-code streams to the concurrent framework. This method determines if @@ -128,7 +145,9 @@ def request_options_contain_stream_state(self) -> bool: ) @staticmethod - def _check_if_interpolation_uses_stream_state(request_input: Optional[Union[RequestInput, NestedMapping]]) -> bool: + def _check_if_interpolation_uses_stream_state( + request_input: Optional[Union[RequestInput, NestedMapping]], + ) -> bool: if not request_input: return False elif isinstance(request_input, str): diff --git a/airbyte_cdk/sources/declarative/requesters/requester.py b/airbyte_cdk/sources/declarative/requesters/requester.py index ef702216..19003a83 100644 --- a/airbyte_cdk/sources/declarative/requesters/requester.py +++ b/airbyte_cdk/sources/declarative/requesters/requester.py @@ -8,7 +8,9 @@ import requests from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator -from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import RequestOptionsProvider +from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( + RequestOptionsProvider, +) from airbyte_cdk.sources.types import StreamSlice, StreamState diff --git a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py index 886ef0f9..f3902dfc 100644 --- a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py @@ -5,7 +5,10 @@ from typing import Any, Callable, Iterable, Mapping, Optional from airbyte_cdk.models import FailureType -from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator, AsyncPartition +from airbyte_cdk.sources.declarative.async_job.job_orchestrator import ( + AsyncJobOrchestrator, + AsyncPartition, +) from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector from airbyte_cdk.sources.declarative.partition_routers import SinglePartitionRouter from airbyte_cdk.sources.declarative.retrievers import Retriever @@ -24,7 +27,9 @@ class AsyncRetriever(Retriever): parameters: InitVar[Mapping[str, Any]] job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator] record_selector: RecordSelector - stream_slicer: StreamSlicer = field(default_factory=lambda: SinglePartitionRouter(parameters={})) + stream_slicer: StreamSlicer = field( + default_factory=lambda: SinglePartitionRouter(parameters={}) + ) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._job_orchestrator_factory = self.job_orchestrator_factory @@ -66,7 +71,9 @@ def _get_stream_state(self) -> StreamState: return self.state - def _validate_and_get_stream_slice_partition(self, stream_slice: Optional[StreamSlice] = None) -> AsyncPartition: + def _validate_and_get_stream_slice_partition( + self, stream_slice: Optional[StreamSlice] = None + ) -> AsyncPartition: """ Validates the stream_slice argument and returns the partition from it. @@ -93,7 +100,8 @@ def stream_slices(self) -> Iterable[Optional[StreamSlice]]: for completed_partition in self._job_orchestrator.create_and_get_completed_partitions(): yield StreamSlice( - partition=dict(completed_partition.stream_slice.partition) | {"partition": completed_partition}, + partition=dict(completed_partition.stream_slice.partition) + | {"partition": completed_partition}, cursor_slice=completed_partition.stream_slice.cursor_slice, ) diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index 99639d84..530cf5f5 100644 --- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -6,7 +6,18 @@ from dataclasses import InitVar, dataclass, field from functools import partial from itertools import islice -from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) import requests from airbyte_cdk.models import AirbyteMessage @@ -14,10 +25,15 @@ from airbyte_cdk.sources.declarative.incremental import ResumableFullRefreshCursor from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.interpolation import InterpolatedString -from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import SinglePartitionRouter +from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import ( + SinglePartitionRouter, +) from airbyte_cdk.sources.declarative.requesters.paginators.no_pagination import NoPagination from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator -from airbyte_cdk.sources.declarative.requesters.request_options import DefaultRequestOptionsProvider, RequestOptionsProvider +from airbyte_cdk.sources.declarative.requesters.request_options import ( + DefaultRequestOptionsProvider, + RequestOptionsProvider, +) from airbyte_cdk.sources.declarative.requesters.requester import Requester from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer @@ -62,8 +78,12 @@ class SimpleRetriever(Retriever): primary_key: Optional[Union[str, List[str], List[List[str]]]] _primary_key: str = field(init=False, repr=False, default="") paginator: Optional[Paginator] = None - stream_slicer: StreamSlicer = field(default_factory=lambda: SinglePartitionRouter(parameters={})) - request_option_provider: RequestOptionsProvider = field(default_factory=lambda: DefaultRequestOptionsProvider(parameters={})) + stream_slicer: StreamSlicer = field( + default_factory=lambda: SinglePartitionRouter(parameters={}) + ) + request_option_provider: RequestOptionsProvider = field( + default_factory=lambda: DefaultRequestOptionsProvider(parameters={}) + ) cursor: Optional[DeclarativeCursor] = None ignore_stream_slicer_parameters_on_paginated_requests: bool = False @@ -73,7 +93,11 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._last_page_size: int = 0 self._last_record: Optional[Record] = None self._parameters = parameters - self._name = InterpolatedString(self._name, parameters=parameters) if isinstance(self._name, str) else self._name + self._name = ( + InterpolatedString(self._name, parameters=parameters) + if isinstance(self._name, str) + else self._name + ) # This mapping is used during a resumable full refresh syncs to indicate whether a partition has started syncing # records. Partitions serve as the key and map to True if they already began processing records @@ -84,7 +108,11 @@ def name(self) -> str: """ :return: Stream name """ - return str(self._name.eval(self.config)) if isinstance(self._name, InterpolatedString) else self._name + return ( + str(self._name.eval(self.config)) + if isinstance(self._name, InterpolatedString) + else self._name + ) @name.setter def name(self, value: str) -> None: @@ -118,10 +146,20 @@ def _get_request_options( """ # FIXME we should eventually remove the usage of stream_state as part of the interpolation mappings = [ - paginator_method(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + paginator_method( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), ] if not next_page_token or not self.ignore_stream_slicer_parameters_on_paginated_requests: - mappings.append(stream_slicer_method(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)) + mappings.append( + stream_slicer_method( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) + ) return combine_mappings(mappings) def _request_headers( @@ -271,20 +309,35 @@ def _next_page_token(self, response: requests.Response) -> Optional[Mapping[str, return self._paginator.next_page_token(response, self._last_page_size, self._last_record) def _fetch_next_page( - self, stream_state: Mapping[str, Any], stream_slice: StreamSlice, next_page_token: Optional[Mapping[str, Any]] = None + self, + stream_state: Mapping[str, Any], + stream_slice: StreamSlice, + next_page_token: Optional[Mapping[str, Any]] = None, ) -> Optional[requests.Response]: return self.requester.send_request( path=self._paginator_path(), stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token, - request_headers=self._request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), - request_params=self._request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + request_headers=self._request_headers( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + request_params=self._request_params( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), request_body_data=self._request_body_data( - stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ), request_body_json=self._request_body_json( - stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ), ) @@ -323,10 +376,14 @@ def _read_single_page( if not response: next_page_token: Mapping[str, Any] = {FULL_REFRESH_SYNC_COMPLETE_KEY: True} else: - next_page_token = self._next_page_token(response) or {FULL_REFRESH_SYNC_COMPLETE_KEY: True} + next_page_token = self._next_page_token(response) or { + FULL_REFRESH_SYNC_COMPLETE_KEY: True + } if self.cursor: - self.cursor.close_slice(StreamSlice(cursor_slice=next_page_token, partition=stream_slice.partition)) + self.cursor.close_slice( + StreamSlice(cursor_slice=next_page_token, partition=stream_slice.partition) + ) # Always return an empty generator just in case no records were ever yielded yield from [] @@ -383,7 +440,9 @@ def read_records( # Latest record read, not necessarily within slice boundaries. # TODO Remove once all custom components implement `observe` method. # https://github.com/airbytehq/airbyte-internal-issues/issues/6955 - most_recent_record_from_slice = self._get_most_recent_record(most_recent_record_from_slice, current_record, _slice) + most_recent_record_from_slice = self._get_most_recent_record( + most_recent_record_from_slice, current_record, _slice + ) yield stream_data if self.cursor: @@ -391,13 +450,20 @@ def read_records( return def _get_most_recent_record( - self, current_most_recent: Optional[Record], current_record: Optional[Record], stream_slice: StreamSlice + self, + current_most_recent: Optional[Record], + current_record: Optional[Record], + stream_slice: StreamSlice, ) -> Optional[Record]: if self.cursor and current_record: if not current_most_recent: return current_record else: - return current_most_recent if self.cursor.is_greater_than_or_equal(current_most_recent, current_record) else current_record + return ( + current_most_recent + if self.cursor.is_greater_than_or_equal(current_most_recent, current_record) + else current_record + ) else: return None @@ -482,20 +548,35 @@ def stream_slices(self) -> Iterable[Optional[StreamSlice]]: # type: ignore return islice(super().stream_slices(), self.maximum_number_of_slices) def _fetch_next_page( - self, stream_state: Mapping[str, Any], stream_slice: StreamSlice, next_page_token: Optional[Mapping[str, Any]] = None + self, + stream_state: Mapping[str, Any], + stream_slice: StreamSlice, + next_page_token: Optional[Mapping[str, Any]] = None, ) -> Optional[requests.Response]: return self.requester.send_request( path=self._paginator_path(), stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token, - request_headers=self._request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), - request_params=self._request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + request_headers=self._request_headers( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + request_params=self._request_params( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), request_body_data=self._request_body_data( - stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ), request_body_json=self._request_body_json( - stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ), log_formatter=lambda response: format_http_message( response, diff --git a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py index 1aa70be1..a9b625e7 100644 --- a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py @@ -41,5 +41,7 @@ def get_json_schema(self) -> Mapping[str, Any]: # A slight hack since we don't directly have the stream name. However, when building the default filepath we assume the # runtime options stores stream name 'name' so we'll do the same here stream_name = self._parameters.get("name", "") - logging.info(f"Could not find schema for stream {stream_name}, defaulting to the empty schema") + logging.info( + f"Could not find schema for stream {stream_name}, defaulting to the empty schema" + ) return {} diff --git a/airbyte_cdk/sources/declarative/spec/spec.py b/airbyte_cdk/sources/declarative/spec/spec.py index 87c8911d..05fa079b 100644 --- a/airbyte_cdk/sources/declarative/spec/spec.py +++ b/airbyte_cdk/sources/declarative/spec/spec.py @@ -5,7 +5,11 @@ from dataclasses import InitVar, dataclass from typing import Any, Mapping, Optional -from airbyte_cdk.models import AdvancedAuth, ConnectorSpecification, ConnectorSpecificationSerializer # type: ignore [attr-defined] +from airbyte_cdk.models import ( + AdvancedAuth, + ConnectorSpecification, + ConnectorSpecificationSerializer, +) # type: ignore [attr-defined] from airbyte_cdk.sources.declarative.models.declarative_component_schema import AuthFlow @@ -29,7 +33,9 @@ def generate_spec(self) -> ConnectorSpecification: Returns the connector specification according the spec block defined in the low code connector manifest. """ - obj: dict[str, Mapping[str, Any] | str | AdvancedAuth] = {"connectionSpecification": self.connection_specification} + obj: dict[str, Mapping[str, Any] | str | AdvancedAuth] = { + "connectionSpecification": self.connection_specification + } if self.documentation_url: obj["documentationUrl"] = self.documentation_url diff --git a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py index a1ecf68a..af9c438f 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py @@ -6,7 +6,9 @@ from dataclasses import dataclass from typing import Iterable -from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import RequestOptionsProvider +from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( + RequestOptionsProvider, +) from airbyte_cdk.sources.types import StreamSlice diff --git a/airbyte_cdk/sources/declarative/transformations/add_fields.py b/airbyte_cdk/sources/declarative/transformations/add_fields.py index 2a69b782..fa920993 100644 --- a/airbyte_cdk/sources/declarative/transformations/add_fields.py +++ b/airbyte_cdk/sources/declarative/transformations/add_fields.py @@ -85,12 +85,16 @@ class AddFields(RecordTransformation): fields: List[AddedFieldDefinition] parameters: InitVar[Mapping[str, Any]] - _parsed_fields: List[ParsedAddFieldDefinition] = field(init=False, repr=False, default_factory=list) + _parsed_fields: List[ParsedAddFieldDefinition] = field( + init=False, repr=False, default_factory=list + ) def __post_init__(self, parameters: Mapping[str, Any]) -> None: for add_field in self.fields: if len(add_field.path) < 1: - raise ValueError(f"Expected a non-zero-length path for the AddFields transformation {add_field}") + raise ValueError( + f"Expected a non-zero-length path for the AddFields transformation {add_field}" + ) if not isinstance(add_field.value, InterpolatedString): if not isinstance(add_field.value, str): @@ -106,7 +110,12 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: ) else: self._parsed_fields.append( - ParsedAddFieldDefinition(add_field.path, add_field.value, value_type=add_field.value_type, parameters={}) + ParsedAddFieldDefinition( + add_field.path, + add_field.value, + value_type=add_field.value_type, + parameters={}, + ) ) def transform( diff --git a/airbyte_cdk/sources/declarative/transformations/remove_fields.py b/airbyte_cdk/sources/declarative/transformations/remove_fields.py index 658d5dd2..8ac20a0d 100644 --- a/airbyte_cdk/sources/declarative/transformations/remove_fields.py +++ b/airbyte_cdk/sources/declarative/transformations/remove_fields.py @@ -44,7 +44,9 @@ class RemoveFields(RecordTransformation): condition: str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._filter_interpolator = InterpolatedBoolean(condition=self.condition, parameters=parameters) + self._filter_interpolator = InterpolatedBoolean( + condition=self.condition, parameters=parameters + ) def transform( self, @@ -63,7 +65,9 @@ def transform( dpath.delete( record, pointer, - afilter=(lambda x: self._filter_interpolator.eval(config or {}, property=x)) if self.condition else None, + afilter=(lambda x: self._filter_interpolator.eval(config or {}, property=x)) + if self.condition + else None, ) except dpath.exceptions.PathNotFound: # if the (potentially nested) property does not exist, silently skip diff --git a/airbyte_cdk/sources/declarative/types.py b/airbyte_cdk/sources/declarative/types.py index 91900d18..a4d0aeb1 100644 --- a/airbyte_cdk/sources/declarative/types.py +++ b/airbyte_cdk/sources/declarative/types.py @@ -4,7 +4,14 @@ from __future__ import annotations -from airbyte_cdk.sources.types import Config, ConnectionDefinition, FieldPointer, Record, StreamSlice, StreamState +from airbyte_cdk.sources.types import ( + Config, + ConnectionDefinition, + FieldPointer, + Record, + StreamSlice, + StreamState, +) # Note: This package originally contained class definitions for low-code CDK types, but we promoted them into the Python CDK. # We've migrated connectors in the repository to reference the new location, but these assignments are used to retain backwards diff --git a/airbyte_cdk/sources/declarative/yaml_declarative_source.py b/airbyte_cdk/sources/declarative/yaml_declarative_source.py index a0443b03..cecb91a6 100644 --- a/airbyte_cdk/sources/declarative/yaml_declarative_source.py +++ b/airbyte_cdk/sources/declarative/yaml_declarative_source.py @@ -7,7 +7,9 @@ import yaml from airbyte_cdk.models import AirbyteStateMessage, ConfiguredAirbyteCatalog -from airbyte_cdk.sources.declarative.concurrent_declarative_source import ConcurrentDeclarativeSource +from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( + ConcurrentDeclarativeSource, +) from airbyte_cdk.sources.types import ConnectionDefinition diff --git a/airbyte_cdk/sources/embedded/base_integration.py b/airbyte_cdk/sources/embedded/base_integration.py index 79c9bd85..c2e67408 100644 --- a/airbyte_cdk/sources/embedded/base_integration.py +++ b/airbyte_cdk/sources/embedded/base_integration.py @@ -7,7 +7,11 @@ from airbyte_cdk.connector import TConfig from airbyte_cdk.models import AirbyteRecordMessage, AirbyteStateMessage, SyncMode, Type -from airbyte_cdk.sources.embedded.catalog import create_configured_catalog, get_stream, get_stream_names +from airbyte_cdk.sources.embedded.catalog import ( + create_configured_catalog, + get_stream, + get_stream_names, +) from airbyte_cdk.sources.embedded.runner import SourceRunner from airbyte_cdk.sources.embedded.tools import get_defined_id from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit @@ -31,11 +35,15 @@ def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Opt """ pass - def _load_data(self, stream_name: str, state: Optional[AirbyteStateMessage] = None) -> Iterable[TOutput]: + def _load_data( + self, stream_name: str, state: Optional[AirbyteStateMessage] = None + ) -> Iterable[TOutput]: catalog = self.source.discover(self.config) stream = get_stream(catalog, stream_name) if not stream: - raise ValueError(f"Stream {stream_name} not found, the following streams are available: {', '.join(get_stream_names(catalog))}") + raise ValueError( + f"Stream {stream_name} not found, the following streams are available: {', '.join(get_stream_names(catalog))}" + ) if SyncMode.incremental not in stream.supported_sync_modes: configured_catalog = create_configured_catalog(stream, sync_mode=SyncMode.full_refresh) else: @@ -43,7 +51,9 @@ def _load_data(self, stream_name: str, state: Optional[AirbyteStateMessage] = No for message in self.source.read(self.config, configured_catalog, state): if message.type == Type.RECORD: - output = self._handle_record(message.record, get_defined_id(stream, message.record.data)) # type: ignore[union-attr] # record has `data` + output = self._handle_record( + message.record, get_defined_id(stream, message.record.data) + ) # type: ignore[union-attr] # record has `data` if output: yield output elif message.type is Type.STATE and message.state: diff --git a/airbyte_cdk/sources/embedded/catalog.py b/airbyte_cdk/sources/embedded/catalog.py index 765e9b26..62c7a623 100644 --- a/airbyte_cdk/sources/embedded/catalog.py +++ b/airbyte_cdk/sources/embedded/catalog.py @@ -31,15 +31,27 @@ def to_configured_stream( primary_key: Optional[List[List[str]]] = None, ) -> ConfiguredAirbyteStream: return ConfiguredAirbyteStream( - stream=stream, sync_mode=sync_mode, destination_sync_mode=destination_sync_mode, cursor_field=cursor_field, primary_key=primary_key + stream=stream, + sync_mode=sync_mode, + destination_sync_mode=destination_sync_mode, + cursor_field=cursor_field, + primary_key=primary_key, ) -def to_configured_catalog(configured_streams: List[ConfiguredAirbyteStream]) -> ConfiguredAirbyteCatalog: +def to_configured_catalog( + configured_streams: List[ConfiguredAirbyteStream], +) -> ConfiguredAirbyteCatalog: return ConfiguredAirbyteCatalog(streams=configured_streams) -def create_configured_catalog(stream: AirbyteStream, sync_mode: SyncMode = SyncMode.full_refresh) -> ConfiguredAirbyteCatalog: - configured_streams = [to_configured_stream(stream, sync_mode=sync_mode, primary_key=stream.source_defined_primary_key)] +def create_configured_catalog( + stream: AirbyteStream, sync_mode: SyncMode = SyncMode.full_refresh +) -> ConfiguredAirbyteCatalog: + configured_streams = [ + to_configured_stream( + stream, sync_mode=sync_mode, primary_key=stream.source_defined_primary_key + ) + ] return to_configured_catalog(configured_streams) diff --git a/airbyte_cdk/sources/embedded/runner.py b/airbyte_cdk/sources/embedded/runner.py index c64e66ed..43217f15 100644 --- a/airbyte_cdk/sources/embedded/runner.py +++ b/airbyte_cdk/sources/embedded/runner.py @@ -8,7 +8,13 @@ from typing import Generic, Iterable, Optional from airbyte_cdk.connector import TConfig -from airbyte_cdk.models import AirbyteCatalog, AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog, ConnectorSpecification +from airbyte_cdk.models import ( + AirbyteCatalog, + AirbyteMessage, + AirbyteStateMessage, + ConfiguredAirbyteCatalog, + ConnectorSpecification, +) from airbyte_cdk.sources.source import Source @@ -22,7 +28,12 @@ def discover(self, config: TConfig) -> AirbyteCatalog: pass @abstractmethod - def read(self, config: TConfig, catalog: ConfiguredAirbyteCatalog, state: Optional[AirbyteStateMessage]) -> Iterable[AirbyteMessage]: + def read( + self, + config: TConfig, + catalog: ConfiguredAirbyteCatalog, + state: Optional[AirbyteStateMessage], + ) -> Iterable[AirbyteMessage]: pass @@ -37,5 +48,10 @@ def spec(self) -> ConnectorSpecification: def discover(self, config: TConfig) -> AirbyteCatalog: return self._source.discover(self._logger, config) - def read(self, config: TConfig, catalog: ConfiguredAirbyteCatalog, state: Optional[AirbyteStateMessage]) -> Iterable[AirbyteMessage]: + def read( + self, + config: TConfig, + catalog: ConfiguredAirbyteCatalog, + state: Optional[AirbyteStateMessage], + ) -> Iterable[AirbyteMessage]: return self._source.read(self._logger, config, catalog, state=[state] if state else []) diff --git a/airbyte_cdk/sources/embedded/tools.py b/airbyte_cdk/sources/embedded/tools.py index 39d70c11..6bffa1a0 100644 --- a/airbyte_cdk/sources/embedded/tools.py +++ b/airbyte_cdk/sources/embedded/tools.py @@ -8,7 +8,9 @@ from airbyte_cdk.models import AirbyteStream -def get_first(iterable: Iterable[Any], predicate: Callable[[Any], bool] = lambda m: True) -> Optional[Any]: +def get_first( + iterable: Iterable[Any], predicate: Callable[[Any], bool] = lambda m: True +) -> Optional[Any]: return next(filter(predicate, iterable), None) diff --git a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py index ba26745e..c0234ca1 100644 --- a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py +++ b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py @@ -22,7 +22,9 @@ class AbstractFileBasedAvailabilityStrategy(AvailabilityStrategy): @abstractmethod - def check_availability(self, stream: Stream, logger: logging.Logger, _: Optional[Source]) -> Tuple[bool, Optional[str]]: + def check_availability( + self, stream: Stream, logger: logging.Logger, _: Optional[Source] + ) -> Tuple[bool, Optional[str]]: """ Perform a connection check for the stream. @@ -48,10 +50,16 @@ def __init__(self, stream: "AbstractFileBasedStream"): self.stream = stream def check_availability(self, logger: logging.Logger) -> StreamAvailability: - is_available, reason = self.stream.availability_strategy.check_availability(self.stream, logger, None) + is_available, reason = self.stream.availability_strategy.check_availability( + self.stream, logger, None + ) if is_available: return StreamAvailable() return StreamUnavailable(reason or "") - def check_availability_and_parsability(self, logger: logging.Logger) -> Tuple[bool, Optional[str]]: - return self.stream.availability_strategy.check_availability_and_parsability(self.stream, logger, None) + def check_availability_and_parsability( + self, logger: logging.Logger + ) -> Tuple[bool, Optional[str]]: + return self.stream.availability_strategy.check_availability_and_parsability( + self.stream, logger, None + ) diff --git a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py index 20ddd114..cf985d9e 100644 --- a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py +++ b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py @@ -8,8 +8,14 @@ from airbyte_cdk import AirbyteTracedException from airbyte_cdk.sources import Source -from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy -from airbyte_cdk.sources.file_based.exceptions import CheckAvailabilityError, CustomFileBasedException, FileBasedSourceError +from airbyte_cdk.sources.file_based.availability_strategy import ( + AbstractFileBasedAvailabilityStrategy, +) +from airbyte_cdk.sources.file_based.exceptions import ( + CheckAvailabilityError, + CustomFileBasedException, + FileBasedSourceError, +) from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema @@ -89,15 +95,25 @@ def _check_list_files(self, stream: "AbstractFileBasedStream") -> RemoteFile: except CustomFileBasedException as exc: raise CheckAvailabilityError(str(exc), stream=stream.name) from exc except Exception as exc: - raise CheckAvailabilityError(FileBasedSourceError.ERROR_LISTING_FILES, stream=stream.name) from exc + raise CheckAvailabilityError( + FileBasedSourceError.ERROR_LISTING_FILES, stream=stream.name + ) from exc return file - def _check_parse_record(self, stream: "AbstractFileBasedStream", file: RemoteFile, logger: logging.Logger) -> None: + def _check_parse_record( + self, stream: "AbstractFileBasedStream", file: RemoteFile, logger: logging.Logger + ) -> None: parser = stream.get_parser() try: - record = next(iter(parser.parse_records(stream.config, file, self.stream_reader, logger, discovered_schema=None))) + record = next( + iter( + parser.parse_records( + stream.config, file, self.stream_reader, logger, discovered_schema=None + ) + ) + ) except StopIteration: # The file is empty. We've verified that we can open it, so will # consider the connection check successful even though it means @@ -106,7 +122,9 @@ def _check_parse_record(self, stream: "AbstractFileBasedStream", file: RemoteFil except AirbyteTracedException as ate: raise ate except Exception as exc: - raise CheckAvailabilityError(FileBasedSourceError.ERROR_READING_FILE, stream=stream.name, file=file.uri) from exc + raise CheckAvailabilityError( + FileBasedSourceError.ERROR_READING_FILE, stream=stream.name, file=file.uri + ) from exc schema = stream.catalog_schema or stream.config.input_schema if schema and stream.validation_policy.validate_schema_before_sync: diff --git a/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py b/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py index 38159698..ee220388 100644 --- a/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py +++ b/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py @@ -107,10 +107,16 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]: properties_to_change = ["validation_policy"] for property_to_change in properties_to_change: - property_object = schema["properties"]["streams"]["items"]["properties"][property_to_change] + property_object = schema["properties"]["streams"]["items"]["properties"][ + property_to_change + ] if "anyOf" in property_object: - schema["properties"]["streams"]["items"]["properties"][property_to_change]["type"] = "object" - schema["properties"]["streams"]["items"]["properties"][property_to_change]["oneOf"] = property_object.pop("anyOf") + schema["properties"]["streams"]["items"]["properties"][property_to_change][ + "type" + ] = "object" + schema["properties"]["streams"]["items"]["properties"][property_to_change][ + "oneOf" + ] = property_object.pop("anyOf") AbstractFileBasedSpec.move_enum_to_root(property_object) csv_format_schemas = list( @@ -121,9 +127,9 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]: ) if len(csv_format_schemas) != 1: raise ValueError(f"Expecting only one CSV format but got {csv_format_schemas}") - csv_format_schemas[0]["properties"]["header_definition"]["oneOf"] = csv_format_schemas[0]["properties"]["header_definition"].pop( - "anyOf", [] - ) + csv_format_schemas[0]["properties"]["header_definition"]["oneOf"] = csv_format_schemas[0][ + "properties" + ]["header_definition"].pop("anyOf", []) csv_format_schemas[0]["properties"]["header_definition"]["type"] = "object" return schema diff --git a/airbyte_cdk/sources/file_based/config/csv_format.py b/airbyte_cdk/sources/file_based/config/csv_format.py index 317a0172..83789c45 100644 --- a/airbyte_cdk/sources/file_based/config/csv_format.py +++ b/airbyte_cdk/sources/file_based/config/csv_format.py @@ -70,7 +70,9 @@ def has_header_row(self) -> bool: @validator("column_names") def validate_column_names(cls, v: List[str]) -> List[str]: if not v: - raise ValueError("At least one column name needs to be provided when using user provided headers") + raise ValueError( + "At least one column name needs to be provided when using user provided headers" + ) return v @@ -107,7 +109,9 @@ class Config(OneOfOptionConfig): description='The character encoding of the CSV data. Leave blank to default to UTF8. See list of python encodings for allowable options.', ) double_quote: bool = Field( - title="Double Quote", default=True, description="Whether two quotes in a quoted CSV value denote a single quote in the data." + title="Double Quote", + default=True, + description="Whether two quotes in a quoted CSV value denote a single quote in the data.", ) null_values: Set[str] = Field( title="Null Values", @@ -125,12 +129,16 @@ class Config(OneOfOptionConfig): description="The number of rows to skip before the header row. For example, if the header row is on the 3rd row, enter 2 in this field.", ) skip_rows_after_header: int = Field( - title="Skip Rows After Header", default=0, description="The number of rows to skip after the header row." + title="Skip Rows After Header", + default=0, + description="The number of rows to skip after the header row.", ) - header_definition: Union[CsvHeaderFromCsv, CsvHeaderAutogenerated, CsvHeaderUserProvided] = Field( - title="CSV Header Definition", - default=CsvHeaderFromCsv(header_definition_type=CsvHeaderDefinitionType.FROM_CSV.value), - description="How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.", + header_definition: Union[CsvHeaderFromCsv, CsvHeaderAutogenerated, CsvHeaderUserProvided] = ( + Field( + title="CSV Header Definition", + default=CsvHeaderFromCsv(header_definition_type=CsvHeaderDefinitionType.FROM_CSV.value), + description="How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.", + ) ) true_values: Set[str] = Field( title="True Values", @@ -189,9 +197,13 @@ def validate_optional_args(cls, values: Dict[str, Any]) -> Dict[str, Any]: definition_type = values.get("header_definition_type") column_names = values.get("user_provided_column_names") if definition_type == CsvHeaderDefinitionType.USER_PROVIDED and not column_names: - raise ValidationError("`user_provided_column_names` should be defined if the definition 'User Provided'.", model=CsvFormat) + raise ValidationError( + "`user_provided_column_names` should be defined if the definition 'User Provided'.", + model=CsvFormat, + ) if definition_type != CsvHeaderDefinitionType.USER_PROVIDED and column_names: raise ValidationError( - "`user_provided_column_names` should not be defined if the definition is not 'User Provided'.", model=CsvFormat + "`user_provided_column_names` should not be defined if the definition is not 'User Provided'.", + model=CsvFormat, ) return values diff --git a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py index 5419f4a6..5d92f6f0 100644 --- a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py +++ b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py @@ -56,7 +56,9 @@ class FileBasedStreamConfig(BaseModel): description="When the state history of the file store is full, syncs will only read files that were last modified in the provided day range.", default=3, ) - format: Union[AvroFormat, CsvFormat, JsonlFormat, ParquetFormat, UnstructuredFormat, ExcelFormat] = Field( + format: Union[ + AvroFormat, CsvFormat, JsonlFormat, ParquetFormat, UnstructuredFormat, ExcelFormat + ] = Field( title="Format", description="The configuration options that are used to alter how to read incoming files that deviate from the standard formatting.", ) @@ -89,6 +91,8 @@ def get_input_schema(self) -> Optional[Mapping[str, Any]]: if self.input_schema: schema = type_mapping_to_jsonschema(self.input_schema) if not schema: - raise ValueError(f"Unable to create JSON schema from input schema {self.input_schema}") + raise ValueError( + f"Unable to create JSON schema from input schema {self.input_schema}" + ) return schema return None diff --git a/airbyte_cdk/sources/file_based/config/unstructured_format.py b/airbyte_cdk/sources/file_based/config/unstructured_format.py index b799d1fe..dcebd951 100644 --- a/airbyte_cdk/sources/file_based/config/unstructured_format.py +++ b/airbyte_cdk/sources/file_based/config/unstructured_format.py @@ -13,7 +13,9 @@ class LocalProcessingConfigModel(BaseModel): class Config(OneOfOptionConfig): title = "Local" - description = "Process files locally, supporting `fast` and `ocr` modes. This is the default option." + description = ( + "Process files locally, supporting `fast` and `ocr` modes. This is the default option." + ) discriminator = "mode" @@ -23,7 +25,9 @@ class APIParameterConfigModel(BaseModel): description="The name of the unstructured API parameter to use", examples=["combine_under_n_chars", "languages"], ) - value: str = Field(title="Value", description="The value of the parameter", examples=["true", "hi_res"]) + value: str = Field( + title="Value", description="The value of the parameter", examples=["true", "hi_res"] + ) class APIProcessingConfigModel(BaseModel): diff --git a/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py b/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py index 3ce09889..f651c2ce 100644 --- a/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py +++ b/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py @@ -2,7 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from airbyte_cdk.sources.file_based.discovery_policy.abstract_discovery_policy import AbstractDiscoveryPolicy +from airbyte_cdk.sources.file_based.discovery_policy.abstract_discovery_policy import ( + AbstractDiscoveryPolicy, +) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser DEFAULT_N_CONCURRENT_REQUESTS = 10 @@ -23,6 +25,9 @@ def get_max_n_files_for_schema_inference(self, parser: FileTypeParser) -> int: return min( filter( None, - (DEFAULT_MAX_N_FILES_FOR_STREAM_SCHEMA_INFERENCE, parser.parser_max_n_files_for_schema_inference), + ( + DEFAULT_MAX_N_FILES_FOR_STREAM_SCHEMA_INFERENCE, + parser.parser_max_n_files_for_schema_inference, + ), ) ) diff --git a/airbyte_cdk/sources/file_based/exceptions.py b/airbyte_cdk/sources/file_based/exceptions.py index 60adf321..1c5ce0b1 100644 --- a/airbyte_cdk/sources/file_based/exceptions.py +++ b/airbyte_cdk/sources/file_based/exceptions.py @@ -11,27 +11,21 @@ class FileBasedSourceError(Enum): EMPTY_STREAM = "No files were identified in the stream. This may be because there are no files in the specified container, or because your glob patterns did not match any files. Please verify that your source contains files last modified after the start_date and that your glob patterns are not overly strict." - GLOB_PARSE_ERROR = ( - "Error parsing glob pattern. Please refer to the glob pattern rules at https://facelessuser.github.io/wcmatch/glob/#split." - ) + GLOB_PARSE_ERROR = "Error parsing glob pattern. Please refer to the glob pattern rules at https://facelessuser.github.io/wcmatch/glob/#split." ENCODING_ERROR = "File encoding error. The configured encoding must match file encoding." ERROR_CASTING_VALUE = "Could not cast the value to the expected type." ERROR_CASTING_VALUE_UNRECOGNIZED_TYPE = "Could not cast the value to the expected type because the type is not recognized. Valid types are null, array, boolean, integer, number, object, and string." ERROR_DECODING_VALUE = "Expected a JSON-decodeable value but could not decode record." - ERROR_LISTING_FILES = ( - "Error listing files. Please check the credentials provided in the config and verify that they provide permission to list files." - ) - ERROR_READING_FILE = ( - "Error opening file. Please check the credentials provided in the config and verify that they provide permission to read files." - ) + ERROR_LISTING_FILES = "Error listing files. Please check the credentials provided in the config and verify that they provide permission to list files." + ERROR_READING_FILE = "Error opening file. Please check the credentials provided in the config and verify that they provide permission to read files." ERROR_PARSING_RECORD = "Error parsing record. This could be due to a mismatch between the config's file type and the actual file type, or because the file or record is not parseable." - ERROR_PARSING_USER_PROVIDED_SCHEMA = "The provided schema could not be transformed into valid JSON Schema." + ERROR_PARSING_USER_PROVIDED_SCHEMA = ( + "The provided schema could not be transformed into valid JSON Schema." + ) ERROR_VALIDATING_RECORD = "One or more records do not pass the schema validation policy. Please modify your input schema, or select a more lenient validation policy." ERROR_PARSING_RECORD_MISMATCHED_COLUMNS = "A header field has resolved to `None`. This indicates that the CSV has more rows than the number of header fields. If you input your schema or headers, please verify that the number of columns corresponds to the number of columns in your CSV's rows." ERROR_PARSING_RECORD_MISMATCHED_ROWS = "A row's value has resolved to `None`. This indicates that the CSV has more columns in the header field than the number of columns in the row(s). If you input your schema or headers, please verify that the number of columns corresponds to the number of columns in your CSV's rows." - STOP_SYNC_PER_SCHEMA_VALIDATION_POLICY = ( - "Stopping sync in accordance with the configured validation policy. Records in file did not conform to the schema." - ) + STOP_SYNC_PER_SCHEMA_VALIDATION_POLICY = "Stopping sync in accordance with the configured validation policy. Records in file did not conform to the schema." NULL_VALUE_IN_SCHEMA = "Error during schema inference: no type was detected for key." UNRECOGNIZED_TYPE = "Error during schema inference: unrecognized type." SCHEMA_INFERENCE_ERROR = "Error inferring schema from files. Are the files valid?" @@ -39,7 +33,9 @@ class FileBasedSourceError(Enum): CONFIG_VALIDATION_ERROR = "Error creating stream config object." MISSING_SCHEMA = "Expected `json_schema` in the configured catalog but it is missing." UNDEFINED_PARSER = "No parser is defined for this file type." - UNDEFINED_VALIDATION_POLICY = "The validation policy defined in the config does not exist for the source." + UNDEFINED_VALIDATION_POLICY = ( + "The validation policy defined in the config does not exist for the source." + ) class FileBasedErrorsCollector: @@ -70,7 +66,9 @@ class BaseFileBasedSourceError(Exception): def __init__(self, error: Union[FileBasedSourceError, str], **kwargs): # type: ignore # noqa if isinstance(error, FileBasedSourceError): error = FileBasedSourceError(error).value - super().__init__(f"{error} Contact Support if you need assistance.\n{' '.join([f'{k}={v}' for k, v in kwargs.items()])}") + super().__init__( + f"{error} Contact Support if you need assistance.\n{' '.join([f'{k}={v}' for k, v in kwargs.items()])}" + ) class ConfigValidationError(BaseFileBasedSourceError): diff --git a/airbyte_cdk/sources/file_based/file_based_source.py b/airbyte_cdk/sources/file_based/file_based_source.py index 65f9e531..2c5758b2 100644 --- a/airbyte_cdk/sources/file_based/file_based_source.py +++ b/airbyte_cdk/sources/file_based/file_based_source.py @@ -22,15 +22,31 @@ from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager -from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy, DefaultFileBasedAvailabilityStrategy +from airbyte_cdk.sources.file_based.availability_strategy import ( + AbstractFileBasedAvailabilityStrategy, + DefaultFileBasedAvailabilityStrategy, +) from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec -from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ValidationPolicy -from airbyte_cdk.sources.file_based.discovery_policy import AbstractDiscoveryPolicy, DefaultDiscoveryPolicy -from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedErrorsCollector, FileBasedSourceError +from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( + FileBasedStreamConfig, + ValidationPolicy, +) +from airbyte_cdk.sources.file_based.discovery_policy import ( + AbstractDiscoveryPolicy, + DefaultDiscoveryPolicy, +) +from airbyte_cdk.sources.file_based.exceptions import ( + ConfigValidationError, + FileBasedErrorsCollector, + FileBasedSourceError, +) from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader from airbyte_cdk.sources.file_based.file_types import default_parsers from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser -from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy +from airbyte_cdk.sources.file_based.schema_validation_policies import ( + DEFAULT_SCHEMA_VALIDATION_POLICIES, + AbstractSchemaValidationPolicy, +) from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade from airbyte_cdk.sources.file_based.stream.concurrent.cursor import ( @@ -65,25 +81,37 @@ def __init__( availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None, discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(), parsers: Mapping[Type[Any], FileTypeParser] = default_parsers, - validation_policies: Mapping[ValidationPolicy, AbstractSchemaValidationPolicy] = DEFAULT_SCHEMA_VALIDATION_POLICIES, - cursor_cls: Type[Union[AbstractConcurrentFileBasedCursor, AbstractFileBasedCursor]] = FileBasedConcurrentCursor, + validation_policies: Mapping[ + ValidationPolicy, AbstractSchemaValidationPolicy + ] = DEFAULT_SCHEMA_VALIDATION_POLICIES, + cursor_cls: Type[ + Union[AbstractConcurrentFileBasedCursor, AbstractFileBasedCursor] + ] = FileBasedConcurrentCursor, ): self.stream_reader = stream_reader self.spec_class = spec_class self.config = config self.catalog = catalog self.state = state - self.availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy(stream_reader) + self.availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy( + stream_reader + ) self.discovery_policy = discovery_policy self.parsers = parsers self.validation_policies = validation_policies - self.stream_schemas = {s.stream.name: s.stream.json_schema for s in catalog.streams} if catalog else {} + self.stream_schemas = ( + {s.stream.name: s.stream.json_schema for s in catalog.streams} if catalog else {} + ) self.cursor_cls = cursor_cls self.logger = init_logger(f"airbyte.{self.name}") self.errors_collector: FileBasedErrorsCollector = FileBasedErrorsCollector() self._message_repository: Optional[MessageRepository] = None concurrent_source = ConcurrentSource.create( - MAX_CONCURRENCY, INITIAL_N_PARTITIONS, self.logger, self._slice_logger, self.message_repository + MAX_CONCURRENCY, + INITIAL_N_PARTITIONS, + self.logger, + self._slice_logger, + self.message_repository, ) self._state = None super().__init__(concurrent_source) @@ -91,10 +119,14 @@ def __init__( @property def message_repository(self) -> MessageRepository: if self._message_repository is None: - self._message_repository = InMemoryMessageRepository(Level(AirbyteLogFormatter.level_mapping[self.logger.level])) + self._message_repository = InMemoryMessageRepository( + Level(AirbyteLogFormatter.level_mapping[self.logger.level]) + ) return self._message_repository - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: """ Check that the source can be accessed using the user-provided configuration. @@ -195,13 +227,21 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: sync_mode = self._get_sync_mode_from_catalog(stream_config.name) - if sync_mode == SyncMode.full_refresh and hasattr(self, "_concurrency_level") and self._concurrency_level is not None: + if ( + sync_mode == SyncMode.full_refresh + and hasattr(self, "_concurrency_level") + and self._concurrency_level is not None + ): cursor = FileBasedFinalStateCursor( - stream_config=stream_config, stream_namespace=None, message_repository=self.message_repository + stream_config=stream_config, + stream_namespace=None, + message_repository=self.message_repository, ) stream = FileBasedStreamFacade.create_from_stream( stream=self._make_default_stream( - stream_config=stream_config, cursor=cursor, use_file_transfer=self._use_file_transfer(parsed_config) + stream_config=stream_config, + cursor=cursor, + use_file_transfer=self._use_file_transfer(parsed_config), ), source=self, logger=self.logger, @@ -230,7 +270,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: ) stream = FileBasedStreamFacade.create_from_stream( stream=self._make_default_stream( - stream_config=stream_config, cursor=cursor, use_file_transfer=self._use_file_transfer(parsed_config) + stream_config=stream_config, + cursor=cursor, + use_file_transfer=self._use_file_transfer(parsed_config), ), source=self, logger=self.logger, @@ -240,7 +282,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: else: cursor = self.cursor_cls(stream_config) stream = self._make_default_stream( - stream_config=stream_config, cursor=cursor, use_file_transfer=self._use_file_transfer(parsed_config) + stream_config=stream_config, + cursor=cursor, + use_file_transfer=self._use_file_transfer(parsed_config), ) streams.append(stream) @@ -250,7 +294,10 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) from exc def _make_default_stream( - self, stream_config: FileBasedStreamConfig, cursor: Optional[AbstractFileBasedCursor], use_file_transfer: bool = False + self, + stream_config: FileBasedStreamConfig, + cursor: Optional[AbstractFileBasedCursor], + use_file_transfer: bool = False, ) -> AbstractFileBasedStream: return DefaultFileBasedStream( config=stream_config, @@ -265,7 +312,9 @@ def _make_default_stream( use_file_transfer=use_file_transfer, ) - def _get_stream_from_catalog(self, stream_config: FileBasedStreamConfig) -> Optional[AirbyteStream]: + def _get_stream_from_catalog( + self, stream_config: FileBasedStreamConfig + ) -> Optional[AirbyteStream]: if self.catalog: for stream in self.catalog.streams or []: if stream.stream.name == stream_config.name: @@ -292,7 +341,9 @@ def read( yield from self.errors_collector.yield_and_raise_collected() # count streams using a certain parser parsed_config = self._get_parsed_config(config) - for parser, count in Counter(stream.format.filetype for stream in parsed_config.streams).items(): + for parser, count in Counter( + stream.format.filetype for stream in parsed_config.streams + ).items(): yield create_analytics_message(f"file-cdk-{parser}-stream-count", count) def spec(self, *args: Any, **kwargs: Any) -> ConnectorSpecification: @@ -308,21 +359,28 @@ def spec(self, *args: Any, **kwargs: Any) -> ConnectorSpecification: def _get_parsed_config(self, config: Mapping[str, Any]) -> AbstractFileBasedSpec: return self.spec_class(**config) - def _validate_and_get_validation_policy(self, stream_config: FileBasedStreamConfig) -> AbstractSchemaValidationPolicy: + def _validate_and_get_validation_policy( + self, stream_config: FileBasedStreamConfig + ) -> AbstractSchemaValidationPolicy: if stream_config.validation_policy not in self.validation_policies: # This should never happen because we validate the config against the schema's validation_policy enum raise ValidationError( - f"`validation_policy` must be one of {list(self.validation_policies.keys())}", model=FileBasedStreamConfig + f"`validation_policy` must be one of {list(self.validation_policies.keys())}", + model=FileBasedStreamConfig, ) return self.validation_policies[stream_config.validation_policy] def _validate_input_schema(self, stream_config: FileBasedStreamConfig) -> None: if stream_config.schemaless and stream_config.input_schema: - raise ValidationError("`input_schema` and `schemaless` options cannot both be set", model=FileBasedStreamConfig) + raise ValidationError( + "`input_schema` and `schemaless` options cannot both be set", + model=FileBasedStreamConfig, + ) @staticmethod def _use_file_transfer(parsed_config: AbstractFileBasedSpec) -> bool: use_file_transfer = ( - hasattr(parsed_config.delivery_method, "delivery_type") and parsed_config.delivery_method.delivery_type == "use_file_transfer" + hasattr(parsed_config.delivery_method, "delivery_type") + and parsed_config.delivery_method.delivery_type == "use_file_transfer" ) return use_file_transfer diff --git a/airbyte_cdk/sources/file_based/file_based_stream_reader.py b/airbyte_cdk/sources/file_based/file_based_stream_reader.py index d98513da..f8a9f89f 100644 --- a/airbyte_cdk/sources/file_based/file_based_stream_reader.py +++ b/airbyte_cdk/sources/file_based/file_based_stream_reader.py @@ -45,7 +45,9 @@ def config(self, value: AbstractFileBasedSpec) -> None: ... @abstractmethod - def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase: + def open_file( + self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + ) -> IOBase: """ Return a file handle for reading. @@ -80,11 +82,17 @@ def get_matching_files( """ ... - def filter_files_by_globs_and_start_date(self, files: List[RemoteFile], globs: List[str]) -> Iterable[RemoteFile]: + def filter_files_by_globs_and_start_date( + self, files: List[RemoteFile], globs: List[str] + ) -> Iterable[RemoteFile]: """ Utility method for filtering files based on globs. """ - start_date = datetime.strptime(self.config.start_date, self.DATE_TIME_FORMAT) if self.config and self.config.start_date else None + start_date = ( + datetime.strptime(self.config.start_date, self.DATE_TIME_FORMAT) + if self.config and self.config.start_date + else None + ) seen = set() for file in files: @@ -120,13 +128,16 @@ def get_prefixes_from_globs(globs: List[str]) -> Set[str]: def use_file_transfer(self) -> bool: if self.config: use_file_transfer = ( - hasattr(self.config.delivery_method, "delivery_type") and self.config.delivery_method.delivery_type == "use_file_transfer" + hasattr(self.config.delivery_method, "delivery_type") + and self.config.delivery_method.delivery_type == "use_file_transfer" ) return use_file_transfer return False @abstractmethod - def get_file(self, file: RemoteFile, local_directory: str, logger: logging.Logger) -> Dict[str, Any]: + def get_file( + self, file: RemoteFile, local_directory: str, logger: logging.Logger + ) -> Dict[str, Any]: """ This is required for connectors that will support writing to files. It will handle the logic to download,get,read,acquire or diff --git a/airbyte_cdk/sources/file_based/file_types/avro_parser.py b/airbyte_cdk/sources/file_based/file_types/avro_parser.py index b033afa5..a1535eaa 100644 --- a/airbyte_cdk/sources/file_based/file_types/avro_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/avro_parser.py @@ -9,7 +9,10 @@ from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType @@ -64,15 +67,21 @@ async def infer_schema( avro_schema = avro_reader.writer_schema if not avro_schema["type"] == "record": unsupported_type = avro_schema["type"] - raise ValueError(f"Only record based avro files are supported. Found {unsupported_type}") + raise ValueError( + f"Only record based avro files are supported. Found {unsupported_type}" + ) json_schema = { - field["name"]: AvroParser._convert_avro_type_to_json(avro_format, field["name"], field["type"]) + field["name"]: AvroParser._convert_avro_type_to_json( + avro_format, field["name"], field["type"] + ) for field in avro_schema["fields"] } return json_schema @classmethod - def _convert_avro_type_to_json(cls, avro_format: AvroFormat, field_name: str, avro_field: str) -> Mapping[str, Any]: + def _convert_avro_type_to_json( + cls, avro_format: AvroFormat, field_name: str, avro_field: str + ) -> Mapping[str, Any]: if isinstance(avro_field, str) and avro_field in AVRO_TYPE_TO_JSON_TYPE: # Legacy behavior to retain backwards compatibility. Long term we should always represent doubles as strings if avro_field == "double" and not avro_format.double_as_string: @@ -83,17 +92,28 @@ def _convert_avro_type_to_json(cls, avro_format: AvroFormat, field_name: str, av return { "type": "object", "properties": { - object_field["name"]: AvroParser._convert_avro_type_to_json(avro_format, object_field["name"], object_field["type"]) + object_field["name"]: AvroParser._convert_avro_type_to_json( + avro_format, object_field["name"], object_field["type"] + ) for object_field in avro_field["fields"] }, } elif avro_field["type"] == "array": if "items" not in avro_field: - raise ValueError(f"{field_name} array type does not have a required field items") - return {"type": "array", "items": AvroParser._convert_avro_type_to_json(avro_format, "", avro_field["items"])} + raise ValueError( + f"{field_name} array type does not have a required field items" + ) + return { + "type": "array", + "items": AvroParser._convert_avro_type_to_json( + avro_format, "", avro_field["items"] + ), + } elif avro_field["type"] == "enum": if "symbols" not in avro_field: - raise ValueError(f"{field_name} enum type does not have a required field symbols") + raise ValueError( + f"{field_name} enum type does not have a required field symbols" + ) if "name" not in avro_field: raise ValueError(f"{field_name} enum type does not have a required field name") return {"type": "string", "enum": avro_field["symbols"]} @@ -102,7 +122,9 @@ def _convert_avro_type_to_json(cls, avro_format: AvroFormat, field_name: str, av raise ValueError(f"{field_name} map type does not have a required field values") return { "type": "object", - "additionalProperties": AvroParser._convert_avro_type_to_json(avro_format, "", avro_field["values"]), + "additionalProperties": AvroParser._convert_avro_type_to_json( + avro_format, "", avro_field["values"] + ), } elif avro_field["type"] == "fixed" and avro_field.get("logicalType") != "duration": if "size" not in avro_field: @@ -115,18 +137,27 @@ def _convert_avro_type_to_json(cls, avro_format: AvroFormat, field_name: str, av } elif avro_field.get("logicalType") == "decimal": if "precision" not in avro_field: - raise ValueError(f"{field_name} decimal type does not have a required field precision") + raise ValueError( + f"{field_name} decimal type does not have a required field precision" + ) if "scale" not in avro_field: - raise ValueError(f"{field_name} decimal type does not have a required field scale") + raise ValueError( + f"{field_name} decimal type does not have a required field scale" + ) max_whole_number_range = avro_field["precision"] - avro_field["scale"] decimal_range = avro_field["scale"] # This regex looks like a mess, but it is validation for at least one whole number and optional fractional numbers # For example: ^-?\d{1,5}(?:\.\d{1,3})?$ would accept 12345.123 and 123456.12345 would be rejected - return {"type": "string", "pattern": f"^-?\\d{{{1,max_whole_number_range}}}(?:\\.\\d{1,decimal_range})?$"} + return { + "type": "string", + "pattern": f"^-?\\d{{{1,max_whole_number_range}}}(?:\\.\\d{1,decimal_range})?$", + } elif "logicalType" in avro_field: if avro_field["logicalType"] not in AVRO_LOGICAL_TYPE_TO_JSON: - raise ValueError(f"{avro_field['logicalType']} is not a valid Avro logical type") + raise ValueError( + f"{avro_field['logicalType']} is not a valid Avro logical type" + ) return AVRO_LOGICAL_TYPE_TO_JSON[avro_field["logicalType"]] else: raise ValueError(f"Unsupported avro type: {avro_field}") @@ -150,22 +181,32 @@ def parse_records( with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp: avro_reader = fastavro.reader(fp) schema = avro_reader.writer_schema - schema_field_name_to_type = {field["name"]: field["type"] for field in schema["fields"]} + schema_field_name_to_type = { + field["name"]: field["type"] for field in schema["fields"] + } for record in avro_reader: line_no += 1 yield { - record_field: self._to_output_value(avro_format, schema_field_name_to_type[record_field], record[record_field]) + record_field: self._to_output_value( + avro_format, + schema_field_name_to_type[record_field], + record[record_field], + ) for record_field, record_value in schema_field_name_to_type.items() } except Exception as exc: - raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=line_no) from exc + raise RecordParseError( + FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=line_no + ) from exc @property def file_read_mode(self) -> FileReadMode: return FileReadMode.READ_BINARY @staticmethod - def _to_output_value(avro_format: AvroFormat, record_type: Mapping[str, Any], record_value: Any) -> Any: + def _to_output_value( + avro_format: AvroFormat, record_type: Mapping[str, Any], record_value: Any + ) -> Any: if isinstance(record_value, bytes): return record_value.decode() elif not isinstance(record_type, Mapping): diff --git a/airbyte_cdk/sources/file_based/file_types/csv_parser.py b/airbyte_cdk/sources/file_based/file_types/csv_parser.py index 961fc8f1..6927b2fd 100644 --- a/airbyte_cdk/sources/file_based/file_types/csv_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/csv_parser.py @@ -13,10 +13,18 @@ from uuid import uuid4 from airbyte_cdk.models import FailureType -from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, CsvHeaderAutogenerated, CsvHeaderUserProvided, InferenceType +from airbyte_cdk.sources.file_based.config.csv_format import ( + CsvFormat, + CsvHeaderAutogenerated, + CsvHeaderUserProvided, + InferenceType, +) from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import TYPE_PYTHON_MAPPING, SchemaType @@ -77,7 +85,9 @@ def read_data( # than headers or more headers dans columns if None in row: if config_format.ignore_errors_on_fields_mismatch: - logger.error(f"Skipping record in line {lineno} of file {file.uri}; invalid CSV row with missing column.") + logger.error( + f"Skipping record in line {lineno} of file {file.uri}; invalid CSV row with missing column." + ) else: raise RecordParseError( FileBasedSourceError.ERROR_PARSING_RECORD_MISMATCHED_COLUMNS, @@ -86,10 +96,14 @@ def read_data( ) if None in row.values(): if config_format.ignore_errors_on_fields_mismatch: - logger.error(f"Skipping record in line {lineno} of file {file.uri}; invalid CSV row with extra column.") + logger.error( + f"Skipping record in line {lineno} of file {file.uri}; invalid CSV row with extra column." + ) else: raise RecordParseError( - FileBasedSourceError.ERROR_PARSING_RECORD_MISMATCHED_ROWS, filename=file.uri, lineno=lineno + FileBasedSourceError.ERROR_PARSING_RECORD_MISMATCHED_ROWS, + filename=file.uri, + lineno=lineno, ) yield row finally: @@ -105,7 +119,9 @@ def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) return config_format.header_definition.column_names # type: ignore # should be CsvHeaderUserProvided given the type if isinstance(config_format.header_definition, CsvHeaderAutogenerated): - self._skip_rows(fp, config_format.skip_rows_before_header + config_format.skip_rows_after_header) + self._skip_rows( + fp, config_format.skip_rows_before_header + config_format.skip_rows_after_header + ) headers = self._auto_generate_headers(fp, dialect_name) else: # Then read the header @@ -165,11 +181,15 @@ async def infer_schema( # sources will likely require one. Rather than modify the interface now we can wait until the real use case config_format = _extract_format(config) type_inferrer_by_field: Dict[str, _TypeInferrer] = defaultdict( - lambda: _JsonTypeInferrer(config_format.true_values, config_format.false_values, config_format.null_values) + lambda: _JsonTypeInferrer( + config_format.true_values, config_format.false_values, config_format.null_values + ) if config_format.inference_type != InferenceType.NONE else _DisabledTypeInferrer() ) - data_generator = self._csv_reader.read_data(config, file, stream_reader, logger, self.file_read_mode) + data_generator = self._csv_reader.read_data( + config, file, stream_reader, logger, self.file_read_mode + ) read_bytes = 0 for row in data_generator: for header, value in row.items(): @@ -187,7 +207,10 @@ async def infer_schema( f"Else, please contact Airbyte.", failure_type=FailureType.config_error, ) - schema = {header.strip(): {"type": type_inferred.infer()} for header, type_inferred in type_inferrer_by_field.items()} + schema = { + header.strip(): {"type": type_inferred.infer()} + for header, type_inferred in type_inferrer_by_field.items() + } data_generator.close() return schema @@ -203,19 +226,30 @@ def parse_records( try: config_format = _extract_format(config) if discovered_schema: - property_types = {col: prop["type"] for col, prop in discovered_schema["properties"].items()} # type: ignore # discovered_schema["properties"] is known to be a mapping + property_types = { + col: prop["type"] for col, prop in discovered_schema["properties"].items() + } # type: ignore # discovered_schema["properties"] is known to be a mapping deduped_property_types = CsvParser._pre_propcess_property_types(property_types) else: deduped_property_types = {} - cast_fn = CsvParser._get_cast_function(deduped_property_types, config_format, logger, config.schemaless) - data_generator = self._csv_reader.read_data(config, file, stream_reader, logger, self.file_read_mode) + cast_fn = CsvParser._get_cast_function( + deduped_property_types, config_format, logger, config.schemaless + ) + data_generator = self._csv_reader.read_data( + config, file, stream_reader, logger, self.file_read_mode + ) for row in data_generator: line_no += 1 yield CsvParser._to_nullable( - cast_fn(row), deduped_property_types, config_format.null_values, config_format.strings_can_be_null + cast_fn(row), + deduped_property_types, + config_format.null_values, + config_format.strings_can_be_null, ) except RecordParseError as parse_err: - raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=line_no) from parse_err + raise RecordParseError( + FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=line_no + ) from parse_err finally: data_generator.close() @@ -225,27 +259,47 @@ def file_read_mode(self) -> FileReadMode: @staticmethod def _get_cast_function( - deduped_property_types: Mapping[str, str], config_format: CsvFormat, logger: logging.Logger, schemaless: bool + deduped_property_types: Mapping[str, str], + config_format: CsvFormat, + logger: logging.Logger, + schemaless: bool, ) -> Callable[[Mapping[str, str]], Mapping[str, str]]: # Only cast values if the schema is provided if deduped_property_types and not schemaless: - return partial(CsvParser._cast_types, deduped_property_types=deduped_property_types, config_format=config_format, logger=logger) + return partial( + CsvParser._cast_types, + deduped_property_types=deduped_property_types, + config_format=config_format, + logger=logger, + ) else: # If no schema is provided, yield the rows as they are return _no_cast @staticmethod def _to_nullable( - row: Mapping[str, str], deduped_property_types: Mapping[str, str], null_values: Set[str], strings_can_be_null: bool + row: Mapping[str, str], + deduped_property_types: Mapping[str, str], + null_values: Set[str], + strings_can_be_null: bool, ) -> Dict[str, Optional[str]]: nullable = { - k: None if CsvParser._value_is_none(v, deduped_property_types.get(k), null_values, strings_can_be_null) else v + k: None + if CsvParser._value_is_none( + v, deduped_property_types.get(k), null_values, strings_can_be_null + ) + else v for k, v in row.items() } return nullable @staticmethod - def _value_is_none(value: Any, deduped_property_type: Optional[str], null_values: Set[str], strings_can_be_null: bool) -> bool: + def _value_is_none( + value: Any, + deduped_property_type: Optional[str], + null_values: Set[str], + strings_can_be_null: bool, + ) -> bool: return value in null_values and (strings_can_be_null or deduped_property_type != "string") @staticmethod @@ -280,7 +334,10 @@ def _pre_propcess_property_types(property_types: Dict[str, Any]) -> Mapping[str, @staticmethod def _cast_types( - row: Dict[str, str], deduped_property_types: Mapping[str, str], config_format: CsvFormat, logger: logging.Logger + row: Dict[str, str], + deduped_property_types: Mapping[str, str], + config_format: CsvFormat, + logger: logging.Logger, ) -> Dict[str, Any]: """ Casts the values in the input 'row' dictionary according to the types defined in the JSON schema. @@ -307,7 +364,9 @@ def _cast_types( elif python_type == bool: try: - cast_value = _value_to_bool(value, config_format.true_values, config_format.false_values) + cast_value = _value_to_bool( + value, config_format.true_values, config_format.false_values + ) except ValueError: warnings.append(_format_warning(key, value, prop_type)) @@ -364,7 +423,9 @@ class _JsonTypeInferrer(_TypeInferrer): _NUMBER_TYPE = "number" _STRING_TYPE = "string" - def __init__(self, boolean_trues: Set[str], boolean_falses: Set[str], null_values: Set[str]) -> None: + def __init__( + self, boolean_trues: Set[str], boolean_falses: Set[str], null_values: Set[str] + ) -> None: self._boolean_trues = boolean_trues self._boolean_falses = boolean_falses self._null_values = null_values @@ -375,7 +436,9 @@ def add_value(self, value: Any) -> None: def infer(self) -> str: types_by_value = {value: self._infer_type(value) for value in self._values} - types_excluding_null_values = [types for types in types_by_value.values() if self._NULL_TYPE not in types] + types_excluding_null_values = [ + types for types in types_by_value.values() if self._NULL_TYPE not in types + ] if not types_excluding_null_values: # this is highly unusual but we will consider the column as a string return self._STRING_TYPE diff --git a/airbyte_cdk/sources/file_based/file_types/excel_parser.py b/airbyte_cdk/sources/file_based/file_types/excel_parser.py index 93add410..7a8b4e4b 100644 --- a/airbyte_cdk/sources/file_based/file_types/excel_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/excel_parser.py @@ -8,9 +8,19 @@ from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union import pandas as pd -from airbyte_cdk.sources.file_based.config.file_based_stream_config import ExcelFormat, FileBasedStreamConfig -from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError, RecordParseError -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( + ExcelFormat, + FileBasedStreamConfig, +) +from airbyte_cdk.sources.file_based.exceptions import ( + ConfigValidationError, + FileBasedSourceError, + RecordParseError, +) +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType @@ -63,7 +73,11 @@ async def infer_schema( fields[column] = self.dtype_to_json_type(prev_frame_column_type, df_type) schema = { - field: ({"type": "string", "format": "date-time"} if fields[field] == "date-time" else {"type": fields[field]}) + field: ( + {"type": "string", "format": "date-time"} + if fields[field] == "date-time" + else {"type": fields[field]} + ) for field in fields } return schema @@ -101,11 +115,15 @@ def parse_records( # DataFrame.to_dict() method returns datetime values in pandas.Timestamp values, which are not serializable by orjson # DataFrame.to_json() returns string with datetime values serialized to iso8601 with microseconds to align with pydantic behavior # see PR description: https://github.com/airbytehq/airbyte/pull/44444/ - yield from orjson.loads(df.to_json(orient="records", date_format="iso", date_unit="us")) + yield from orjson.loads( + df.to_json(orient="records", date_format="iso", date_unit="us") + ) except Exception as exc: # Raise a RecordParseError if any exception occurs during parsing - raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri) from exc + raise RecordParseError( + FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri + ) from exc @property def file_read_mode(self) -> FileReadMode: diff --git a/airbyte_cdk/sources/file_based/file_types/file_transfer.py b/airbyte_cdk/sources/file_based/file_types/file_transfer.py index e3481867..154b6ff4 100644 --- a/airbyte_cdk/sources/file_based/file_types/file_transfer.py +++ b/airbyte_cdk/sources/file_based/file_types/file_transfer.py @@ -15,7 +15,11 @@ class FileTransfer: def __init__(self) -> None: - self._local_directory = AIRBYTE_STAGING_DIRECTORY if os.path.exists(AIRBYTE_STAGING_DIRECTORY) else DEFAULT_LOCAL_DIRECTORY + self._local_directory = ( + AIRBYTE_STAGING_DIRECTORY + if os.path.exists(AIRBYTE_STAGING_DIRECTORY) + else DEFAULT_LOCAL_DIRECTORY + ) def get_file( self, @@ -25,7 +29,9 @@ def get_file( logger: logging.Logger, ) -> Iterable[Dict[str, Any]]: try: - yield stream_reader.get_file(file=file, local_directory=self._local_directory, logger=logger) + yield stream_reader.get_file( + file=file, local_directory=self._local_directory, logger=logger + ) except Exception as ex: logger.error("An error has occurred while getting file: %s", str(ex)) raise ex diff --git a/airbyte_cdk/sources/file_based/file_types/file_type_parser.py b/airbyte_cdk/sources/file_based/file_types/file_type_parser.py index d334621a..e6a9c5cb 100644 --- a/airbyte_cdk/sources/file_based/file_types/file_type_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/file_type_parser.py @@ -7,7 +7,10 @@ from typing import Any, Dict, Iterable, Mapping, Optional, Tuple from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType diff --git a/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py b/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py index 4772173f..6cd59075 100644 --- a/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py @@ -8,10 +8,17 @@ from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from airbyte_cdk.sources.file_based.schema_helpers import PYTHON_TYPE_MAPPING, SchemaType, merge_schemas +from airbyte_cdk.sources.file_based.schema_helpers import ( + PYTHON_TYPE_MAPPING, + SchemaType, + merge_schemas, +) from orjson import orjson @@ -102,7 +109,9 @@ def _parse_jsonl_entries( try: record = orjson.loads(accumulator) if had_json_parsing_error and not has_warned_for_multiline_json_object: - logger.warning(f"File at {file.uri} is using multiline JSON. Performance could be greatly reduced") + logger.warning( + f"File at {file.uri} is using multiline JSON. Performance could be greatly reduced" + ) has_warned_for_multiline_json_object = True yield record @@ -111,7 +120,11 @@ def _parse_jsonl_entries( except orjson.JSONDecodeError: had_json_parsing_error = True - if read_limit and yielded_at_least_once and read_bytes >= self.MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE: + if ( + read_limit + and yielded_at_least_once + and read_bytes >= self.MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE + ): logger.warning( f"Exceeded the maximum number of bytes per file for schema inference ({self.MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE}). " f"Inferring schema from an incomplete set of records." @@ -119,7 +132,9 @@ def _parse_jsonl_entries( break if had_json_parsing_error and not yielded_at_least_once: - raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=line) + raise RecordParseError( + FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=line + ) @staticmethod def _instantiate_accumulator(line: Union[bytes, str]) -> Union[bytes, str]: diff --git a/airbyte_cdk/sources/file_based/file_types/parquet_parser.py b/airbyte_cdk/sources/file_based/file_types/parquet_parser.py index ed25ceb4..99b6373d 100644 --- a/airbyte_cdk/sources/file_based/file_types/parquet_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/parquet_parser.py @@ -10,9 +10,19 @@ import pyarrow as pa import pyarrow.parquet as pq -from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ParquetFormat -from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError, RecordParseError -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( + FileBasedStreamConfig, + ParquetFormat, +) +from airbyte_cdk.sources.file_based.exceptions import ( + ConfigValidationError, + FileBasedSourceError, + RecordParseError, +) +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType @@ -44,9 +54,15 @@ async def infer_schema( parquet_schema = parquet_file.schema_arrow # Inferred non-partition schema - schema = {field.name: ParquetParser.parquet_type_to_schema_type(field.type, parquet_format) for field in parquet_schema} + schema = { + field.name: ParquetParser.parquet_type_to_schema_type(field.type, parquet_format) + for field in parquet_schema + } # Inferred partition schema - partition_columns = {partition.split("=")[0]: {"type": "string"} for partition in self._extract_partitions(file.uri)} + partition_columns = { + partition.split("=")[0]: {"type": "string"} + for partition in self._extract_partitions(file.uri) + } schema.update(partition_columns) return schema @@ -68,21 +84,27 @@ def parse_records( try: with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp: reader = pq.ParquetFile(fp) - partition_columns = {x.split("=")[0]: x.split("=")[1] for x in self._extract_partitions(file.uri)} + partition_columns = { + x.split("=")[0]: x.split("=")[1] for x in self._extract_partitions(file.uri) + } for row_group in range(reader.num_row_groups): batch = reader.read_row_group(row_group) for row in range(batch.num_rows): line_no += 1 yield { **{ - column: ParquetParser._to_output_value(batch.column(column)[row], parquet_format) + column: ParquetParser._to_output_value( + batch.column(column)[row], parquet_format + ) for column in batch.column_names }, **partition_columns, } except Exception as exc: raise RecordParseError( - FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=f"{row_group=}, {line_no=}" + FileBasedSourceError.ERROR_PARSING_RECORD, + filename=file.uri, + lineno=f"{row_group=}, {line_no=}", ) from exc @staticmethod @@ -94,7 +116,9 @@ def file_read_mode(self) -> FileReadMode: return FileReadMode.READ_BINARY @staticmethod - def _to_output_value(parquet_value: Union[Scalar, DictionaryArray], parquet_format: ParquetFormat) -> Any: + def _to_output_value( + parquet_value: Union[Scalar, DictionaryArray], parquet_format: ParquetFormat + ) -> Any: """ Convert an entry in a pyarrow table to a value that can be output by the source. """ @@ -112,7 +136,11 @@ def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat return None # Convert date and datetime objects to isoformat strings - if pa.types.is_time(parquet_value.type) or pa.types.is_timestamp(parquet_value.type) or pa.types.is_date(parquet_value.type): + if ( + pa.types.is_time(parquet_value.type) + or pa.types.is_timestamp(parquet_value.type) + or pa.types.is_date(parquet_value.type) + ): return parquet_value.as_py().isoformat() # Convert month_day_nano_interval to array @@ -167,7 +195,9 @@ def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> Dict[st } @staticmethod - def parquet_type_to_schema_type(parquet_type: pa.DataType, parquet_format: ParquetFormat) -> Mapping[str, str]: + def parquet_type_to_schema_type( + parquet_type: pa.DataType, parquet_format: ParquetFormat + ) -> Mapping[str, str]: """ Convert a pyarrow data type to an Airbyte schema type. Parquet data types are defined at https://arrow.apache.org/docs/python/api/datatypes.html @@ -197,7 +227,9 @@ def parquet_type_to_schema_type(parquet_type: pa.DataType, parquet_format: Parqu @staticmethod def _is_binary(parquet_type: pa.DataType) -> bool: return bool( - pa.types.is_binary(parquet_type) or pa.types.is_large_binary(parquet_type) or pa.types.is_fixed_size_binary(parquet_type) + pa.types.is_binary(parquet_type) + or pa.types.is_large_binary(parquet_type) + or pa.types.is_fixed_size_binary(parquet_type) ) @staticmethod @@ -220,13 +252,23 @@ def _is_string(parquet_type: pa.DataType, parquet_format: ParquetFormat) -> bool pa.types.is_time(parquet_type) or pa.types.is_string(parquet_type) or pa.types.is_large_string(parquet_type) - or ParquetParser._is_binary(parquet_type) # Best we can do is return as a string since we do not support binary + or ParquetParser._is_binary( + parquet_type + ) # Best we can do is return as a string since we do not support binary ) @staticmethod def _is_object(parquet_type: pa.DataType) -> bool: - return bool(pa.types.is_dictionary(parquet_type) or pa.types.is_struct(parquet_type) or pa.types.is_map(parquet_type)) + return bool( + pa.types.is_dictionary(parquet_type) + or pa.types.is_struct(parquet_type) + or pa.types.is_map(parquet_type) + ) @staticmethod def _is_list(parquet_type: pa.DataType) -> bool: - return bool(pa.types.is_list(parquet_type) or pa.types.is_large_list(parquet_type) or parquet_type == pa.month_day_nano_interval()) + return bool( + pa.types.is_list(parquet_type) + or pa.types.is_large_list(parquet_type) + or parquet_type == pa.month_day_nano_interval() + ) diff --git a/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py b/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py index 659fbd2c..e397ceae 100644 --- a/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py @@ -19,13 +19,21 @@ UnstructuredFormat, ) from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType from airbyte_cdk.utils import is_cloud_environment from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from unstructured.file_utils.filetype import FILETYPE_TO_MIMETYPE, STR_TO_FILETYPE, FileType, detect_filetype +from unstructured.file_utils.filetype import ( + FILETYPE_TO_MIMETYPE, + STR_TO_FILETYPE, + FileType, + detect_filetype, +) unstructured_partition_pdf = None unstructured_partition_docx = None @@ -109,7 +117,10 @@ async def infer_schema( "type": "string", "description": "Content of the file as markdown. Might be null if the file could not be parsed", }, - "document_key": {"type": "string", "description": "Unique identifier of the document, e.g. the file path"}, + "document_key": { + "type": "string", + "description": "Unique identifier of the document, e.g. the file path", + }, "_ab_source_file_parse_error": { "type": "string", "description": "Error message if the file could not be parsed even though the file is supported", @@ -149,9 +160,19 @@ def parse_records( else: raise e - def _read_file(self, file_handle: IOBase, remote_file: RemoteFile, format: UnstructuredFormat, logger: logging.Logger) -> str: + def _read_file( + self, + file_handle: IOBase, + remote_file: RemoteFile, + format: UnstructuredFormat, + logger: logging.Logger, + ) -> str: _import_unstructured() - if (not unstructured_partition_pdf) or (not unstructured_partition_docx) or (not unstructured_partition_pptx): + if ( + (not unstructured_partition_pdf) + or (not unstructured_partition_docx) + or (not unstructured_partition_pptx) + ): # check whether unstructured library is actually available for better error message and to ensure proper typing (can't be None after this point) raise Exception("unstructured library is not available") @@ -167,7 +188,9 @@ def _read_file(self, file_handle: IOBase, remote_file: RemoteFile, format: Unstr return self._read_file_locally(file_handle, filetype, format.strategy, remote_file) elif format.processing.mode == "api": try: - result: str = self._read_file_remotely_with_retries(file_handle, format.processing, filetype, format.strategy, remote_file) + result: str = self._read_file_remotely_with_retries( + file_handle, format.processing, filetype, format.strategy, remote_file + ) except Exception as e: # If a parser error happens during remotely processing the file, this means the file is corrupted. This case is handled by the parse_records method, so just rethrow. # @@ -175,11 +198,15 @@ def _read_file(self, file_handle: IOBase, remote_file: RemoteFile, format: Unstr # Once this parser leaves experimental stage, we should consider making this a system error instead for issues that might be transient. if isinstance(e, RecordParseError): raise e - raise AirbyteTracedException.from_exception(e, failure_type=FailureType.config_error) + raise AirbyteTracedException.from_exception( + e, failure_type=FailureType.config_error + ) return result - def _params_to_dict(self, params: Optional[List[APIParameterConfigModel]], strategy: str) -> Dict[str, Union[str, List[str]]]: + def _params_to_dict( + self, params: Optional[List[APIParameterConfigModel]], strategy: str + ) -> Dict[str, Union[str, List[str]]]: result_dict: Dict[str, Union[str, List[str]]] = {"strategy": strategy} if params is None: return result_dict @@ -229,9 +256,16 @@ def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[st return True, None - @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_tries=5, giveup=user_error) + @backoff.on_exception( + backoff.expo, requests.exceptions.RequestException, max_tries=5, giveup=user_error + ) def _read_file_remotely_with_retries( - self, file_handle: IOBase, format: APIProcessingConfigModel, filetype: FileType, strategy: str, remote_file: RemoteFile + self, + file_handle: IOBase, + format: APIProcessingConfigModel, + filetype: FileType, + strategy: str, + remote_file: RemoteFile, ) -> str: """ Read a file remotely, retrying up to 5 times if the error is not caused by user error. This is useful for transient network errors or the API server being overloaded temporarily. @@ -239,7 +273,12 @@ def _read_file_remotely_with_retries( return self._read_file_remotely(file_handle, format, filetype, strategy, remote_file) def _read_file_remotely( - self, file_handle: IOBase, format: APIProcessingConfigModel, filetype: FileType, strategy: str, remote_file: RemoteFile + self, + file_handle: IOBase, + format: APIProcessingConfigModel, + filetype: FileType, + strategy: str, + remote_file: RemoteFile, ) -> str: headers = {"accept": "application/json", "unstructured-api-key": format.api_key} @@ -247,7 +286,9 @@ def _read_file_remotely( file_data = {"files": ("filename", file_handle, FILETYPE_TO_MIMETYPE[filetype])} - response = requests.post(f"{format.api_url}/general/v0/general", headers=headers, data=data, files=file_data) + response = requests.post( + f"{format.api_url}/general/v0/general", headers=headers, data=data, files=file_data + ) if response.status_code == 422: # 422 means the file couldn't be processed, but the API is working. Treat this as a parsing error (passing an error record to the destination). @@ -260,9 +301,15 @@ def _read_file_remotely( return self._render_markdown(json_response) - def _read_file_locally(self, file_handle: IOBase, filetype: FileType, strategy: str, remote_file: RemoteFile) -> str: + def _read_file_locally( + self, file_handle: IOBase, filetype: FileType, strategy: str, remote_file: RemoteFile + ) -> str: _import_unstructured() - if (not unstructured_partition_pdf) or (not unstructured_partition_docx) or (not unstructured_partition_pptx): + if ( + (not unstructured_partition_pdf) + or (not unstructured_partition_docx) + or (not unstructured_partition_pptx) + ): # check whether unstructured library is actually available for better error message and to ensure proper typing (can't be None after this point) raise Exception("unstructured library is not available") @@ -290,7 +337,9 @@ def _read_file_locally(self, file_handle: IOBase, filetype: FileType, strategy: return self._render_markdown([element.to_dict() for element in elements]) def _create_parse_error(self, remote_file: RemoteFile, message: str) -> RecordParseError: - return RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=remote_file.uri, message=message) + return RecordParseError( + FileBasedSourceError.ERROR_PARSING_RECORD, filename=remote_file.uri, message=message + ) def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileType]: """ diff --git a/airbyte_cdk/sources/file_based/schema_helpers.py b/airbyte_cdk/sources/file_based/schema_helpers.py index fb714120..1b653db6 100644 --- a/airbyte_cdk/sources/file_based/schema_helpers.py +++ b/airbyte_cdk/sources/file_based/schema_helpers.py @@ -8,13 +8,20 @@ from functools import total_ordering from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Type, Union -from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError, SchemaInferenceError +from airbyte_cdk.sources.file_based.exceptions import ( + ConfigValidationError, + FileBasedSourceError, + SchemaInferenceError, +) JsonSchemaSupportedType = Union[List[str], Literal["string"], str] SchemaType = Mapping[str, Mapping[str, JsonSchemaSupportedType]] schemaless_schema = {"type": "object", "properties": {"data": {"type": "object"}}} -file_transfer_schema = {"type": "object", "properties": {"data": {"type": "object"}, "file": {"type": "object"}}} +file_transfer_schema = { + "type": "object", + "properties": {"data": {"type": "object"}, "file": {"type": "object"}}, +} @total_ordering @@ -129,7 +136,12 @@ def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) - detected_types=f"{t1},{t2}", ) # Schemas can still be merged if a key contains a null value in either t1 or t2, but it is still an object - elif (t1_type == "object" or t2_type == "object") and t1_type != "null" and t2_type != "null" and t1 != t2: + elif ( + (t1_type == "object" or t2_type == "object") + and t1_type != "null" + and t2_type != "null" + and t1 != t2 + ): raise SchemaInferenceError( FileBasedSourceError.SCHEMA_INFERENCE_ERROR, details="Cannot merge schema for unequal object types.", @@ -137,12 +149,19 @@ def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) - detected_types=f"{t1},{t2}", ) else: - comparable_t1 = get_comparable_type(TYPE_PYTHON_MAPPING[t1_type][0]) # accessing the type_mapping value - comparable_t2 = get_comparable_type(TYPE_PYTHON_MAPPING[t2_type][0]) # accessing the type_mapping value + comparable_t1 = get_comparable_type( + TYPE_PYTHON_MAPPING[t1_type][0] + ) # accessing the type_mapping value + comparable_t2 = get_comparable_type( + TYPE_PYTHON_MAPPING[t2_type][0] + ) # accessing the type_mapping value if not comparable_t1 and comparable_t2: - raise SchemaInferenceError(FileBasedSourceError.UNRECOGNIZED_TYPE, key=key, detected_types=f"{t1},{t2}") + raise SchemaInferenceError( + FileBasedSourceError.UNRECOGNIZED_TYPE, key=key, detected_types=f"{t1},{t2}" + ) return max( - [t1, t2], key=lambda x: ComparableType(get_comparable_type(TYPE_PYTHON_MAPPING[x["type"]][0])) + [t1, t2], + key=lambda x: ComparableType(get_comparable_type(TYPE_PYTHON_MAPPING[x["type"]][0])), ) # accessing the type_mapping value @@ -205,7 +224,8 @@ def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[M schema = input_schema if not all(isinstance(s, str) for s in schema.values()): raise ConfigValidationError( - FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA, details="Invalid input schema; nested schemas are not supported." + FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA, + details="Invalid input schema; nested schemas are not supported.", ) except json.decoder.JSONDecodeError: @@ -214,7 +234,9 @@ def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[M return schema -def type_mapping_to_jsonschema(input_schema: Optional[Union[str, Mapping[str, str]]]) -> Optional[Mapping[str, Any]]: +def type_mapping_to_jsonschema( + input_schema: Optional[Union[str, Mapping[str, str]]], +) -> Optional[Mapping[str, Any]]: """ Return the user input schema (type mapping), transformed to JSON Schema format. @@ -241,7 +263,8 @@ def type_mapping_to_jsonschema(input_schema: Optional[Union[str, Mapping[str, st if not _json_schema_type: raise ConfigValidationError( - FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA, details=f"Invalid type '{type_name}' for property '{col_name}'." + FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA, + details=f"Invalid type '{type_name}' for property '{col_name}'.", ) json_schema_type = _json_schema_type[0] diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py b/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py index 004139b7..139511a9 100644 --- a/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py +++ b/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py @@ -11,7 +11,9 @@ class AbstractSchemaValidationPolicy(ABC): validate_schema_before_sync = False # Whether to verify that records conform to the schema during the stream's availabilty check @abstractmethod - def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool: + def record_passes_validation_policy( + self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + ) -> bool: """ Return True if the record passes the user's validation policy. """ diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py b/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py index 02134d1b..261b0fab 100644 --- a/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py +++ b/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py @@ -5,7 +5,10 @@ from typing import Any, Mapping, Optional from airbyte_cdk.sources.file_based.config.file_based_stream_config import ValidationPolicy -from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, StopSyncPerValidationPolicy +from airbyte_cdk.sources.file_based.exceptions import ( + FileBasedSourceError, + StopSyncPerValidationPolicy, +) from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy @@ -13,14 +16,18 @@ class EmitRecordPolicy(AbstractSchemaValidationPolicy): name = "emit_record" - def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool: + def record_passes_validation_policy( + self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + ) -> bool: return True class SkipRecordPolicy(AbstractSchemaValidationPolicy): name = "skip_record" - def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool: + def record_passes_validation_policy( + self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + ) -> bool: return schema is not None and conforms_to_schema(record, schema) @@ -28,9 +35,13 @@ class WaitForDiscoverPolicy(AbstractSchemaValidationPolicy): name = "wait_for_discover" validate_schema_before_sync = True - def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool: + def record_passes_validation_policy( + self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + ) -> bool: if schema is None or not conforms_to_schema(record, schema): - raise StopSyncPerValidationPolicy(FileBasedSourceError.STOP_SYNC_PER_SCHEMA_VALIDATION_POLICY) + raise StopSyncPerValidationPolicy( + FileBasedSourceError.STOP_SYNC_PER_SCHEMA_VALIDATION_POLICY + ) return True diff --git a/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py b/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py index 5c2393e9..8c0e1ebf 100644 --- a/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py +++ b/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py @@ -8,10 +8,20 @@ from airbyte_cdk import AirbyteMessage from airbyte_cdk.models import SyncMode -from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy -from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, PrimaryKeyType +from airbyte_cdk.sources.file_based.availability_strategy import ( + AbstractFileBasedAvailabilityStrategy, +) +from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( + FileBasedStreamConfig, + PrimaryKeyType, +) from airbyte_cdk.sources.file_based.discovery_policy import AbstractDiscoveryPolicy -from airbyte_cdk.sources.file_based.exceptions import FileBasedErrorsCollector, FileBasedSourceError, RecordParseError, UndefinedParserError +from airbyte_cdk.sources.file_based.exceptions import ( + FileBasedErrorsCollector, + FileBasedSourceError, + RecordParseError, + UndefinedParserError, +) from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile @@ -101,14 +111,20 @@ def read_records( return self.read_records_from_slice(stream_slice) @abstractmethod - def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping[str, Any] | AirbyteMessage]: + def read_records_from_slice( + self, stream_slice: StreamSlice + ) -> Iterable[Mapping[str, Any] | AirbyteMessage]: """ Yield all records from all remote files in `list_files_for_this_sync`. """ ... def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: """ This method acts as an adapter between the generic Stream interface and the file-based's @@ -143,14 +159,22 @@ def get_parser(self) -> FileTypeParser: try: return self._parsers[type(self.config.format)] except KeyError: - raise UndefinedParserError(FileBasedSourceError.UNDEFINED_PARSER, stream=self.name, format=type(self.config.format)) + raise UndefinedParserError( + FileBasedSourceError.UNDEFINED_PARSER, + stream=self.name, + format=type(self.config.format), + ) def record_passes_validation_policy(self, record: Mapping[str, Any]) -> bool: if self.validation_policy: - return self.validation_policy.record_passes_validation_policy(record=record, schema=self.catalog_schema) + return self.validation_policy.record_passes_validation_policy( + record=record, schema=self.catalog_schema + ) else: raise RecordParseError( - FileBasedSourceError.UNDEFINED_VALIDATION_POLICY, stream=self.name, validation_policy=self.config.validation_policy + FileBasedSourceError.UNDEFINED_VALIDATION_POLICY, + stream=self.name, + validation_policy=self.config.validation_policy, ) @cached_property diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py index d335819d..fda609ae 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py @@ -7,7 +7,14 @@ from functools import cache, lru_cache from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, ConfiguredAirbyteStream, Level, SyncMode, Type +from airbyte_cdk.models import ( + AirbyteLogMessage, + AirbyteMessage, + ConfiguredAirbyteStream, + Level, + SyncMode, + Type, +) from airbyte_cdk.sources import AbstractSource from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.file_based.availability_strategy import ( @@ -26,7 +33,10 @@ from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage -from airbyte_cdk.sources.streams.concurrent.helpers import get_cursor_field_from_stream, get_primary_key_from_stream +from airbyte_cdk.sources.streams.concurrent.helpers import ( + get_cursor_field_from_stream, + get_primary_key_from_stream, +) from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator from airbyte_cdk.sources.streams.concurrent.partitions.record import Record @@ -36,7 +46,9 @@ from deprecated.classic import deprecated if TYPE_CHECKING: - from airbyte_cdk.sources.file_based.stream.concurrent.cursor import AbstractConcurrentFileBasedCursor + from airbyte_cdk.sources.file_based.stream.concurrent.cursor import ( + AbstractConcurrentFileBasedCursor, + ) """ This module contains adapters to help enabling concurrency on File-based Stream objects without needing to migrate to AbstractStream @@ -72,7 +84,9 @@ def create_from_stream( partition_generator=FileBasedStreamPartitionGenerator( stream, message_repository, - SyncMode.full_refresh if isinstance(cursor, FileBasedFinalStateCursor) else SyncMode.incremental, + SyncMode.full_refresh + if isinstance(cursor, FileBasedFinalStateCursor) + else SyncMode.incremental, [cursor_field] if cursor_field is not None else None, state, cursor, @@ -138,7 +152,10 @@ def get_json_schema(self) -> Mapping[str, Any]: @property def primary_key(self) -> PrimaryKeyType: - return self._legacy_stream.config.primary_key or self.get_parser().get_parser_defined_primary_key(self._legacy_stream.config) + return ( + self._legacy_stream.config.primary_key + or self.get_parser().get_parser_defined_primary_key(self._legacy_stream.config) + ) def get_parser(self) -> FileTypeParser: return self._legacy_stream.get_parser() @@ -185,7 +202,10 @@ def read_records( # This shouldn't happen if the ConcurrentCursor was used state = "unknown; no state attribute was available on the cursor" yield AirbyteMessage( - type=Type.LOG, log=AirbyteLogMessage(level=Level.ERROR, message=f"Cursor State at time of exception: {state}") + type=Type.LOG, + log=AirbyteLogMessage( + level=Level.ERROR, message=f"Cursor State at time of exception: {state}" + ), ) raise exc @@ -227,16 +247,30 @@ def read(self) -> Iterable[Record]: ): if isinstance(record_data, Mapping): data_to_return = dict(record_data) - self._stream.transformer.transform(data_to_return, self._stream.get_json_schema()) + self._stream.transformer.transform( + data_to_return, self._stream.get_json_schema() + ) yield Record(data_to_return, self) - elif isinstance(record_data, AirbyteMessage) and record_data.type == Type.RECORD and record_data.record is not None: + elif ( + isinstance(record_data, AirbyteMessage) + and record_data.type == Type.RECORD + and record_data.record is not None + ): # `AirbyteMessage`s of type `Record` should also be yielded so they are enqueued # If stream is flagged for file_transfer the record should data in file key - record_message_data = record_data.record.file if self._use_file_transfer() else record_data.record.data + record_message_data = ( + record_data.record.file + if self._use_file_transfer() + else record_data.record.data + ) if not record_message_data: raise ExceptionWithDisplayMessage("A record without data was found") else: - yield Record(data=record_message_data, partition=self, is_file_transfer_message=self._use_file_transfer()) + yield Record( + data=record_message_data, + partition=self, + is_file_transfer_message=self._use_file_transfer(), + ) else: self._message_repository.emit_message(record_data) except Exception as e: @@ -305,7 +339,9 @@ def __init__( def generate(self) -> Iterable[FileBasedStreamPartition]: pending_partitions = [] - for _slice in self._stream.stream_slices(sync_mode=self._sync_mode, cursor_field=self._cursor_field, stream_state=self._state): + for _slice in self._stream.stream_slices( + sync_mode=self._sync_mode, cursor_field=self._cursor_field, stream_state=self._state + ): if _slice is not None: for file in _slice.get("files", []): pending_partitions.append( diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py index 9cb3541c..ef8b290d 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py @@ -39,7 +39,9 @@ def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) - def add_file(self, file: RemoteFile) -> None: ... @abstractmethod - def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]: ... + def get_files_to_sync( + self, all_files: Iterable[RemoteFile], logger: logging.Logger + ) -> Iterable[RemoteFile]: ... @abstractmethod def get_state(self) -> MutableMapping[str, Any]: ... diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py index 0e3acaf8..e7bb2796 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py @@ -11,7 +11,9 @@ from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from airbyte_cdk.sources.file_based.stream.concurrent.cursor.abstract_concurrent_file_based_cursor import AbstractConcurrentFileBasedCursor +from airbyte_cdk.sources.file_based.stream.concurrent.cursor.abstract_concurrent_file_based_cursor import ( + AbstractConcurrentFileBasedCursor, +) from airbyte_cdk.sources.file_based.stream.cursor import DefaultFileBasedCursor from airbyte_cdk.sources.file_based.types import StreamState from airbyte_cdk.sources.message.repository import MessageRepository @@ -27,7 +29,9 @@ class FileBasedConcurrentCursor(AbstractConcurrentFileBasedCursor): CURSOR_FIELD = "_ab_source_file_last_modified" - DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL = DefaultFileBasedCursor.DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL + DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL = ( + DefaultFileBasedCursor.DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL + ) DEFAULT_MAX_HISTORY_SIZE = 10_000 DATE_TIME_FORMAT = DefaultFileBasedCursor.DATE_TIME_FORMAT zero_value = datetime.min @@ -51,7 +55,8 @@ def __init__( self._connector_state_manager = connector_state_manager self._cursor_field = cursor_field self._time_window_if_history_is_full = timedelta( - days=stream_config.days_to_sync_if_history_is_full or self.DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL + days=stream_config.days_to_sync_if_history_is_full + or self.DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL ) self._state_lock = RLock() self._pending_files_lock = RLock() @@ -70,7 +75,9 @@ def observe(self, record: Record) -> None: def close_partition(self, partition: Partition) -> None: with self._pending_files_lock: if self._pending_files is None: - raise RuntimeError("Expected pending partitions to be set but it was not. This is unexpected. Please contact Support.") + raise RuntimeError( + "Expected pending partitions to be set but it was not. This is unexpected. Please contact Support." + ) def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None: with self._pending_files_lock: @@ -81,7 +88,9 @@ def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) - continue for file in _slice["files"]: if file.uri in self._pending_files.keys(): - raise RuntimeError(f"Already found file {_slice} in pending files. This is unexpected. Please contact Support.") + raise RuntimeError( + f"Already found file {_slice} in pending files. This is unexpected. Please contact Support." + ) self._pending_files.update({file.uri: file}) def _compute_prev_sync_cursor(self, value: Optional[StreamState]) -> Tuple[datetime, str]: @@ -96,7 +105,9 @@ def _compute_prev_sync_cursor(self, value: Optional[StreamState]) -> Tuple[datet # represents the start time that the file was uploaded, we can usually expect that all previous # files have already been uploaded. If that's the case, they'll be in history and we'll skip # re-uploading them. - earliest_file_cursor_value = self._get_cursor_key_from_file(self._compute_earliest_file_in_history()) + earliest_file_cursor_value = self._get_cursor_key_from_file( + self._compute_earliest_file_in_history() + ) cursor_str = min(prev_cursor_str, earliest_file_cursor_value) cursor_dt, cursor_uri = cursor_str.split("_", 1) return datetime.strptime(cursor_dt, self.DATE_TIME_FORMAT), cursor_uri @@ -109,8 +120,13 @@ def _get_cursor_key_from_file(self, file: Optional[RemoteFile]) -> str: def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]: with self._state_lock: if self._file_to_datetime_history: - filename, last_modified = min(self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0])) - return RemoteFile(uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT)) + filename, last_modified = min( + self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0]) + ) + return RemoteFile( + uri=filename, + last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT), + ) else: return None @@ -120,7 +136,9 @@ def add_file(self, file: RemoteFile) -> None: :param file: The file to add """ if self._pending_files is None: - raise RuntimeError("Expected pending partitions to be set but it was not. This is unexpected. Please contact Support.") + raise RuntimeError( + "Expected pending partitions to be set but it was not. This is unexpected. Please contact Support." + ) with self._pending_files_lock: with self._state_lock: if file.uri not in self._pending_files: @@ -135,7 +153,9 @@ def add_file(self, file: RemoteFile) -> None: ) else: self._pending_files.pop(file.uri) - self._file_to_datetime_history[file.uri] = file.last_modified.strftime(self.DATE_TIME_FORMAT) + self._file_to_datetime_history[file.uri] = file.last_modified.strftime( + self.DATE_TIME_FORMAT + ) if len(self._file_to_datetime_history) > self.DEFAULT_MAX_HISTORY_SIZE: # Get the earliest file based on its last modified date and its uri oldest_file = self._compute_earliest_file_in_history() @@ -155,7 +175,9 @@ def emit_state_message(self) -> None: self._stream_namespace, new_state, ) - state_message = self._connector_state_manager.create_state_message(self._stream_name, self._stream_namespace) + state_message = self._connector_state_manager.create_state_message( + self._stream_name, self._stream_namespace + ) self._message_repository.emit_message(state_message) def _get_new_cursor_value(self) -> str: @@ -183,12 +205,19 @@ def _compute_earliest_pending_file(self) -> Optional[RemoteFile]: def _compute_latest_file_in_history(self) -> Optional[RemoteFile]: with self._state_lock: if self._file_to_datetime_history: - filename, last_modified = max(self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0])) - return RemoteFile(uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT)) + filename, last_modified = max( + self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0]) + ) + return RemoteFile( + uri=filename, + last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT), + ) else: return None - def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]: + def get_files_to_sync( + self, all_files: Iterable[RemoteFile], logger: logging.Logger + ) -> Iterable[RemoteFile]: """ Given the list of files in the source, return the files that should be synced. :param all_files: All files in the source @@ -210,7 +239,9 @@ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: with self._state_lock: if file.uri in self._file_to_datetime_history: # If the file's uri is in the history, we should sync the file if it has been modified since it was synced - updated_at_from_history = datetime.strptime(self._file_to_datetime_history[file.uri], self.DATE_TIME_FORMAT) + updated_at_from_history = datetime.strptime( + self._file_to_datetime_history[file.uri], self.DATE_TIME_FORMAT + ) if file.last_modified < updated_at_from_history: self._message_repository.emit_message( AirbyteMessage( @@ -246,7 +277,9 @@ def _is_history_full(self) -> bool: """ with self._state_lock: if self._file_to_datetime_history is None: - raise RuntimeError("The history object has not been set. This is unexpected. Please contact Support.") + raise RuntimeError( + "The history object has not been set. This is unexpected. Please contact Support." + ) return len(self._file_to_datetime_history) >= self.DEFAULT_MAX_HISTORY_SIZE def _compute_start_time(self) -> datetime: @@ -268,7 +301,10 @@ def get_state(self) -> MutableMapping[str, Any]: Get the state of the cursor. """ with self._state_lock: - return {"history": self._file_to_datetime_history, self._cursor_field.cursor_field_key: self._get_new_cursor_value()} + return { + "history": self._file_to_datetime_history, + self._cursor_field.cursor_field_key: self._get_new_cursor_value(), + } def set_initial_state(self, value: StreamState) -> None: pass diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py index 7181ecd1..b8926451 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py @@ -9,7 +9,9 @@ from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from airbyte_cdk.sources.file_based.stream.concurrent.cursor.abstract_concurrent_file_based_cursor import AbstractConcurrentFileBasedCursor +from airbyte_cdk.sources.file_based.stream.concurrent.cursor.abstract_concurrent_file_based_cursor import ( + AbstractConcurrentFileBasedCursor, +) from airbyte_cdk.sources.file_based.types import StreamState from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams import NO_CURSOR_STATE_KEY @@ -24,7 +26,11 @@ class FileBasedFinalStateCursor(AbstractConcurrentFileBasedCursor): """Cursor that is used to guarantee at least one state message is emitted for a concurrent file-based stream.""" def __init__( - self, stream_config: FileBasedStreamConfig, message_repository: MessageRepository, stream_namespace: Optional[str], **kwargs: Any + self, + stream_config: FileBasedStreamConfig, + message_repository: MessageRepository, + stream_namespace: Optional[str], + **kwargs: Any, ): self._stream_name = stream_config.name self._stream_namespace = stream_namespace @@ -50,7 +56,9 @@ def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) - def add_file(self, file: RemoteFile) -> None: pass - def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]: + def get_files_to_sync( + self, all_files: Iterable[RemoteFile], logger: logging.Logger + ) -> Iterable[RemoteFile]: return all_files def get_state(self) -> MutableMapping[str, Any]: @@ -66,6 +74,10 @@ def emit_state_message(self) -> None: pass def ensure_at_least_one_state_emitted(self) -> None: - self._connector_state_manager.update_state_for_stream(self._stream_name, self._stream_namespace, self.state) - state_message = self._connector_state_manager.create_state_message(self._stream_name, self._stream_namespace) + self._connector_state_manager.update_state_for_stream( + self._stream_name, self._stream_namespace, self.state + ) + state_message = self._connector_state_manager.create_state_message( + self._stream_name, self._stream_namespace + ) self._message_repository.emit_message(state_message) diff --git a/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py index f38a5364..4a5eadb4 100644 --- a/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py @@ -54,7 +54,9 @@ def get_start_time(self) -> datetime: ... @abstractmethod - def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]: + def get_files_to_sync( + self, all_files: Iterable[RemoteFile], logger: logging.Logger + ) -> Iterable[RemoteFile]: """ Given the list of files in the source, return the files that should be synced. :param all_files: All files in the source diff --git a/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py index 58d64acb..814bc1a1 100644 --- a/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py @@ -8,7 +8,9 @@ from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from airbyte_cdk.sources.file_based.stream.cursor.abstract_file_based_cursor import AbstractFileBasedCursor +from airbyte_cdk.sources.file_based.stream.cursor.abstract_file_based_cursor import ( + AbstractFileBasedCursor, +) from airbyte_cdk.sources.file_based.types import StreamState @@ -22,11 +24,14 @@ def __init__(self, stream_config: FileBasedStreamConfig, **_: Any): super().__init__(stream_config) self._file_to_datetime_history: MutableMapping[str, str] = {} self._time_window_if_history_is_full = timedelta( - days=stream_config.days_to_sync_if_history_is_full or self.DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL + days=stream_config.days_to_sync_if_history_is_full + or self.DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL ) if self._time_window_if_history_is_full <= timedelta(): - raise ValueError(f"days_to_sync_if_history_is_full must be a positive timedelta, got {self._time_window_if_history_is_full}") + raise ValueError( + f"days_to_sync_if_history_is_full must be a positive timedelta, got {self._time_window_if_history_is_full}" + ) self._start_time = self._compute_start_time() self._initial_earliest_file_in_history: Optional[RemoteFile] = None @@ -37,7 +42,9 @@ def set_initial_state(self, value: StreamState) -> None: self._initial_earliest_file_in_history = self._compute_earliest_file_in_history() def add_file(self, file: RemoteFile) -> None: - self._file_to_datetime_history[file.uri] = file.last_modified.strftime(self.DATE_TIME_FORMAT) + self._file_to_datetime_history[file.uri] = file.last_modified.strftime( + self.DATE_TIME_FORMAT + ) if len(self._file_to_datetime_history) > self.DEFAULT_MAX_HISTORY_SIZE: # Get the earliest file based on its last modified date and its uri oldest_file = self._compute_earliest_file_in_history() @@ -60,7 +67,9 @@ def _get_cursor(self) -> Optional[str]: a string joining the last-modified timestamp of the last synced file and the name of the file. """ if self._file_to_datetime_history.items(): - filename, timestamp = max(self._file_to_datetime_history.items(), key=lambda x: (x[1], x[0])) + filename, timestamp = max( + self._file_to_datetime_history.items(), key=lambda x: (x[1], x[0]) + ) return f"{timestamp}_{filename}" return None @@ -73,7 +82,9 @@ def _is_history_full(self) -> bool: def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: if file.uri in self._file_to_datetime_history: # If the file's uri is in the history, we should sync the file if it has been modified since it was synced - updated_at_from_history = datetime.strptime(self._file_to_datetime_history[file.uri], self.DATE_TIME_FORMAT) + updated_at_from_history = datetime.strptime( + self._file_to_datetime_history[file.uri], self.DATE_TIME_FORMAT + ) if file.last_modified < updated_at_from_history: logger.warning( f"The file {file.uri}'s last modified date is older than the last time it was synced. This is unexpected. Skipping the file." @@ -99,7 +110,9 @@ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: # The file is not in the history and the history is complete. We know we need to sync the file return True - def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]: + def get_files_to_sync( + self, all_files: Iterable[RemoteFile], logger: logging.Logger + ) -> Iterable[RemoteFile]: if self._is_history_full(): logger.warning( f"The state history is full. " @@ -115,8 +128,12 @@ def get_start_time(self) -> datetime: def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]: if self._file_to_datetime_history: - filename, last_modified = min(self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0])) - return RemoteFile(uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT)) + filename, last_modified = min( + self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0]) + ) + return RemoteFile( + uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT) + ) else: return None diff --git a/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py b/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py index a69712ef..a5cae2e6 100644 --- a/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py +++ b/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py @@ -22,7 +22,12 @@ ) from airbyte_cdk.sources.file_based.file_types import FileTransfer from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from airbyte_cdk.sources.file_based.schema_helpers import SchemaType, file_transfer_schema, merge_schemas, schemaless_schema +from airbyte_cdk.sources.file_based.schema_helpers import ( + SchemaType, + file_transfer_schema, + merge_schemas, + schemaless_schema, +) from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor from airbyte_cdk.sources.file_based.types import StreamSlice @@ -67,18 +72,28 @@ def cursor(self) -> Optional[AbstractFileBasedCursor]: @cursor.setter def cursor(self, value: AbstractFileBasedCursor) -> None: if self._cursor is not None: - raise RuntimeError(f"Cursor for stream {self.name} is already set. This is unexpected. Please contact Support.") + raise RuntimeError( + f"Cursor for stream {self.name} is already set. This is unexpected. Please contact Support." + ) self._cursor = value @property def primary_key(self) -> PrimaryKeyType: - return self.config.primary_key or self.get_parser().get_parser_defined_primary_key(self.config) + return self.config.primary_key or self.get_parser().get_parser_defined_primary_key( + self.config + ) - def _filter_schema_invalid_properties(self, configured_catalog_json_schema: Dict[str, Any]) -> Dict[str, Any]: + def _filter_schema_invalid_properties( + self, configured_catalog_json_schema: Dict[str, Any] + ) -> Dict[str, Any]: if self.use_file_transfer: return { "type": "object", - "properties": {"file_path": {"type": "string"}, "file_size": {"type": "string"}, self.ab_file_name_col: {"type": "string"}}, + "properties": { + "file_path": {"type": "string"}, + "file_size": {"type": "string"}, + self.ab_file_name_col: {"type": "string"}, + }, } else: return super()._filter_schema_invalid_properties(configured_catalog_json_schema) @@ -88,16 +103,23 @@ def compute_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: all_files = self.list_files() files_to_read = self._cursor.get_files_to_sync(all_files, self.logger) sorted_files_to_read = sorted(files_to_read, key=lambda f: (f.last_modified, f.uri)) - slices = [{"files": list(group[1])} for group in itertools.groupby(sorted_files_to_read, lambda f: f.last_modified)] + slices = [ + {"files": list(group[1])} + for group in itertools.groupby(sorted_files_to_read, lambda f: f.last_modified) + ] return slices - def transform_record(self, record: dict[str, Any], file: RemoteFile, last_updated: str) -> dict[str, Any]: + def transform_record( + self, record: dict[str, Any], file: RemoteFile, last_updated: str + ) -> dict[str, Any]: # adds _ab_source_file_last_modified and _ab_source_file_url to the record record[self.ab_last_mod_col] = last_updated record[self.ab_file_name_col] = file.uri return record - def transform_record_for_file_transfer(self, record: dict[str, Any], file: RemoteFile) -> dict[str, Any]: + def transform_record_for_file_transfer( + self, record: dict[str, Any], file: RemoteFile + ) -> dict[str, Any]: # timstamp() returns a float representing the number of seconds since the unix epoch record[self.modified] = int(file.last_modified.timestamp()) * 1000 record[self.source_file_url] = file.uri @@ -126,15 +148,21 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte self.logger.info(f"{self.name}: {file} file-based syncing") # todo: complete here the code to not rely on local parser file_transfer = FileTransfer() - for record in file_transfer.get_file(self.config, file, self.stream_reader, self.logger): + for record in file_transfer.get_file( + self.config, file, self.stream_reader, self.logger + ): line_no += 1 if not self.record_passes_validation_policy(record): n_skipped += 1 continue record = self.transform_record_for_file_transfer(record, file) - yield stream_data_to_airbyte_message(self.name, record, is_file_transfer_message=True) + yield stream_data_to_airbyte_message( + self.name, record, is_file_transfer_message=True + ) else: - for record in parser.parse_records(self.config, file, self.stream_reader, self.logger, schema): + for record in parser.parse_records( + self.config, file, self.stream_reader, self.logger, schema + ): line_no += 1 if self.config.schemaless: record = {"data": record} @@ -219,7 +247,9 @@ def get_json_schema(self) -> JsonSchema: except AirbyteTracedException as ate: raise ate except Exception as exc: - raise SchemaInferenceError(FileBasedSourceError.SCHEMA_INFERENCE_ERROR, stream=self.name) from exc + raise SchemaInferenceError( + FileBasedSourceError.SCHEMA_INFERENCE_ERROR, stream=self.name + ) from exc else: return {"type": "object", "properties": {**extra_fields, **schema["properties"]}} @@ -244,14 +274,20 @@ def _get_raw_json_schema(self) -> JsonSchema: first_n_files = self.config.recent_n_files_to_read_for_schema_discovery if first_n_files == 0: - self.logger.warning(msg=f"No files were identified in the stream {self.name}. Setting default schema for the stream.") + self.logger.warning( + msg=f"No files were identified in the stream {self.name}. Setting default schema for the stream." + ) return schemaless_schema - max_n_files_for_schema_inference = self._discovery_policy.get_max_n_files_for_schema_inference(self.get_parser()) + max_n_files_for_schema_inference = ( + self._discovery_policy.get_max_n_files_for_schema_inference(self.get_parser()) + ) if first_n_files > max_n_files_for_schema_inference: # Use the most recent files for schema inference, so we pick up schema changes during discovery. - self.logger.warning(msg=f"Refusing to infer schema for {first_n_files} files; using {max_n_files_for_schema_inference} files.") + self.logger.warning( + msg=f"Refusing to infer schema for {first_n_files} files; using {max_n_files_for_schema_inference} files." + ) first_n_files = max_n_files_for_schema_inference files = sorted(files, key=lambda x: x.last_modified, reverse=True)[:first_n_files] @@ -273,7 +309,9 @@ def get_files(self) -> Iterable[RemoteFile]: """ Return all files that belong to the stream as defined by the stream's globs. """ - return self.stream_reader.get_matching_files(self.config.globs or [], self.config.legacy_prefix, self.logger) + return self.stream_reader.get_matching_files( + self.config.globs or [], self.config.legacy_prefix, self.logger + ) def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: loop = asyncio.get_event_loop() @@ -311,25 +349,34 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: n_started, n_files = 0, len(files) files_iterator = iter(files) while pending_tasks or n_started < n_files: - while len(pending_tasks) <= self._discovery_policy.n_concurrent_requests and (file := next(files_iterator, None)): + while len(pending_tasks) <= self._discovery_policy.n_concurrent_requests and ( + file := next(files_iterator, None) + ): pending_tasks.add(asyncio.create_task(self._infer_file_schema(file))) n_started += 1 # Return when the first task is completed so that we can enqueue a new task as soon as the # number of concurrent tasks drops below the number allowed. - done, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED) + done, pending_tasks = await asyncio.wait( + pending_tasks, return_when=asyncio.FIRST_COMPLETED + ) for task in done: try: base_schema = merge_schemas(base_schema, task.result()) except AirbyteTracedException as ate: raise ate except Exception as exc: - self.logger.error(f"An error occurred inferring the schema. \n {traceback.format_exc()}", exc_info=exc) + self.logger.error( + f"An error occurred inferring the schema. \n {traceback.format_exc()}", + exc_info=exc, + ) return base_schema async def _infer_file_schema(self, file: RemoteFile) -> SchemaType: try: - return await self.get_parser().infer_schema(self.config, file, self.stream_reader, self.logger) + return await self.get_parser().infer_schema( + self.config, file, self.stream_reader, self.logger + ) except AirbyteTracedException as ate: raise ate except Exception as exc: diff --git a/airbyte_cdk/sources/http_logger.py b/airbyte_cdk/sources/http_logger.py index 7158c800..cbdc3c68 100644 --- a/airbyte_cdk/sources/http_logger.py +++ b/airbyte_cdk/sources/http_logger.py @@ -9,7 +9,11 @@ def format_http_message( - response: requests.Response, title: str, description: str, stream_name: Optional[str], is_auxiliary: bool = None + response: requests.Response, + title: str, + description: str, + stream_name: Optional[str], + is_auxiliary: bool = None, ) -> LogMessage: request = response.request log_message = { diff --git a/airbyte_cdk/sources/message/repository.py b/airbyte_cdk/sources/message/repository.py index bf908309..2fc156e8 100644 --- a/airbyte_cdk/sources/message/repository.py +++ b/airbyte_cdk/sources/message/repository.py @@ -28,7 +28,9 @@ def _is_severe_enough(threshold: Level, level: Level) -> bool: if threshold not in _SEVERITY_BY_LOG_LEVEL: - _LOGGER.warning(f"Log level {threshold} for threshold is not supported. This is probably a CDK bug. Please contact Airbyte.") + _LOGGER.warning( + f"Log level {threshold} for threshold is not supported. This is probably a CDK bug. Please contact Airbyte." + ) return True if level not in _SEVERITY_BY_LOG_LEVEL: @@ -80,7 +82,12 @@ def emit_message(self, message: AirbyteMessage) -> None: def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None: if _is_severe_enough(self._log_level, level): self.emit_message( - AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=level, message=filter_secrets(json.dumps(message_provider())))) + AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage( + level=level, message=filter_secrets(json.dumps(message_provider())) + ), + ) ) def consume_queue(self) -> Iterable[AirbyteMessage]: @@ -89,7 +96,12 @@ def consume_queue(self) -> Iterable[AirbyteMessage]: class LogAppenderMessageRepositoryDecorator(MessageRepository): - def __init__(self, dict_to_append: LogMessage, decorated: MessageRepository, log_level: Level = Level.INFO): + def __init__( + self, + dict_to_append: LogMessage, + decorated: MessageRepository, + log_level: Level = Level.INFO, + ): self._dict_to_append = dict_to_append self._decorated = decorated self._log_level = log_level @@ -106,7 +118,9 @@ def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) def consume_queue(self) -> Iterable[AirbyteMessage]: return self._decorated.consume_queue() - def _append_second_to_first(self, first: LogMessage, second: LogMessage, path: Optional[List[str]] = None) -> LogMessage: + def _append_second_to_first( + self, first: LogMessage, second: LogMessage, path: Optional[List[str]] = None + ) -> LogMessage: if path is None: path = [] diff --git a/airbyte_cdk/sources/source.py b/airbyte_cdk/sources/source.py index c1d8ec66..2958d82c 100644 --- a/airbyte_cdk/sources/source.py +++ b/airbyte_cdk/sources/source.py @@ -33,7 +33,13 @@ def read_state(self, state_path: str) -> TState: ... def read_catalog(self, catalog_path: str) -> TCatalog: ... @abstractmethod - def read(self, logger: logging.Logger, config: TConfig, catalog: TCatalog, state: Optional[TState] = None) -> Iterable[AirbyteMessage]: + def read( + self, + logger: logging.Logger, + config: TConfig, + catalog: TCatalog, + state: Optional[TState] = None, + ) -> Iterable[AirbyteMessage]: """ Returns a generator of the AirbyteMessages generated by reading the source with the given configuration, catalog, and state. """ @@ -67,8 +73,14 @@ def read_state(cls, state_path: str) -> List[AirbyteStateMessage]: if state_obj: for state in state_obj: # type: ignore # `isinstance(state_obj, List)` ensures that this is a list parsed_message = AirbyteStateMessageSerializer.load(state) - if not parsed_message.stream and not parsed_message.data and not parsed_message.global_: - raise ValueError("AirbyteStateMessage should contain either a stream, global, or state field") + if ( + not parsed_message.stream + and not parsed_message.data + and not parsed_message.global_ + ): + raise ValueError( + "AirbyteStateMessage should contain either a stream, global, or state field" + ) parsed_state_messages.append(parsed_message) return parsed_state_messages diff --git a/airbyte_cdk/sources/streams/availability_strategy.py b/airbyte_cdk/sources/streams/availability_strategy.py index f2042bc1..312ddae1 100644 --- a/airbyte_cdk/sources/streams/availability_strategy.py +++ b/airbyte_cdk/sources/streams/availability_strategy.py @@ -20,7 +20,9 @@ class AvailabilityStrategy(ABC): """ @abstractmethod - def check_availability(self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None) -> Tuple[bool, Optional[str]]: + def check_availability( + self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None + ) -> Tuple[bool, Optional[str]]: """ Checks stream availability. @@ -52,7 +54,9 @@ def get_first_stream_slice(stream: Stream) -> Optional[Mapping[str, Any]]: return next(slices) @staticmethod - def get_first_record_for_slice(stream: Stream, stream_slice: Optional[Mapping[str, Any]]) -> StreamData: + def get_first_record_for_slice( + stream: Stream, stream_slice: Optional[Mapping[str, Any]] + ) -> StreamData: """ Gets the first record for a stream_slice of a stream. @@ -70,7 +74,9 @@ def get_first_record_for_slice(stream: Stream, stream_slice: Optional[Mapping[st # We wrap the return output of read_records() because some implementations return types that are iterable, # but not iterators such as lists or tuples - records_for_slice = iter(stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice)) + records_for_slice = iter( + stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice) + ) return next(records_for_slice) finally: diff --git a/airbyte_cdk/sources/streams/call_rate.py b/airbyte_cdk/sources/streams/call_rate.py index eb337545..19ae603c 100644 --- a/airbyte_cdk/sources/streams/call_rate.py +++ b/airbyte_cdk/sources/streams/call_rate.py @@ -76,7 +76,9 @@ def try_acquire(self, request: Any, weight: int) -> None: """ @abc.abstractmethod - def update(self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime]) -> None: + def update( + self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] + ) -> None: """Update call rate counting with current values :param available_calls: @@ -202,12 +204,20 @@ class UnlimitedCallRatePolicy(BaseCallRatePolicy): def try_acquire(self, request: Any, weight: int) -> None: """Do nothing""" - def update(self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime]) -> None: + def update( + self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] + ) -> None: """Do nothing""" class FixedWindowCallRatePolicy(BaseCallRatePolicy): - def __init__(self, next_reset_ts: datetime.datetime, period: timedelta, call_limit: int, matchers: list[RequestMatcher]): + def __init__( + self, + next_reset_ts: datetime.datetime, + period: timedelta, + call_limit: int, + matchers: list[RequestMatcher], + ): """A policy that allows {call_limit} calls within a {period} time interval :param next_reset_ts: next call rate reset time point @@ -235,7 +245,8 @@ def try_acquire(self, request: Any, weight: int) -> None: if self._calls_num + weight > self._call_limit: reset_in = self._next_reset_ts - datetime.datetime.now() error_message = ( - f"reached maximum number of allowed calls {self._call_limit} " f"per {self._offset} interval, next reset in {reset_in}." + f"reached maximum number of allowed calls {self._call_limit} " + f"per {self._offset} interval, next reset in {reset_in}." ) raise CallRateLimitHit( error=error_message, @@ -247,7 +258,9 @@ def try_acquire(self, request: Any, weight: int) -> None: self._calls_num += weight - def update(self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime]) -> None: + def update( + self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] + ) -> None: """Update call rate counters, by default, only reacts to decreasing updates of available_calls and changes to call_reset_ts. We ignore updates with available_calls > current_available_calls to support call rate limits that are lower than API limits. @@ -260,12 +273,18 @@ def update(self, available_calls: Optional[int], call_reset_ts: Optional[datetim if available_calls is not None and current_available_calls > available_calls: logger.debug( - "got rate limit update from api, adjusting available calls from %s to %s", current_available_calls, available_calls + "got rate limit update from api, adjusting available calls from %s to %s", + current_available_calls, + available_calls, ) self._calls_num = self._call_limit - available_calls if call_reset_ts is not None and call_reset_ts != self._next_reset_ts: - logger.debug("got rate limit update from api, adjusting reset time from %s to %s", self._next_reset_ts, call_reset_ts) + logger.debug( + "got rate limit update from api, adjusting reset time from %s to %s", + self._next_reset_ts, + call_reset_ts, + ) self._next_reset_ts = call_reset_ts def _update_current_window(self) -> None: @@ -292,7 +311,10 @@ def __init__(self, rates: list[Rate], matchers: list[RequestMatcher]): """ if not rates: raise ValueError("The list of rates can not be empty") - pyrate_rates = [PyRateRate(limit=rate.limit, interval=int(rate.interval.total_seconds() * 1000)) for rate in rates] + pyrate_rates = [ + PyRateRate(limit=rate.limit, interval=int(rate.interval.total_seconds() * 1000)) + for rate in rates + ] self._bucket = InMemoryBucket(pyrate_rates) # Limiter will create the background task that clears old requests in the bucket self._limiter = Limiter(self._bucket) @@ -320,14 +342,18 @@ def try_acquire(self, request: Any, weight: int) -> None: time_to_wait=timedelta(milliseconds=time_to_wait), ) - def update(self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime]) -> None: + def update( + self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] + ) -> None: """Adjust call bucket to reflect the state of the API server :param available_calls: :param call_reset_ts: :return: """ - if available_calls is not None and call_reset_ts is None: # we do our best to sync buckets with API + if ( + available_calls is not None and call_reset_ts is None + ): # we do our best to sync buckets with API if available_calls == 0: with self._limiter.lock: items_to_add = self._bucket.count() < self._bucket.rates[0].limit @@ -350,7 +376,9 @@ class AbstractAPIBudget(abc.ABC): """ @abc.abstractmethod - def acquire_call(self, request: Any, block: bool = True, timeout: Optional[float] = None) -> None: + def acquire_call( + self, request: Any, block: bool = True, timeout: Optional[float] = None + ) -> None: """Try to get a call from budget, will block by default :param request: @@ -375,7 +403,9 @@ def update_from_response(self, request: Any, response: Any) -> None: class APIBudget(AbstractAPIBudget): """Default APIBudget implementation""" - def __init__(self, policies: list[AbstractCallRatePolicy], maximum_attempts_to_acquire: int = 100000) -> None: + def __init__( + self, policies: list[AbstractCallRatePolicy], maximum_attempts_to_acquire: int = 100000 + ) -> None: """Constructor :param policies: list of policies in this budget @@ -392,7 +422,9 @@ def get_matching_policy(self, request: Any) -> Optional[AbstractCallRatePolicy]: return policy return None - def acquire_call(self, request: Any, block: bool = True, timeout: Optional[float] = None) -> None: + def acquire_call( + self, request: Any, block: bool = True, timeout: Optional[float] = None + ) -> None: """Try to get a call from budget, will block by default. Matchers will be called sequentially in the same order they were added. The first matcher that returns True will @@ -417,7 +449,9 @@ def update_from_response(self, request: Any, response: Any) -> None: """ pass - def _do_acquire(self, request: Any, policy: AbstractCallRatePolicy, block: bool, timeout: Optional[float]) -> None: + def _do_acquire( + self, request: Any, policy: AbstractCallRatePolicy, block: bool, timeout: Optional[float] + ) -> None: """Internal method to try to acquire a call credit :param request: @@ -439,14 +473,20 @@ def _do_acquire(self, request: Any, policy: AbstractCallRatePolicy, block: bool, else: time_to_wait = exc.time_to_wait - time_to_wait = max(timedelta(0), time_to_wait) # sometimes we get negative duration - logger.info("reached call limit %s. going to sleep for %s", exc.rate, time_to_wait) + time_to_wait = max( + timedelta(0), time_to_wait + ) # sometimes we get negative duration + logger.info( + "reached call limit %s. going to sleep for %s", exc.rate, time_to_wait + ) time.sleep(time_to_wait.total_seconds()) else: raise if last_exception: - logger.info("we used all %s attempts to acquire and failed", self._maximum_attempts_to_acquire) + logger.info( + "we used all %s attempts to acquire and failed", self._maximum_attempts_to_acquire + ) raise last_exception @@ -481,9 +521,13 @@ def update_from_response(self, request: Any, response: Any) -> None: reset_ts = self.get_reset_ts_from_response(response) policy.update(available_calls=available_calls, call_reset_ts=reset_ts) - def get_reset_ts_from_response(self, response: requests.Response) -> Optional[datetime.datetime]: + def get_reset_ts_from_response( + self, response: requests.Response + ) -> Optional[datetime.datetime]: if response.headers.get(self._ratelimit_reset_header): - return datetime.datetime.fromtimestamp(int(response.headers[self._ratelimit_reset_header])) + return datetime.datetime.fromtimestamp( + int(response.headers[self._ratelimit_reset_header]) + ) return None def get_calls_left_from_response(self, response: requests.Response) -> Optional[int]: diff --git a/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py b/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py index 1b6d6324..6e4ef98d 100644 --- a/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py +++ b/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py @@ -53,7 +53,9 @@ class IncrementalCheckpointReader(CheckpointReader): before syncing data. """ - def __init__(self, stream_state: Mapping[str, Any], stream_slices: Iterable[Optional[Mapping[str, Any]]]): + def __init__( + self, stream_state: Mapping[str, Any], stream_slices: Iterable[Optional[Mapping[str, Any]]] + ): self._state: Optional[Mapping[str, Any]] = stream_state self._stream_slices = iter(stream_slices) self._has_slices = False @@ -87,7 +89,12 @@ class CursorBasedCheckpointReader(CheckpointReader): that belongs to the Concurrent CDK. """ - def __init__(self, cursor: Cursor, stream_slices: Iterable[Optional[Mapping[str, Any]]], read_state_from_cursor: bool = False): + def __init__( + self, + cursor: Cursor, + stream_slices: Iterable[Optional[Mapping[str, Any]]], + read_state_from_cursor: bool = False, + ): self._cursor = cursor self._stream_slices = iter(stream_slices) # read_state_from_cursor is used to delineate that partitions should determine when to stop syncing dynamically according @@ -153,7 +160,11 @@ def _find_next_slice(self) -> StreamSlice: next_slice = self.read_and_convert_slice() state_for_slice = self._cursor.select_state(next_slice) has_more = state_for_slice == FULL_REFRESH_COMPLETE_STATE - return StreamSlice(cursor_slice=state_for_slice or {}, partition=next_slice.partition, extra_fields=next_slice.extra_fields) + return StreamSlice( + cursor_slice=state_for_slice or {}, + partition=next_slice.partition, + extra_fields=next_slice.extra_fields, + ) else: state_for_slice = self._cursor.select_state(self.current_slice) if state_for_slice == FULL_REFRESH_COMPLETE_STATE: @@ -173,7 +184,9 @@ def _find_next_slice(self) -> StreamSlice: ) # The reader continues to process the current partition if it's state is still in progress return StreamSlice( - cursor_slice=state_for_slice or {}, partition=self.current_slice.partition, extra_fields=self.current_slice.extra_fields + cursor_slice=state_for_slice or {}, + partition=self.current_slice.partition, + extra_fields=self.current_slice.extra_fields, ) else: # Unlike RFR cursors that iterate dynamically according to how stream state is updated, most cursors operate @@ -218,8 +231,17 @@ class LegacyCursorBasedCheckpointReader(CursorBasedCheckpointReader): } """ - def __init__(self, cursor: Cursor, stream_slices: Iterable[Optional[Mapping[str, Any]]], read_state_from_cursor: bool = False): - super().__init__(cursor=cursor, stream_slices=stream_slices, read_state_from_cursor=read_state_from_cursor) + def __init__( + self, + cursor: Cursor, + stream_slices: Iterable[Optional[Mapping[str, Any]]], + read_state_from_cursor: bool = False, + ): + super().__init__( + cursor=cursor, + stream_slices=stream_slices, + read_state_from_cursor=read_state_from_cursor, + ) def next(self) -> Optional[Mapping[str, Any]]: try: @@ -228,7 +250,9 @@ def next(self) -> Optional[Mapping[str, Any]]: if "partition" in dict(self.current_slice): raise ValueError("Stream is configured to use invalid stream slice key 'partition'") elif "cursor_slice" in dict(self.current_slice): - raise ValueError("Stream is configured to use invalid stream slice key 'cursor_slice'") + raise ValueError( + "Stream is configured to use invalid stream slice key 'cursor_slice'" + ) # We convert StreamSlice to a regular mapping because legacy connectors operate on the basic Mapping object. We # also duplicate all fields at the top level for backwards compatibility for existing Python sources diff --git a/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py b/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py index 8ebadcaf..9966959f 100644 --- a/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py +++ b/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py @@ -5,7 +5,9 @@ from airbyte_cdk.models import FailureType from airbyte_cdk.sources.streams.checkpoint import Cursor -from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer +from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import ( + PerPartitionKeySerializer, +) from airbyte_cdk.sources.types import Record, StreamSlice, StreamState from airbyte_cdk.utils import AirbyteTracedException @@ -97,7 +99,9 @@ def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[S if not stream_slice: raise ValueError("A partition needs to be provided in order to extract a state") - return self._per_partition_state.get(self._to_partition_key(stream_slice.partition), {}).get("cursor") + return self._per_partition_state.get( + self._to_partition_key(stream_slice.partition), {} + ).get("cursor") def _to_partition_key(self, partition: Mapping[str, Any]) -> str: return self._partition_serializer.to_partition_key(partition) diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py index 5fc775a1..d4b539a5 100644 --- a/airbyte_cdk/sources/streams/concurrent/adapters.py +++ b/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -8,7 +8,15 @@ from functools import lru_cache from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, ConfiguredAirbyteStream, Level, SyncMode, Type +from airbyte_cdk.models import ( + AirbyteLogMessage, + AirbyteMessage, + AirbyteStream, + ConfiguredAirbyteStream, + Level, + SyncMode, + Type, +) from airbyte_cdk.sources import AbstractSource, Source from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.message import MessageRepository @@ -16,15 +24,23 @@ from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade -from airbyte_cdk.sources.streams.concurrent.availability_strategy import AbstractAvailabilityStrategy, AlwaysAvailableAvailabilityStrategy +from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( + AbstractAvailabilityStrategy, + AlwaysAvailableAvailabilityStrategy, +) from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, FinalStateCursor from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage -from airbyte_cdk.sources.streams.concurrent.helpers import get_cursor_field_from_stream, get_primary_key_from_stream +from airbyte_cdk.sources.streams.concurrent.helpers import ( + get_cursor_field_from_stream, + get_primary_key_from_stream, +) from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator from airbyte_cdk.sources.streams.concurrent.partitions.record import Record -from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import DateTimeStreamStateConverter +from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( + DateTimeStreamStateConverter, +) from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.sources.types import StreamSlice from airbyte_cdk.sources.utils.schema_helpers import InternalConfig @@ -75,7 +91,9 @@ def create_from_stream( partition_generator=StreamPartitionGenerator( stream, message_repository, - SyncMode.full_refresh if isinstance(cursor, FinalStateCursor) else SyncMode.incremental, + SyncMode.full_refresh + if isinstance(cursor, FinalStateCursor) + else SyncMode.incremental, [cursor_field] if cursor_field is not None else None, state, cursor, @@ -97,14 +115,23 @@ def create_from_stream( @property def state(self) -> MutableMapping[str, Any]: - raise NotImplementedError("This should not be called as part of the Concurrent CDK code. Please report the problem to Airbyte") + raise NotImplementedError( + "This should not be called as part of the Concurrent CDK code. Please report the problem to Airbyte" + ) @state.setter def state(self, value: Mapping[str, Any]) -> None: if "state" in dir(self._legacy_stream): self._legacy_stream.state = value # type: ignore # validating `state` is attribute of stream using `if` above - def __init__(self, stream: DefaultStream, legacy_stream: Stream, cursor: Cursor, slice_logger: SliceLogger, logger: logging.Logger): + def __init__( + self, + stream: DefaultStream, + legacy_stream: Stream, + cursor: Cursor, + slice_logger: SliceLogger, + logger: logging.Logger, + ): """ :param stream: The underlying AbstractStream """ @@ -141,7 +168,10 @@ def read_records( # This shouldn't happen if the ConcurrentCursor was used state = "unknown; no state attribute was available on the cursor" yield AirbyteMessage( - type=Type.LOG, log=AirbyteLogMessage(level=Level.ERROR, message=f"Cursor State at time of exception: {state}") + type=Type.LOG, + log=AirbyteLogMessage( + level=Level.ERROR, message=f"Cursor State at time of exception: {state}" + ), ) raise exc @@ -180,7 +210,9 @@ def get_json_schema(self) -> Mapping[str, Any]: def supports_incremental(self) -> bool: return self._legacy_stream.supports_incremental - def check_availability(self, logger: logging.Logger, source: Optional["Source"] = None) -> Tuple[bool, Optional[str]]: + def check_availability( + self, logger: logging.Logger, source: Optional["Source"] = None + ) -> Tuple[bool, Optional[str]]: """ Verifies the stream is available. Delegates to the underlying AbstractStream and ignores the parameters :param logger: (ignored) @@ -264,7 +296,9 @@ def read(self) -> Iterable[Record]: ): if isinstance(record_data, Mapping): data_to_return = dict(record_data) - self._stream.transformer.transform(data_to_return, self._stream.get_json_schema()) + self._stream.transformer.transform( + data_to_return, self._stream.get_json_schema() + ) yield Record(data_to_return, self) else: self._message_repository.emit_message(record_data) @@ -329,9 +363,17 @@ def __init__( self._cursor = cursor def generate(self) -> Iterable[Partition]: - for s in self._stream.stream_slices(sync_mode=self._sync_mode, cursor_field=self._cursor_field, stream_state=self._state): + for s in self._stream.stream_slices( + sync_mode=self._sync_mode, cursor_field=self._cursor_field, stream_state=self._state + ): yield StreamPartition( - self._stream, copy.deepcopy(s), self.message_repository, self._sync_mode, self._cursor_field, self._state, self._cursor + self._stream, + copy.deepcopy(s), + self.message_repository, + self._sync_mode, + self._cursor_field, + self._state, + self._cursor, ) @@ -382,8 +424,16 @@ def generate(self) -> Iterable[Partition]: :return: An iterable of StreamPartition objects. """ - start_boundary = self._slice_boundary_fields[self._START_BOUNDARY] if self._slice_boundary_fields else "start" - end_boundary = self._slice_boundary_fields[self._END_BOUNDARY] if self._slice_boundary_fields else "end" + start_boundary = ( + self._slice_boundary_fields[self._START_BOUNDARY] + if self._slice_boundary_fields + else "start" + ) + end_boundary = ( + self._slice_boundary_fields[self._END_BOUNDARY] + if self._slice_boundary_fields + else "end" + ) for slice_start, slice_end in self._cursor.generate_slices(): stream_slice = StreamSlice( @@ -405,12 +455,17 @@ def generate(self) -> Iterable[Partition]: ) -@deprecated("Availability strategy has been soft deprecated. Do not use. Class is subject to removal", category=ExperimentalClassWarning) +@deprecated( + "Availability strategy has been soft deprecated. Do not use. Class is subject to removal", + category=ExperimentalClassWarning, +) class AvailabilityStrategyFacade(AvailabilityStrategy): def __init__(self, abstract_availability_strategy: AbstractAvailabilityStrategy): self._abstract_availability_strategy = abstract_availability_strategy - def check_availability(self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None) -> Tuple[bool, Optional[str]]: + def check_availability( + self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None + ) -> Tuple[bool, Optional[str]]: """ Checks stream availability. diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index d5b8fbca..15e9b59a 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -11,7 +11,9 @@ from airbyte_cdk.sources.streams import NO_CURSOR_STATE_KEY from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record -from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import AbstractStreamStateConverter +from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ( + AbstractStreamStateConverter, +) def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any: @@ -127,8 +129,12 @@ def ensure_at_least_one_state_emitted(self) -> None: Used primarily for full refresh syncs that do not have a valid cursor value to emit at the end of a sync """ - self._connector_state_manager.update_state_for_stream(self._stream_name, self._stream_namespace, self.state) - state_message = self._connector_state_manager.create_state_message(self._stream_name, self._stream_namespace) + self._connector_state_manager.update_state_for_stream( + self._stream_name, self._stream_namespace, self.state + ) + state_message = self._connector_state_manager.create_state_message( + self._stream_name, self._stream_namespace + ) self._message_repository.emit_message(state_message) @@ -181,13 +187,22 @@ def cursor_field(self) -> CursorField: def slice_boundary_fields(self) -> Optional[Tuple[str, str]]: return self._slice_boundary_fields - def _get_concurrent_state(self, state: MutableMapping[str, Any]) -> Tuple[CursorValueType, MutableMapping[str, Any]]: + def _get_concurrent_state( + self, state: MutableMapping[str, Any] + ) -> Tuple[CursorValueType, MutableMapping[str, Any]]: if self._connector_state_converter.is_state_message_compatible(state): - return self._start or self._connector_state_converter.zero_value, self._connector_state_converter.deserialize(state) - return self._connector_state_converter.convert_from_sequential_state(self._cursor_field, state, self._start) + return ( + self._start or self._connector_state_converter.zero_value, + self._connector_state_converter.deserialize(state), + ) + return self._connector_state_converter.convert_from_sequential_state( + self._cursor_field, state, self._start + ) def observe(self, record: Record) -> None: - most_recent_cursor_value = self._most_recent_cursor_value_per_partition.get(record.partition) + most_recent_cursor_value = self._most_recent_cursor_value_per_partition.get( + record.partition + ) cursor_value = self._extract_cursor_value(record) if most_recent_cursor_value is None or most_recent_cursor_value < cursor_value: @@ -199,7 +214,9 @@ def _extract_cursor_value(self, record: Record) -> Any: def close_partition(self, partition: Partition) -> None: slice_count_before = len(self.state.get("slices", [])) self._add_slice_to_state(partition) - if slice_count_before < len(self.state["slices"]): # only emit if at least one slice has been processed + if slice_count_before < len( + self.state["slices"] + ): # only emit if at least one slice has been processed self._merge_partitions() self._emit_state_message() self._has_closed_at_least_one_slice = True @@ -252,9 +269,13 @@ def _emit_state_message(self) -> None: self._connector_state_manager.update_state_for_stream( self._stream_name, self._stream_namespace, - self._connector_state_converter.convert_to_state_message(self._cursor_field, self.state), + self._connector_state_converter.convert_to_state_message( + self._cursor_field, self.state + ), + ) + state_message = self._connector_state_manager.create_state_message( + self._stream_name, self._stream_namespace ) - state_message = self._connector_state_manager.create_state_message(self._stream_name, self._stream_namespace) self._message_repository.emit_message(state_message) def _merge_partitions(self) -> None: @@ -267,7 +288,9 @@ def _extract_from_slice(self, partition: Partition, key: str) -> CursorValueType raise KeyError(f"Could not find key `{key}` in empty slice") return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a CursorValueType except KeyError as exception: - raise KeyError(f"Partition is expected to have key `{key}` but could not be found") from exception + raise KeyError( + f"Partition is expected to have key `{key}` but could not be found" + ) from exception def ensure_at_least_one_state_emitted(self) -> None: """ @@ -299,7 +322,9 @@ def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]: if len(self.state["slices"]) == 1: yield from self._split_per_slice_range( - self._calculate_lower_boundary_of_last_slice(self.state["slices"][0][self._connector_state_converter.END_KEY]), + self._calculate_lower_boundary_of_last_slice( + self.state["slices"][0][self._connector_state_converter.END_KEY] + ), self._end_provider(), True, ) @@ -307,7 +332,8 @@ def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]: for i in range(len(self.state["slices"]) - 1): if self._cursor_granularity: yield from self._split_per_slice_range( - self.state["slices"][i][self._connector_state_converter.END_KEY] + self._cursor_granularity, + self.state["slices"][i][self._connector_state_converter.END_KEY] + + self._cursor_granularity, self.state["slices"][i + 1][self._connector_state_converter.START_KEY], False, ) @@ -318,7 +344,9 @@ def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]: False, ) yield from self._split_per_slice_range( - self._calculate_lower_boundary_of_last_slice(self.state["slices"][-1][self._connector_state_converter.END_KEY]), + self._calculate_lower_boundary_of_last_slice( + self.state["slices"][-1][self._connector_state_converter.END_KEY] + ), self._end_provider(), True, ) @@ -326,9 +354,14 @@ def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]: raise ValueError("Expected at least one slice") def _is_start_before_first_slice(self) -> bool: - return self._start is not None and self._start < self.state["slices"][0][self._connector_state_converter.START_KEY] + return ( + self._start is not None + and self._start < self.state["slices"][0][self._connector_state_converter.START_KEY] + ) - def _calculate_lower_boundary_of_last_slice(self, lower_boundary: CursorValueType) -> CursorValueType: + def _calculate_lower_boundary_of_last_slice( + self, lower_boundary: CursorValueType + ) -> CursorValueType: if self._lookback_window: return lower_boundary - self._lookback_window return lower_boundary @@ -352,9 +385,13 @@ def _split_per_slice_range( stop_processing = False current_lower_boundary = lower while not stop_processing: - current_upper_boundary = min(self._evaluate_upper_safely(current_lower_boundary, self._slice_range), upper) + current_upper_boundary = min( + self._evaluate_upper_safely(current_lower_boundary, self._slice_range), upper + ) has_reached_upper_boundary = current_upper_boundary >= upper - if self._cursor_granularity and (not upper_is_end or not has_reached_upper_boundary): + if self._cursor_granularity and ( + not upper_is_end or not has_reached_upper_boundary + ): yield current_lower_boundary, current_upper_boundary - self._cursor_granularity else: yield current_lower_boundary, current_upper_boundary diff --git a/airbyte_cdk/sources/streams/concurrent/default_stream.py b/airbyte_cdk/sources/streams/concurrent/default_stream.py index a48d897e..eb94ebba 100644 --- a/airbyte_cdk/sources/streams/concurrent/default_stream.py +++ b/airbyte_cdk/sources/streams/concurrent/default_stream.py @@ -8,7 +8,10 @@ from airbyte_cdk.models import AirbyteStream, SyncMode from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream -from airbyte_cdk.sources.streams.concurrent.availability_strategy import AbstractAvailabilityStrategy, StreamAvailability +from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( + AbstractAvailabilityStrategy, + StreamAvailability, +) from airbyte_cdk.sources.streams.concurrent.cursor import Cursor from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator @@ -60,7 +63,11 @@ def get_json_schema(self) -> Mapping[str, Any]: return self._json_schema def as_airbyte_stream(self) -> AirbyteStream: - stream = AirbyteStream(name=self.name, json_schema=dict(self._json_schema), supported_sync_modes=[SyncMode.full_refresh]) + stream = AirbyteStream( + name=self.name, + json_schema=dict(self._json_schema), + supported_sync_modes=[SyncMode.full_refresh], + ) if self._namespace: stream.namespace = self._namespace diff --git a/airbyte_cdk/sources/streams/concurrent/helpers.py b/airbyte_cdk/sources/streams/concurrent/helpers.py index ad772272..d839068a 100644 --- a/airbyte_cdk/sources/streams/concurrent/helpers.py +++ b/airbyte_cdk/sources/streams/concurrent/helpers.py @@ -5,7 +5,9 @@ from airbyte_cdk.sources.streams import Stream -def get_primary_key_from_stream(stream_primary_key: Optional[Union[str, List[str], List[List[str]]]]) -> List[str]: +def get_primary_key_from_stream( + stream_primary_key: Optional[Union[str, List[str], List[List[str]]]], +) -> List[str]: if stream_primary_key is None: return [] elif isinstance(stream_primary_key, str): @@ -22,7 +24,9 @@ def get_primary_key_from_stream(stream_primary_key: Optional[Union[str, List[str def get_cursor_field_from_stream(stream: Stream) -> Optional[str]: if isinstance(stream.cursor_field, list): if len(stream.cursor_field) > 1: - raise ValueError(f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}") + raise ValueError( + f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}" + ) elif len(stream.cursor_field) == 0: return None else: diff --git a/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py b/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py index 8e63c16a..a4dd81f2 100644 --- a/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py +++ b/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py @@ -4,7 +4,9 @@ import time from queue import Queue -from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( + PartitionGenerationCompletedSentinel, +) from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream @@ -16,7 +18,12 @@ class PartitionEnqueuer: Generates partitions from a partition generator and puts them in a queue. """ - def __init__(self, queue: Queue[QueueItem], thread_pool_manager: ThreadPoolManager, sleep_time_in_seconds: float = 0.1) -> None: + def __init__( + self, + queue: Queue[QueueItem], + thread_pool_manager: ThreadPoolManager, + sleep_time_in_seconds: float = 0.1, + ) -> None: """ :param queue: The queue to put the partitions in. :param throttler: The throttler to use to throttle the partition generation. diff --git a/airbyte_cdk/sources/streams/concurrent/partition_reader.py b/airbyte_cdk/sources/streams/concurrent/partition_reader.py index eec69d56..3d23fd9c 100644 --- a/airbyte_cdk/sources/streams/concurrent/partition_reader.py +++ b/airbyte_cdk/sources/streams/concurrent/partition_reader.py @@ -5,7 +5,10 @@ from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition -from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel, QueueItem +from airbyte_cdk.sources.streams.concurrent.partitions.types import ( + PartitionCompleteSentinel, + QueueItem, +) class PartitionReader: diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/record.py b/airbyte_cdk/sources/streams/concurrent/partitions/record.py index 0b34ae13..e67dc656 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/record.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/record.py @@ -13,7 +13,12 @@ class Record: Represents a record read from a stream. """ - def __init__(self, data: Mapping[str, Any], partition: "Partition", is_file_transfer_message: bool = False): + def __init__( + self, + data: Mapping[str, Any], + partition: "Partition", + is_file_transfer_message: bool = False, + ): self.data = data self.partition = partition self.is_file_transfer_message = is_file_transfer_message @@ -21,7 +26,10 @@ def __init__(self, data: Mapping[str, Any], partition: "Partition", is_file_tran def __eq__(self, other: Any) -> bool: if not isinstance(other, Record): return False - return self.data == other.data and self.partition.stream_name() == other.partition.stream_name() + return ( + self.data == other.data + and self.partition.stream_name() == other.partition.stream_name() + ) def __repr__(self) -> str: return f"Record(data={self.data}, stream_name={self.partition.stream_name()})" diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/types.py b/airbyte_cdk/sources/streams/concurrent/partitions/types.py index c36d9d94..7abebe07 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/types.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/types.py @@ -4,7 +4,9 @@ from typing import Any, Union -from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( + PartitionGenerationCompletedSentinel, +) from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record @@ -31,4 +33,6 @@ def __eq__(self, other: Any) -> bool: """ Typedef representing the items that can be added to the ThreadBasedConcurrentStream """ -QueueItem = Union[Record, Partition, PartitionCompleteSentinel, PartitionGenerationCompletedSentinel, Exception] +QueueItem = Union[ + Record, Partition, PartitionCompleteSentinel, PartitionGenerationCompletedSentinel, Exception +] diff --git a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py index 60d8f17f..1b477976 100644 --- a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py +++ b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py @@ -30,7 +30,9 @@ def _to_state_message(self, value: Any) -> Any: def __init__(self, is_sequential_state: bool = True): self._is_sequential_state = is_sequential_state - def convert_to_state_message(self, cursor_field: "CursorField", stream_state: MutableMapping[str, Any]) -> MutableMapping[str, Any]: + def convert_to_state_message( + self, cursor_field: "CursorField", stream_state: MutableMapping[str, Any] + ) -> MutableMapping[str, Any]: """ Convert the state message from the concurrency-compatible format to the stream's original format. @@ -41,7 +43,9 @@ def convert_to_state_message(self, cursor_field: "CursorField", stream_state: Mu legacy_state = stream_state.get("legacy", {}) latest_complete_time = self._get_latest_complete_time(stream_state.get("slices", [])) if latest_complete_time is not None: - legacy_state.update({cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)}) + legacy_state.update( + {cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)} + ) return legacy_state or {} else: return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range) @@ -51,7 +55,9 @@ def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> A Get the latest time before which all records have been processed. """ if not slices: - raise RuntimeError("Expected at least one slice but there were none. This is unexpected; please contact Support.") + raise RuntimeError( + "Expected at least one slice but there were none. This is unexpected; please contact Support." + ) merged_intervals = self.merge_intervals(slices) first_interval = merged_intervals[0] @@ -66,7 +72,9 @@ def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, An stream_slice[self.END_KEY] = self._from_state_message(stream_slice[self.END_KEY]) return state - def serialize(self, state: MutableMapping[str, Any], state_type: ConcurrencyCompatibleStateType) -> MutableMapping[str, Any]: + def serialize( + self, state: MutableMapping[str, Any], state_type: ConcurrencyCompatibleStateType + ) -> MutableMapping[str, Any]: """ Perform any transformations needed for compatibility with the converter. """ @@ -77,13 +85,17 @@ def serialize(self, state: MutableMapping[str, Any], state_type: ConcurrencyComp self.END_KEY: self._to_state_message(stream_slice[self.END_KEY]), } if stream_slice.get(self.MOST_RECENT_RECORD_KEY): - serialized_slice[self.MOST_RECENT_RECORD_KEY] = self._to_state_message(stream_slice[self.MOST_RECENT_RECORD_KEY]) + serialized_slice[self.MOST_RECENT_RECORD_KEY] = self._to_state_message( + stream_slice[self.MOST_RECENT_RECORD_KEY] + ) serialized_slices.append(serialized_slice) return {"slices": serialized_slices, "state_type": state_type.value} @staticmethod def is_state_message_compatible(state: MutableMapping[str, Any]) -> bool: - return bool(state) and state.get("state_type") in [t.value for t in ConcurrencyCompatibleStateType] + return bool(state) and state.get("state_type") in [ + t.value for t in ConcurrencyCompatibleStateType + ] @abstractmethod def convert_from_sequential_state( @@ -112,7 +124,9 @@ def increment(self, value: Any) -> Any: """ ... - def merge_intervals(self, intervals: List[MutableMapping[str, Any]]) -> List[MutableMapping[str, Any]]: + def merge_intervals( + self, intervals: List[MutableMapping[str, Any]] + ) -> List[MutableMapping[str, Any]]: """ Compute and return a list of merged intervals. @@ -122,7 +136,9 @@ def merge_intervals(self, intervals: List[MutableMapping[str, Any]]) -> List[Mut if not intervals: return [] - sorted_intervals = sorted(intervals, key=lambda interval: (interval[self.START_KEY], interval[self.END_KEY])) + sorted_intervals = sorted( + intervals, key=lambda interval: (interval[self.START_KEY], interval[self.END_KEY]) + ) merged_intervals = [sorted_intervals[0]] for current_interval in sorted_intervals[1:]: diff --git a/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py b/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py index a6a33fac..3ff22c09 100644 --- a/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py +++ b/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py @@ -57,7 +57,10 @@ def _compare_intervals(self, end_time: Any, start_time: Any) -> bool: return bool(self.increment(end_time) >= start_time) def convert_from_sequential_state( - self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], start: Optional[datetime] + self, + cursor_field: CursorField, + stream_state: MutableMapping[str, Any], + start: Optional[datetime], ) -> Tuple[datetime, MutableMapping[str, Any]]: """ Convert the state message to the format required by the ConcurrentCursor. @@ -78,7 +81,9 @@ def convert_from_sequential_state( # Create a slice to represent the records synced during prior syncs. # The start and end are the same to avoid confusion as to whether the records for this slice # were actually synced - slices = [{self.START_KEY: start if start is not None else sync_start, self.END_KEY: sync_start}] + slices = [ + {self.START_KEY: start if start is not None else sync_start, self.END_KEY: sync_start} + ] return sync_start, { "state_type": ConcurrencyCompatibleStateType.date_range.value, @@ -86,10 +91,17 @@ def convert_from_sequential_state( "legacy": stream_state, } - def _get_sync_start(self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], start: Optional[datetime]) -> datetime: + def _get_sync_start( + self, + cursor_field: CursorField, + stream_state: MutableMapping[str, Any], + start: Optional[datetime], + ) -> datetime: sync_start = start if start is not None else self.zero_value prev_sync_low_water_mark = ( - self.parse_timestamp(stream_state[cursor_field.cursor_field_key]) if cursor_field.cursor_field_key in stream_state else None + self.parse_timestamp(stream_state[cursor_field.cursor_field_key]) + if cursor_field.cursor_field_key in stream_state + else None ) if prev_sync_low_water_mark and prev_sync_low_water_mark >= sync_start: return prev_sync_low_water_mark @@ -122,7 +134,9 @@ def output_format(self, timestamp: datetime) -> int: def parse_timestamp(self, timestamp: int) -> datetime: dt_object = pendulum.from_timestamp(timestamp) if not isinstance(dt_object, DateTime): - raise ValueError(f"DateTime object was expected but got {type(dt_object)} from pendulum.parse({timestamp})") + raise ValueError( + f"DateTime object was expected but got {type(dt_object)} from pendulum.parse({timestamp})" + ) return dt_object # type: ignore # we are manually type checking because pendulum.parse may return different types @@ -142,7 +156,9 @@ class IsoMillisConcurrentStreamStateConverter(DateTimeStreamStateConverter): _zero_value = "0001-01-01T00:00:00.000Z" - def __init__(self, is_sequential_state: bool = True, cursor_granularity: Optional[timedelta] = None): + def __init__( + self, is_sequential_state: bool = True, cursor_granularity: Optional[timedelta] = None + ): super().__init__(is_sequential_state=is_sequential_state) self._cursor_granularity = cursor_granularity or timedelta(milliseconds=1) @@ -155,7 +171,9 @@ def output_format(self, timestamp: datetime) -> Any: def parse_timestamp(self, timestamp: str) -> datetime: dt_object = pendulum.parse(timestamp) if not isinstance(dt_object, DateTime): - raise ValueError(f"DateTime object was expected but got {type(dt_object)} from pendulum.parse({timestamp})") + raise ValueError( + f"DateTime object was expected but got {type(dt_object)} from pendulum.parse({timestamp})" + ) return dt_object # type: ignore # we are manually type checking because pendulum.parse may return different types @@ -172,7 +190,9 @@ def __init__( is_sequential_state: bool = True, cursor_granularity: Optional[timedelta] = None, ): - super().__init__(is_sequential_state=is_sequential_state, cursor_granularity=cursor_granularity) + super().__init__( + is_sequential_state=is_sequential_state, cursor_granularity=cursor_granularity + ) self._datetime_format = datetime_format self._input_datetime_formats = input_datetime_formats if input_datetime_formats else [] self._input_datetime_formats += [self._datetime_format] diff --git a/airbyte_cdk/sources/streams/core.py b/airbyte_cdk/sources/streams/core.py index c7a0cf02..90925c4c 100644 --- a/airbyte_cdk/sources/streams/core.py +++ b/airbyte_cdk/sources/streams/core.py @@ -11,7 +11,13 @@ from typing import Any, Dict, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Union import airbyte_cdk.sources.utils.casing as casing -from airbyte_cdk.models import AirbyteMessage, AirbyteStream, ConfiguredAirbyteStream, DestinationSyncMode, SyncMode +from airbyte_cdk.models import ( + AirbyteMessage, + AirbyteStream, + ConfiguredAirbyteStream, + DestinationSyncMode, + SyncMode, +) from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.streams.checkpoint import ( CheckpointMode, @@ -84,7 +90,10 @@ def state(self, value: MutableMapping[str, Any]) -> None: """State setter, accept state serialized by state getter.""" -@deprecated(version="0.87.0", reason="Deprecated in favor of the CheckpointMixin which offers similar functionality") +@deprecated( + version="0.87.0", + reason="Deprecated in favor of the CheckpointMixin which offers similar functionality", +) class IncrementalMixin(CheckpointMixin, ABC): """Mixin to make stream incremental. @@ -192,9 +201,14 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o for record_data_or_message in records: yield record_data_or_message if isinstance(record_data_or_message, Mapping) or ( - hasattr(record_data_or_message, "type") and record_data_or_message.type == MessageType.RECORD + hasattr(record_data_or_message, "type") + and record_data_or_message.type == MessageType.RECORD ): - record_data = record_data_or_message if isinstance(record_data_or_message, Mapping) else record_data_or_message.record + record_data = ( + record_data_or_message + if isinstance(record_data_or_message, Mapping) + else record_data_or_message.record + ) # Thanks I hate it. RFR fundamentally doesn't fit with the concept of the legacy Stream.get_updated_state() # method because RFR streams rely on pagination as a cursor. Stream.get_updated_state() was designed to make @@ -206,14 +220,23 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o if self.cursor_field: # Some connectors have streams that implement get_updated_state(), but do not define a cursor_field. This # should be fixed on the stream implementation, but we should also protect against this in the CDK as well - stream_state_tracker = self.get_updated_state(stream_state_tracker, record_data) + stream_state_tracker = self.get_updated_state( + stream_state_tracker, record_data + ) self._observe_state(checkpoint_reader, stream_state_tracker) record_counter += 1 checkpoint_interval = self.state_checkpoint_interval checkpoint = checkpoint_reader.get_checkpoint() - if should_checkpoint and checkpoint_interval and record_counter % checkpoint_interval == 0 and checkpoint is not None: - airbyte_state_message = self._checkpoint_state(checkpoint, state_manager=state_manager) + if ( + should_checkpoint + and checkpoint_interval + and record_counter % checkpoint_interval == 0 + and checkpoint is not None + ): + airbyte_state_message = self._checkpoint_state( + checkpoint, state_manager=state_manager + ) yield airbyte_state_message if internal_config.is_limit_reached(record_counter): @@ -221,7 +244,9 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o self._observe_state(checkpoint_reader) checkpoint_state = checkpoint_reader.get_checkpoint() if should_checkpoint and checkpoint_state is not None: - airbyte_state_message = self._checkpoint_state(checkpoint_state, state_manager=state_manager) + airbyte_state_message = self._checkpoint_state( + checkpoint_state, state_manager=state_manager + ) yield airbyte_state_message next_slice = checkpoint_reader.next() @@ -252,7 +277,9 @@ def read_only_records(self, state: Optional[Mapping[str, Any]] = None) -> Iterab configured_stream=configured_stream, logger=self.logger, slice_logger=DebugSliceLogger(), - stream_state=dict(state) if state else {}, # read() expects MutableMapping instead of Mapping which is used more often + stream_state=dict(state) + if state + else {}, # read() expects MutableMapping instead of Mapping which is used more often state_manager=None, internal_config=InternalConfig(), ) @@ -378,7 +405,11 @@ def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: """ def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: """ Override to define the slices for this stream. See the stream slicing section of the docs for more information. @@ -449,12 +480,16 @@ def _get_checkpoint_reader( mappings_or_slices = [{}] slices_iterable_copy, iterable_for_detecting_format = itertools.tee(mappings_or_slices, 2) - stream_classification = self._classify_stream(mappings_or_slices=iterable_for_detecting_format) + stream_classification = self._classify_stream( + mappings_or_slices=iterable_for_detecting_format + ) # Streams that override has_multiple_slices are explicitly indicating that they will iterate over # multiple partitions. Inspecting slices to automatically apply the correct cursor is only needed as # a backup. So if this value was already assigned to True by the stream, we don't need to reassign it - self.has_multiple_slices = self.has_multiple_slices or stream_classification.has_multiple_slices + self.has_multiple_slices = ( + self.has_multiple_slices or stream_classification.has_multiple_slices + ) cursor = self.get_cursor() if cursor: @@ -463,7 +498,9 @@ def _get_checkpoint_reader( checkpoint_mode = self._checkpoint_mode if cursor and stream_classification.is_legacy_format: - return LegacyCursorBasedCheckpointReader(stream_slices=slices_iterable_copy, cursor=cursor, read_state_from_cursor=True) + return LegacyCursorBasedCheckpointReader( + stream_slices=slices_iterable_copy, cursor=cursor, read_state_from_cursor=True + ) elif cursor: return CursorBasedCheckpointReader( stream_slices=slices_iterable_copy, @@ -475,7 +512,9 @@ def _get_checkpoint_reader( # not iterate over a static set of slices. return ResumableFullRefreshCheckpointReader(stream_state=stream_state) elif checkpoint_mode == CheckpointMode.INCREMENTAL: - return IncrementalCheckpointReader(stream_slices=slices_iterable_copy, stream_state=stream_state) + return IncrementalCheckpointReader( + stream_slices=slices_iterable_copy, stream_state=stream_state + ) else: return FullRefreshCheckpointReader(stream_slices=slices_iterable_copy) @@ -489,7 +528,9 @@ def _checkpoint_mode(self) -> CheckpointMode: return CheckpointMode.FULL_REFRESH @staticmethod - def _classify_stream(mappings_or_slices: Iterator[Optional[Union[Mapping[str, Any], StreamSlice]]]) -> StreamClassification: + def _classify_stream( + mappings_or_slices: Iterator[Optional[Union[Mapping[str, Any], StreamSlice]]], + ) -> StreamClassification: """ This is a bit of a crazy solution, but also the only way we can detect certain attributes about the stream since Python streams do not follow consistent implementation patterns. We care about the following two attributes: @@ -506,7 +547,9 @@ def _classify_stream(mappings_or_slices: Iterator[Optional[Union[Mapping[str, An raise ValueError("A stream should always have at least one slice") try: next_slice = next(mappings_or_slices) - if isinstance(next_slice, StreamSlice) and next_slice == StreamSlice(partition={}, cursor_slice={}): + if isinstance(next_slice, StreamSlice) and next_slice == StreamSlice( + partition={}, cursor_slice={} + ): is_legacy_format = False slice_has_value = False elif next_slice == {}: @@ -526,7 +569,9 @@ def _classify_stream(mappings_or_slices: Iterator[Optional[Union[Mapping[str, An if slice_has_value: # If the first slice contained a partition value from the result of stream_slices(), this is a substream that might # have multiple parent records to iterate over - return StreamClassification(is_legacy_format=is_legacy_format, has_multiple_slices=slice_has_value) + return StreamClassification( + is_legacy_format=is_legacy_format, has_multiple_slices=slice_has_value + ) try: # If stream_slices() returns multiple slices, this is also a substream that can potentially generate empty slices @@ -534,7 +579,9 @@ def _classify_stream(mappings_or_slices: Iterator[Optional[Union[Mapping[str, An return StreamClassification(is_legacy_format=is_legacy_format, has_multiple_slices=True) except StopIteration: # If the result of stream_slices() only returns a single empty stream slice, then we know this is a regular stream - return StreamClassification(is_legacy_format=is_legacy_format, has_multiple_slices=False) + return StreamClassification( + is_legacy_format=is_legacy_format, has_multiple_slices=False + ) def log_stream_sync_configuration(self) -> None: """ @@ -549,7 +596,9 @@ def log_stream_sync_configuration(self) -> None: ) @staticmethod - def _wrapped_primary_key(keys: Optional[Union[str, List[str], List[List[str]]]]) -> Optional[List[List[str]]]: + def _wrapped_primary_key( + keys: Optional[Union[str, List[str], List[List[str]]]], + ) -> Optional[List[List[str]]]: """ :return: wrap the primary_key property in a list of list of strings required by the Airbyte Stream object. """ @@ -571,7 +620,9 @@ def _wrapped_primary_key(keys: Optional[Union[str, List[str], List[List[str]]]]) else: raise ValueError(f"Element must be either list or str. Got: {type(keys)}") - def _observe_state(self, checkpoint_reader: CheckpointReader, stream_state: Optional[Mapping[str, Any]] = None) -> None: + def _observe_state( + self, checkpoint_reader: CheckpointReader, stream_state: Optional[Mapping[str, Any]] = None + ) -> None: """ Convenience method that attempts to read the Stream's state using the recommended way of connector's managing their own state via state setter/getter. But if we get back an AttributeError, then the legacy Stream.get_updated_state() @@ -617,7 +668,9 @@ def configured_json_schema(self) -> Optional[Dict[str, Any]]: def configured_json_schema(self, json_schema: Dict[str, Any]) -> None: self._configured_json_schema = self._filter_schema_invalid_properties(json_schema) - def _filter_schema_invalid_properties(self, configured_catalog_json_schema: Dict[str, Any]) -> Dict[str, Any]: + def _filter_schema_invalid_properties( + self, configured_catalog_json_schema: Dict[str, Any] + ) -> Dict[str, Any]: """ Filters the properties in json_schema that are not present in the stream schema. Configured Schemas can have very old fields, so we need to housekeeping ourselves. @@ -639,6 +692,8 @@ def _filter_schema_invalid_properties(self, configured_catalog_json_schema: Dict valid_configured_schema_properties = {} for configured_schema_property in valid_configured_schema_properties_keys: - valid_configured_schema_properties[configured_schema_property] = stream_schema_properties[configured_schema_property] + valid_configured_schema_properties[configured_schema_property] = ( + stream_schema_properties[configured_schema_property] + ) return {**configured_catalog_json_schema, "properties": valid_configured_schema_properties} diff --git a/airbyte_cdk/sources/streams/http/availability_strategy.py b/airbyte_cdk/sources/streams/http/availability_strategy.py index 4b3dba10..494fcf15 100644 --- a/airbyte_cdk/sources/streams/http/availability_strategy.py +++ b/airbyte_cdk/sources/streams/http/availability_strategy.py @@ -15,7 +15,9 @@ class HttpAvailabilityStrategy(AvailabilityStrategy): - def check_availability(self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None) -> Tuple[bool, Optional[str]]: + def check_availability( + self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None + ) -> Tuple[bool, Optional[str]]: """ Check stream availability by attempting to read the first record of the stream. diff --git a/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py b/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py index 546a910f..fa8864db 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py @@ -5,7 +5,10 @@ from typing import Mapping, Type, Union from airbyte_cdk.models import FailureType -from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction +from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( + ErrorResolution, + ResponseAction, +) from requests.exceptions import InvalidSchema, InvalidURL, RequestException DEFAULT_ERROR_MAPPING: Mapping[Union[int, str, Type[Exception]], ErrorResolution] = { diff --git a/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py b/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py index f1789cc6..b231e72e 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py @@ -30,7 +30,9 @@ def max_time(self) -> Optional[int]: pass @abstractmethod - def interpret_response(self, response: Optional[Union[requests.Response, Exception]]) -> ErrorResolution: + def interpret_response( + self, response: Optional[Union[requests.Response, Exception]] + ) -> ErrorResolution: """ Interpret the response or exception and return the corresponding response action, failure type, and error message. diff --git a/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py b/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py index 69adab30..f18e3db2 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py @@ -8,9 +8,14 @@ import requests from airbyte_cdk.models import FailureType -from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import DEFAULT_ERROR_MAPPING +from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import ( + DEFAULT_ERROR_MAPPING, +) from airbyte_cdk.sources.streams.http.error_handlers.error_handler import ErrorHandler -from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction +from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( + ErrorResolution, + ResponseAction, +) class HttpStatusErrorHandler(ErrorHandler): @@ -39,7 +44,9 @@ def max_retries(self) -> Optional[int]: def max_time(self) -> Optional[int]: return self._max_time - def interpret_response(self, response_or_exception: Optional[Union[requests.Response, Exception]] = None) -> ErrorResolution: + def interpret_response( + self, response_or_exception: Optional[Union[requests.Response, Exception]] = None + ) -> ErrorResolution: """ Interpret the response and return the corresponding response action, failure type, and error message. @@ -48,12 +55,16 @@ def interpret_response(self, response_or_exception: Optional[Union[requests.Resp """ if isinstance(response_or_exception, Exception): - mapped_error: Optional[ErrorResolution] = self._error_mapping.get(response_or_exception.__class__) + mapped_error: Optional[ErrorResolution] = self._error_mapping.get( + response_or_exception.__class__ + ) if mapped_error is not None: return mapped_error else: - self._logger.error(f"Unexpected exception in error handler: {response_or_exception}") + self._logger.error( + f"Unexpected exception in error handler: {response_or_exception}" + ) return ErrorResolution( response_action=ResponseAction.RETRY, failure_type=FailureType.system_error, diff --git a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py index 21e20049..aca13a8c 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py @@ -33,13 +33,17 @@ def _format_response_error_message(response: requests.Response) -> str: try: response.raise_for_status() except HTTPError as exception: - return filter_secrets(f"Response was not ok: `{str(exception)}`. Response content is: {response.text}") + return filter_secrets( + f"Response was not ok: `{str(exception)}`. Response content is: {response.text}" + ) # We purposefully do not add the response.content because the response is "ok" so there might be sensitive information in the payload. # Feel free the return f"Unexpected response with HTTP status {response.status_code}" -def create_fallback_error_resolution(response_or_exception: Optional[Union[requests.Response, Exception]]) -> ErrorResolution: +def create_fallback_error_resolution( + response_or_exception: Optional[Union[requests.Response, Exception]], +) -> ErrorResolution: if response_or_exception is None: # We do not expect this case to happen but if it does, it would be good to understand the cause and improve the error message error_message = "Error handler did not receive a valid response or exception. This is unexpected please contact Airbyte Support" @@ -55,4 +59,6 @@ def create_fallback_error_resolution(response_or_exception: Optional[Union[reque ) -SUCCESS_RESOLUTION = ErrorResolution(response_action=ResponseAction.SUCCESS, failure_type=None, error_message=None) +SUCCESS_RESOLUTION = ErrorResolution( + response_action=ResponseAction.SUCCESS, failure_type=None, error_message=None +) diff --git a/airbyte_cdk/sources/streams/http/exceptions.py b/airbyte_cdk/sources/streams/http/exceptions.py index 3db57ffe..ee468762 100644 --- a/airbyte_cdk/sources/streams/http/exceptions.py +++ b/airbyte_cdk/sources/streams/http/exceptions.py @@ -17,7 +17,8 @@ def __init__( ): if isinstance(response, requests.Response): error_message = ( - error_message or f"Request URL: {request.url}, Response Code: {response.status_code}, Response Text: {response.text}" + error_message + or f"Request URL: {request.url}, Response Code: {response.status_code}, Response Text: {response.text}" ) super().__init__(error_message, request=request, response=response) else: diff --git a/airbyte_cdk/sources/streams/http/http.py b/airbyte_cdk/sources/streams/http/http.py index a132702e..f9731517 100644 --- a/airbyte_cdk/sources/streams/http/http.py +++ b/airbyte_cdk/sources/streams/http/http.py @@ -14,11 +14,22 @@ from airbyte_cdk.sources.message.repository import InMemoryMessageRepository from airbyte_cdk.sources.streams.call_rate import APIBudget from airbyte_cdk.sources.streams.checkpoint.cursor import Cursor -from airbyte_cdk.sources.streams.checkpoint.resumable_full_refresh_cursor import ResumableFullRefreshCursor -from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import SubstreamResumableFullRefreshCursor +from airbyte_cdk.sources.streams.checkpoint.resumable_full_refresh_cursor import ( + ResumableFullRefreshCursor, +) +from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import ( + SubstreamResumableFullRefreshCursor, +) from airbyte_cdk.sources.streams.core import CheckpointMixin, Stream, StreamData -from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy, ErrorHandler, HttpStatusErrorHandler -from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction +from airbyte_cdk.sources.streams.http.error_handlers import ( + BackoffStrategy, + ErrorHandler, + HttpStatusErrorHandler, +) +from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( + ErrorResolution, + ResponseAction, +) from airbyte_cdk.sources.streams.http.http_client import HttpClient from airbyte_cdk.sources.types import Record, StreamSlice from airbyte_cdk.sources.utils.types import JsonType @@ -35,9 +46,13 @@ class HttpStream(Stream, CheckpointMixin, ABC): """ source_defined_cursor = True # Most HTTP streams use a source defined cursor (i.e: the user can't configure it like on a SQL table) - page_size: Optional[int] = None # Use this variable to define page size for API http requests with pagination support + page_size: Optional[int] = ( + None # Use this variable to define page size for API http requests with pagination support + ) - def __init__(self, authenticator: Optional[AuthBase] = None, api_budget: Optional[APIBudget] = None): + def __init__( + self, authenticator: Optional[AuthBase] = None, api_budget: Optional[APIBudget] = None + ): self._exit_on_rate_limit: bool = False self._http_client = HttpClient( name=self.name, @@ -55,7 +70,11 @@ def __init__(self, authenticator: Optional[AuthBase] = None, api_budget: Optiona # 2. Streams with at least one cursor_field are incremental and thus a superior sync to RFR. # 3. Streams overriding read_records() do not guarantee that they will call the parent implementation which can perform # per-page checkpointing so RFR is only supported if a stream use the default `HttpStream.read_records()` method - if not self.cursor and len(self.cursor_field) == 0 and type(self).read_records is HttpStream.read_records: + if ( + not self.cursor + and len(self.cursor_field) == 0 + and type(self).read_records is HttpStream.read_records + ): self.cursor = ResumableFullRefreshCursor() @property @@ -100,7 +119,10 @@ def http_method(self) -> str: return "GET" @property - @deprecated(version="3.0.0", reason="You should set error_handler explicitly in HttpStream.get_error_handler() instead.") + @deprecated( + version="3.0.0", + reason="You should set error_handler explicitly in HttpStream.get_error_handler() instead.", + ) def raise_on_http_errors(self) -> bool: """ Override if needed. If set to False, allows opting-out of raising HTTP code exception. @@ -108,7 +130,10 @@ def raise_on_http_errors(self) -> bool: return True @property - @deprecated(version="3.0.0", reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.") + @deprecated( + version="3.0.0", + reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.", + ) def max_retries(self) -> Union[int, None]: """ Override if needed. Specifies maximum amount of retries for backoff policy. Return None for no limit. @@ -116,7 +141,10 @@ def max_retries(self) -> Union[int, None]: return 5 @property - @deprecated(version="3.0.0", reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.") + @deprecated( + version="3.0.0", + reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.", + ) def max_time(self) -> Union[int, None]: """ Override if needed. Specifies maximum total waiting time (in seconds) for backoff policy. Return None for no limit. @@ -124,7 +152,10 @@ def max_time(self) -> Union[int, None]: return 60 * 10 @property - @deprecated(version="3.0.0", reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.") + @deprecated( + version="3.0.0", + reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.", + ) def retry_factor(self) -> float: """ Override if needed. Specifies factor for backoff policy. @@ -262,7 +293,10 @@ def get_error_handler(self) -> Optional[ErrorHandler]: """ if hasattr(self, "should_retry"): error_handler = HttpStreamAdapterHttpStatusErrorHandler( - stream=self, logger=logging.getLogger(), max_retries=self.max_retries, max_time=timedelta(seconds=self.max_time or 0) + stream=self, + logger=logging.getLogger(), + max_retries=self.max_retries, + max_time=timedelta(seconds=self.max_time or 0), ) return error_handler else: @@ -333,13 +367,17 @@ def read_records( # A cursor_field indicates this is an incremental stream which offers better checkpointing than RFR enabled via the cursor if self.cursor_field or not isinstance(self.get_cursor(), ResumableFullRefreshCursor): yield from self._read_pages( - lambda req, res, state, _slice: self.parse_response(res, stream_slice=_slice, stream_state=state), + lambda req, res, state, _slice: self.parse_response( + res, stream_slice=_slice, stream_state=state + ), stream_slice, stream_state, ) else: yield from self._read_single_page( - lambda req, res, state, _slice: self.parse_response(res, stream_slice=_slice, stream_state=state), + lambda req, res, state, _slice: self.parse_response( + res, stream_slice=_slice, stream_state=state + ), stream_slice, stream_state, ) @@ -373,7 +411,13 @@ def get_cursor(self) -> Optional[Cursor]: def _read_pages( self, records_generator_fn: Callable[ - [requests.PreparedRequest, requests.Response, Mapping[str, Any], Optional[Mapping[str, Any]]], Iterable[StreamData] + [ + requests.PreparedRequest, + requests.Response, + Mapping[str, Any], + Optional[Mapping[str, Any]], + ], + Iterable[StreamData], ], stream_slice: Optional[Mapping[str, Any]] = None, stream_state: Optional[Mapping[str, Any]] = None, @@ -403,19 +447,29 @@ def _read_pages( def _read_single_page( self, records_generator_fn: Callable[ - [requests.PreparedRequest, requests.Response, Mapping[str, Any], Optional[Mapping[str, Any]]], Iterable[StreamData] + [ + requests.PreparedRequest, + requests.Response, + Mapping[str, Any], + Optional[Mapping[str, Any]], + ], + Iterable[StreamData], ], stream_slice: Optional[Mapping[str, Any]] = None, stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[StreamData]: - partition, cursor_slice, remaining_slice = self._extract_slice_fields(stream_slice=stream_slice) + partition, cursor_slice, remaining_slice = self._extract_slice_fields( + stream_slice=stream_slice + ) stream_state = stream_state or {} next_page_token = cursor_slice or None request, response = self._fetch_next_page(remaining_slice, stream_state, next_page_token) yield from records_generator_fn(request, response, stream_state, remaining_slice) - next_page_token = self.next_page_token(response) or {"__ab_full_refresh_sync_complete": True} + next_page_token = self.next_page_token(response) or { + "__ab_full_refresh_sync_complete": True + } cursor = self.get_cursor() if cursor: @@ -425,7 +479,9 @@ def _read_single_page( yield from [] @staticmethod - def _extract_slice_fields(stream_slice: Optional[Mapping[str, Any]]) -> tuple[Mapping[str, Any], Mapping[str, Any], Mapping[str, Any]]: + def _extract_slice_fields( + stream_slice: Optional[Mapping[str, Any]], + ) -> tuple[Mapping[str, Any], Mapping[str, Any], Mapping[str, Any]]: if not stream_slice: return {}, {}, {} @@ -439,7 +495,11 @@ def _extract_slice_fields(stream_slice: Optional[Mapping[str, Any]]) -> tuple[Ma # fields for the partition and cursor_slice value partition = stream_slice.get("partition", {}) cursor_slice = stream_slice.get("cursor_slice", {}) - remaining = {key: val for key, val in stream_slice.items() if key != "partition" and key != "cursor_slice"} + remaining = { + key: val + for key, val in stream_slice.items() + if key != "partition" and key != "cursor_slice" + } return partition, cursor_slice, remaining def _fetch_next_page( @@ -452,13 +512,37 @@ def _fetch_next_page( http_method=self.http_method, url=self._join_url( self.url_base, - self.path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + self.path( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + ), + request_kwargs=self.request_kwargs( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + headers=self.request_headers( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + params=self.request_params( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + json=self.request_body_json( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + data=self.request_body_data( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ), - request_kwargs=self.request_kwargs(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), - headers=self.request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), - params=self.request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), - json=self.request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), - data=self.request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), dedupe_query_params=True, log_formatter=self.get_log_formatter(), exit_on_rate_limit=self.exit_on_rate_limit, @@ -481,18 +565,27 @@ def __init__(self, parent: HttpStream, **kwargs: Any): """ super().__init__(**kwargs) self.parent = parent - self.has_multiple_slices = True # Substreams are based on parent records which implies there are multiple slices + self.has_multiple_slices = ( + True # Substreams are based on parent records which implies there are multiple slices + ) # There are three conditions that dictate if RFR should automatically be applied to a stream # 1. Streams that explicitly initialize their own cursor should defer to it and not automatically apply RFR # 2. Streams with at least one cursor_field are incremental and thus a superior sync to RFR. # 3. Streams overriding read_records() do not guarantee that they will call the parent implementation which can perform # per-page checkpointing so RFR is only supported if a stream use the default `HttpStream.read_records()` method - if not self.cursor and len(self.cursor_field) == 0 and type(self).read_records is HttpStream.read_records: + if ( + not self.cursor + and len(self.cursor_field) == 0 + and type(self).read_records is HttpStream.read_records + ): self.cursor = SubstreamResumableFullRefreshCursor() def stream_slices( - self, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: # read_stateless() assumes the parent is not concurrent. This is currently okay since the concurrent CDK does # not support either substreams or RFR, but something that needs to be considered once we do @@ -508,7 +601,10 @@ def stream_slices( yield {"parent": parent_record} -@deprecated(version="3.0.0", reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.") +@deprecated( + version="3.0.0", + reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.", +) class HttpStreamAdapterBackoffStrategy(BackoffStrategy): def __init__(self, stream: HttpStream): self.stream = stream @@ -521,13 +617,18 @@ def backoff_time( return self.stream.backoff_time(response_or_exception) # type: ignore # noqa # HttpStream.backoff_time has been deprecated -@deprecated(version="3.0.0", reason="You should set error_handler explicitly in HttpStream.get_error_handler() instead.") +@deprecated( + version="3.0.0", + reason="You should set error_handler explicitly in HttpStream.get_error_handler() instead.", +) class HttpStreamAdapterHttpStatusErrorHandler(HttpStatusErrorHandler): def __init__(self, stream: HttpStream, **kwargs): # type: ignore # noqa self.stream = stream super().__init__(**kwargs) - def interpret_response(self, response_or_exception: Optional[Union[requests.Response, Exception]] = None) -> ErrorResolution: + def interpret_response( + self, response_or_exception: Optional[Union[requests.Response, Exception]] = None + ) -> ErrorResolution: if isinstance(response_or_exception, Exception): return super().interpret_response(response_or_exception) elif isinstance(response_or_exception, requests.Response): diff --git a/airbyte_cdk/sources/streams/http/http_client.py b/airbyte_cdk/sources/streams/http/http_client.py index 0b57b4ed..704b715c 100644 --- a/airbyte_cdk/sources/streams/http/http_client.py +++ b/airbyte_cdk/sources/streams/http/http_client.py @@ -44,7 +44,9 @@ user_defined_backoff_handler, ) from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH -from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message +from airbyte_cdk.utils.stream_status_utils import ( + as_airbyte_message as stream_status_as_airbyte_message, +) from airbyte_cdk.utils.traced_exception import AirbyteTracedException from requests.auth import AuthBase @@ -94,7 +96,10 @@ def __init__( self._use_cache = use_cache self._session = self._request_session() self._session.mount( - "https://", requests.adapters.HTTPAdapter(pool_connections=MAX_CONNECTION_POOL_SIZE, pool_maxsize=MAX_CONNECTION_POOL_SIZE) + "https://", + requests.adapters.HTTPAdapter( + pool_connections=MAX_CONNECTION_POOL_SIZE, pool_maxsize=MAX_CONNECTION_POOL_SIZE + ), ) if isinstance(authenticator, AuthBase): self._session.auth = authenticator @@ -133,7 +138,9 @@ def _request_session(self) -> requests.Session: sqlite_path = str(Path(cache_dir) / self.cache_filename) else: sqlite_path = "file::memory:?cache=shared" - return CachedLimiterSession(sqlite_path, backend="sqlite", api_budget=self._api_budget, match_headers=True) # type: ignore # there are no typeshed stubs for requests_cache + return CachedLimiterSession( + sqlite_path, backend="sqlite", api_budget=self._api_budget, match_headers=True + ) # type: ignore # there are no typeshed stubs for requests_cache else: return LimiterSession(api_budget=self._api_budget) @@ -144,7 +151,9 @@ def clear_cache(self) -> None: if isinstance(self._session, requests_cache.CachedSession): self._session.cache.clear() # type: ignore # cache.clear is not typed - def _dedupe_query_params(self, url: str, params: Optional[Mapping[str, str]]) -> Mapping[str, str]: + def _dedupe_query_params( + self, url: str, params: Optional[Mapping[str, str]] + ) -> Mapping[str, str]: """ Remove query parameters from params mapping if they are already encoded in the URL. :param url: URL with @@ -156,7 +165,9 @@ def _dedupe_query_params(self, url: str, params: Optional[Mapping[str, str]]) -> query_string = urllib.parse.urlparse(url).query query_dict = {k: v[0] for k, v in urllib.parse.parse_qs(query_string).items()} - duplicate_keys_with_same_value = {k for k in query_dict.keys() if str(params.get(k)) == str(query_dict[k])} + duplicate_keys_with_same_value = { + k for k in query_dict.keys() if str(params.get(k)) == str(query_dict[k]) + } return {k: v for k, v in params.items() if k not in duplicate_keys_with_same_value} def _create_prepared_request( @@ -183,7 +194,9 @@ def _create_prepared_request( args["json"] = json elif data: args["data"] = data - prepared_request: requests.PreparedRequest = self._session.prepare_request(requests.Request(**args)) + prepared_request: requests.PreparedRequest = self._session.prepare_request( + requests.Request(**args) + ) return prepared_request @@ -204,7 +217,11 @@ def _max_time(self) -> int: """ Determines the max time based on the provided error handler. """ - return self._error_handler.max_time if self._error_handler.max_time is not None else self._DEFAULT_MAX_TIME + return ( + self._error_handler.max_time + if self._error_handler.max_time is not None + else self._DEFAULT_MAX_TIME + ) def _send_with_retry( self, @@ -228,12 +245,19 @@ def _send_with_retry( max_tries = max(0, max_retries) + 1 max_time = self._max_time - user_backoff_handler = user_defined_backoff_handler(max_tries=max_tries, max_time=max_time)(self._send) + user_backoff_handler = user_defined_backoff_handler(max_tries=max_tries, max_time=max_time)( + self._send + ) rate_limit_backoff_handler = rate_limit_default_backoff_handler() - backoff_handler = http_client_default_backoff_handler(max_tries=max_tries, max_time=max_time) + backoff_handler = http_client_default_backoff_handler( + max_tries=max_tries, max_time=max_time + ) # backoff handlers wrap _send, so it will always return a response response = backoff_handler(rate_limit_backoff_handler(user_backoff_handler))( - request, request_kwargs, log_formatter=log_formatter, exit_on_rate_limit=exit_on_rate_limit + request, + request_kwargs, + log_formatter=log_formatter, + exit_on_rate_limit=exit_on_rate_limit, ) # type: ignore # mypy can't infer that backoff_handler wraps _send return response @@ -253,7 +277,8 @@ def _send( self._session.auth(request) self._logger.debug( - "Making outbound API request", extra={"headers": request.headers, "url": request.url, "request_body": request.body} + "Making outbound API request", + extra={"headers": request.headers, "url": request.url, "request_body": request.body}, ) response: Optional[requests.Response] = None @@ -264,7 +289,9 @@ def _send( except requests.RequestException as e: exc = e - error_resolution: ErrorResolution = self._error_handler.interpret_response(response if response is not None else exc) + error_resolution: ErrorResolution = self._error_handler.interpret_response( + response if response is not None else exc + ) # Evaluation of response.text can be heavy, for example, if streaming a large response # Do it only in debug mode @@ -276,11 +303,20 @@ def _send( ) else: self._logger.debug( - "Receiving response", extra={"headers": response.headers, "status": response.status_code, "body": response.text} + "Receiving response", + extra={ + "headers": response.headers, + "status": response.status_code, + "body": response.text, + }, ) # Request/response logging for declarative cdk - if log_formatter is not None and response is not None and self._message_repository is not None: + if ( + log_formatter is not None + and response is not None + and self._message_repository is not None + ): formatter = log_formatter self._message_repository.log_message( Level.DEBUG, @@ -288,7 +324,11 @@ def _send( ) self._handle_error_resolution( - response=response, exc=exc, request=request, error_resolution=error_resolution, exit_on_rate_limit=exit_on_rate_limit + response=response, + exc=exc, + request=request, + error_resolution=error_resolution, + exit_on_rate_limit=exit_on_rate_limit, ) return response # type: ignore # will either return a valid response of type requests.Response or raise an exception @@ -307,7 +347,9 @@ def _handle_error_resolution( reasons = [AirbyteStreamStatusReason(type=AirbyteStreamStatusReasonType.RATE_LIMITED)] message = orjson.dumps( AirbyteMessageSerializer.dump( - stream_status_as_airbyte_message(StreamDescriptor(name=self._name), AirbyteStreamStatus.RUNNING, reasons) + stream_status_as_airbyte_message( + StreamDescriptor(name=self._name), AirbyteStreamStatus.RUNNING, reasons + ) ) ).decode() @@ -321,7 +363,9 @@ def _handle_error_resolution( if response is not None: error_message = f"'{request.method}' request to '{request.url}' failed with status code '{response.status_code}' and error message '{self._error_message_parser.parse_response_error_message(response)}'" else: - error_message = f"'{request.method}' request to '{request.url}' failed with exception: '{exc}'" + error_message = ( + f"'{request.method}' request to '{request.url}' failed with exception: '{exc}'" + ) raise MessageRepresentationAirbyteTracedErrors( internal_message=error_message, @@ -331,20 +375,22 @@ def _handle_error_resolution( elif error_resolution.response_action == ResponseAction.IGNORE: if response is not None: - log_message = ( - f"Ignoring response for '{request.method}' request to '{request.url}' with response code '{response.status_code}'" - ) + log_message = f"Ignoring response for '{request.method}' request to '{request.url}' with response code '{response.status_code}'" else: log_message = f"Ignoring response for '{request.method}' request to '{request.url}' with error '{exc}'" self._logger.info(error_resolution.error_message or log_message) # TODO: Consider dynamic retry count depending on subsequent error codes - elif error_resolution.response_action == ResponseAction.RETRY or error_resolution.response_action == ResponseAction.RATE_LIMITED: + elif ( + error_resolution.response_action == ResponseAction.RETRY + or error_resolution.response_action == ResponseAction.RATE_LIMITED + ): user_defined_backoff_time = None for backoff_strategy in self._backoff_strategies: backoff_time = backoff_strategy.backoff_time( - response_or_exception=response if response is not None else exc, attempt_count=self._request_attempt_count[request] + response_or_exception=response if response is not None else exc, + attempt_count=self._request_attempt_count[request], ) if backoff_time: user_defined_backoff_time = backoff_time @@ -354,7 +400,10 @@ def _handle_error_resolution( or f"Request to {request.url} failed with failure type {error_resolution.failure_type}, response action {error_resolution.response_action}." ) - retry_endlessly = error_resolution.response_action == ResponseAction.RATE_LIMITED and not exit_on_rate_limit + retry_endlessly = ( + error_resolution.response_action == ResponseAction.RATE_LIMITED + and not exit_on_rate_limit + ) if user_defined_backoff_time: raise UserDefinedBackoffException( @@ -365,10 +414,14 @@ def _handle_error_resolution( ) elif retry_endlessly: - raise RateLimitBackoffException(request=request, response=response or exc, error_message=error_message) + raise RateLimitBackoffException( + request=request, response=response or exc, error_message=error_message + ) raise DefaultBackoffException( - request=request, response=(response if response is not None else exc), error_message=error_message + request=request, + response=(response if response is not None else exc), + error_message=error_message, ) elif response: @@ -400,11 +453,20 @@ def send_request( """ request: requests.PreparedRequest = self._create_prepared_request( - http_method=http_method, url=url, dedupe_query_params=dedupe_query_params, headers=headers, params=params, json=json, data=data + http_method=http_method, + url=url, + dedupe_query_params=dedupe_query_params, + headers=headers, + params=params, + json=json, + data=data, ) response: requests.Response = self._send_with_retry( - request=request, request_kwargs=request_kwargs, log_formatter=log_formatter, exit_on_rate_limit=exit_on_rate_limit + request=request, + request_kwargs=request_kwargs, + log_formatter=log_formatter, + exit_on_rate_limit=exit_on_rate_limit, ) return request, response diff --git a/airbyte_cdk/sources/streams/http/rate_limiting.py b/airbyte_cdk/sources/streams/http/rate_limiting.py index cae3907d..926a7ad5 100644 --- a/airbyte_cdk/sources/streams/http/rate_limiting.py +++ b/airbyte_cdk/sources/streams/http/rate_limiting.py @@ -10,7 +10,11 @@ import backoff from requests import PreparedRequest, RequestException, Response, codes, exceptions -from .exceptions import DefaultBackoffException, RateLimitBackoffException, UserDefinedBackoffException +from .exceptions import ( + DefaultBackoffException, + RateLimitBackoffException, + UserDefinedBackoffException, +) TRANSIENT_EXCEPTIONS = ( DefaultBackoffException, @@ -32,7 +36,9 @@ def default_backoff_handler( def log_retry_attempt(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() if isinstance(exc, RequestException) and exc.response: - logger.info(f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}") + logger.info( + f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" + ) logger.info( f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." ) @@ -71,7 +77,9 @@ def http_client_default_backoff_handler( def log_retry_attempt(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() if isinstance(exc, RequestException) and exc.response: - logger.info(f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}") + logger.info( + f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" + ) logger.info( f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." ) @@ -99,7 +107,9 @@ def sleep_on_ratelimit(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() if isinstance(exc, UserDefinedBackoffException): if exc.response: - logger.info(f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}") + logger.info( + f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" + ) retry_after = exc.backoff logger.info(f"Retrying. Sleeping for {retry_after} seconds") time.sleep(retry_after + 1) # extra second to cover any fractions of second @@ -107,7 +117,9 @@ def sleep_on_ratelimit(details: Mapping[str, Any]) -> None: def log_give_up(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() if isinstance(exc, RequestException): - logger.error(f"Max retry limit reached in {details['elapsed']}s. Request: {exc.request}, Response: {exc.response}") + logger.error( + f"Max retry limit reached in {details['elapsed']}s. Request: {exc.request}, Response: {exc.response}" + ) else: logger.error("Max retry limit reached for unknown request and response") @@ -124,11 +136,15 @@ def log_give_up(details: Mapping[str, Any]) -> None: ) -def rate_limit_default_backoff_handler(**kwargs: Any) -> Callable[[SendRequestCallableType], SendRequestCallableType]: +def rate_limit_default_backoff_handler( + **kwargs: Any, +) -> Callable[[SendRequestCallableType], SendRequestCallableType]: def log_retry_attempt(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() if isinstance(exc, RequestException) and exc.response: - logger.info(f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}") + logger.info( + f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" + ) logger.info( f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." ) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 63915f71..7942aa36 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -92,7 +92,9 @@ def build_refresh_request_body(self) -> Mapping[str, Any]: return payload - def _wrap_refresh_token_exception(self, exception: requests.exceptions.RequestException) -> bool: + def _wrap_refresh_token_exception( + self, exception: requests.exceptions.RequestException + ) -> bool: try: if exception.response is not None: exception_content = exception.response.json() @@ -102,7 +104,8 @@ def _wrap_refresh_token_exception(self, exception: requests.exceptions.RequestEx return False return ( exception.response.status_code in self._refresh_token_error_status_codes - and exception_content.get(self._refresh_token_error_key) in self._refresh_token_error_values + and exception_content.get(self._refresh_token_error_key) + in self._refresh_token_error_values ) @backoff.on_exception( @@ -115,14 +118,20 @@ def _wrap_refresh_token_exception(self, exception: requests.exceptions.RequestEx ) def _get_refresh_access_token_response(self) -> Any: try: - response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body()) + response = requests.request( + method="POST", + url=self.get_token_refresh_endpoint(), + data=self.build_refresh_request_body(), + ) if response.ok: response_json = response.json() # Add the access token to the list of secrets so it is replaced before logging the response # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen... access_key = response_json.get(self.get_access_token_name()) if not access_key: - raise Exception("Token refresh API response was missing access token {self.get_access_token_name()}") + raise Exception( + "Token refresh API response was missing access token {self.get_access_token_name()}" + ) add_to_secrets(access_key) self._log_response(response) return response_json @@ -136,7 +145,9 @@ def _get_refresh_access_token_response(self) -> Any: raise DefaultBackoffException(request=e.response.request, response=e.response) if self._wrap_refresh_token_exception(e): message = "Refresh token is invalid or expired. Please re-authenticate from Sources//Settings." - raise AirbyteTracedException(internal_message=message, message=message, failure_type=FailureType.config_error) + raise AirbyteTracedException( + internal_message=message, message=message, failure_type=FailureType.config_error + ) raise except Exception as e: raise Exception(f"Error while refreshing access token: {e}") from e @@ -149,7 +160,9 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]: """ response_json = self._get_refresh_access_token_response() - return response_json[self.get_access_token_name()], response_json[self.get_expires_in_name()] + return response_json[self.get_access_token_name()], response_json[ + self.get_expires_in_name() + ] def _parse_token_expiration_date(self, value: Union[str, int]) -> pendulum.DateTime: """ diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 1728f409..4ae84048 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -6,9 +6,14 @@ import dpath import pendulum -from airbyte_cdk.config_observation import create_connector_config_control_message, emit_configuration_as_airbyte_control_message +from airbyte_cdk.config_observation import ( + create_connector_config_control_message, + emit_configuration_as_airbyte_control_message, +) from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository -from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator +from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import ( + AbstractOauth2Authenticator, +) class Oauth2Authenticator(AbstractOauth2Authenticator): @@ -50,7 +55,9 @@ def __init__( self._token_expiry_date_format = token_expiry_date_format self._token_expiry_is_time_of_expiration = token_expiry_is_time_of_expiration self._access_token = None - super().__init__(refresh_token_error_status_codes, refresh_token_error_key, refresh_token_error_values) + super().__init__( + refresh_token_error_status_codes, refresh_token_error_key, refresh_token_error_values + ) def get_token_refresh_endpoint(self) -> str: return self._token_refresh_endpoint @@ -153,8 +160,16 @@ def __init__( token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration message_repository (MessageRepository): the message repository used to emit logs on HTTP requests and control message on config update """ - self._client_id = client_id if client_id is not None else dpath.get(connector_config, ("credentials", "client_id")) - self._client_secret = client_secret if client_secret is not None else dpath.get(connector_config, ("credentials", "client_secret")) + self._client_id = ( + client_id + if client_id is not None + else dpath.get(connector_config, ("credentials", "client_id")) + ) + self._client_secret = ( + client_secret + if client_secret is not None + else dpath.get(connector_config, ("credentials", "client_secret")) + ) self._access_token_config_path = access_token_config_path self._refresh_token_config_path = refresh_token_config_path self._token_expiry_date_config_path = token_expiry_date_config_path @@ -204,18 +219,24 @@ def set_refresh_token(self, new_refresh_token: str): dpath.new(self._connector_config, self._refresh_token_config_path, new_refresh_token) def get_token_expiry_date(self) -> pendulum.DateTime: - expiry_date = dpath.get(self._connector_config, self._token_expiry_date_config_path, default="") + expiry_date = dpath.get( + self._connector_config, self._token_expiry_date_config_path, default="" + ) return pendulum.now().subtract(days=1) if expiry_date == "" else pendulum.parse(expiry_date) def set_token_expiry_date(self, new_token_expiry_date): - dpath.new(self._connector_config, self._token_expiry_date_config_path, str(new_token_expiry_date)) + dpath.new( + self._connector_config, self._token_expiry_date_config_path, str(new_token_expiry_date) + ) def token_has_expired(self) -> bool: """Returns True if the token is expired""" return pendulum.now("UTC") > self.get_token_expiry_date() @staticmethod - def get_new_token_expiry_date(access_token_expires_in: str, token_expiry_date_format: str = None) -> pendulum.DateTime: + def get_new_token_expiry_date( + access_token_expires_in: str, token_expiry_date_format: str = None + ) -> pendulum.DateTime: if token_expiry_date_format: return pendulum.from_format(access_token_expires_in, token_expiry_date_format) else: @@ -228,8 +249,12 @@ def get_access_token(self) -> str: str: The current access_token, updated if it was previously expired. """ if self.token_has_expired(): - new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token() - new_token_expiry_date = self.get_new_token_expiry_date(access_token_expires_in, self._token_expiry_date_format) + new_access_token, access_token_expires_in, new_refresh_token = ( + self.refresh_access_token() + ) + new_token_expiry_date = self.get_new_token_expiry_date( + access_token_expires_in, self._token_expiry_date_format + ) self.access_token = new_access_token self.set_refresh_token(new_refresh_token) self.set_token_expiry_date(new_token_expiry_date) @@ -237,7 +262,9 @@ def get_access_token(self) -> str: # Usually, a class shouldn't care about the implementation details but to keep backward compatibility where we print the # message directly in the console, this is needed if not isinstance(self._message_repository, NoopMessageRepository): - self._message_repository.emit_message(create_connector_config_control_message(self._connector_config)) + self._message_repository.emit_message( + create_connector_config_control_message(self._connector_config) + ) else: emit_configuration_as_airbyte_control_message(self._connector_config) return self.access_token diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/token.py b/airbyte_cdk/sources/streams/http/requests_native_auth/token.py index becfe810..eec7fd0c 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/token.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/token.py @@ -6,7 +6,9 @@ from itertools import cycle from typing import List -from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import AbstractHeaderAuthenticator +from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import ( + AbstractHeaderAuthenticator, +) class MultipleTokenAuthenticator(AbstractHeaderAuthenticator): @@ -24,7 +26,9 @@ def auth_header(self) -> str: def token(self) -> str: return f"{self._auth_method} {next(self._tokens_iter)}" - def __init__(self, tokens: List[str], auth_method: str = "Bearer", auth_header: str = "Authorization"): + def __init__( + self, tokens: List[str], auth_method: str = "Bearer", auth_header: str = "Authorization" + ): self._auth_method = auth_method self._auth_header = auth_header self._tokens = tokens @@ -65,7 +69,13 @@ def auth_header(self) -> str: def token(self) -> str: return f"{self._auth_method} {self._token}" - def __init__(self, username: str, password: str = "", auth_method: str = "Basic", auth_header: str = "Authorization"): + def __init__( + self, + username: str, + password: str = "", + auth_method: str = "Basic", + auth_header: str = "Authorization", + ): auth_string = f"{username}:{password}".encode("utf8") b64_encoded = base64.b64encode(auth_string).decode("utf8") self._auth_header = auth_header diff --git a/airbyte_cdk/sources/types.py b/airbyte_cdk/sources/types.py index 6659c8dd..eb13cd08 100644 --- a/airbyte_cdk/sources/types.py +++ b/airbyte_cdk/sources/types.py @@ -54,7 +54,11 @@ def __ne__(self, other: object) -> bool: class StreamSlice(Mapping[str, Any]): def __init__( - self, *, partition: Mapping[str, Any], cursor_slice: Mapping[str, Any], extra_fields: Optional[Mapping[str, Any]] = None + self, + *, + partition: Mapping[str, Any], + cursor_slice: Mapping[str, Any], + extra_fields: Optional[Mapping[str, Any]] = None, ) -> None: """ :param partition: The partition keys representing a unique partition in the stream. diff --git a/airbyte_cdk/sources/utils/record_helper.py b/airbyte_cdk/sources/utils/record_helper.py index 98cefd1a..e45601c2 100644 --- a/airbyte_cdk/sources/utils/record_helper.py +++ b/airbyte_cdk/sources/utils/record_helper.py @@ -5,7 +5,12 @@ from collections.abc import Mapping as ABCMapping from typing import Any, Mapping, Optional -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteRecordMessage, AirbyteTraceMessage +from airbyte_cdk.models import ( + AirbyteLogMessage, + AirbyteMessage, + AirbyteRecordMessage, + AirbyteTraceMessage, +) from airbyte_cdk.models import Type as MessageType from airbyte_cdk.models.file_transfer_record_message import AirbyteFileTransferRecordMessage from airbyte_cdk.sources.streams.core import StreamData @@ -32,7 +37,9 @@ def stream_data_to_airbyte_message( # docs/connector-development/cdk-python/schemas.md for details. transformer.transform(data, schema) # type: ignore if is_file_transfer_message: - message = AirbyteFileTransferRecordMessage(stream=stream_name, file=data, emitted_at=now_millis, data={}) + message = AirbyteFileTransferRecordMessage( + stream=stream_name, file=data, emitted_at=now_millis, data={} + ) else: message = AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=now_millis) return AirbyteMessage(type=MessageType.RECORD, record=message) @@ -41,4 +48,6 @@ def stream_data_to_airbyte_message( case AirbyteLogMessage(): return AirbyteMessage(type=MessageType.LOG, log=data_or_message) case _: - raise ValueError(f"Unexpected type for data_or_message: {type(data_or_message)}: {data_or_message}") + raise ValueError( + f"Unexpected type for data_or_message: {type(data_or_message)}: {data_or_message}" + ) diff --git a/airbyte_cdk/sources/utils/schema_helpers.py b/airbyte_cdk/sources/utils/schema_helpers.py index 7eef091a..5b1287ac 100644 --- a/airbyte_cdk/sources/utils/schema_helpers.py +++ b/airbyte_cdk/sources/utils/schema_helpers.py @@ -74,7 +74,9 @@ def _expand_refs(schema: Any, ref_resolver: Optional[RefResolver] = None) -> Non if "$ref" in schema: ref_url = schema.pop("$ref") _, definition = ref_resolver.resolve(ref_url) - _expand_refs(definition, ref_resolver=ref_resolver) # expand refs in definitions as well + _expand_refs( + definition, ref_resolver=ref_resolver + ) # expand refs in definitions as well schema.update(definition) else: for key, value in schema.items(): @@ -152,7 +154,9 @@ def _resolve_schema_references(self, raw_schema: dict[str, Any]) -> dict[str, An base = os.path.dirname(package.__file__) + "/" else: raise ValueError(f"Package {package} does not have a valid __file__ field") - resolved = jsonref.JsonRef.replace_refs(raw_schema, loader=JsonFileLoader(base, "schemas/shared"), base_uri=base) + resolved = jsonref.JsonRef.replace_refs( + raw_schema, loader=JsonFileLoader(base, "schemas/shared"), base_uri=base + ) resolved = resolve_ref_links(resolved) if isinstance(resolved, dict): return resolved @@ -160,7 +164,9 @@ def _resolve_schema_references(self, raw_schema: dict[str, Any]) -> dict[str, An raise ValueError(f"Expected resolved to be a dict. Got {resolved}") -def check_config_against_spec_or_exit(config: Mapping[str, Any], spec: ConnectorSpecification) -> None: +def check_config_against_spec_or_exit( + config: Mapping[str, Any], spec: ConnectorSpecification +) -> None: """ Check config object against spec. In case of spec is invalid, throws an exception with validation error description. diff --git a/airbyte_cdk/sources/utils/slice_logger.py b/airbyte_cdk/sources/utils/slice_logger.py index 6981cdde..ee802a7a 100644 --- a/airbyte_cdk/sources/utils/slice_logger.py +++ b/airbyte_cdk/sources/utils/slice_logger.py @@ -27,7 +27,10 @@ def create_slice_log_message(self, _slice: Optional[Mapping[str, Any]]) -> Airby printable_slice = dict(_slice) if _slice else _slice return AirbyteMessage( type=MessageType.LOG, - log=AirbyteLogMessage(level=Level.INFO, message=f"{SliceLogger.SLICE_LOG_PREFIX}{json.dumps(printable_slice, default=str)}"), + log=AirbyteLogMessage( + level=Level.INFO, + message=f"{SliceLogger.SLICE_LOG_PREFIX}{json.dumps(printable_slice, default=str)}", + ), ) @abstractmethod diff --git a/airbyte_cdk/sources/utils/transform.py b/airbyte_cdk/sources/utils/transform.py index b15ff11d..ef52c5fd 100644 --- a/airbyte_cdk/sources/utils/transform.py +++ b/airbyte_cdk/sources/utils/transform.py @@ -9,7 +9,13 @@ from jsonschema import Draft7Validator, ValidationError, validators -json_to_python_simple = {"string": str, "number": float, "integer": int, "boolean": bool, "null": type(None)} +json_to_python_simple = { + "string": str, + "number": float, + "integer": int, + "boolean": bool, + "null": type(None), +} json_to_python = {**json_to_python_simple, **{"object": dict, "array": list}} python_to_json = {v: k for k, v in json_to_python.items()} @@ -56,9 +62,13 @@ def __init__(self, config: TransformConfig): # Do not validate field we do not transform for maximum performance. if key in ["type", "array", "$ref", "properties", "items"] } - self._normalizer = validators.create(meta_schema=Draft7Validator.META_SCHEMA, validators=all_validators) + self._normalizer = validators.create( + meta_schema=Draft7Validator.META_SCHEMA, validators=all_validators + ) - def registerCustomTransform(self, normalization_callback: Callable[[Any, Dict[str, Any]], Any]) -> Callable: + def registerCustomTransform( + self, normalization_callback: Callable[[Any, Dict[str, Any]], Any] + ) -> Callable: """ Register custom normalization callback. :param normalization_callback function to be used for value @@ -68,7 +78,9 @@ def registerCustomTransform(self, normalization_callback: Callable[[Any, Dict[st :return Same callbeck, this is usefull for using registerCustomTransform function as decorator. """ if TransformConfig.CustomSchemaNormalization not in self._config: - raise Exception("Please set TransformConfig.CustomSchemaNormalization config before registering custom normalizer") + raise Exception( + "Please set TransformConfig.CustomSchemaNormalization config before registering custom normalizer" + ) self._custom_normalizer = normalization_callback return normalization_callback @@ -120,7 +132,10 @@ def default_convert(original_item: Any, subschema: Dict[str, Any]) -> Any: return bool(original_item) elif target_type == "array": item_types = set(subschema.get("items", {}).get("type", set())) - if item_types.issubset(json_to_python_simple) and type(original_item) in json_to_python_simple.values(): + if ( + item_types.issubset(json_to_python_simple) + and type(original_item) in json_to_python_simple.values() + ): return [original_item] except (ValueError, TypeError): return original_item @@ -133,7 +148,9 @@ def __get_normalizer(self, schema_key: str, original_validator: Callable): :original_validator: native jsonschema validator callback. """ - def normalizator(validator_instance: Callable, property_value: Any, instance: Any, schema: Dict[str, Any]): + def normalizator( + validator_instance: Callable, property_value: Any, instance: Any, schema: Dict[str, Any] + ): """ Jsonschema validator callable it uses for validating instance. We override default Draft7Validator to perform value transformation @@ -191,6 +208,4 @@ def transform(self, record: Dict[str, Any], schema: Mapping[str, Any]): def get_error_message(self, e: ValidationError) -> str: instance_json_type = python_to_json[type(e.instance)] key_path = "." + ".".join(map(str, e.path)) - return ( - f"Failed to transform value {repr(e.instance)} of type '{instance_json_type}' to '{e.validator_value}', key path: '{key_path}'" - ) + return f"Failed to transform value {repr(e.instance)} of type '{instance_json_type}' to '{e.validator_value}', key path: '{key_path}'" diff --git a/airbyte_cdk/sql/exceptions.py b/airbyte_cdk/sql/exceptions.py index 0192d829..963dc469 100644 --- a/airbyte_cdk/sql/exceptions.py +++ b/airbyte_cdk/sql/exceptions.py @@ -90,12 +90,18 @@ def __str__(self) -> str: "original_exception", ] display_properties = { - k: v for k, v in self.__dict__.items() if k not in special_properties and not k.startswith("_") and v is not None + k: v + for k, v in self.__dict__.items() + if k not in special_properties and not k.startswith("_") and v is not None } display_properties.update(self.context or {}) - context_str = "\n ".join(f"{str(k).replace('_', ' ').title()}: {v!r}" for k, v in display_properties.items()) + context_str = "\n ".join( + f"{str(k).replace('_', ' ').title()}: {v!r}" for k, v in display_properties.items() + ) exception_str = ( - f"{self.get_message()} ({self.__class__.__name__})" + VERTICAL_SEPARATOR + f"\n{self.__class__.__name__}: {self.get_message()}" + f"{self.get_message()} ({self.__class__.__name__})" + + VERTICAL_SEPARATOR + + f"\n{self.__class__.__name__}: {self.get_message()}" ) if self.guidance: @@ -124,7 +130,9 @@ def __str__(self) -> str: def __repr__(self) -> str: """Return a string representation of the exception.""" class_name = self.__class__.__name__ - properties_str = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")) + properties_str = ", ".join( + f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_") + ) return f"{class_name}({properties_str})" def safe_logging_dict(self) -> dict[str, Any]: @@ -180,7 +188,10 @@ class AirbyteInputError(AirbyteError, ValueError): class AirbyteNameNormalizationError(AirbyteError, ValueError): """Error occurred while normalizing a table or column name.""" - guidance = "Please consider renaming the source object if possible, or " "raise an issue in GitHub if not." + guidance = ( + "Please consider renaming the source object if possible, or " + "raise an issue in GitHub if not." + ) help_url = NEW_ISSUE_URL raw_name: str | None = None @@ -205,7 +216,9 @@ def _get_log_file(self) -> Path | None: logger = logging.getLogger(f"airbyte.{self.connector_name}") log_paths: list[Path] = [ - Path(handler.baseFilename).absolute() for handler in logger.handlers if isinstance(handler, logging.FileHandler) + Path(handler.baseFilename).absolute() + for handler in logger.handlers + if isinstance(handler, logging.FileHandler) ] if log_paths: diff --git a/airbyte_cdk/sql/secrets.py b/airbyte_cdk/sql/secrets.py index aaf7641a..bff9e810 100644 --- a/airbyte_cdk/sql/secrets.py +++ b/airbyte_cdk/sql/secrets.py @@ -101,7 +101,9 @@ def __get_pydantic_core_schema__( # noqa: PLW3201 # Pydantic dunder handler: GetCoreSchemaHandler, ) -> CoreSchema: """Return a modified core schema for the secret string.""" - return core_schema.with_info_after_validator_function(function=cls.validate, schema=handler(str), field_name=handler.field_name) + return core_schema.with_info_after_validator_function( + function=cls.validate, schema=handler(str), field_name=handler.field_name + ) @classmethod def __get_pydantic_json_schema__( # noqa: PLW3201 # Pydantic dunder method diff --git a/airbyte_cdk/sql/shared/catalog_providers.py b/airbyte_cdk/sql/shared/catalog_providers.py index 8d139c9c..80713a35 100644 --- a/airbyte_cdk/sql/shared/catalog_providers.py +++ b/airbyte_cdk/sql/shared/catalog_providers.py @@ -77,13 +77,17 @@ def get_configured_stream_info( ) matching_streams: list[ConfiguredAirbyteStream] = [ - stream for stream in self.configured_catalog.streams if stream.stream.name == stream_name + stream + for stream in self.configured_catalog.streams + if stream.stream.name == stream_name ] if not matching_streams: raise exc.AirbyteStreamNotFoundError( stream_name=stream_name, context={ - "available_streams": [stream.stream.name for stream in self.configured_catalog.streams], + "available_streams": [ + stream.stream.name for stream in self.configured_catalog.streams + ], }, ) @@ -121,12 +125,17 @@ def get_primary_keys( if not pks: return [] - normalized_pks: list[list[str]] = [[LowerCaseNormalizer.normalize(c) for c in pk] for pk in pks] + normalized_pks: list[list[str]] = [ + [LowerCaseNormalizer.normalize(c) for c in pk] for pk in pks + ] for pk_nodes in normalized_pks: if len(pk_nodes) != 1: raise exc.AirbyteError( - message=("Nested primary keys are not supported. " "Each PK column should have exactly one node. "), + message=( + "Nested primary keys are not supported. " + "Each PK column should have exactly one node. " + ), context={ "stream_name": stream_name, "primary_key_nodes": pk_nodes, diff --git a/airbyte_cdk/sql/shared/sql_processor.py b/airbyte_cdk/sql/shared/sql_processor.py index 52a8e52d..dd8cb3e5 100644 --- a/airbyte_cdk/sql/shared/sql_processor.py +++ b/airbyte_cdk/sql/shared/sql_processor.py @@ -16,7 +16,12 @@ from airbyte_cdk.sql import exceptions as exc from airbyte_cdk.sql._util.hashing import one_way_hash from airbyte_cdk.sql._util.name_normalizers import LowerCaseNormalizer -from airbyte_cdk.sql.constants import AB_EXTRACTED_AT_COLUMN, AB_META_COLUMN, AB_RAW_ID_COLUMN, DEBUG_MODE +from airbyte_cdk.sql.constants import ( + AB_EXTRACTED_AT_COLUMN, + AB_META_COLUMN, + AB_RAW_ID_COLUMN, + DEBUG_MODE, +) from airbyte_cdk.sql.secrets import SecretString from airbyte_cdk.sql.types import SQLTypeConverter from airbyte_protocol_dataclasses.models import AirbyteStateMessage @@ -100,7 +105,9 @@ def get_vendor_client(self) -> object: Raises `NotImplementedError` if a custom vendor client is not defined. """ - raise NotImplementedError(f"The type '{type(self).__name__}' does not define a custom client.") + raise NotImplementedError( + f"The type '{type(self).__name__}' does not define a custom client." + ) class SqlProcessorBase(abc.ABC): @@ -270,7 +277,9 @@ def _get_table_by_name( query. To ignore the cache and force a refresh, set 'force_refresh' to True. """ if force_refresh and shallow_okay: - raise exc.AirbyteInternalError(message="Cannot force refresh and use shallow query at the same time.") + raise exc.AirbyteInternalError( + message="Cannot force refresh and use shallow query at the same time." + ) if force_refresh and table_name in self._cached_table_definitions: self._invalidate_table_cache(table_name) @@ -315,7 +324,9 @@ def _ensure_schema_exists( if DEBUG_MODE: found_schemas = schemas_list - assert schema_name in found_schemas, f"Schema {schema_name} was not created. Found: {found_schemas}" + assert ( + schema_name in found_schemas + ), f"Schema {schema_name} was not created. Found: {found_schemas}" def _quote_identifier(self, identifier: str) -> str: """Return the given identifier, quoted.""" @@ -387,7 +398,8 @@ def _get_schemas_list( self._known_schemas_list = [ found_schema.split(".")[-1].strip('"') for found_schema in found_schemas - if "." not in found_schema or (found_schema.split(".")[0].lower().strip('"') == database_name.lower()) + if "." not in found_schema + or (found_schema.split(".")[0].lower().strip('"') == database_name.lower()) ] return self._known_schemas_list @@ -511,7 +523,9 @@ def _write_files_to_new_table( for file_path in files: dataframe = pd.read_json(file_path, lines=True) - sql_column_definitions: dict[str, TypeEngine[Any]] = self._get_sql_column_definitions(stream_name) + sql_column_definitions: dict[str, TypeEngine[Any]] = self._get_sql_column_definitions( + stream_name + ) # Remove fields that are not in the schema for col_name in dataframe.columns: @@ -549,7 +563,10 @@ def _add_column_to_table( ) -> None: """Add a column to the given table.""" self._execute_sql( - text(f"ALTER TABLE {self._fully_qualified(table.name)} " f"ADD COLUMN {column_name} {column_type}"), + text( + f"ALTER TABLE {self._fully_qualified(table.name)} " + f"ADD COLUMN {column_name} {column_type}" + ), ) def _add_missing_columns_to_table( @@ -626,8 +643,10 @@ def _swap_temp_table_with_final_table( deletion_name = f"{final_table_name}_deleteme" commands = "\n".join( [ - f"ALTER TABLE {self._fully_qualified(final_table_name)} RENAME " f"TO {deletion_name};", - f"ALTER TABLE {self._fully_qualified(temp_table_name)} RENAME " f"TO {final_table_name};", + f"ALTER TABLE {self._fully_qualified(final_table_name)} RENAME " + f"TO {deletion_name};", + f"ALTER TABLE {self._fully_qualified(temp_table_name)} RENAME " + f"TO {final_table_name};", f"DROP TABLE {self._fully_qualified(deletion_name)};", ] ) @@ -646,7 +665,9 @@ def _merge_temp_table_to_final_table( """ nl = "\n" columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)} - pk_columns = {self._quote_identifier(c) for c in self.catalog_provider.get_primary_keys(stream_name)} + pk_columns = { + self._quote_identifier(c) for c in self.catalog_provider.get_primary_keys(stream_name) + } non_pk_columns = columns - pk_columns join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns) set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns) @@ -704,16 +725,23 @@ def _emulated_merge_temp_table_to_final_table( temp_table = self._get_table_by_name(temp_table_name) pk_columns = self.catalog_provider.get_primary_keys(stream_name) - columns_to_update: set[str] = self._get_sql_column_definitions(stream_name=stream_name).keys() - set(pk_columns) + columns_to_update: set[str] = self._get_sql_column_definitions( + stream_name=stream_name + ).keys() - set(pk_columns) # Create a dictionary mapping columns in users_final to users_stage for updating update_values = { - self._get_column_by_name(final_table, column): (self._get_column_by_name(temp_table, column)) for column in columns_to_update + self._get_column_by_name(final_table, column): ( + self._get_column_by_name(temp_table, column) + ) + for column in columns_to_update } # Craft the WHERE clause for composite primary keys join_conditions = [ - self._get_column_by_name(final_table, pk_column) == self._get_column_by_name(temp_table, pk_column) for pk_column in pk_columns + self._get_column_by_name(final_table, pk_column) + == self._get_column_by_name(temp_table, pk_column) + for pk_column in pk_columns ] join_clause = and_(*join_conditions) @@ -728,7 +756,9 @@ def _emulated_merge_temp_table_to_final_table( where_not_exists_clause = self._get_column_by_name(final_table, pk_columns[0]) == null() # Select records from temp_table that are not in final_table - select_new_records_stmt = select(temp_table).select_from(joined_table).where(where_not_exists_clause) + select_new_records_stmt = ( + select(temp_table).select_from(joined_table).where(where_not_exists_clause) + ) # Craft the INSERT statement using the select statement insert_new_records_stmt = insert(final_table).from_select( diff --git a/airbyte_cdk/test/catalog_builder.py b/airbyte_cdk/test/catalog_builder.py index ac02a561..b1bf4341 100644 --- a/airbyte_cdk/test/catalog_builder.py +++ b/airbyte_cdk/test/catalog_builder.py @@ -2,7 +2,12 @@ from typing import Any, Dict, List, Union, overload -from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, ConfiguredAirbyteStreamSerializer, SyncMode +from airbyte_cdk.models import ( + ConfiguredAirbyteCatalog, + ConfiguredAirbyteStream, + ConfiguredAirbyteStreamSerializer, + SyncMode, +) class ConfiguredAirbyteStreamBuilder: @@ -50,7 +55,11 @@ def with_stream(self, name: ConfiguredAirbyteStreamBuilder) -> "CatalogBuilder": @overload def with_stream(self, name: str, sync_mode: SyncMode) -> "CatalogBuilder": ... - def with_stream(self, name: Union[str, ConfiguredAirbyteStreamBuilder], sync_mode: Union[SyncMode, None] = None) -> "CatalogBuilder": + def with_stream( + self, + name: Union[str, ConfiguredAirbyteStreamBuilder], + sync_mode: Union[SyncMode, None] = None, + ) -> "CatalogBuilder": # As we are introducing a fully fledge ConfiguredAirbyteStreamBuilder, we would like to deprecate the previous interface # with_stream(str, SyncMode) @@ -59,10 +68,14 @@ def with_stream(self, name: Union[str, ConfiguredAirbyteStreamBuilder], sync_mod builder = ( name_or_builder if isinstance(name_or_builder, ConfiguredAirbyteStreamBuilder) - else ConfiguredAirbyteStreamBuilder().with_name(name_or_builder).with_sync_mode(sync_mode) + else ConfiguredAirbyteStreamBuilder() + .with_name(name_or_builder) + .with_sync_mode(sync_mode) ) self._streams.append(builder) return self def build(self) -> ConfiguredAirbyteCatalog: - return ConfiguredAirbyteCatalog(streams=list(map(lambda builder: builder.build(), self._streams))) + return ConfiguredAirbyteCatalog( + streams=list(map(lambda builder: builder.build(), self._streams)) + ) diff --git a/airbyte_cdk/test/entrypoint_wrapper.py b/airbyte_cdk/test/entrypoint_wrapper.py index 9cc74ec2..5e7a80da 100644 --- a/airbyte_cdk/test/entrypoint_wrapper.py +++ b/airbyte_cdk/test/entrypoint_wrapper.py @@ -53,7 +53,11 @@ def __init__(self, messages: List[str], uncaught_exception: Optional[BaseExcepti raise ValueError("All messages are expected to be AirbyteMessage") from exception if uncaught_exception: - self._messages.append(assemble_uncaught_exception(type(uncaught_exception), uncaught_exception).as_airbyte_message()) + self._messages.append( + assemble_uncaught_exception( + type(uncaught_exception), uncaught_exception + ).as_airbyte_message() + ) @staticmethod def _parse_message(message: str) -> AirbyteMessage: @@ -61,7 +65,9 @@ def _parse_message(message: str) -> AirbyteMessage: return AirbyteMessageSerializer.load(orjson.loads(message)) # type: ignore[no-any-return] # Serializer.load() always returns AirbyteMessage except (orjson.JSONDecodeError, SchemaValidationError): # The platform assumes that logs that are not of AirbyteMessage format are log messages - return AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message=message)) + return AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message=message) + ) @property def records_and_state_messages(self) -> List[AirbyteMessage]: @@ -119,18 +125,26 @@ def _get_message_by_types(self, message_types: List[Type]) -> List[AirbyteMessag return [message for message in self._messages if message.type in message_types] def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[AirbyteMessage]: - return [message for message in self._get_message_by_types([Type.TRACE]) if message.trace.type == trace_type] # type: ignore[union-attr] # trace has `type` + return [ + message + for message in self._get_message_by_types([Type.TRACE]) + if message.trace.type == trace_type + ] # type: ignore[union-attr] # trace has `type` def is_in_logs(self, pattern: str) -> bool: """Check if any log message case-insensitive matches the pattern.""" - return any(re.search(pattern, entry.log.message, flags=re.IGNORECASE) for entry in self.logs) # type: ignore[union-attr] # log has `message` + return any( + re.search(pattern, entry.log.message, flags=re.IGNORECASE) for entry in self.logs + ) # type: ignore[union-attr] # log has `message` def is_not_in_logs(self, pattern: str) -> bool: """Check if no log message matches the case-insensitive pattern.""" return not self.is_in_logs(pattern) -def _run_command(source: Source, args: List[str], expecting_exception: bool = False) -> EntrypointOutput: +def _run_command( + source: Source, args: List[str], expecting_exception: bool = False +) -> EntrypointOutput: log_capture_buffer = StringIO() stream_handler = logging.StreamHandler(log_capture_buffer) stream_handler.setLevel(logging.INFO) @@ -174,7 +188,9 @@ def discover( tmp_directory_path = Path(tmp_directory) config_file = make_file(tmp_directory_path / "config.json", config) - return _run_command(source, ["discover", "--config", config_file, "--debug"], expecting_exception) + return _run_command( + source, ["discover", "--config", config_file, "--debug"], expecting_exception + ) def read( @@ -194,7 +210,8 @@ def read( tmp_directory_path = Path(tmp_directory) config_file = make_file(tmp_directory_path / "config.json", config) catalog_file = make_file( - tmp_directory_path / "catalog.json", orjson.dumps(ConfiguredAirbyteCatalogSerializer.dump(catalog)).decode() + tmp_directory_path / "catalog.json", + orjson.dumps(ConfiguredAirbyteCatalogSerializer.dump(catalog)).decode(), ) args = [ "read", @@ -217,7 +234,9 @@ def read( return _run_command(source, args, expecting_exception) -def make_file(path: Path, file_contents: Optional[Union[str, Mapping[str, Any], List[Mapping[str, Any]]]]) -> str: +def make_file( + path: Path, file_contents: Optional[Union[str, Mapping[str, Any], List[Mapping[str, Any]]]] +) -> str: if isinstance(file_contents, str): path.write_text(file_contents) else: diff --git a/airbyte_cdk/test/mock_http/mocker.py b/airbyte_cdk/test/mock_http/mocker.py index 4ac690dc..a62c46a5 100644 --- a/airbyte_cdk/test/mock_http/mocker.py +++ b/airbyte_cdk/test/mock_http/mocker.py @@ -42,7 +42,12 @@ def __enter__(self) -> "HttpMocker": self._mocker.__enter__() return self - def __exit__(self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None: + def __exit__( + self, + exc_type: Optional[BaseException], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self._mocker.__exit__(exc_type, exc_val, exc_tb) def _validate_all_matchers_called(self) -> None: @@ -51,7 +56,10 @@ def _validate_all_matchers_called(self) -> None: raise ValueError(f"Invalid number of matches for `{matcher}`") def _mock_request_method( - self, method: SupportedHttpMethods, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] + self, + method: SupportedHttpMethods, + request: HttpRequest, + responses: Union[HttpResponse, List[HttpResponse]], ) -> None: if isinstance(responses, HttpResponse): responses = [responses] @@ -65,37 +73,57 @@ def _mock_request_method( requests_mock.ANY, additional_matcher=self._matches_wrapper(matcher), response_list=[ - {"text": response.body, "status_code": response.status_code, "headers": response.headers} for response in responses + { + "text": response.body, + "status_code": response.status_code, + "headers": response.headers, + } + for response in responses ], ) def get(self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]]) -> None: self._mock_request_method(SupportedHttpMethods.GET, request, responses) - def patch(self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]]) -> None: + def patch( + self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] + ) -> None: self._mock_request_method(SupportedHttpMethods.PATCH, request, responses) - def post(self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]]) -> None: + def post( + self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] + ) -> None: self._mock_request_method(SupportedHttpMethods.POST, request, responses) - def delete(self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]]) -> None: + def delete( + self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] + ) -> None: self._mock_request_method(SupportedHttpMethods.DELETE, request, responses) @staticmethod - def _matches_wrapper(matcher: HttpRequestMatcher) -> Callable[[requests_mock.request._RequestObjectProxy], bool]: + def _matches_wrapper( + matcher: HttpRequestMatcher, + ) -> Callable[[requests_mock.request._RequestObjectProxy], bool]: def matches(requests_mock_request: requests_mock.request._RequestObjectProxy) -> bool: # query_params are provided as part of `requests_mock_request.url` http_request = HttpRequest( - requests_mock_request.url, query_params={}, headers=requests_mock_request.headers, body=requests_mock_request.body + requests_mock_request.url, + query_params={}, + headers=requests_mock_request.headers, + body=requests_mock_request.body, ) return matcher.matches(http_request) return matches def assert_number_of_calls(self, request: HttpRequest, number_of_calls: int) -> None: - corresponding_matchers = list(filter(lambda matcher: matcher.request == request, self._matchers)) + corresponding_matchers = list( + filter(lambda matcher: matcher.request == request, self._matchers) + ) if len(corresponding_matchers) != 1: - raise ValueError(f"Was expecting only one matcher to match the request but got `{corresponding_matchers}`") + raise ValueError( + f"Was expecting only one matcher to match the request but got `{corresponding_matchers}`" + ) assert corresponding_matchers[0].actual_number_of_matches == number_of_calls @@ -110,7 +138,9 @@ def wrapper(*args, **kwargs): # type: ignore # this is a very generic wrapper try: result = f(*args, **kwargs) except requests_mock.NoMockAddress as no_mock_exception: - matchers_as_string = "\n\t".join(map(lambda matcher: str(matcher.request), self._matchers)) + matchers_as_string = "\n\t".join( + map(lambda matcher: str(matcher.request), self._matchers) + ) raise ValueError( f"No matcher matches {no_mock_exception.args[0]} with headers `{no_mock_exception.request.headers}` " f"and body `{no_mock_exception.request.body}`. " diff --git a/airbyte_cdk/test/mock_http/request.py b/airbyte_cdk/test/mock_http/request.py index 756be23e..7209513d 100644 --- a/airbyte_cdk/test/mock_http/request.py +++ b/airbyte_cdk/test/mock_http/request.py @@ -24,7 +24,9 @@ def __init__( if not self._parsed_url.query and query_params: self._parsed_url = urlparse(f"{url}?{self._encode_qs(query_params)}") elif self._parsed_url.query and query_params: - raise ValueError("If query params are provided as part of the url, `query_params` should be empty") + raise ValueError( + "If query params are provided as part of the url, `query_params` should be empty" + ) self._headers = headers or {} self._body = body @@ -62,7 +64,9 @@ def matches(self, other: Any) -> bool: return False @staticmethod - def _to_mapping(body: Optional[Union[str, bytes, Mapping[str, Any]]]) -> Optional[Mapping[str, Any]]: + def _to_mapping( + body: Optional[Union[str, bytes, Mapping[str, Any]]], + ) -> Optional[Mapping[str, Any]]: if isinstance(body, Mapping): return body elif isinstance(body, bytes): @@ -84,7 +88,9 @@ def __str__(self) -> str: return f"{self._parsed_url} with headers {self._headers} and body {self._body!r})" def __repr__(self) -> str: - return f"HttpRequest(request={self._parsed_url}, headers={self._headers}, body={self._body!r})" + return ( + f"HttpRequest(request={self._parsed_url}, headers={self._headers}, body={self._body!r})" + ) def __eq__(self, other: Any) -> bool: if isinstance(other, HttpRequest): diff --git a/airbyte_cdk/test/mock_http/response.py b/airbyte_cdk/test/mock_http/response.py index 8d5dc4c3..848be55a 100644 --- a/airbyte_cdk/test/mock_http/response.py +++ b/airbyte_cdk/test/mock_http/response.py @@ -5,7 +5,9 @@ class HttpResponse: - def __init__(self, body: str, status_code: int = 200, headers: Mapping[str, str] = MappingProxyType({})): + def __init__( + self, body: str, status_code: int = 200, headers: Mapping[str, str] = MappingProxyType({}) + ): self._body = body self._status_code = status_code self._headers = headers diff --git a/airbyte_cdk/test/mock_http/response_builder.py b/airbyte_cdk/test/mock_http/response_builder.py index 27bb5125..b517343e 100644 --- a/airbyte_cdk/test/mock_http/response_builder.py +++ b/airbyte_cdk/test/mock_http/response_builder.py @@ -91,7 +91,12 @@ def update(self, response: Dict[str, Any]) -> None: class RecordBuilder: - def __init__(self, template: Dict[str, Any], id_path: Optional[Path], cursor_path: Optional[Union[FieldPath, NestedPath]]): + def __init__( + self, + template: Dict[str, Any], + id_path: Optional[Path], + cursor_path: Optional[Union[FieldPath, NestedPath]], + ): self._record = template self._id_path = id_path self._cursor_path = cursor_path @@ -109,9 +114,13 @@ def _validate_template(self) -> None: def _validate_field(self, field_name: str, path: Optional[Path]) -> None: try: if path and not path.extract(self._record): - raise ValueError(f"{field_name} `{path}` was provided but it is not part of the template `{self._record}`") + raise ValueError( + f"{field_name} `{path}` was provided but it is not part of the template `{self._record}`" + ) except (IndexError, KeyError) as exception: - raise ValueError(f"{field_name} `{path}` was provided but it is not part of the template `{self._record}`") from exception + raise ValueError( + f"{field_name} `{path}` was provided but it is not part of the template `{self._record}`" + ) from exception def with_id(self, identifier: Any) -> "RecordBuilder": self._set_field("id", self._id_path, identifier) @@ -139,7 +148,10 @@ def build(self) -> Dict[str, Any]: class HttpResponseBuilder: def __init__( - self, template: Dict[str, Any], records_path: Union[FieldPath, NestedPath], pagination_strategy: Optional[PaginationStrategy] + self, + template: Dict[str, Any], + records_path: Union[FieldPath, NestedPath], + pagination_strategy: Optional[PaginationStrategy], ): self._response = template self._records: List[RecordBuilder] = [] @@ -175,7 +187,13 @@ def _get_unit_test_folder(execution_folder: str) -> FilePath: def find_template(resource: str, execution_folder: str) -> Dict[str, Any]: - response_template_filepath = str(get_unit_test_folder(execution_folder) / "resource" / "http" / "response" / f"{resource}.json") + response_template_filepath = str( + get_unit_test_folder(execution_folder) + / "resource" + / "http" + / "response" + / f"{resource}.json" + ) with open(response_template_filepath, "r") as template_file: return json.load(template_file) # type: ignore # we assume the dev correctly set up the resource file @@ -198,10 +216,14 @@ def create_record_builder( ) return RecordBuilder(record_template, record_id_path, record_cursor_path) except (IndexError, KeyError): - raise ValueError(f"Error while extracting records at path `{records_path}` from response template `{response_template}`") + raise ValueError( + f"Error while extracting records at path `{records_path}` from response template `{response_template}`" + ) def create_response_builder( - response_template: Dict[str, Any], records_path: Union[FieldPath, NestedPath], pagination_strategy: Optional[PaginationStrategy] = None + response_template: Dict[str, Any], + records_path: Union[FieldPath, NestedPath], + pagination_strategy: Optional[PaginationStrategy] = None, ) -> HttpResponseBuilder: return HttpResponseBuilder(response_template, records_path, pagination_strategy) diff --git a/airbyte_cdk/test/state_builder.py b/airbyte_cdk/test/state_builder.py index 50b5dbe5..a1315cf4 100644 --- a/airbyte_cdk/test/state_builder.py +++ b/airbyte_cdk/test/state_builder.py @@ -2,7 +2,13 @@ from typing import Any, List -from airbyte_cdk.models import AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType, AirbyteStreamState, StreamDescriptor +from airbyte_cdk.models import ( + AirbyteStateBlob, + AirbyteStateMessage, + AirbyteStateType, + AirbyteStreamState, + StreamDescriptor, +) class StateBuilder: @@ -14,7 +20,9 @@ def with_stream_state(self, stream_name: str, state: Any) -> "StateBuilder": AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_state=state if isinstance(state, AirbyteStateBlob) else AirbyteStateBlob(state), + stream_state=state + if isinstance(state, AirbyteStateBlob) + else AirbyteStateBlob(state), stream_descriptor=StreamDescriptor(**{"name": stream_name}), ), ) diff --git a/airbyte_cdk/test/utils/data.py b/airbyte_cdk/test/utils/data.py index a4d4fef6..6aaeb839 100644 --- a/airbyte_cdk/test/utils/data.py +++ b/airbyte_cdk/test/utils/data.py @@ -7,14 +7,18 @@ def get_unit_test_folder(execution_folder: str) -> FilePath: path = FilePath(execution_folder) while path.name != "unit_tests": if path.name == path.root or path.name == path.drive: - raise ValueError(f"Could not find `unit_tests` folder as a parent of {execution_folder}") + raise ValueError( + f"Could not find `unit_tests` folder as a parent of {execution_folder}" + ) path = path.parent return path def read_resource_file_contents(resource: str, test_location: str) -> str: """Read the contents of a test data file from the test resource folder.""" - file_path = str(get_unit_test_folder(test_location) / "resource" / "http" / "response" / f"{resource}") + file_path = str( + get_unit_test_folder(test_location) / "resource" / "http" / "response" / f"{resource}" + ) with open(file_path) as f: response = f.read() return response diff --git a/airbyte_cdk/test/utils/http_mocking.py b/airbyte_cdk/test/utils/http_mocking.py index 0cdd8f4c..7fd1419f 100644 --- a/airbyte_cdk/test/utils/http_mocking.py +++ b/airbyte_cdk/test/utils/http_mocking.py @@ -6,7 +6,9 @@ from requests_mock import Mocker -def register_mock_responses(mocker: Mocker, http_calls: list[Mapping[str, Mapping[str, Any]]]) -> None: +def register_mock_responses( + mocker: Mocker, http_calls: list[Mapping[str, Mapping[str, Any]]] +) -> None: """Register a list of HTTP request-response pairs.""" for call in http_calls: request, response = call["request"], call["response"] diff --git a/airbyte_cdk/utils/airbyte_secrets_utils.py b/airbyte_cdk/utils/airbyte_secrets_utils.py index 5afd305f..45279e57 100644 --- a/airbyte_cdk/utils/airbyte_secrets_utils.py +++ b/airbyte_cdk/utils/airbyte_secrets_utils.py @@ -36,7 +36,9 @@ def traverse_schema(schema_item: Any, path: List[str]) -> None: return paths -def get_secrets(connection_specification: Mapping[str, Any], config: Mapping[str, Any]) -> List[Any]: +def get_secrets( + connection_specification: Mapping[str, Any], config: Mapping[str, Any] +) -> List[Any]: """ Get a list of secret values from the source config based on the source specification :type connection_specification: the connection_specification field of an AirbyteSpecification i.e the JSONSchema definition diff --git a/airbyte_cdk/utils/analytics_message.py b/airbyte_cdk/utils/analytics_message.py index 54c3e984..82a07491 100644 --- a/airbyte_cdk/utils/analytics_message.py +++ b/airbyte_cdk/utils/analytics_message.py @@ -3,7 +3,13 @@ import time from typing import Any, Optional -from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, AirbyteMessage, AirbyteTraceMessage, TraceType, Type +from airbyte_cdk.models import ( + AirbyteAnalyticsTraceMessage, + AirbyteMessage, + AirbyteTraceMessage, + TraceType, + Type, +) def create_analytics_message(type: str, value: Optional[Any]) -> AirbyteMessage: @@ -12,6 +18,8 @@ def create_analytics_message(type: str, value: Optional[Any]) -> AirbyteMessage: trace=AirbyteTraceMessage( type=TraceType.ANALYTICS, emitted_at=time.time() * 1000, - analytics=AirbyteAnalyticsTraceMessage(type=type, value=str(value) if value is not None else None), + analytics=AirbyteAnalyticsTraceMessage( + type=type, value=str(value) if value is not None else None + ), ), ) diff --git a/airbyte_cdk/utils/datetime_format_inferrer.py b/airbyte_cdk/utils/datetime_format_inferrer.py index cd423db9..28eaefa3 100644 --- a/airbyte_cdk/utils/datetime_format_inferrer.py +++ b/airbyte_cdk/utils/datetime_format_inferrer.py @@ -29,7 +29,10 @@ def __init__(self) -> None: "%Y-%m", "%d-%m-%Y", ] - self._timestamp_heuristic_ranges = [range(1_000_000_000, 2_000_000_000), range(1_000_000_000_000, 2_000_000_000_000)] + self._timestamp_heuristic_ranges = [ + range(1_000_000_000, 2_000_000_000), + range(1_000_000_000_000, 2_000_000_000_000), + ] def _can_be_datetime(self, value: Any) -> bool: """Checks if the value can be a datetime. diff --git a/airbyte_cdk/utils/mapping_helpers.py b/airbyte_cdk/utils/mapping_helpers.py index ae5e898f..469fb5e0 100644 --- a/airbyte_cdk/utils/mapping_helpers.py +++ b/airbyte_cdk/utils/mapping_helpers.py @@ -6,7 +6,9 @@ from typing import Any, List, Mapping, Optional, Set, Union -def combine_mappings(mappings: List[Optional[Union[Mapping[str, Any], str]]]) -> Union[Mapping[str, Any], str]: +def combine_mappings( + mappings: List[Optional[Union[Mapping[str, Any], str]]], +) -> Union[Mapping[str, Any], str]: """ Combine multiple mappings into a single mapping. If any of the mappings are a string, return that string. Raise errors in the following cases: diff --git a/airbyte_cdk/utils/message_utils.py b/airbyte_cdk/utils/message_utils.py index f9c7b65d..7e740b78 100644 --- a/airbyte_cdk/utils/message_utils.py +++ b/airbyte_cdk/utils/message_utils.py @@ -7,13 +7,19 @@ def get_stream_descriptor(message: AirbyteMessage) -> HashableStreamDescriptor: match message.type: case Type.RECORD: - return HashableStreamDescriptor(name=message.record.stream, namespace=message.record.namespace) # type: ignore[union-attr] # record has `stream` and `namespace` + return HashableStreamDescriptor( + name=message.record.stream, namespace=message.record.namespace + ) # type: ignore[union-attr] # record has `stream` and `namespace` case Type.STATE: if not message.state.stream or not message.state.stream.stream_descriptor: # type: ignore[union-attr] # state has `stream` - raise ValueError("State message was not in per-stream state format, which is required for record counts.") + raise ValueError( + "State message was not in per-stream state format, which is required for record counts." + ) return HashableStreamDescriptor( name=message.state.stream.stream_descriptor.name, namespace=message.state.stream.stream_descriptor.namespace, # type: ignore[union-attr] # state has `stream` ) case _: - raise NotImplementedError(f"get_stream_descriptor is not implemented for message type '{message.type}'.") + raise NotImplementedError( + f"get_stream_descriptor is not implemented for message type '{message.type}'." + ) diff --git a/airbyte_cdk/utils/print_buffer.py b/airbyte_cdk/utils/print_buffer.py index 51ca2a84..ae5a2020 100644 --- a/airbyte_cdk/utils/print_buffer.py +++ b/airbyte_cdk/utils/print_buffer.py @@ -65,6 +65,11 @@ def __enter__(self) -> "PrintBuffer": sys.stderr = self return self - def __exit__(self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None: + def __exit__( + self, + exc_type: Optional[BaseException], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.flush() sys.stdout, sys.stderr = self.old_stdout, self.old_stderr diff --git a/airbyte_cdk/utils/schema_inferrer.py b/airbyte_cdk/utils/schema_inferrer.py index fd749850..65d44369 100644 --- a/airbyte_cdk/utils/schema_inferrer.py +++ b/airbyte_cdk/utils/schema_inferrer.py @@ -55,9 +55,14 @@ class NoRequiredSchemaBuilder(SchemaBuilder): class SchemaValidationException(Exception): @classmethod - def merge_exceptions(cls, exceptions: List["SchemaValidationException"]) -> "SchemaValidationException": + def merge_exceptions( + cls, exceptions: List["SchemaValidationException"] + ) -> "SchemaValidationException": # We assume the schema is the same for all SchemaValidationException - return SchemaValidationException(exceptions[0].schema, [x for exception in exceptions for x in exception._validation_errors]) + return SchemaValidationException( + exceptions[0].schema, + [x for exception in exceptions for x in exception._validation_errors], + ) def __init__(self, schema: InferredSchema, validation_errors: List[Exception]): self._schema = schema @@ -84,7 +89,9 @@ class SchemaInferrer: stream_to_builder: Dict[str, SchemaBuilder] - def __init__(self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None) -> None: + def __init__( + self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None + ) -> None: self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder) self._pk = [] if pk is None else pk self._cursor_field = [] if cursor_field is None else cursor_field @@ -105,7 +112,9 @@ def _remove_type_from_any_of(self, node: InferredSchema) -> None: def _clean_any_of(self, node: InferredSchema) -> None: if len(node[_ANY_OF]) == 2 and self._null_type_in_any_of(node): - real_type = node[_ANY_OF][1] if node[_ANY_OF][0][_TYPE] == _NULL_TYPE else node[_ANY_OF][0] + real_type = ( + node[_ANY_OF][1] if node[_ANY_OF][0][_TYPE] == _NULL_TYPE else node[_ANY_OF][0] + ) node.update(real_type) node[_TYPE] = [node[_TYPE], _NULL_TYPE] node.pop(_ANY_OF) @@ -189,7 +198,9 @@ def _add_fields_as_required(self, node: InferredSchema, composite_key: List[List if errors: raise SchemaValidationException(node, errors) - def _add_field_as_required(self, node: InferredSchema, path: List[str], traveled_path: Optional[List[str]] = None) -> None: + def _add_field_as_required( + self, node: InferredSchema, path: List[str], traveled_path: Optional[List[str]] = None + ) -> None: """ Take a nested key and travel the schema to mark every node as required. """ @@ -208,7 +219,9 @@ def _add_field_as_required(self, node: InferredSchema, path: List[str], traveled next_node = path[0] if next_node not in node[_PROPERTIES]: - raise ValueError(f"Path {traveled_path} does not have field `{next_node}` in the schema and hence can't be marked as required.") + raise ValueError( + f"Path {traveled_path} does not have field `{next_node}` in the schema and hence can't be marked as required." + ) if _TYPE not in node: # We do not expect this case to happen but we added a specific error message just in case @@ -216,8 +229,14 @@ def _add_field_as_required(self, node: InferredSchema, path: List[str], traveled f"Unknown schema error: {traveled_path} is expected to have a type but did not. Schema inferrence is probably broken" ) - if node[_TYPE] not in [_OBJECT_TYPE, [_NULL_TYPE, _OBJECT_TYPE], [_OBJECT_TYPE, _NULL_TYPE]]: - raise ValueError(f"Path {traveled_path} is expected to be an object but was of type `{node['properties'][next_node]['type']}`") + if node[_TYPE] not in [ + _OBJECT_TYPE, + [_NULL_TYPE, _OBJECT_TYPE], + [_OBJECT_TYPE, _NULL_TYPE], + ]: + raise ValueError( + f"Path {traveled_path} is expected to be an object but was of type `{node['properties'][next_node]['type']}`" + ) if _REQUIRED not in node or not node[_REQUIRED]: node[_REQUIRED] = [next_node] @@ -242,7 +261,9 @@ def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]: Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name. """ return ( - self._add_required_properties(self._clean(self.stream_to_builder[stream_name].to_schema())) + self._add_required_properties( + self._clean(self.stream_to_builder[stream_name].to_schema()) + ) if stream_name in self.stream_to_builder else None ) diff --git a/airbyte_cdk/utils/spec_schema_transformations.py b/airbyte_cdk/utils/spec_schema_transformations.py index 2a772d50..8d47f83e 100644 --- a/airbyte_cdk/utils/spec_schema_transformations.py +++ b/airbyte_cdk/utils/spec_schema_transformations.py @@ -17,7 +17,9 @@ def resolve_refs(schema: dict) -> dict: str_schema = json.dumps(schema) for ref_block in re.findall(r'{"\$ref": "#\/definitions\/.+?(?="})"}', str_schema): ref = json.loads(ref_block)["$ref"] - str_schema = str_schema.replace(ref_block, json.dumps(json_schema_ref_resolver.resolve(ref)[1])) + str_schema = str_schema.replace( + ref_block, json.dumps(json_schema_ref_resolver.resolve(ref)[1]) + ) pyschema: dict = json.loads(str_schema) del pyschema["definitions"] return pyschema diff --git a/airbyte_cdk/utils/traced_exception.py b/airbyte_cdk/utils/traced_exception.py index bdc975e9..11f60032 100644 --- a/airbyte_cdk/utils/traced_exception.py +++ b/airbyte_cdk/utils/traced_exception.py @@ -48,7 +48,9 @@ def __init__( self._stream_descriptor = stream_descriptor super().__init__(internal_message) - def as_airbyte_message(self, stream_descriptor: Optional[StreamDescriptor] = None) -> AirbyteMessage: + def as_airbyte_message( + self, stream_descriptor: Optional[StreamDescriptor] = None + ) -> AirbyteMessage: """ Builds an AirbyteTraceMessage from the exception @@ -64,11 +66,14 @@ def as_airbyte_message(self, stream_descriptor: Optional[StreamDescriptor] = Non type=TraceType.ERROR, emitted_at=now_millis, error=AirbyteErrorTraceMessage( - message=self.message or "Something went wrong in the connector. See the logs for more details.", + message=self.message + or "Something went wrong in the connector. See the logs for more details.", internal_message=self.internal_message, failure_type=self.failure_type, stack_trace=stack_trace_str, - stream_descriptor=self._stream_descriptor if self._stream_descriptor is not None else stream_descriptor, + stream_descriptor=self._stream_descriptor + if self._stream_descriptor is not None + else stream_descriptor, ), ) @@ -77,7 +82,10 @@ def as_airbyte_message(self, stream_descriptor: Optional[StreamDescriptor] = Non def as_connection_status_message(self) -> Optional[AirbyteMessage]: if self.failure_type == FailureType.config_error: return AirbyteMessage( - type=MessageType.CONNECTION_STATUS, connectionStatus=AirbyteConnectionStatus(status=Status.FAILED, message=self.message) + type=MessageType.CONNECTION_STATUS, + connectionStatus=AirbyteConnectionStatus( + status=Status.FAILED, message=self.message + ), ) return None @@ -92,16 +100,28 @@ def emit_message(self) -> None: @classmethod def from_exception( - cls, exc: BaseException, stream_descriptor: Optional[StreamDescriptor] = None, *args, **kwargs + cls, + exc: BaseException, + stream_descriptor: Optional[StreamDescriptor] = None, + *args, + **kwargs, ) -> "AirbyteTracedException": # type: ignore # ignoring because of args and kwargs """ Helper to create an AirbyteTracedException from an existing exception :param exc: the exception that caused the error :param stream_descriptor: describe the stream from which the exception comes from """ - return cls(internal_message=str(exc), exception=exc, stream_descriptor=stream_descriptor, *args, **kwargs) # type: ignore # ignoring because of args and kwargs + return cls( + internal_message=str(exc), + exception=exc, + stream_descriptor=stream_descriptor, + *args, + **kwargs, + ) # type: ignore # ignoring because of args and kwargs - def as_sanitized_airbyte_message(self, stream_descriptor: Optional[StreamDescriptor] = None) -> AirbyteMessage: + def as_sanitized_airbyte_message( + self, stream_descriptor: Optional[StreamDescriptor] = None + ) -> AirbyteMessage: """ Builds an AirbyteTraceMessage from the exception and sanitizes any secrets from the message body @@ -112,7 +132,11 @@ def as_sanitized_airbyte_message(self, stream_descriptor: Optional[StreamDescrip if error_message.trace.error.message: # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage error_message.trace.error.message = filter_secrets(error_message.trace.error.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage if error_message.trace.error.internal_message: # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage - error_message.trace.error.internal_message = filter_secrets(error_message.trace.error.internal_message) # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage + error_message.trace.error.internal_message = filter_secrets( + error_message.trace.error.internal_message + ) # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage if error_message.trace.error.stack_trace: # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage - error_message.trace.error.stack_trace = filter_secrets(error_message.trace.error.stack_trace) # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage + error_message.trace.error.stack_trace = filter_secrets( + error_message.trace.error.stack_trace + ) # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage return error_message diff --git a/bin/generate_component_manifest_files.py b/bin/generate_component_manifest_files.py index 6a595934..7e9c6835 100755 --- a/bin/generate_component_manifest_files.py +++ b/bin/generate_component_manifest_files.py @@ -29,12 +29,18 @@ def generate_init_module_content() -> str: async def post_process_codegen(codegen_container: dagger.Container): - codegen_container = codegen_container.with_exec(["mkdir", "/generated_post_processed"], use_entrypoint=True) + codegen_container = codegen_container.with_exec( + ["mkdir", "/generated_post_processed"], use_entrypoint=True + ) for generated_file in await codegen_container.directory("/generated").entries(): if generated_file.endswith(".py"): - original_content = await codegen_container.file(f"/generated/{generated_file}").contents() + original_content = await codegen_container.file( + f"/generated/{generated_file}" + ).contents() # the space before _parameters is intentional to avoid replacing things like `request_parameters:` with `requestparameters:` - post_processed_content = original_content.replace(" _parameters:", " parameters:").replace("from pydantic", "from pydantic.v1") + post_processed_content = original_content.replace( + " _parameters:", " parameters:" + ).replace("from pydantic", "from pydantic.v1") codegen_container = codegen_container.with_new_file( f"/generated_post_processed/{generated_file}", contents=post_processed_content ) @@ -50,7 +56,9 @@ async def main(): .from_(PYTHON_IMAGE) .with_exec(["mkdir", "/generated"], use_entrypoint=True) .with_exec(["pip", "install", " ".join(PIP_DEPENDENCIES)], use_entrypoint=True) - .with_mounted_directory("/yaml", dagger_client.host().directory(LOCAL_YAML_DIR_PATH, include=["*.yaml"])) + .with_mounted_directory( + "/yaml", dagger_client.host().directory(LOCAL_YAML_DIR_PATH, include=["*.yaml"]) + ) .with_new_file("/generated/__init__.py", contents=init_module_content) ) for yaml_file in get_all_yaml_files_without_ext(): @@ -69,7 +77,11 @@ async def main(): use_entrypoint=True, ) - await (await post_process_codegen(codegen_container)).directory("/generated_post_processed").export(LOCAL_OUTPUT_DIR_PATH) + await ( + (await post_process_codegen(codegen_container)) + .directory("/generated_post_processed") + .export(LOCAL_OUTPUT_DIR_PATH) + ) anyio.run(main) diff --git a/docs/generate.py b/docs/generate.py index f5467f67..58589771 100644 --- a/docs/generate.py +++ b/docs/generate.py @@ -55,7 +55,9 @@ def run() -> None: continue print(f"Found module file: {'|'.join([parent_dir, file_name])}") - module = cast(str, ".".join([parent_dir, file_name])).replace("/", ".").removesuffix(".py") + module = ( + cast(str, ".".join([parent_dir, file_name])).replace("/", ".").removesuffix(".py") + ) public_modules.append(module) # recursively delete the docs/generated folder if it exists diff --git a/reference_docs/generate_rst_schema.py b/reference_docs/generate_rst_schema.py index b401d2e4..1a2268ee 100755 --- a/reference_docs/generate_rst_schema.py +++ b/reference_docs/generate_rst_schema.py @@ -20,7 +20,9 @@ def write_master_file(templatedir: str, master_name: str, values: Dict, opts: An if __name__ == "__main__": parser = get_parser() - parser.add_argument("--master", metavar="MASTER", default="index", help=__("master document name")) + parser.add_argument( + "--master", metavar="MASTER", default="index", help=__("master document name") + ) args = parser.parse_args(sys.argv[1:]) rootpath = path.abspath(args.module_path) @@ -39,8 +41,14 @@ def write_master_file(templatedir: str, master_name: str, values: Dict, opts: An modules = recurse_tree(rootpath, excludes, args, args.templatedir) template_values = { - "top_modules": [{"path": f"api/{module}", "caption": module.split(".")[1].title()} for module in modules if module.count(".") == 1], + "top_modules": [ + {"path": f"api/{module}", "caption": module.split(".")[1].title()} + for module in modules + if module.count(".") == 1 + ], "maxdepth": args.maxdepth, } - write_master_file(templatedir=args.templatedir, master_name=args.master, values=template_values, opts=args) + write_master_file( + templatedir=args.templatedir, master_name=args.master, values=template_values, opts=args + ) main() diff --git a/unit_tests/conftest.py b/unit_tests/conftest.py index e40ddd21..3a21552b 100644 --- a/unit_tests/conftest.py +++ b/unit_tests/conftest.py @@ -10,7 +10,9 @@ @pytest.fixture() def mock_sleep(monkeypatch): - with freezegun.freeze_time(datetime.datetime.now(), ignore=["_pytest.runner", "_pytest.terminal"]) as frozen_datetime: + with freezegun.freeze_time( + datetime.datetime.now(), ignore=["_pytest.runner", "_pytest.terminal"] + ) as frozen_datetime: monkeypatch.setattr("time.sleep", lambda x: frozen_datetime.tick(x)) yield @@ -25,7 +27,9 @@ def pytest_configure(config): def pytest_collection_modifyitems(config, items): if config.getoption("--skipslow"): - skip_slow = pytest.mark.skip(reason="--skipslow option has been provided and this test is marked as slow") + skip_slow = pytest.mark.skip( + reason="--skipslow option has been provided and this test is marked as slow" + ) for item in items: if "slow" in item.keywords: item.add_marker(skip_slow) diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index 0212466d..10bd4513 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -22,8 +22,17 @@ get_limits, resolve_manifest, ) -from airbyte_cdk.connector_builder.main import handle_connector_builder_request, handle_request, read_stream -from airbyte_cdk.connector_builder.models import LogMessage, StreamRead, StreamReadPages, StreamReadSlices +from airbyte_cdk.connector_builder.main import ( + handle_connector_builder_request, + handle_request, + read_stream, +) +from airbyte_cdk.connector_builder.models import ( + LogMessage, + StreamRead, + StreamReadPages, + StreamReadSlices, +) from airbyte_cdk.models import ( AirbyteLogMessage, AirbyteMessage, @@ -54,12 +63,19 @@ _stream_name = "stream_with_custom_requester" _stream_primary_key = "id" _stream_url_base = "https://api.sendgrid.com" -_stream_options = {"name": _stream_name, "primary_key": _stream_primary_key, "url_base": _stream_url_base} +_stream_options = { + "name": _stream_name, + "primary_key": _stream_primary_key, + "url_base": _stream_url_base, +} _page_size = 2 _A_STATE = [ AirbyteStateMessage( - type="STREAM", stream=AirbyteStreamState(stream_descriptor=StreamDescriptor(name=_stream_name), stream_state={"key": "value"}) + type="STREAM", + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name=_stream_name), stream_state={"key": "value"} + ), ) ] @@ -103,7 +119,10 @@ }, "" "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"a_param": "10"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -193,7 +212,11 @@ { "stream": { "name": "dummy_stream", - "json_schema": {"$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": {}}, + "json_schema": { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {}, + }, "supported_sync_modes": ["full_refresh"], "source_defined_cursor": False, }, @@ -208,7 +231,11 @@ { "stream": { "name": _stream_name, - "json_schema": {"$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": {}}, + "json_schema": { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {}, + }, "supported_sync_modes": ["full_refresh"], "source_defined_cursor": False, }, @@ -279,17 +306,31 @@ def _mocked_send(self, request, **kwargs) -> requests.Response: def test_handle_resolve_manifest(valid_resolve_manifest_config_file, dummy_catalog): with mock.patch.object( - connector_builder.main, "handle_connector_builder_request", return_value=AirbyteMessage(type=MessageType.RECORD) + connector_builder.main, + "handle_connector_builder_request", + return_value=AirbyteMessage(type=MessageType.RECORD), ) as patched_handle: - handle_request(["read", "--config", str(valid_resolve_manifest_config_file), "--catalog", str(dummy_catalog)]) + handle_request( + [ + "read", + "--config", + str(valid_resolve_manifest_config_file), + "--catalog", + str(dummy_catalog), + ] + ) assert patched_handle.call_count == 1 def test_handle_test_read(valid_read_config_file, configured_catalog): with mock.patch.object( - connector_builder.main, "handle_connector_builder_request", return_value=AirbyteMessage(type=MessageType.RECORD) + connector_builder.main, + "handle_connector_builder_request", + return_value=AirbyteMessage(type=MessageType.RECORD), ) as patch: - handle_request(["read", "--config", str(valid_read_config_file), "--catalog", str(configured_catalog)]) + handle_request( + ["read", "--config", str(valid_read_config_file), "--catalog", str(configured_catalog)] + ) assert patch.call_count == 1 @@ -311,7 +352,10 @@ def test_resolve_manifest(valid_resolve_manifest_config_file): "paginator": { "type": "DefaultPaginator", "page_size": _page_size, - "page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"inject_into": "path", "type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -326,7 +370,10 @@ def test_resolve_manifest(valid_resolve_manifest_config_file): }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"a_param": "10"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -454,7 +501,9 @@ def test_read(): config = TEST_READ_CONFIG source = ManifestDeclarativeSource(MANIFEST) - real_record = AirbyteRecordMessage(data={"id": "1234", "key": "value"}, emitted_at=1, stream=_stream_name) + real_record = AirbyteRecordMessage( + data={"id": "1234", "key": "value"}, emitted_at=1, stream=_stream_name + ) stream_read = StreamRead( logs=[{"message": "here be a log message"}], slices=[ @@ -478,7 +527,11 @@ def test_read(): data={ "logs": [{"message": "here be a log message"}], "slices": [ - {"pages": [{"records": [real_record], "request": None, "response": None}], "slice_descriptor": None, "state": None} + { + "pages": [{"records": [real_record], "request": None, "response": None}], + "slice_descriptor": None, + "state": None, + } ], "test_read_limit_reached": False, "auxiliary_requests": [], @@ -490,11 +543,25 @@ def test_read(): ), ) limits = TestReadLimits() - with patch("airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups", return_value=stream_read) as mock: + with patch( + "airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups", + return_value=stream_read, + ) as mock: output_record = handle_connector_builder_request( - source, "test_read", config, ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), _A_STATE, limits + source, + "test_read", + config, + ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), + _A_STATE, + limits, + ) + mock.assert_called_with( + source, + config, + ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), + _A_STATE, + limits.max_records, ) - mock.assert_called_with(source, config, ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), _A_STATE, limits.max_records) output_record.record.emitted_at = 1 assert ( orjson.dumps(AirbyteMessageSerializer.dump(output_record)).decode() @@ -573,7 +640,13 @@ def check_config_against_spec(self): source = MockManifestDeclarativeSource() limits = TestReadLimits() - response = read_stream(source, TEST_READ_CONFIG, ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), _A_STATE, limits) + response = read_stream( + source, + TEST_READ_CONFIG, + ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), + _A_STATE, + limits, + ) expected_stream_read = StreamRead( logs=[LogMessage("error_message", "ERROR", "error_message", "a stack trace")], @@ -587,19 +660,23 @@ def check_config_against_spec(self): expected_message = AirbyteMessage( type=MessageType.RECORD, - record=AirbyteRecordMessage(stream=_stream_name, data=dataclasses.asdict(expected_stream_read), emitted_at=1), + record=AirbyteRecordMessage( + stream=_stream_name, data=dataclasses.asdict(expected_stream_read), emitted_at=1 + ), ) response.record.emitted_at = 1 assert response == expected_message def test_handle_429_response(): - response = _create_429_page_response({"result": [{"error": "too many requests"}], "_metadata": {"next": "next"}}) + response = _create_429_page_response( + {"result": [{"error": "too many requests"}], "_metadata": {"next": "next"}} + ) # Add backoff strategy to avoid default endless backoff loop - TEST_READ_CONFIG["__injected_declarative_manifest"]["definitions"]["retriever"]["requester"]["error_handler"] = { - "backoff_strategies": [{"type": "ConstantBackoffStrategy", "backoff_time_in_seconds": 5}] - } + TEST_READ_CONFIG["__injected_declarative_manifest"]["definitions"]["retriever"]["requester"][ + "error_handler" + ] = {"backoff_strategies": [{"type": "ConstantBackoffStrategy", "backoff_time_in_seconds": 5}]} config = TEST_READ_CONFIG limits = TestReadLimits() @@ -607,7 +684,12 @@ def test_handle_429_response(): with patch("requests.Session.send", return_value=response) as mock_send: response = handle_connector_builder_request( - source, "test_read", config, ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), _A_PER_PARTITION_STATE, limits + source, + "test_read", + config, + ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), + _A_PER_PARTITION_STATE, + limits, ) mock_send.assert_called_once() @@ -627,7 +709,9 @@ def test_invalid_protocol_command(command, valid_resolve_manifest_config_file): config = copy.deepcopy(RESOLVE_MANIFEST_CONFIG) config["__command"] = "resolve_manifest" with pytest.raises(SystemExit): - handle_request([command, "--config", str(valid_resolve_manifest_config_file), "--catalog", ""]) + handle_request( + [command, "--config", str(valid_resolve_manifest_config_file), "--catalog", ""] + ) def test_missing_command(valid_resolve_manifest_config_file): @@ -647,7 +731,9 @@ def test_missing_config(valid_resolve_manifest_config_file): def test_invalid_config_command(invalid_config_file, dummy_catalog): with pytest.raises(ValueError): - handle_request(["read", "--config", str(invalid_config_file), "--catalog", str(dummy_catalog)]) + handle_request( + ["read", "--config", str(invalid_config_file), "--catalog", str(dummy_catalog)] + ) @pytest.fixture @@ -688,10 +774,18 @@ def create_mock_declarative_stream(http_stream): DEFAULT_MAXIMUM_NUMBER_OF_SLICES, DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, ), - ("test_values_are_set", {"__test_read_config": {"max_slices": 1, "max_pages_per_slice": 2, "max_records": 3}}, 3, 1, 2), + ( + "test_values_are_set", + {"__test_read_config": {"max_slices": 1, "max_pages_per_slice": 2, "max_records": 3}}, + 3, + 1, + 2, + ), ], ) -def test_get_limits(test_name, config, expected_max_records, expected_max_slices, expected_max_pages_per_slice): +def test_get_limits( + test_name, config, expected_max_records, expected_max_slices, expected_max_pages_per_slice +): limits = get_limits(config) assert limits.max_records == expected_max_records assert limits.max_pages_per_slice == expected_max_pages_per_slice @@ -715,11 +809,17 @@ def test_create_source(): def request_log_message(request: dict) -> AirbyteMessage: - return AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"request:{json.dumps(request)}")) + return AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage(level=Level.INFO, message=f"request:{json.dumps(request)}"), + ) def response_log_message(response: dict) -> AirbyteMessage: - return AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"response:{json.dumps(response)}")) + return AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage(level=Level.INFO, message=f"response:{json.dumps(response)}"), + ) def _create_request(): @@ -783,7 +883,9 @@ def test_read_source(mock_http_stream): catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ) @@ -828,7 +930,9 @@ def test_read_source_single_page_single_slice(mock_http_stream): catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ) @@ -858,13 +962,42 @@ def test_read_source_single_page_single_slice(mock_http_stream): @pytest.mark.parametrize( "deployment_mode, url_base, expected_error", [ - pytest.param("CLOUD", "https://airbyte.com/api/v1/characters", None, id="test_cloud_read_with_public_endpoint"), - pytest.param("CLOUD", "https://10.0.27.27", "AirbyteTracedException", id="test_cloud_read_with_private_endpoint"), - pytest.param("CLOUD", "https://localhost:80/api/v1/cast", "AirbyteTracedException", id="test_cloud_read_with_localhost"), - pytest.param("CLOUD", "http://unsecured.protocol/api/v1", "InvalidSchema", id="test_cloud_read_with_unsecured_endpoint"), - pytest.param("CLOUD", "https://domainwithoutextension", "Invalid URL", id="test_cloud_read_with_invalid_url_endpoint"), - pytest.param("OSS", "https://airbyte.com/api/v1/", None, id="test_oss_read_with_public_endpoint"), - pytest.param("OSS", "https://10.0.27.27/api/v1/", None, id="test_oss_read_with_private_endpoint"), + pytest.param( + "CLOUD", + "https://airbyte.com/api/v1/characters", + None, + id="test_cloud_read_with_public_endpoint", + ), + pytest.param( + "CLOUD", + "https://10.0.27.27", + "AirbyteTracedException", + id="test_cloud_read_with_private_endpoint", + ), + pytest.param( + "CLOUD", + "https://localhost:80/api/v1/cast", + "AirbyteTracedException", + id="test_cloud_read_with_localhost", + ), + pytest.param( + "CLOUD", + "http://unsecured.protocol/api/v1", + "InvalidSchema", + id="test_cloud_read_with_unsecured_endpoint", + ), + pytest.param( + "CLOUD", + "https://domainwithoutextension", + "Invalid URL", + id="test_cloud_read_with_invalid_url_endpoint", + ), + pytest.param( + "OSS", "https://airbyte.com/api/v1/", None, id="test_oss_read_with_public_endpoint" + ), + pytest.param( + "OSS", "https://10.0.27.27/api/v1/", None, id="test_oss_read_with_private_endpoint" + ), ], ) @patch.object(requests.Session, "send", _mocked_send) @@ -881,7 +1014,9 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ) @@ -895,9 +1030,13 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error source = create_source(config, limits) with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False): - output_data = read_stream(source, config, catalog, _A_PER_PARTITION_STATE, limits).record.data + output_data = read_stream( + source, config, catalog, _A_PER_PARTITION_STATE, limits + ).record.data if expected_error: - assert len(output_data["logs"]) > 0, "Expected at least one log message with the expected error" + assert ( + len(output_data["logs"]) > 0 + ), "Expected at least one log message with the expected error" error_message = output_data["logs"][0] assert error_message["level"] == "ERROR" assert expected_error in error_message["stacktrace"] @@ -909,12 +1048,42 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error @pytest.mark.parametrize( "deployment_mode, token_url, expected_error", [ - pytest.param("CLOUD", "https://airbyte.com/tokens/bearer", None, id="test_cloud_read_with_public_endpoint"), - pytest.param("CLOUD", "https://10.0.27.27/tokens/bearer", "AirbyteTracedException", id="test_cloud_read_with_private_endpoint"), - pytest.param("CLOUD", "http://unsecured.protocol/tokens/bearer", "InvalidSchema", id="test_cloud_read_with_unsecured_endpoint"), - pytest.param("CLOUD", "https://domainwithoutextension", "Invalid URL", id="test_cloud_read_with_invalid_url_endpoint"), - pytest.param("OSS", "https://airbyte.com/tokens/bearer", None, id="test_oss_read_with_public_endpoint"), - pytest.param("OSS", "https://10.0.27.27/tokens/bearer", None, id="test_oss_read_with_private_endpoint"), + pytest.param( + "CLOUD", + "https://airbyte.com/tokens/bearer", + None, + id="test_cloud_read_with_public_endpoint", + ), + pytest.param( + "CLOUD", + "https://10.0.27.27/tokens/bearer", + "AirbyteTracedException", + id="test_cloud_read_with_private_endpoint", + ), + pytest.param( + "CLOUD", + "http://unsecured.protocol/tokens/bearer", + "InvalidSchema", + id="test_cloud_read_with_unsecured_endpoint", + ), + pytest.param( + "CLOUD", + "https://domainwithoutextension", + "Invalid URL", + id="test_cloud_read_with_invalid_url_endpoint", + ), + pytest.param( + "OSS", + "https://airbyte.com/tokens/bearer", + None, + id="test_oss_read_with_public_endpoint", + ), + pytest.param( + "OSS", + "https://10.0.27.27/tokens/bearer", + None, + id="test_oss_read_with_private_endpoint", + ), ], ) @patch.object(requests.Session, "send", _mocked_send) @@ -931,7 +1100,9 @@ def test_handle_read_external_oauth_request(deployment_mode, token_url, expected catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ) @@ -947,15 +1118,21 @@ def test_handle_read_external_oauth_request(deployment_mode, token_url, expected } test_manifest = MANIFEST - test_manifest["definitions"]["retriever"]["requester"]["authenticator"] = oauth_authenticator_config + test_manifest["definitions"]["retriever"]["requester"]["authenticator"] = ( + oauth_authenticator_config + ) config = {"__injected_declarative_manifest": test_manifest} source = create_source(config, limits) with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False): - output_data = read_stream(source, config, catalog, _A_PER_PARTITION_STATE, limits).record.data + output_data = read_stream( + source, config, catalog, _A_PER_PARTITION_STATE, limits + ).record.data if expected_error: - assert len(output_data["logs"]) > 0, "Expected at least one log message with the expected error" + assert ( + len(output_data["logs"]) > 0 + ), "Expected at least one log message with the expected error" error_message = output_data["logs"][0] assert error_message["level"] == "ERROR" assert expected_error in error_message["stacktrace"] @@ -967,7 +1144,9 @@ def test_read_stream_exception_with_secrets(): catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name=_stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ) @@ -983,7 +1162,9 @@ def test_read_stream_exception_with_secrets(): mock_source = MagicMock() # Patch the handler to raise an exception - with patch("airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups") as mock_handler: + with patch( + "airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups" + ) as mock_handler: mock_handler.side_effect = Exception("Test exception with secret key: super_secret_key") # Call the read_stream function and check for the correct error message diff --git a/unit_tests/connector_builder/test_message_grouper.py b/unit_tests/connector_builder/test_message_grouper.py index 6a4c7a99..e95f7fcc 100644 --- a/unit_tests/connector_builder/test_message_grouper.py +++ b/unit_tests/connector_builder/test_message_grouper.py @@ -8,7 +8,13 @@ import pytest from airbyte_cdk.connector_builder.message_grouper import MessageGrouper -from airbyte_cdk.connector_builder.models import HttpRequest, HttpResponse, LogMessage, StreamRead, StreamReadPages +from airbyte_cdk.connector_builder.models import ( + HttpRequest, + HttpResponse, + LogMessage, + StreamRead, + StreamReadPages, +) from airbyte_cdk.models import ( AirbyteControlConnectorConfigMessage, AirbyteControlMessage, @@ -38,29 +44,57 @@ "version": "0.30.0", "type": "DeclarativeSource", "definitions": { - "selector": {"extractor": {"field_path": ["items"], "type": "DpathExtractor"}, "type": "RecordSelector"}, - "requester": {"url_base": "https://demonslayers.com/api/v1/", "http_method": "GET", "type": "DeclarativeSource"}, + "selector": { + "extractor": {"field_path": ["items"], "type": "DpathExtractor"}, + "type": "RecordSelector", + }, + "requester": { + "url_base": "https://demonslayers.com/api/v1/", + "http_method": "GET", + "type": "DeclarativeSource", + }, "retriever": { "type": "DeclarativeSource", - "record_selector": {"extractor": {"field_path": ["items"], "type": "DpathExtractor"}, "type": "RecordSelector"}, + "record_selector": { + "extractor": {"field_path": ["items"], "type": "DpathExtractor"}, + "type": "RecordSelector", + }, "paginator": {"type": "NoPagination"}, - "requester": {"url_base": "https://demonslayers.com/api/v1/", "http_method": "GET", "type": "HttpRequester"}, + "requester": { + "url_base": "https://demonslayers.com/api/v1/", + "http_method": "GET", + "type": "HttpRequester", + }, }, "hashiras_stream": { "retriever": { "type": "DeclarativeSource", - "record_selector": {"extractor": {"field_path": ["items"], "type": "DpathExtractor"}, "type": "RecordSelector"}, + "record_selector": { + "extractor": {"field_path": ["items"], "type": "DpathExtractor"}, + "type": "RecordSelector", + }, "paginator": {"type": "NoPagination"}, - "requester": {"url_base": "https://demonslayers.com/api/v1/", "http_method": "GET", "type": "HttpRequester"}, + "requester": { + "url_base": "https://demonslayers.com/api/v1/", + "http_method": "GET", + "type": "HttpRequester", + }, }, "$parameters": {"name": "hashiras", "path": "/hashiras"}, }, "breathing_techniques_stream": { "retriever": { "type": "DeclarativeSource", - "record_selector": {"extractor": {"field_path": ["items"], "type": "DpathExtractor"}, "type": "RecordSelector"}, + "record_selector": { + "extractor": {"field_path": ["items"], "type": "DpathExtractor"}, + "type": "RecordSelector", + }, "paginator": {"type": "NoPagination"}, - "requester": {"url_base": "https://demonslayers.com/api/v1/", "http_method": "GET", "type": "HttpRequester"}, + "requester": { + "url_base": "https://demonslayers.com/api/v1/", + "http_method": "GET", + "type": "HttpRequester", + }, }, "$parameters": {"name": "breathing-techniques", "path": "/breathing_techniques"}, }, @@ -70,9 +104,16 @@ "type": "DeclarativeStream", "retriever": { "type": "SimpleRetriever", - "record_selector": {"extractor": {"field_path": ["items"], "type": "DpathExtractor"}, "type": "RecordSelector"}, + "record_selector": { + "extractor": {"field_path": ["items"], "type": "DpathExtractor"}, + "type": "RecordSelector", + }, "paginator": {"type": "NoPagination"}, - "requester": {"url_base": "https://demonslayers.com/api/v1/", "http_method": "GET", "type": "HttpRequester"}, + "requester": { + "url_base": "https://demonslayers.com/api/v1/", + "http_method": "GET", + "type": "HttpRequester", + }, }, "$parameters": {"name": "hashiras", "path": "/hashiras"}, }, @@ -80,9 +121,16 @@ "type": "DeclarativeStream", "retriever": { "type": "SimpleRetriever", - "record_selector": {"extractor": {"field_path": ["items"], "type": "DpathExtractor"}, "type": "RecordSelector"}, + "record_selector": { + "extractor": {"field_path": ["items"], "type": "DpathExtractor"}, + "type": "RecordSelector", + }, "paginator": {"type": "NoPagination"}, - "requester": {"url_base": "https://demonslayers.com/api/v1/", "http_method": "GET", "type": "HttpRequester"}, + "requester": { + "url_base": "https://demonslayers.com/api/v1/", + "http_method": "GET", + "type": "HttpRequester", + }, }, "$parameters": {"name": "breathing-techniques", "path": "/breathing_techniques"}, }, @@ -103,7 +151,11 @@ def test_get_grouped_messages(mock_entrypoint_read: Mock) -> None: "method": "GET", "body": {"content": '{"custom": "field"}'}, } - response = {"status_code": 200, "headers": {"field": "value"}, "body": {"content": '{"name": "field"}'}} + response = { + "status_code": 200, + "headers": {"field": "value"}, + "body": {"content": '{"name": "field"}'}, + } expected_schema = { "$schema": "http://json-schema.org/schema#", "properties": {"name": {"type": ["string", "null"]}, "date": {"type": ["string", "null"]}}, @@ -119,7 +171,10 @@ def test_get_grouped_messages(mock_entrypoint_read: Mock) -> None: http_method="GET", ), response=HttpResponse(status=200, headers={"field": "value"}, body='{"name": "field"}'), - records=[{"name": "Shinobu Kocho", "date": "2023-03-03"}, {"name": "Muichiro Tokito", "date": "2023-03-04"}], + records=[ + {"name": "Shinobu Kocho", "date": "2023-03-03"}, + {"name": "Muichiro Tokito", "date": "2023-03-04"}, + ], ), StreamReadPages( request=HttpRequest( @@ -170,7 +225,11 @@ def test_get_grouped_messages_with_logs(mock_entrypoint_read: Mock) -> None: "method": "GET", "body": {"content": '{"custom": "field"}'}, } - response = {"status_code": 200, "headers": {"field": "value"}, "body": {"content": '{"name": "field"}'}} + response = { + "status_code": 200, + "headers": {"field": "value"}, + "body": {"content": '{"name": "field"}'}, + } expected_pages = [ StreamReadPages( request=HttpRequest( @@ -203,12 +262,25 @@ def test_get_grouped_messages_with_logs(mock_entrypoint_read: Mock) -> None: mock_entrypoint_read, iter( [ - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message before the request")), + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=Level.INFO, message="log message before the request" + ), + ), request_response_log_message(request, response, url), record_message("hashiras", {"name": "Shinobu Kocho"}), - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message during the page")), + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage(level=Level.INFO, message="log message during the page"), + ), record_message("hashiras", {"name": "Muichiro Tokito"}), - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message after the response")), + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=Level.INFO, message="log message after the response" + ), + ), ] ), ) @@ -246,7 +318,11 @@ def test_get_grouped_messages_record_limit( "method": "GET", "body": {"content": '{"custom": "field"}'}, } - response = {"status_code": 200, "headers": {"field": "value"}, "body": {"content": '{"name": "field"}'}} + response = { + "status_code": 200, + "headers": {"field": "value"}, + "body": {"content": '{"name": "field"}'}, + } mock_source = make_mock_source( mock_entrypoint_read, iter( @@ -298,14 +374,20 @@ def test_get_grouped_messages_record_limit( ], ) @patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") -def test_get_grouped_messages_default_record_limit(mock_entrypoint_read: Mock, max_record_limit: int) -> None: +def test_get_grouped_messages_default_record_limit( + mock_entrypoint_read: Mock, max_record_limit: int +) -> None: url = "https://demonslayers.com/api/v1/hashiras?era=taisho" request = { "headers": {"Content-Type": "application/json"}, "method": "GET", "body": {"content": '{"custom": "field"}'}, } - response = {"status_code": 200, "headers": {"field": "value"}, "body": {"content": '{"name": "field"}'}} + response = { + "status_code": 200, + "headers": {"field": "value"}, + "body": {"content": '{"name": "field"}'}, + } mock_source = make_mock_source( mock_entrypoint_read, iter( @@ -322,7 +404,10 @@ def test_get_grouped_messages_default_record_limit(mock_entrypoint_read: Mock, m api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit) actual_response: StreamRead = api.get_message_groups( - source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), state=_NO_STATE + source=mock_source, + config=CONFIG, + configured_catalog=create_configured_catalog("hashiras"), + state=_NO_STATE, ) single_slice = actual_response.slices[0] total_records = 0 @@ -339,7 +424,11 @@ def test_get_grouped_messages_limit_0(mock_entrypoint_read: Mock) -> None: "method": "GET", "body": {"content": '{"custom": "field"}'}, } - response = {"status_code": 200, "headers": {"field": "value"}, "body": {"content": '{"name": "field"}'}} + response = { + "status_code": 200, + "headers": {"field": "value"}, + "body": {"content": '{"name": "field"}'}, + } mock_source = make_mock_source( mock_entrypoint_read, iter( @@ -356,7 +445,11 @@ def test_get_grouped_messages_limit_0(mock_entrypoint_read: Mock) -> None: with pytest.raises(ValueError): api.get_message_groups( - source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), state=_NO_STATE, record_limit=0 + source=mock_source, + config=CONFIG, + configured_catalog=create_configured_catalog("hashiras"), + state=_NO_STATE, + record_limit=0, ) @@ -368,7 +461,11 @@ def test_get_grouped_messages_no_records(mock_entrypoint_read: Mock) -> None: "method": "GET", "body": {"content": '{"custom": "field"}'}, } - response = {"status_code": 200, "headers": {"field": "value"}, "body": {"content": '{"name": "field"}'}} + response = { + "status_code": 200, + "headers": {"field": "value"}, + "body": {"content": '{"name": "field"}'}, + } expected_pages = [ StreamReadPages( request=HttpRequest( @@ -429,7 +526,11 @@ def test_get_grouped_messages_no_records(mock_entrypoint_read: Mock) -> None: } } }, - HttpResponse(status=200, headers={"field": "name"}, body='{"id": "fire", "owner": "kyojuro_rengoku"}'), + HttpResponse( + status=200, + headers={"field": "name"}, + body='{"id": "fire", "owner": "kyojuro_rengoku"}', + ), id="test_create_response_with_all_fields", ), pytest.param( @@ -438,7 +539,14 @@ def test_get_grouped_messages_no_records(mock_entrypoint_read: Mock) -> None: id="test_create_response_with_no_body", ), pytest.param( - {"http": {"response": {"status_code": 200, "body": {"content": '{"id": "fire", "owner": "kyojuro_rengoku"}'}}}}, + { + "http": { + "response": { + "status_code": 200, + "body": {"content": '{"id": "fire", "owner": "kyojuro_rengoku"}'}, + } + } + }, HttpResponse(status=200, body='{"id": "fire", "owner": "kyojuro_rengoku"}'), id="test_create_response_with_no_headers", ), @@ -448,7 +556,9 @@ def test_get_grouped_messages_no_records(mock_entrypoint_read: Mock) -> None: "response": { "status_code": 200, "headers": {"field": "name"}, - "body": {"content": '[{"id": "fire", "owner": "kyojuro_rengoku"}, {"id": "mist", "owner": "muichiro_tokito"}]'}, + "body": { + "content": '[{"id": "fire", "owner": "kyojuro_rengoku"}, {"id": "mist", "owner": "muichiro_tokito"}]' + }, } } }, @@ -466,7 +576,9 @@ def test_get_grouped_messages_no_records(mock_entrypoint_read: Mock) -> None: ), ], ) -def test_create_response_from_log_message(log_message: str, expected_response: HttpResponse) -> None: +def test_create_response_from_log_message( + log_message: str, expected_response: HttpResponse +) -> None: if isinstance(log_message, str): response_message = json.loads(log_message) else: @@ -533,12 +645,18 @@ def test_get_grouped_messages_with_many_slices(mock_entrypoint_read: Mock) -> No @patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") -def test_get_grouped_messages_given_maximum_number_of_slices_then_test_read_limit_reached(mock_entrypoint_read: Mock) -> None: +def test_get_grouped_messages_given_maximum_number_of_slices_then_test_read_limit_reached( + mock_entrypoint_read: Mock, +) -> None: maximum_number_of_slices = 5 request: Mapping[str, Any] = {} response = {"status_code": 200} mock_source = make_mock_source( - mock_entrypoint_read, iter([slice_message(), request_response_log_message(request, response, "a_url")] * maximum_number_of_slices) + mock_entrypoint_read, + iter( + [slice_message(), request_response_log_message(request, response, "a_url")] + * maximum_number_of_slices + ), ) api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) @@ -554,13 +672,19 @@ def test_get_grouped_messages_given_maximum_number_of_slices_then_test_read_limi @patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") -def test_get_grouped_messages_given_maximum_number_of_pages_then_test_read_limit_reached(mock_entrypoint_read: Mock) -> None: +def test_get_grouped_messages_given_maximum_number_of_pages_then_test_read_limit_reached( + mock_entrypoint_read: Mock, +) -> None: maximum_number_of_pages_per_slice = 5 request: Mapping[str, Any] = {} response = {"status_code": 200} mock_source = make_mock_source( mock_entrypoint_read, - iter([slice_message()] + [request_response_log_message(request, response, "a_url")] * maximum_number_of_pages_per_slice), + iter( + [slice_message()] + + [request_response_log_message(request, response, "a_url")] + * maximum_number_of_pages_per_slice + ), ) api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) @@ -596,10 +720,16 @@ def test_read_stream_returns_error_if_stream_does_not_exist() -> None: @patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") -def test_given_control_message_then_stream_read_has_config_update(mock_entrypoint_read: Mock) -> None: +def test_given_control_message_then_stream_read_has_config_update( + mock_entrypoint_read: Mock, +) -> None: updated_config = {"x": 1} mock_source = make_mock_source( - mock_entrypoint_read, iter(any_request_and_response_with_a_record() + [connector_configuration_control_message(1, updated_config)]) + mock_entrypoint_read, + iter( + any_request_and_response_with_a_record() + + [connector_configuration_control_message(1, updated_config)] + ), ) connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) stream_read: StreamRead = connector_builder_handler.get_message_groups( @@ -613,7 +743,9 @@ def test_given_control_message_then_stream_read_has_config_update(mock_entrypoin @patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") -def test_given_multiple_control_messages_then_stream_read_has_latest_based_on_emitted_at(mock_entrypoint_read: Mock) -> None: +def test_given_multiple_control_messages_then_stream_read_has_latest_based_on_emitted_at( + mock_entrypoint_read: Mock, +) -> None: earliest = 0 earliest_config = {"earliest": 0} latest = 1 @@ -670,10 +802,16 @@ def test_given_multiple_control_messages_with_same_timestamp_then_stream_read_ha @patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") def test_given_auxiliary_requests_then_return_auxiliary_request(mock_entrypoint_read: Mock) -> None: - mock_source = make_mock_source(mock_entrypoint_read, iter(any_request_and_response_with_a_record() + [auxiliary_request_log_message()])) + mock_source = make_mock_source( + mock_entrypoint_read, + iter(any_request_and_response_with_a_record() + [auxiliary_request_log_message()]), + ) connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) stream_read: StreamRead = connector_builder_handler.get_message_groups( - source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), state=_NO_STATE + source=mock_source, + config=CONFIG, + configured_catalog=create_configured_catalog("hashiras"), + state=_NO_STATE, ) assert len(stream_read.auxiliary_requests) == 1 @@ -684,7 +822,10 @@ def test_given_no_slices_then_return_empty_slices(mock_entrypoint_read: Mock) -> mock_source = make_mock_source(mock_entrypoint_read, iter([auxiliary_request_log_message()])) connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) stream_read: StreamRead = connector_builder_handler.get_message_groups( - source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), state=_NO_STATE + source=mock_source, + config=CONFIG, + configured_catalog=create_configured_catalog("hashiras"), + state=_NO_STATE, ) assert len(stream_read.slices) == 0 @@ -708,14 +849,19 @@ def test_given_pk_then_ensure_pk_is_pass_to_schema_inferrence(mock_entrypoint_re connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) stream_read: StreamRead = connector_builder_handler.get_message_groups( - source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), state=_NO_STATE + source=mock_source, + config=CONFIG, + configured_catalog=create_configured_catalog("hashiras"), + state=_NO_STATE, ) assert stream_read.inferred_schema["required"] == ["id"] @patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") -def test_given_cursor_field_then_ensure_cursor_field_is_pass_to_schema_inferrence(mock_entrypoint_read: Mock) -> None: +def test_given_cursor_field_then_ensure_cursor_field_is_pass_to_schema_inferrence( + mock_entrypoint_read: Mock, +) -> None: mock_source = make_mock_source( mock_entrypoint_read, iter( @@ -732,13 +878,18 @@ def test_given_cursor_field_then_ensure_cursor_field_is_pass_to_schema_inferrenc connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) stream_read: StreamRead = connector_builder_handler.get_message_groups( - source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), state=_NO_STATE + source=mock_source, + config=CONFIG, + configured_catalog=create_configured_catalog("hashiras"), + state=_NO_STATE, ) assert stream_read.inferred_schema["required"] == ["date"] -def make_mock_source(mock_entrypoint_read: Mock, return_value: Iterator[AirbyteMessage]) -> MagicMock: +def make_mock_source( + mock_entrypoint_read: Mock, return_value: Iterator[AirbyteMessage] +) -> MagicMock: mock_source = MagicMock() mock_entrypoint_read.return_value = return_value mock_source.streams.return_value = [make_mock_stream()] @@ -753,29 +904,47 @@ def make_mock_stream(): def request_log_message(request: Mapping[str, Any]) -> AirbyteMessage: - return AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"request:{json.dumps(request)}")) + return AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage(level=Level.INFO, message=f"request:{json.dumps(request)}"), + ) def response_log_message(response: Mapping[str, Any]) -> AirbyteMessage: - return AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"response:{json.dumps(response)}")) + return AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage(level=Level.INFO, message=f"response:{json.dumps(response)}"), + ) def record_message(stream: str, data: Mapping[str, Any]) -> AirbyteMessage: - return AirbyteMessage(type=MessageType.RECORD, record=AirbyteRecordMessage(stream=stream, data=data, emitted_at=1234)) + return AirbyteMessage( + type=MessageType.RECORD, + record=AirbyteRecordMessage(stream=stream, data=data, emitted_at=1234), + ) def state_message(stream: str, data: Mapping[str, Any]) -> AirbyteMessage: return AirbyteMessage( type=MessageType.STATE, - state=AirbyteStateMessage(stream=AirbyteStreamState(stream_descriptor=StreamDescriptor(name=stream), stream_state=data)), + state=AirbyteStateMessage( + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name=stream), stream_state=data + ) + ), ) def slice_message(slice_descriptor: str = '{"key": "value"}') -> AirbyteMessage: - return AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message="slice:" + slice_descriptor)) + return AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage(level=Level.INFO, message="slice:" + slice_descriptor), + ) -def connector_configuration_control_message(emitted_at: float, config: Mapping[str, Any]) -> AirbyteMessage: +def connector_configuration_control_message( + emitted_at: float, config: Mapping[str, Any] +) -> AirbyteMessage: return AirbyteMessage( type=MessageType.CONTROL, control=AirbyteControlMessage( @@ -807,7 +976,9 @@ def auxiliary_request_log_message() -> AirbyteMessage: ) -def request_response_log_message(request: Mapping[str, Any], response: Mapping[str, Any], url: str) -> AirbyteMessage: +def request_response_log_message( + request: Mapping[str, Any], response: Mapping[str, Any], url: str +) -> AirbyteMessage: return AirbyteMessage( type=MessageType.LOG, log=AirbyteLogMessage( @@ -815,7 +986,12 @@ def request_response_log_message(request: Mapping[str, Any], response: Mapping[s message=json.dumps( { "airbyte_cdk": {"stream": {"name": "a stream name"}}, - "http": {"title": "a title", "description": "a description", "request": request, "response": response}, + "http": { + "title": "a title", + "description": "a description", + "request": request, + "response": response, + }, "url": {"full": url}, } ), diff --git a/unit_tests/destinations/test_destination.py b/unit_tests/destinations/test_destination.py index 3620b671..ffe1fd37 100644 --- a/unit_tests/destinations/test_destination.py +++ b/unit_tests/destinations/test_destination.py @@ -52,7 +52,9 @@ class TestArgParsing: ), ], ) - def test_successful_parse(self, arg_list: List[str], expected_output: Mapping[str, Any], destination: Destination): + def test_successful_parse( + self, arg_list: List[str], expected_output: Mapping[str, Any], destination: Destination + ): parsed_args = vars(destination.parse_args(arg_list)) assert ( parsed_args == expected_output @@ -98,7 +100,13 @@ def write_file(path: PathLike, content: Union[str, Mapping]): def _wrapped( - msg: Union[AirbyteRecordMessage, AirbyteStateMessage, AirbyteCatalog, ConnectorSpecification, AirbyteConnectionStatus], + msg: Union[ + AirbyteRecordMessage, + AirbyteStateMessage, + AirbyteCatalog, + ConnectorSpecification, + AirbyteConnectionStatus, + ], ) -> AirbyteMessage: if isinstance(msg, AirbyteRecordMessage): return AirbyteMessage(type=Type.RECORD, record=msg) @@ -145,13 +153,17 @@ def test_run_initializes_exception_handler(self, mocker, destination: Destinatio mocker.patch.object(destination, "parse_args") mocker.patch.object(destination, "run_cmd") destination.run(["dummy"]) - destination_module.init_uncaught_exception_handler.assert_called_once_with(destination_module.logger) + destination_module.init_uncaught_exception_handler.assert_called_once_with( + destination_module.logger + ) def test_run_spec(self, mocker, destination: Destination): args = {"command": "spec"} parsed_args = argparse.Namespace(**args) - expected_spec = ConnectorSpecification(connectionSpecification={"json_schema": {"prop": "value"}}) + expected_spec = ConnectorSpecification( + connectionSpecification={"json_schema": {"prop": "value"}} + ) mocker.patch.object(destination, "spec", return_value=expected_spec, autospec=True) spec_message = next(iter(destination.run_cmd(parsed_args))) @@ -172,7 +184,9 @@ def test_run_check(self, mocker, destination: Destination, tmp_path): destination.run_cmd(parsed_args) spec_msg = ConnectorSpecification(connectionSpecification={}) mocker.patch.object(destination, "spec", return_value=spec_msg) - validate_mock = mocker.patch("airbyte_cdk.destinations.destination.check_config_against_spec_or_exit") + validate_mock = mocker.patch( + "airbyte_cdk.destinations.destination.check_config_against_spec_or_exit" + ) expected_check_result = AirbyteConnectionStatus(status=Status.SUCCEEDED) mocker.patch.object(destination, "check", return_value=expected_check_result, autospec=True) @@ -226,7 +240,11 @@ def test_run_write(self, mocker, destination: Destination, tmp_path, monkeypatch dummy_catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name="mystream", json_schema={"type": "object"}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="mystream", + json_schema={"type": "object"}, + supported_sync_modes=[SyncMode.full_refresh], + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.overwrite, ) @@ -247,10 +265,20 @@ def test_run_write(self, mocker, destination: Destination, tmp_path, monkeypatch ) spec_msg = ConnectorSpecification(connectionSpecification={}) mocker.patch.object(destination, "spec", return_value=spec_msg) - validate_mock = mocker.patch("airbyte_cdk.destinations.destination.check_config_against_spec_or_exit") + validate_mock = mocker.patch( + "airbyte_cdk.destinations.destination.check_config_against_spec_or_exit" + ) # mock input is a record followed by some state messages - mocked_input: List[AirbyteMessage] = [_wrapped(_record("s1", {"k1": "v1"})), *expected_write_result] - mocked_stdin_string = "\n".join([orjson.dumps(AirbyteMessageSerializer.dump(record)).decode() for record in mocked_input]) + mocked_input: List[AirbyteMessage] = [ + _wrapped(_record("s1", {"k1": "v1"})), + *expected_write_result, + ] + mocked_stdin_string = "\n".join( + [ + orjson.dumps(AirbyteMessageSerializer.dump(record)).decode() + for record in mocked_input + ] + ) mocked_stdin_string += "\n add this non-serializable string to verify the destination does not break on malformed input" mocked_stdin = io.TextIOWrapper(io.BytesIO(bytes(mocked_stdin_string, "utf-8"))) diff --git a/unit_tests/destinations/vector_db_based/document_processor_test.py b/unit_tests/destinations/vector_db_based/document_processor_test.py index db3ce730..f427f42d 100644 --- a/unit_tests/destinations/vector_db_based/document_processor_test.py +++ b/unit_tests/destinations/vector_db_based/document_processor_test.py @@ -25,7 +25,11 @@ from airbyte_cdk.utils.traced_exception import AirbyteTracedException -def initialize_processor(config=ProcessingConfigModel(chunk_size=48, chunk_overlap=0, text_fields=None, metadata_fields=None)): +def initialize_processor( + config=ProcessingConfigModel( + chunk_size=48, chunk_overlap=0, text_fields=None, metadata_fields=None + ), +): catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( @@ -129,7 +133,9 @@ def test_process_single_chunk_limited_metadata(): def test_process_single_chunk_without_namespace(): - config = ProcessingConfigModel(chunk_size=48, chunk_overlap=0, text_fields=None, metadata_fields=None) + config = ProcessingConfigModel( + chunk_size=48, chunk_overlap=0, text_fields=None, metadata_fields=None + ) catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( diff --git a/unit_tests/destinations/vector_db_based/embedder_test.py b/unit_tests/destinations/vector_db_based/embedder_test.py index 600a4c08..dc0f5378 100644 --- a/unit_tests/destinations/vector_db_based/embedder_test.py +++ b/unit_tests/destinations/vector_db_based/embedder_test.py @@ -31,8 +31,16 @@ @pytest.mark.parametrize( "embedder_class, args, dimensions", ( - (OpenAIEmbedder, [OpenAIEmbeddingConfigModel(**{"mode": "openai", "openai_key": "abc"}), 1000], OPEN_AI_VECTOR_SIZE), - (CohereEmbedder, [CohereEmbeddingConfigModel(**{"mode": "cohere", "cohere_key": "abc"})], COHERE_VECTOR_SIZE), + ( + OpenAIEmbedder, + [OpenAIEmbeddingConfigModel(**{"mode": "openai", "openai_key": "abc"}), 1000], + OPEN_AI_VECTOR_SIZE, + ), + ( + CohereEmbedder, + [CohereEmbeddingConfigModel(**{"mode": "cohere", "cohere_key": "abc"})], + COHERE_VECTOR_SIZE, + ), (FakeEmbedder, [FakeEmbeddingConfigModel(**{"mode": "fake"})], OPEN_AI_VECTOR_SIZE), ( AzureOpenAIEmbedder, @@ -82,8 +90,12 @@ def test_embedder(embedder_class, args, dimensions): mock_embedding_instance.embed_documents.return_value = [[0] * dimensions] * 2 chunks = [ - Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)), - Document(page_content="b", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)), + Document( + page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0) + ), + Document( + page_content="b", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0) + ), ] assert embedder.embed_documents(chunks) == mock_embedding_instance.embed_documents.return_value mock_embedding_instance.embed_documents.assert_called_with(["a", "b"]) @@ -101,8 +113,17 @@ def test_embedder(embedder_class, args, dimensions): ), ) def test_from_field_embedder(field_name, dimensions, metadata, expected_embedding, expected_error): - embedder = FromFieldEmbedder(FromFieldEmbeddingConfigModel(mode="from_field", dimensions=dimensions, field_name=field_name)) - chunks = [Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0))] + embedder = FromFieldEmbedder( + FromFieldEmbeddingConfigModel( + mode="from_field", dimensions=dimensions, field_name=field_name + ) + ) + chunks = [ + Document( + page_content="a", + record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0), + ) + ] if expected_error: with pytest.raises(AirbyteTracedException): embedder.embed_documents(chunks) @@ -116,8 +137,15 @@ def test_openai_chunking(): mock_embedding_instance = MagicMock() embedder.embeddings = mock_embedding_instance - mock_embedding_instance.embed_documents.side_effect = lambda texts: [[0] * OPEN_AI_VECTOR_SIZE] * len(texts) + mock_embedding_instance.embed_documents.side_effect = lambda texts: [ + [0] * OPEN_AI_VECTOR_SIZE + ] * len(texts) - chunks = [Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)) for _ in range(1005)] + chunks = [ + Document( + page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0) + ) + for _ in range(1005) + ] assert embedder.embed_documents(chunks) == [[0] * OPEN_AI_VECTOR_SIZE] * 1005 mock_embedding_instance.embed_documents.assert_has_calls([call(["a"] * 1000), call(["a"] * 5)]) diff --git a/unit_tests/destinations/vector_db_based/writer_test.py b/unit_tests/destinations/vector_db_based/writer_test.py index ac831694..b39ce8d3 100644 --- a/unit_tests/destinations/vector_db_based/writer_test.py +++ b/unit_tests/destinations/vector_db_based/writer_test.py @@ -19,11 +19,16 @@ ) -def _generate_record_message(index: int, stream: str = "example_stream", namespace: Optional[str] = None): +def _generate_record_message( + index: int, stream: str = "example_stream", namespace: Optional[str] = None +): return AirbyteMessage( type=Type.RECORD, record=AirbyteRecordMessage( - stream=stream, namespace=namespace, emitted_at=1234, data={"column_name": f"value {index}", "id": index} + stream=stream, + namespace=namespace, + emitted_at=1234, + data={"column_name": f"value {index}", "id": index}, ), ) @@ -36,7 +41,11 @@ def generate_stream(name: str = "example_stream", namespace: Optional[str] = Non "stream": { "name": name, "namespace": namespace, - "json_schema": {"$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": {}}, + "json_schema": { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {}, + }, "supported_sync_modes": ["full_refresh", "incremental"], "source_defined_cursor": False, "default_cursor_field": ["column_name"], @@ -60,9 +69,13 @@ def test_write(omit_raw_text: bool): """ Basic test for the write method, batcher and document processor. """ - config_model = ProcessingConfigModel(chunk_overlap=0, chunk_size=1000, metadata_fields=None, text_fields=["column_name"]) + config_model = ProcessingConfigModel( + chunk_overlap=0, chunk_size=1000, metadata_fields=None, text_fields=["column_name"] + ) - configured_catalog: ConfiguredAirbyteCatalog = ConfiguredAirbyteCatalogSerializer.load({"streams": [generate_stream()]}) + configured_catalog: ConfiguredAirbyteCatalog = ConfiguredAirbyteCatalogSerializer.load( + {"streams": [generate_stream()]} + ) # messages are flushed after 32 records or after a state message, so this will trigger two batches to be processed input_messages = [_generate_record_message(i) for i in range(BATCH_SIZE + 5)] state_message = AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage()) @@ -73,7 +86,9 @@ def test_write(omit_raw_text: bool): mock_embedder = generate_mock_embedder() mock_indexer = MagicMock() - post_sync_log_message = AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="post sync")) + post_sync_log_message = AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="post sync") + ) mock_indexer.post_sync.return_value = [post_sync_log_message] # Create the DestinationLangchain instance @@ -125,7 +140,9 @@ def test_write_stream_namespace_split(): * out of the first batch of 32, example_stream, example stream with namespace abd and the first 5 records for example_stream2 * in the second batch, the remaining 5 records for example_stream2 """ - config_model = ProcessingConfigModel(chunk_overlap=0, chunk_size=1000, metadata_fields=None, text_fields=["column_name"]) + config_model = ProcessingConfigModel( + chunk_overlap=0, chunk_size=1000, metadata_fields=None, text_fields=["column_name"] + ) configured_catalog: ConfiguredAirbyteCatalog = ConfiguredAirbyteCatalogSerializer.load( { @@ -137,7 +154,9 @@ def test_write_stream_namespace_split(): } ) - input_messages = [_generate_record_message(i, "example_stream", None) for i in range(BATCH_SIZE - 10)] + input_messages = [ + _generate_record_message(i, "example_stream", None) for i in range(BATCH_SIZE - 10) + ] input_messages.extend([_generate_record_message(i, "example_stream", "abc") for i in range(5)]) input_messages.extend([_generate_record_message(i, "example_stream2", None) for i in range(10)]) state_message = AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage()) diff --git a/unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py b/unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py index 38e87d67..6593416a 100644 --- a/unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py +++ b/unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py @@ -28,7 +28,13 @@ class _MockSource(ConcurrentSourceAdapter): - def __init__(self, concurrent_source, _streams_to_is_concurrent, logger, raise_exception_on_missing_stream=True): + def __init__( + self, + concurrent_source, + _streams_to_is_concurrent, + logger, + raise_exception_on_missing_stream=True, + ): super().__init__(concurrent_source) self._streams_to_is_concurrent = _streams_to_is_concurrent self._logger = logger @@ -36,7 +42,9 @@ def __init__(self, concurrent_source, _streams_to_is_concurrent, logger, raise_e message_repository = InMemoryMessageRepository() - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: raise NotImplementedError def streams(self, config: Mapping[str, Any]) -> List[Stream]: @@ -72,10 +80,14 @@ def test_concurrent_source_adapter(as_stream_status, remove_stack_trace): unavailable_stream = _mock_stream("s3", [{"data": 3}], False) concurrent_stream.name = "s2" logger = Mock() - adapter = _MockSource(concurrent_source, {regular_stream: False, concurrent_stream: True}, logger) + adapter = _MockSource( + concurrent_source, {regular_stream: False, concurrent_stream: True}, logger + ) with pytest.raises(AirbyteTracedException): messages = [] - for message in adapter.read(logger, {}, _configured_catalog([regular_stream, concurrent_stream, unavailable_stream])): + for message in adapter.read( + logger, {}, _configured_catalog([regular_stream, concurrent_stream, unavailable_stream]) + ): messages.append(message) records = [m for m in messages if m.type == MessageType.RECORD] @@ -104,7 +116,10 @@ def test_concurrent_source_adapter(as_stream_status, remove_stack_trace): expected_status = [as_stream_status("s3", AirbyteStreamStatus.INCOMPLETE)] assert len(unavailable_stream_trace_messages) == 1 - assert unavailable_stream_trace_messages[0].trace.stream_status == expected_status[0].trace.stream_status + assert ( + unavailable_stream_trace_messages[0].trace.stream_status + == expected_status[0].trace.stream_status + ) def _mock_stream(name: str, data=[], available: bool = True): @@ -152,16 +167,24 @@ def test_read_nonexistent_concurrent_stream_emit_incomplete_stream_status( concurrent_source.read.return_value = [] adapter = _MockSource(concurrent_source, {s1: True}, logger) - expected_status = [as_stream_status("this_stream_doesnt_exist_in_the_source", AirbyteStreamStatus.INCOMPLETE)] + expected_status = [ + as_stream_status("this_stream_doesnt_exist_in_the_source", AirbyteStreamStatus.INCOMPLETE) + ] adapter.raise_exception_on_missing_stream = raise_exception_on_missing_stream if not raise_exception_on_missing_stream: - messages = [remove_stack_trace(message) for message in adapter.read(logger, {}, _configured_catalog([s2]))] + messages = [ + remove_stack_trace(message) + for message in adapter.read(logger, {}, _configured_catalog([s2])) + ] assert messages[0].trace.stream_status == expected_status[0].trace.stream_status else: with pytest.raises(AirbyteTracedException) as exc_info: - messages = [remove_stack_trace(message) for message in adapter.read(logger, {}, _configured_catalog([s2]))] + messages = [ + remove_stack_trace(message) + for message in adapter.read(logger, {}, _configured_catalog([s2])) + ] assert messages == expected_status assert exc_info.value.failure_type == FailureType.config_error assert "not found in the source" in exc_info.value.message diff --git a/unit_tests/sources/declarative/async_job/test_integration.py b/unit_tests/sources/declarative/async_job/test_integration.py index b3d4f095..be078488 100644 --- a/unit_tests/sources/declarative/async_job/test_integration.py +++ b/unit_tests/sources/declarative/async_job/test_integration.py @@ -5,7 +5,13 @@ from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple from unittest import TestCase, mock -from airbyte_cdk import AbstractSource, DeclarativeStream, SinglePartitionRouter, Stream, StreamSlice +from airbyte_cdk import ( + AbstractSource, + DeclarativeStream, + SinglePartitionRouter, + Stream, + StreamSlice, +) from airbyte_cdk.models import ConnectorSpecification from airbyte_cdk.sources.declarative.async_job.job import AsyncJob from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator @@ -50,7 +56,9 @@ def __init__(self, stream_slicer: Optional[StreamSlicer] = None) -> None: self._stream_slicer = SinglePartitionRouter({}) if stream_slicer is None else stream_slicer self._message_repository = NoopMessageRepository() - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: return True, None def spec(self, logger: logging.Logger) -> ConnectorSpecification: @@ -101,7 +109,11 @@ def setUp(self) -> None: def test_when_read_then_return_records_from_repository(self) -> None: output = read( - self._source, self._CONFIG, CatalogBuilder().with_stream(ConfiguredAirbyteStreamBuilder().with_name(_A_STREAM_NAME)).build() + self._source, + self._CONFIG, + CatalogBuilder() + .with_stream(ConfiguredAirbyteStreamBuilder().with_name(_A_STREAM_NAME)) + .build(), ) assert len(output.records) == 1 @@ -111,7 +123,11 @@ def test_when_read_then_call_stream_slices_only_once(self) -> None: As generating stream slices is very expensive, we want to ensure that during a read, it is only called once. """ output = read( - self._source, self._CONFIG, CatalogBuilder().with_stream(ConfiguredAirbyteStreamBuilder().with_name(_A_STREAM_NAME)).build() + self._source, + self._CONFIG, + CatalogBuilder() + .with_stream(ConfiguredAirbyteStreamBuilder().with_name(_A_STREAM_NAME)) + .build(), ) assert not output.errors diff --git a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py index 5eb8d569..af8e84e7 100644 --- a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py +++ b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py @@ -12,7 +12,10 @@ from airbyte_cdk import AirbyteTracedException, StreamSlice from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.async_job.job import AsyncJob, AsyncJobStatus -from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator, AsyncPartition +from airbyte_cdk.sources.declarative.async_job.job_orchestrator import ( + AsyncJobOrchestrator, + AsyncPartition, +) from airbyte_cdk.sources.declarative.async_job.job_tracker import JobTracker from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository from airbyte_cdk.sources.message import MessageRepository @@ -34,7 +37,9 @@ def _create_job(status: AsyncJobStatus = AsyncJobStatus.FAILED) -> AsyncJob: class AsyncPartitionTest(TestCase): def test_given_one_failed_job_when_status_then_return_failed(self) -> None: - partition = AsyncPartition([_create_job(status) for status in AsyncJobStatus], _ANY_STREAM_SLICE) + partition = AsyncPartition( + [_create_job(status) for status in AsyncJobStatus], _ANY_STREAM_SLICE + ) assert partition.status == AsyncJobStatus.FAILED def test_given_all_status_except_failed_when_status_then_return_timed_out(self) -> None: @@ -43,15 +48,22 @@ def test_given_all_status_except_failed_when_status_then_return_timed_out(self) assert partition.status == AsyncJobStatus.TIMED_OUT def test_given_running_and_completed_jobs_when_status_then_return_running(self) -> None: - partition = AsyncPartition([_create_job(AsyncJobStatus.RUNNING), _create_job(AsyncJobStatus.COMPLETED)], _ANY_STREAM_SLICE) + partition = AsyncPartition( + [_create_job(AsyncJobStatus.RUNNING), _create_job(AsyncJobStatus.COMPLETED)], + _ANY_STREAM_SLICE, + ) assert partition.status == AsyncJobStatus.RUNNING def test_given_only_completed_jobs_when_status_then_return_running(self) -> None: - partition = AsyncPartition([_create_job(AsyncJobStatus.COMPLETED) for _ in range(10)], _ANY_STREAM_SLICE) + partition = AsyncPartition( + [_create_job(AsyncJobStatus.COMPLETED) for _ in range(10)], _ANY_STREAM_SLICE + ) assert partition.status == AsyncJobStatus.COMPLETED -def _status_update_per_jobs(status_update_per_jobs: Mapping[AsyncJob, List[AsyncJobStatus]]) -> Callable[[set[AsyncJob]], None]: +def _status_update_per_jobs( + status_update_per_jobs: Mapping[AsyncJob, List[AsyncJobStatus]], +) -> Callable[[set[AsyncJob]], None]: status_index_by_job = {job: 0 for job in status_update_per_jobs.keys()} def _update_status(jobs: Set[AsyncJob]) -> None: @@ -74,7 +86,9 @@ def setUp(self) -> None: self._logger = Mock(spec=logging.Logger) self._job_for_a_slice = self._an_async_job("an api job id", _A_STREAM_SLICE) - self._job_for_another_slice = self._an_async_job("another api job id", _ANOTHER_STREAM_SLICE) + self._job_for_another_slice = self._an_async_job( + "another api job id", _ANOTHER_STREAM_SLICE + ) @mock.patch(sleep_mock_target) def test_when_create_and_get_completed_partitions_then_create_job_and_update_status_until_completed( @@ -82,20 +96,27 @@ def test_when_create_and_get_completed_partitions_then_create_job_and_update_sta ) -> None: self._job_repository.start.return_value = self._job_for_a_slice status_updates = [AsyncJobStatus.RUNNING, AsyncJobStatus.RUNNING, AsyncJobStatus.COMPLETED] - self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs({self._job_for_a_slice: status_updates}) + self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs( + {self._job_for_a_slice: status_updates} + ) orchestrator = self._orchestrator([_A_STREAM_SLICE]) partitions = list(orchestrator.create_and_get_completed_partitions()) assert len(partitions) == 1 assert partitions[0].status == AsyncJobStatus.COMPLETED - assert self._job_for_a_slice.update_status.mock_calls == [call(status) for status in status_updates] + assert self._job_for_a_slice.update_status.mock_calls == [ + call(status) for status in status_updates + ] @mock.patch(sleep_mock_target) def test_given_one_job_still_running_when_create_and_get_completed_partitions_then_only_update_running_job_status( self, mock_sleep: MagicMock ) -> None: - self._job_repository.start.side_effect = [self._job_for_a_slice, self._job_for_another_slice] + self._job_repository.start.side_effect = [ + self._job_for_a_slice, + self._job_for_another_slice, + ] self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs( { self._job_for_a_slice: [AsyncJobStatus.COMPLETED], @@ -117,23 +138,35 @@ def test_given_timeout_when_create_and_get_completed_partitions_then_free_budget ) -> None: job_tracker = JobTracker(1) self._job_repository.start.return_value = self._job_for_a_slice - self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs({self._job_for_a_slice: [AsyncJobStatus.TIMED_OUT]}) + self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs( + {self._job_for_a_slice: [AsyncJobStatus.TIMED_OUT]} + ) orchestrator = self._orchestrator([_A_STREAM_SLICE], job_tracker) with pytest.raises(AirbyteTracedException): list(orchestrator.create_and_get_completed_partitions()) assert job_tracker.try_to_get_intent() - assert self._job_repository.start.call_args_list == [call(_A_STREAM_SLICE)] * _MAX_NUMBER_OF_ATTEMPTS + assert ( + self._job_repository.start.call_args_list + == [call(_A_STREAM_SLICE)] * _MAX_NUMBER_OF_ATTEMPTS + ) @mock.patch(sleep_mock_target) - def test_given_failure_when_create_and_get_completed_partitions_then_raise_exception(self, mock_sleep: MagicMock) -> None: + def test_given_failure_when_create_and_get_completed_partitions_then_raise_exception( + self, mock_sleep: MagicMock + ) -> None: self._job_repository.start.return_value = self._job_for_a_slice - self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs({self._job_for_a_slice: [AsyncJobStatus.FAILED]}) + self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs( + {self._job_for_a_slice: [AsyncJobStatus.FAILED]} + ) orchestrator = self._orchestrator([_A_STREAM_SLICE]) with pytest.raises(AirbyteTracedException): list(orchestrator.create_and_get_completed_partitions()) - assert self._job_repository.start.call_args_list == [call(_A_STREAM_SLICE)] * _MAX_NUMBER_OF_ATTEMPTS + assert ( + self._job_repository.start.call_args_list + == [call(_A_STREAM_SLICE)] * _MAX_NUMBER_OF_ATTEMPTS + ) def test_when_fetch_records_then_yield_records_from_each_job(self) -> None: self._job_repository.fetch_records.return_value = [_ANY_RECORD] @@ -148,15 +181,22 @@ def test_when_fetch_records_then_yield_records_from_each_job(self) -> None: assert self._job_repository.fetch_records.mock_calls == [call(first_job), call(second_job)] assert self._job_repository.delete.mock_calls == [call(first_job), call(second_job)] - def _orchestrator(self, slices: List[StreamSlice], job_tracker: Optional[JobTracker] = None) -> AsyncJobOrchestrator: + def _orchestrator( + self, slices: List[StreamSlice], job_tracker: Optional[JobTracker] = None + ) -> AsyncJobOrchestrator: job_tracker = job_tracker if job_tracker else JobTracker(_NO_JOB_LIMIT) - return AsyncJobOrchestrator(self._job_repository, slices, job_tracker, self._message_repository) + return AsyncJobOrchestrator( + self._job_repository, slices, job_tracker, self._message_repository + ) def test_given_more_jobs_than_limit_when_create_and_get_completed_partitions_then_still_return_all_slices_and_free_job_budget( self, ) -> None: job_tracker = JobTracker(1) - self._job_repository.start.side_effect = [self._job_for_a_slice, self._job_for_another_slice] + self._job_repository.start.side_effect = [ + self._job_for_a_slice, + self._job_for_another_slice, + ] self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs( { self._job_for_a_slice: [AsyncJobStatus.COMPLETED], @@ -164,7 +204,8 @@ def test_given_more_jobs_than_limit_when_create_and_get_completed_partitions_the } ) orchestrator = self._orchestrator( - [self._job_for_a_slice.job_parameters(), self._job_for_another_slice.job_parameters()], job_tracker + [self._job_for_a_slice.job_parameters(), self._job_for_another_slice.job_parameters()], + job_tracker, ) partitions = list(orchestrator.create_and_get_completed_partitions()) @@ -173,7 +214,9 @@ def test_given_more_jobs_than_limit_when_create_and_get_completed_partitions_the assert job_tracker.try_to_get_intent() @mock.patch(sleep_mock_target) - def test_given_exception_to_break_when_start_job_and_raise_this_exception_and_abort_jobs(self, mock_sleep: MagicMock) -> None: + def test_given_exception_to_break_when_start_job_and_raise_this_exception_and_abort_jobs( + self, mock_sleep: MagicMock + ) -> None: orchestrator = AsyncJobOrchestrator( self._job_repository, [_A_STREAM_SLICE, _ANOTHER_STREAM_SLICE], @@ -181,7 +224,10 @@ def test_given_exception_to_break_when_start_job_and_raise_this_exception_and_ab self._message_repository, exceptions_to_break_on=[ValueError], ) - self._job_repository.start.side_effect = [self._job_for_a_slice, ValueError("Something went wrong")] + self._job_repository.start.side_effect = [ + self._job_for_a_slice, + ValueError("Something went wrong"), + ] with pytest.raises(ValueError): # assert that orchestrator exits on expected error @@ -189,7 +235,9 @@ def test_given_exception_to_break_when_start_job_and_raise_this_exception_and_ab assert len(orchestrator._job_tracker._jobs) == 0 self._job_repository.abort.assert_called_once_with(self._job_for_a_slice) - def test_given_traced_config_error_when_start_job_and_raise_this_exception_and_abort_jobs(self) -> None: + def test_given_traced_config_error_when_start_job_and_raise_this_exception_and_abort_jobs( + self, + ) -> None: """ Since this is a config error, we assume the other jobs will fail for the same reasons. """ @@ -198,7 +246,13 @@ def test_given_traced_config_error_when_start_job_and_raise_this_exception_and_a "Can't create job", failure_type=FailureType.config_error ) - orchestrator = AsyncJobOrchestrator(self._job_repository, [_A_STREAM_SLICE], job_tracker, self._message_repository, [ValueError]) + orchestrator = AsyncJobOrchestrator( + self._job_repository, + [_A_STREAM_SLICE], + job_tracker, + self._message_repository, + [ValueError], + ) with pytest.raises(AirbyteTracedException): list(orchestrator.create_and_get_completed_partitions()) @@ -206,7 +260,9 @@ def test_given_traced_config_error_when_start_job_and_raise_this_exception_and_a assert job_tracker.try_to_get_intent() @mock.patch(sleep_mock_target) - def test_given_exception_on_single_job_when_create_and_get_completed_partitions_then_return(self, mock_sleep: MagicMock) -> None: + def test_given_exception_on_single_job_when_create_and_get_completed_partitions_then_return( + self, mock_sleep: MagicMock + ) -> None: """ We added this test because the initial logic of breaking the main loop we implemented (when `self._has_started_a_job and self._running_partitions`) was not enough in the case where there was only one slice and it would fail to start. """ @@ -218,7 +274,9 @@ def test_given_exception_on_single_job_when_create_and_get_completed_partitions_ list(orchestrator.create_and_get_completed_partitions()) @mock.patch(sleep_mock_target) - def test_given_exception_when_start_job_and_skip_this_exception(self, mock_sleep: MagicMock) -> None: + def test_given_exception_when_start_job_and_skip_this_exception( + self, mock_sleep: MagicMock + ) -> None: self._job_repository.start.side_effect = [ AirbyteTracedException("Something went wrong. Expected error #1"), self._job_for_another_slice, @@ -246,7 +304,9 @@ def test_given_jobs_failed_more_than_max_attempts_when_create_and_get_completed_ job_tracker = JobTracker(1) jobs = [self._an_async_job(str(i), _A_STREAM_SLICE) for i in range(_MAX_NUMBER_OF_ATTEMPTS)] self._job_repository.start.side_effect = jobs - self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs({job: [AsyncJobStatus.FAILED] for job in jobs}) + self._job_repository.update_jobs_status.side_effect = _status_update_per_jobs( + {job: [AsyncJobStatus.FAILED] for job in jobs} + ) orchestrator = self._orchestrator([_A_STREAM_SLICE], job_tracker) @@ -255,7 +315,9 @@ def test_given_jobs_failed_more_than_max_attempts_when_create_and_get_completed_ assert job_tracker.try_to_get_intent() - def given_budget_already_taken_before_start_when_create_and_get_completed_partitions_then_wait_for_budget_to_be_freed(self) -> None: + def given_budget_already_taken_before_start_when_create_and_get_completed_partitions_then_wait_for_budget_to_be_freed( + self, + ) -> None: job_tracker = JobTracker(1) intent_to_free = job_tracker.try_to_get_intent() @@ -276,11 +338,19 @@ def wait_and_free_intent(_job_tracker: JobTracker, _intent_to_free: str) -> None assert len(partitions) == 1 - def test_given_start_job_raise_when_create_and_get_completed_partitions_then_free_budget(self) -> None: + def test_given_start_job_raise_when_create_and_get_completed_partitions_then_free_budget( + self, + ) -> None: job_tracker = JobTracker(1) self._job_repository.start.side_effect = ValueError("Can't create job") - orchestrator = AsyncJobOrchestrator(self._job_repository, [_A_STREAM_SLICE], job_tracker, self._message_repository, [ValueError]) + orchestrator = AsyncJobOrchestrator( + self._job_repository, + [_A_STREAM_SLICE], + job_tracker, + self._message_repository, + [ValueError], + ) with pytest.raises(Exception): list(orchestrator.create_and_get_completed_partitions()) diff --git a/unit_tests/sources/declarative/async_job/test_job_tracker.py b/unit_tests/sources/declarative/async_job/test_job_tracker.py index f3c2744d..6d09df16 100644 --- a/unit_tests/sources/declarative/async_job/test_job_tracker.py +++ b/unit_tests/sources/declarative/async_job/test_job_tracker.py @@ -4,7 +4,10 @@ from unittest import TestCase import pytest -from airbyte_cdk.sources.declarative.async_job.job_tracker import ConcurrentJobLimitReached, JobTracker +from airbyte_cdk.sources.declarative.async_job.job_tracker import ( + ConcurrentJobLimitReached, + JobTracker, +) _LIMIT = 3 diff --git a/unit_tests/sources/declarative/auth/test_jwt.py b/unit_tests/sources/declarative/auth/test_jwt.py index 51bef482..a26042f7 100644 --- a/unit_tests/sources/declarative/auth/test_jwt.py +++ b/unit_tests/sources/declarative/auth/test_jwt.py @@ -27,7 +27,13 @@ class TestJwtAuthenticator: "test_typ", "test_cty", {"test": "test"}, - {"kid": "test_kid", "typ": "test_typ", "cty": "test_cty", "test": "test", "alg": "ALGORITHM"}, + { + "kid": "test_kid", + "typ": "test_typ", + "cty": "test_cty", + "test": "test", + "alg": "ALGORITHM", + }, ), ("ALGORITHM", None, None, None, None, {"alg": "ALGORITHM"}), ], diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index 78dd0b59..4cdfad2f 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -160,7 +160,9 @@ def test_refresh_access_token(self, mocker): ) resp.status_code = 200 - mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}) + mocker.patch.object( + resp, "json", return_value={"access_token": "access_token", "expires_in": 1000} + ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) token = oauth.refresh_access_token() @@ -200,7 +202,9 @@ def test_refresh_access_token_missing_access_token(self, mocker): ], ids=["timestamp_as_integer", "timestamp_as_integer_inside_string"], ) - def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(self, timestamp, expected_date): + def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp( + self, timestamp, expected_date + ): # TODO: should be fixed inside DeclarativeOauth2Authenticator, remove next line after fixing with pytest.raises(TypeError): oauth = DeclarativeOauth2Authenticator( @@ -231,9 +235,13 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(self, ids=["rfc3339", "iso8601", "simple_date"], ) @freezegun.freeze_time("2020-01-01") - def test_refresh_access_token_expire_format(self, mocker, expires_in_response, token_expiry_date_format): + def test_refresh_access_token_expire_format( + self, mocker, expires_in_response, token_expiry_date_format + ): next_day = "2020-01-02T00:00:00Z" - config.update({"token_expiry_date": pendulum.parse(next_day).subtract(days=2).to_rfc3339_string()}) + config.update( + {"token_expiry_date": pendulum.parse(next_day).subtract(days=2).to_rfc3339_string()} + ) message_repository = Mock() oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", @@ -255,7 +263,11 @@ def test_refresh_access_token_expire_format(self, mocker, expires_in_response, t ) resp.status_code = 200 - mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": expires_in_response}) + mocker.patch.object( + resp, + "json", + return_value={"access_token": "access_token", "expires_in": expires_in_response}, + ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) token = oauth.get_access_token() assert "access_token" == token @@ -271,11 +283,19 @@ def test_refresh_access_token_expire_format(self, mocker, expires_in_response, t ("86400.1", "2020-01-02T00:00:00Z", False), ("2020-01-02T00:00:00Z", "2020-01-02T00:00:00Z", True), ], - ids=["time_in_seconds", "time_in_seconds_float", "time_in_seconds_str", "time_in_seconds_str_float", "invalid"], + ids=[ + "time_in_seconds", + "time_in_seconds_float", + "time_in_seconds_str", + "time_in_seconds_str_float", + "invalid", + ], ) @freezegun.freeze_time("2020-01-01") def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next_day, raises): - config.update({"token_expiry_date": pendulum.parse(next_day).subtract(days=2).to_rfc3339_string()}) + config.update( + {"token_expiry_date": pendulum.parse(next_day).subtract(days=2).to_rfc3339_string()} + ) oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", client_id="{{ config['client_id'] }}", @@ -292,7 +312,11 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next ) resp.status_code = 200 - mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": expires_in_response}) + mocker.patch.object( + resp, + "json", + return_value={"access_token": "access_token", "expires_in": expires_in_response}, + ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) if raises: with pytest.raises(ValueError): @@ -318,7 +342,9 @@ def test_error_handling(self, mocker): parameters={}, ) resp.status_code = 400 - mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": 123}) + mocker.patch.object( + resp, "json", return_value={"access_token": "access_token", "expires_in": 123} + ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) with pytest.raises(requests.exceptions.HTTPError) as e: oauth.refresh_access_token() diff --git a/unit_tests/sources/declarative/auth/test_selective_authenticator.py b/unit_tests/sources/declarative/auth/test_selective_authenticator.py index 346b284c..55b0a1ed 100644 --- a/unit_tests/sources/declarative/auth/test_selective_authenticator.py +++ b/unit_tests/sources/declarative/auth/test_selective_authenticator.py @@ -20,7 +20,9 @@ def test_authenticator_selected(mocker): def test_selection_path_not_found(mocker): authenticators = {"one": mocker.Mock(), "two": mocker.Mock()} - with pytest.raises(ValueError, match="The path from `authenticator_selection_path` is not found in the config"): + with pytest.raises( + ValueError, match="The path from `authenticator_selection_path` is not found in the config" + ): _ = SelectiveAuthenticator( config={"auth": {"type": "one"}}, authenticators=authenticators, diff --git a/unit_tests/sources/declarative/auth/test_session_token_auth.py b/unit_tests/sources/declarative/auth/test_session_token_auth.py index 5f99fabf..eda2f36b 100644 --- a/unit_tests/sources/declarative/auth/test_session_token_auth.py +++ b/unit_tests/sources/declarative/auth/test_session_token_auth.py @@ -3,7 +3,10 @@ # import pytest -from airbyte_cdk.sources.declarative.auth.token import LegacySessionTokenAuthenticator, get_new_session_token +from airbyte_cdk.sources.declarative.auth.token import ( + LegacySessionTokenAuthenticator, + get_new_session_token, +) from requests.exceptions import HTTPError parameters = {"hello": "world"} @@ -73,7 +76,8 @@ def test_auth_header(): def test_get_token_valid_session(requests_mock): requests_mock.get( - f"{config_session_token['instance_api_url']}user/current", json={"common_name": "common_name", "last_login": "last_login"} + f"{config_session_token['instance_api_url']}user/current", + json={"common_name": "common_name", "last_login": "last_login"}, ) token = LegacySessionTokenAuthenticator( @@ -142,7 +146,10 @@ def test_get_token_username_password(requests_mock): def test_check_is_valid_session_token(requests_mock): - requests_mock.get(f"{config['instance_api_url']}user/current", json={"common_name": "common_name", "last_login": "last_login"}) + requests_mock.get( + f"{config['instance_api_url']}user/current", + json={"common_name": "common_name", "last_login": "last_login"}, + ) assert LegacySessionTokenAuthenticator( config=config, @@ -174,9 +181,16 @@ def test_check_is_valid_session_token_unauthorized(): def test_get_new_session_token(requests_mock): - requests_mock.post(f"{config['instance_api_url']}session", headers={"Content-Type": "application/json"}, json={"id": "some session id"}) + requests_mock.post( + f"{config['instance_api_url']}session", + headers={"Content-Type": "application/json"}, + json={"id": "some session id"}, + ) session_token = get_new_session_token( - f'{config["instance_api_url"]}session', config["username"], config["password"], config["session_token_response_key"] + f'{config["instance_api_url"]}session', + config["username"], + config["password"], + config["session_token_response_key"], ) assert session_token == "some session id" diff --git a/unit_tests/sources/declarative/auth/test_token_auth.py b/unit_tests/sources/declarative/auth/test_token_auth.py index 599667c4..64b181c4 100644 --- a/unit_tests/sources/declarative/auth/test_token_auth.py +++ b/unit_tests/sources/declarative/auth/test_token_auth.py @@ -6,9 +6,16 @@ import pytest import requests -from airbyte_cdk.sources.declarative.auth.token import ApiKeyAuthenticator, BasicHttpAuthenticator, BearerAuthenticator +from airbyte_cdk.sources.declarative.auth.token import ( + ApiKeyAuthenticator, + BasicHttpAuthenticator, + BearerAuthenticator, +) from airbyte_cdk.sources.declarative.auth.token_provider import InterpolatedStringTokenProvider -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from requests import Response LOGGER = logging.getLogger(__name__) @@ -30,7 +37,9 @@ def test_bearer_token_authenticator(test_name, token, expected_header_value): """ Should match passed in token, no matter how many times token is retrieved. """ - token_provider = InterpolatedStringTokenProvider(config=config, api_token=token, parameters=parameters) + token_provider = InterpolatedStringTokenProvider( + config=config, api_token=token, parameters=parameters + ) token_auth = BearerAuthenticator(token_provider, config, parameters=parameters) header1 = token_auth.get_auth_header() header2 = token_auth.get_auth_header() @@ -48,15 +57,27 @@ def test_bearer_token_authenticator(test_name, token, expected_header_value): "test_name, username, password, expected_header_value", [ ("test_static_creds", "user", "password", "Basic dXNlcjpwYXNzd29yZA=="), - ("test_creds_from_config", "{{ config.username }}", "{{ config.password }}", "Basic dXNlcjpwYXNzd29yZA=="), - ("test_creds_from_parameters", "{{ parameters.username }}", "{{ parameters.password }}", "Basic dXNlcjpwYXNzd29yZA=="), + ( + "test_creds_from_config", + "{{ config.username }}", + "{{ config.password }}", + "Basic dXNlcjpwYXNzd29yZA==", + ), + ( + "test_creds_from_parameters", + "{{ parameters.username }}", + "{{ parameters.password }}", + "Basic dXNlcjpwYXNzd29yZA==", + ), ], ) def test_basic_authenticator(test_name, username, password, expected_header_value): """ Should match passed in token, no matter how many times token is retrieved. """ - token_auth = BasicHttpAuthenticator(username=username, password=password, config=config, parameters=parameters) + token_auth = BasicHttpAuthenticator( + username=username, password=password, config=config, parameters=parameters + ) header1 = token_auth.get_auth_header() header2 = token_auth.get_auth_header() @@ -73,17 +94,33 @@ def test_basic_authenticator(test_name, username, password, expected_header_valu "test_name, header, token, expected_header, expected_header_value", [ ("test_static_token", "Authorization", "test-token", "Authorization", "test-token"), - ("test_token_from_config", "{{ config.header }}", "{{ config.username }}", "header", "user"), - ("test_token_from_parameters", "{{ parameters.header }}", "{{ parameters.username }}", "header", "user"), + ( + "test_token_from_config", + "{{ config.header }}", + "{{ config.username }}", + "header", + "user", + ), + ( + "test_token_from_parameters", + "{{ parameters.header }}", + "{{ parameters.username }}", + "header", + "user", + ), ], ) def test_api_key_authenticator(test_name, header, token, expected_header, expected_header_value): """ Should match passed in token, no matter how many times token is retrieved. """ - token_provider = InterpolatedStringTokenProvider(config=config, api_token=token, parameters=parameters) + token_provider = InterpolatedStringTokenProvider( + config=config, api_token=token, parameters=parameters + ) token_auth = ApiKeyAuthenticator( - request_option=RequestOption(inject_into=RequestOptionType.header, field_name=header, parameters=parameters), + request_option=RequestOption( + inject_into=RequestOptionType.header, field_name=header, parameters=parameters + ), token_provider=token_provider, config=config, parameters=parameters, @@ -186,13 +223,25 @@ def test_api_key_authenticator(test_name, header, token, expected_header, expect ), ], ) -def test_api_key_authenticator_inject(test_name, field_name, token, expected_field_name, expected_field_value, inject_type, validation_fn): +def test_api_key_authenticator_inject( + test_name, + field_name, + token, + expected_field_name, + expected_field_value, + inject_type, + validation_fn, +): """ Should match passed in token, no matter how many times token is retrieved. """ - token_provider = InterpolatedStringTokenProvider(config=config, api_token=token, parameters=parameters) + token_provider = InterpolatedStringTokenProvider( + config=config, api_token=token, parameters=parameters + ) token_auth = ApiKeyAuthenticator( - request_option=RequestOption(inject_into=inject_type, field_name=field_name, parameters=parameters), + request_option=RequestOption( + inject_into=inject_type, field_name=field_name, parameters=parameters + ), token_provider=token_provider, config=config, parameters=parameters, diff --git a/unit_tests/sources/declarative/auth/test_token_provider.py b/unit_tests/sources/declarative/auth/test_token_provider.py index e73e5eef..684dfbf7 100644 --- a/unit_tests/sources/declarative/auth/test_token_provider.py +++ b/unit_tests/sources/declarative/auth/test_token_provider.py @@ -6,7 +6,10 @@ import pendulum import pytest -from airbyte_cdk.sources.declarative.auth.token_provider import InterpolatedStringTokenProvider, SessionTokenProvider +from airbyte_cdk.sources.declarative.auth.token_provider import ( + InterpolatedStringTokenProvider, + SessionTokenProvider, +) from airbyte_cdk.sources.declarative.exceptions import ReadException from isodate import parse_duration @@ -27,7 +30,9 @@ def create_session_token_provider(): def test_interpolated_string_token_provider(): provider = InterpolatedStringTokenProvider( - config={"config_key": "val"}, api_token="{{ config.config_key }}-{{ parameters.test }}", parameters={"test": "test"} + config={"config_key": "val"}, + api_token="{{ config.config_key }}-{{ parameters.test }}", + parameters={"test": "test"}, ) assert provider.get_token() == "val-test" @@ -49,7 +54,9 @@ def test_session_token_provider_cache_expiration(): provider = create_session_token_provider() provider.get_token() - provider.login_requester.send_request.return_value.json.return_value = {"nested": {"token": "updated_token"}} + provider.login_requester.send_request.return_value.json.return_value = { + "nested": {"token": "updated_token"} + } with pendulum.test(pendulum.datetime(2001, 5, 21, 14)): assert provider.get_token() == "updated_token" diff --git a/unit_tests/sources/declarative/checks/test_check_stream.py b/unit_tests/sources/declarative/checks/test_check_stream.py index 4ebe449d..aee429c8 100644 --- a/unit_tests/sources/declarative/checks/test_check_stream.py +++ b/unit_tests/sources/declarative/checks/test_check_stream.py @@ -22,13 +22,21 @@ "test_name, record, streams_to_check, stream_slice, expectation", [ ("test_success_check", record, stream_names, {}, (True, None)), - ("test_success_check_stream_slice", record, stream_names, {"slice": "slice_value"}, (True, None)), + ( + "test_success_check_stream_slice", + record, + stream_names, + {"slice": "slice_value"}, + (True, None), + ), ("test_fail_check", None, stream_names, {}, (True, None)), ("test_try_to_check_invalid stream", record, ["invalid_stream_name"], {}, None), ], ) @pytest.mark.parametrize("slices_as_list", [True, False]) -def test_check_stream_with_slices_as_list(test_name, record, streams_to_check, stream_slice, expectation, slices_as_list): +def test_check_stream_with_slices_as_list( + test_name, record, streams_to_check, stream_slice, expectation, slices_as_list +): stream = MagicMock() stream.name = "s1" stream.availability_strategy = None @@ -53,7 +61,11 @@ def test_check_stream_with_slices_as_list(test_name, record, streams_to_check, s def mock_read_records(responses, default_response=None, **kwargs): - return lambda stream_slice, sync_mode: responses[frozenset(stream_slice)] if frozenset(stream_slice) in responses else default_response + return ( + lambda stream_slice, sync_mode: responses[frozenset(stream_slice)] + if frozenset(stream_slice) in responses + else default_response + ) def test_check_empty_stream(): @@ -87,7 +99,12 @@ def test_check_stream_with_no_stream_slices_aborts(): @pytest.mark.parametrize( "test_name, response_code, available_expectation, expected_messages", [ - ("test_stream_unavailable_unhandled_error", 404, False, ["Not found. The requested resource was not found on the server."]), + ( + "test_stream_unavailable_unhandled_error", + 404, + False, + ["Not found. The requested resource was not found on the server."], + ), ( "test_stream_unavailable_handled_error", 403, @@ -97,7 +114,9 @@ def test_check_stream_with_no_stream_slices_aborts(): ("test_stream_available", 200, True, []), ], ) -def test_check_http_stream_via_availability_strategy(mocker, test_name, response_code, available_expectation, expected_messages): +def test_check_http_stream_via_availability_strategy( + mocker, test_name, response_code, available_expectation, expected_messages +): class MockHttpStream(HttpStream): url_base = "https://test_base_url.com" primary_key = "" diff --git a/unit_tests/sources/declarative/concurrency_level/test_concurrency_level.py b/unit_tests/sources/declarative/concurrency_level/test_concurrency_level.py index b806edba..5858b680 100644 --- a/unit_tests/sources/declarative/concurrency_level/test_concurrency_level.py +++ b/unit_tests/sources/declarative/concurrency_level/test_concurrency_level.py @@ -11,15 +11,35 @@ [ pytest.param(20, 75, 20, id="test_default_concurrency_as_int"), pytest.param(20, 75, 20, id="test_default_concurrency_as_int_ignores_max_concurrency"), - pytest.param("{{ config['num_workers'] or 40 }}", 75, 50, id="test_default_concurrency_using_interpolation"), - pytest.param("{{ config['missing'] or 40 }}", 75, 40, id="test_default_concurrency_using_interpolation_no_value"), - pytest.param("{{ config['num_workers'] or 40 }}", 10, 10, id="test_use_max_concurrency_if_default_is_too_high"), + pytest.param( + "{{ config['num_workers'] or 40 }}", + 75, + 50, + id="test_default_concurrency_using_interpolation", + ), + pytest.param( + "{{ config['missing'] or 40 }}", + 75, + 40, + id="test_default_concurrency_using_interpolation_no_value", + ), + pytest.param( + "{{ config['num_workers'] or 40 }}", + 10, + 10, + id="test_use_max_concurrency_if_default_is_too_high", + ), ], ) -def test_stream_slices(default_concurrency: Union[int, str], max_concurrency: int, expected_concurrency: int) -> None: +def test_stream_slices( + default_concurrency: Union[int, str], max_concurrency: int, expected_concurrency: int +) -> None: config = {"num_workers": 50} concurrency_level = ConcurrencyLevel( - default_concurrency=default_concurrency, max_concurrency=max_concurrency, config=config, parameters={} + default_concurrency=default_concurrency, + max_concurrency=max_concurrency, + config=config, + parameters={}, ) actual_concurrency = concurrency_level.get_concurrency_level() @@ -30,7 +50,12 @@ def test_stream_slices(default_concurrency: Union[int, str], max_concurrency: in @pytest.mark.parametrize( "config, expected_concurrency, expected_error", [ - pytest.param({"num_workers": "fifty five"}, None, ValueError, id="test_invalid_default_concurrency_as_string"), + pytest.param( + {"num_workers": "fifty five"}, + None, + ValueError, + id="test_invalid_default_concurrency_as_string", + ), pytest.param({"num_workers": "55"}, 55, None, id="test_default_concurrency_as_string_int"), pytest.param({"num_workers": 60}, 60, None, id="test_default_concurrency_as_int"), ], @@ -41,7 +66,10 @@ def test_default_concurrency_input_types_and_errors( expected_error: Optional[Type[Exception]], ) -> None: concurrency_level = ConcurrencyLevel( - default_concurrency="{{ config['num_workers'] or 30 }}", max_concurrency=65, config=config, parameters={} + default_concurrency="{{ config['num_workers'] or 30 }}", + max_concurrency=65, + config=config, + parameters={}, ) if expected_error: @@ -57,4 +85,9 @@ def test_max_concurrency_is_required_for_default_concurrency_using_config() -> N config = {"num_workers": "50"} with pytest.raises(ValueError): - ConcurrencyLevel(default_concurrency="{{ config['num_workers'] or 40 }}", max_concurrency=None, config=config, parameters={}) + ConcurrencyLevel( + default_concurrency="{{ config['num_workers'] or 40 }}", + max_concurrency=None, + config=config, + parameters={}, + ) diff --git a/unit_tests/sources/declarative/datetime/test_datetime_parser.py b/unit_tests/sources/declarative/datetime/test_datetime_parser.py index 1a7d45f7..6cbe59c7 100644 --- a/unit_tests/sources/declarative/datetime/test_datetime_parser.py +++ b/unit_tests/sources/declarative/datetime/test_datetime_parser.py @@ -21,7 +21,9 @@ "test_parse_date_iso_with_timezone_not_utc", "2021-01-01T00:00:00.000000+0400", "%Y-%m-%dT%H:%M:%S.%f%z", - datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(seconds=14400))), + datetime.datetime( + 2021, 1, 1, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(seconds=14400)) + ), ), ( "test_parse_timestamp", @@ -41,7 +43,12 @@ "%ms", datetime.datetime(2021, 1, 1, 0, 0, 0, 1000, tzinfo=datetime.timezone.utc), ), - ("test_parse_date_ms", "20210101", "%Y%m%d", datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)), + ( + "test_parse_date_ms", + "20210101", + "%Y%m%d", + datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), + ), ], ) def test_parse_date(test_name, input_date, date_format, expected_output_date): @@ -53,16 +60,36 @@ def test_parse_date(test_name, input_date, date_format, expected_output_date): @pytest.mark.parametrize( "test_name, input_dt, datetimeformat, expected_output", [ - ("test_format_timestamp", datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), "%s", "1609459200"), - ("test_format_timestamp_ms", datetime.datetime(2021, 1, 1, 0, 0, 0, 1000, tzinfo=datetime.timezone.utc), "%ms", "1609459200001"), + ( + "test_format_timestamp", + datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), + "%s", + "1609459200", + ), + ( + "test_format_timestamp_ms", + datetime.datetime(2021, 1, 1, 0, 0, 0, 1000, tzinfo=datetime.timezone.utc), + "%ms", + "1609459200001", + ), ( "test_format_timestamp_as_float", datetime.datetime(2023, 1, 30, 15, 28, 28, 873709, tzinfo=datetime.timezone.utc), "%s_as_float", "1675092508.873709", ), - ("test_format_string", datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), "%Y-%m-%d", "2021-01-01"), - ("test_format_to_number", datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), "%Y%m%d", "20210101"), + ( + "test_format_string", + datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), + "%Y-%m-%d", + "2021-01-01", + ), + ( + "test_format_to_number", + datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), + "%Y%m%d", + "20210101", + ), ], ) def test_format_datetime(test_name, input_dt, datetimeformat, expected_output): diff --git a/unit_tests/sources/declarative/datetime/test_min_max_datetime.py b/unit_tests/sources/declarative/datetime/test_min_max_datetime.py index ff9aedf0..848d673b 100644 --- a/unit_tests/sources/declarative/datetime/test_min_max_datetime.py +++ b/unit_tests/sources/declarative/datetime/test_min_max_datetime.py @@ -18,12 +18,48 @@ @pytest.mark.parametrize( "test_name, date, min_date, max_date, expected_date", [ - ("test_time_is_greater_than_min", "{{ config['older'] }}", "{{ stream_state['newer'] }}", "", new_date), - ("test_time_is_less_than_min", "{{ stream_state['newer'] }}", "{{ config['older'] }}", "", new_date), - ("test_time_is_equal_to_min", "{{ config['older'] }}", "{{ config['older'] }}", "", old_date), - ("test_time_is_greater_than_max", "{{ stream_state['newer'] }}", "", "{{ config['older'] }}", old_date), - ("test_time_is_less_than_max", "{{ config['older'] }}", "", "{{ stream_state['newer'] }}", old_date), - ("test_time_is_equal_to_min", "{{ stream_state['newer'] }}", "{{ stream_state['newer'] }}", "", new_date), + ( + "test_time_is_greater_than_min", + "{{ config['older'] }}", + "{{ stream_state['newer'] }}", + "", + new_date, + ), + ( + "test_time_is_less_than_min", + "{{ stream_state['newer'] }}", + "{{ config['older'] }}", + "", + new_date, + ), + ( + "test_time_is_equal_to_min", + "{{ config['older'] }}", + "{{ config['older'] }}", + "", + old_date, + ), + ( + "test_time_is_greater_than_max", + "{{ stream_state['newer'] }}", + "", + "{{ config['older'] }}", + old_date, + ), + ( + "test_time_is_less_than_max", + "{{ config['older'] }}", + "", + "{{ stream_state['newer'] }}", + old_date, + ), + ( + "test_time_is_equal_to_min", + "{{ stream_state['newer'] }}", + "{{ stream_state['newer'] }}", + "", + new_date, + ), ( "test_time_is_between_min_and_max", "{{ config['middle'] }}", @@ -31,8 +67,20 @@ "{{ stream_state['newer'] }}", middle_date, ), - ("test_min_newer_time_from_parameters", "{{ config['older'] }}", "{{ parameters['newer'] }}", "", new_date), - ("test_max_newer_time_from_parameters", "{{ stream_state['newer'] }}", "", "{{ parameters['older'] }}", old_date), + ( + "test_min_newer_time_from_parameters", + "{{ config['older'] }}", + "{{ parameters['newer'] }}", + "", + new_date, + ), + ( + "test_max_newer_time_from_parameters", + "{{ stream_state['newer'] }}", + "", + "{{ parameters['older'] }}", + old_date, + ), ], ) def test_min_max_datetime(test_name, date, min_date, max_date, expected_date): @@ -40,7 +88,9 @@ def test_min_max_datetime(test_name, date, min_date, max_date, expected_date): stream_state = {"newer": new_date} parameters = {"newer": new_date, "older": old_date} - min_max_date = MinMaxDatetime(datetime=date, min_datetime=min_date, max_datetime=max_date, parameters=parameters) + min_max_date = MinMaxDatetime( + datetime=date, min_datetime=min_date, max_datetime=max_date, parameters=parameters + ) actual_date = min_max_date.get_datetime(config, **{"stream_state": stream_state}) assert actual_date == datetime.datetime.strptime(expected_date, date_format) @@ -59,7 +109,9 @@ def test_custom_datetime_format(): ) actual_date = min_max_date.get_datetime(config, **{"stream_state": stream_state}) - assert actual_date == datetime.datetime.strptime("2022-01-01T20:12:19", "%Y-%m-%dT%H:%M:%S").replace(tzinfo=datetime.timezone.utc) + assert actual_date == datetime.datetime.strptime( + "2022-01-01T20:12:19", "%Y-%m-%dT%H:%M:%S" + ).replace(tzinfo=datetime.timezone.utc) def test_format_is_a_number(): @@ -75,17 +127,26 @@ def test_format_is_a_number(): ) actual_date = min_max_date.get_datetime(config, **{"stream_state": stream_state}) - assert actual_date == datetime.datetime.strptime("20220101", "%Y%m%d").replace(tzinfo=datetime.timezone.utc) + assert actual_date == datetime.datetime.strptime("20220101", "%Y%m%d").replace( + tzinfo=datetime.timezone.utc + ) def test_set_datetime_format(): - min_max_date = MinMaxDatetime(datetime="{{ config['middle'] }}", min_datetime="{{ config['older'] }}", parameters={}) + min_max_date = MinMaxDatetime( + datetime="{{ config['middle'] }}", min_datetime="{{ config['older'] }}", parameters={} + ) # Retrieve datetime using the default datetime formatting - default_fmt_config = {"older": "2021-01-01T20:12:19.597854Z", "middle": "2022-01-01T20:12:19.597854Z"} + default_fmt_config = { + "older": "2021-01-01T20:12:19.597854Z", + "middle": "2022-01-01T20:12:19.597854Z", + } actual_date = min_max_date.get_datetime(default_fmt_config) - assert actual_date == datetime.datetime.strptime("2022-01-01T20:12:19.597854Z", "%Y-%m-%dT%H:%M:%S.%f%z") + assert actual_date == datetime.datetime.strptime( + "2022-01-01T20:12:19.597854Z", "%Y-%m-%dT%H:%M:%S.%f%z" + ) # Set a different datetime format and attempt to retrieve datetime using an updated format min_max_date.datetime_format = "%Y-%m-%dT%H:%M:%S" @@ -93,7 +154,9 @@ def test_set_datetime_format(): custom_fmt_config = {"older": "2021-01-01T20:12:19", "middle": "2022-01-01T20:12:19"} actual_date = min_max_date.get_datetime(custom_fmt_config) - assert actual_date == datetime.datetime.strptime("2022-01-01T20:12:19", "%Y-%m-%dT%H:%M:%S").replace(tzinfo=datetime.timezone.utc) + assert actual_date == datetime.datetime.strptime( + "2022-01-01T20:12:19", "%Y-%m-%dT%H:%M:%S" + ).replace(tzinfo=datetime.timezone.utc) def test_min_max_datetime_lazy_eval(): @@ -104,7 +167,9 @@ def test_min_max_datetime_lazy_eval(): "max_datetime": "{{ parameters.max_datetime }}", } - assert datetime.datetime(2022, 1, 10, 0, 0, tzinfo=datetime.timezone.utc) == MinMaxDatetime(**kwargs, parameters={}).get_datetime({}) + assert datetime.datetime(2022, 1, 10, 0, 0, tzinfo=datetime.timezone.utc) == MinMaxDatetime( + **kwargs, parameters={} + ).get_datetime({}) assert datetime.datetime(2022, 1, 20, 0, 0, tzinfo=datetime.timezone.utc) == MinMaxDatetime( **kwargs, parameters={"min_datetime": "2022-01-20T00:00:00"} ).get_datetime({}) @@ -117,8 +182,14 @@ def test_min_max_datetime_lazy_eval(): "input_datetime", [ pytest.param("2022-01-01T00:00:00", id="test_create_min_max_datetime_from_string"), - pytest.param(InterpolatedString.create("2022-01-01T00:00:00", parameters={}), id="test_create_min_max_datetime_from_string"), - pytest.param(MinMaxDatetime("2022-01-01T00:00:00", parameters={}), id="test_create_min_max_datetime_from_minmaxdatetime"), + pytest.param( + InterpolatedString.create("2022-01-01T00:00:00", parameters={}), + id="test_create_min_max_datetime_from_string", + ), + pytest.param( + MinMaxDatetime("2022-01-01T00:00:00", parameters={}), + id="test_create_min_max_datetime_from_minmaxdatetime", + ), ], ) def test_create_min_max_datetime(input_datetime): diff --git a/unit_tests/sources/declarative/decoders/test_json_decoder.py b/unit_tests/sources/declarative/decoders/test_json_decoder.py index 1b9a552d..861b6e27 100644 --- a/unit_tests/sources/declarative/decoders/test_json_decoder.py +++ b/unit_tests/sources/declarative/decoders/test_json_decoder.py @@ -10,12 +10,18 @@ from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder, JsonlDecoder from airbyte_cdk.sources.declarative.models import DeclarativeStream as DeclarativeStreamModel -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, +) @pytest.mark.parametrize( "response_body, first_element", - [("", {}), ("[]", {}), ('{"healthcheck": {"status": "ok"}}', {"healthcheck": {"status": "ok"}})], + [ + ("", {}), + ("[]", {}), + ('{"healthcheck": {"status": "ok"}}', {"healthcheck": {"status": "ok"}}), + ], ) def test_json_decoder(requests_mock, response_body, first_element): requests_mock.register_uri("GET", "https://airbyte.io/", text=response_body) @@ -28,7 +34,10 @@ def test_json_decoder(requests_mock, response_body, first_element): [ ("", []), ('{"id": 1, "name": "test1"}', [{"id": 1, "name": "test1"}]), - ('{"id": 1, "name": "test1"}\n{"id": 2, "name": "test2"}', [{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}]), + ( + '{"id": 1, "name": "test1"}\n{"id": 2, "name": "test2"}', + [{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}], + ), ], ids=["empty_response", "one_line_json", "multi_line_json"], ) @@ -92,7 +101,9 @@ def test_jsonl_decoder_memory_usage(requests_mock, large_events_response): factory = ModelToComponentFactory() stream_manifest = YamlDeclarativeSource._parse(content) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config={}) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_manifest, config={} + ) def get_body(): return open(file_path, "rb", buffering=30) diff --git a/unit_tests/sources/declarative/decoders/test_pagination_decoder_decorator.py b/unit_tests/sources/declarative/decoders/test_pagination_decoder_decorator.py index 70fc26d1..022482e7 100644 --- a/unit_tests/sources/declarative/decoders/test_pagination_decoder_decorator.py +++ b/unit_tests/sources/declarative/decoders/test_pagination_decoder_decorator.py @@ -11,7 +11,10 @@ def is_stream_response(self) -> bool: return True -@pytest.mark.parametrize("decoder_class, expected", [(StreamingJsonDecoder, {}), (JsonDecoder, {"data": [{"id": 1}, {"id": 2}]})]) +@pytest.mark.parametrize( + "decoder_class, expected", + [(StreamingJsonDecoder, {}), (JsonDecoder, {"data": [{"id": 1}, {"id": 2}]})], +) def test_pagination_decoder_decorator(requests_mock, decoder_class, expected): decoder = PaginationDecoderDecorator(decoder=decoder_class(parameters={})) response_body = '{"data": [{"id": 1}, {"id": 2}]}' diff --git a/unit_tests/sources/declarative/decoders/test_xml_decoder.py b/unit_tests/sources/declarative/decoders/test_xml_decoder.py index 43e970ad..c6295cff 100644 --- a/unit_tests/sources/declarative/decoders/test_xml_decoder.py +++ b/unit_tests/sources/declarative/decoders/test_xml_decoder.py @@ -12,7 +12,14 @@ ('', {"item": {"@name": "item_1"}}), ( 'Item 1Item 2', - {"data": {"item": [{"@name": "item_1", "#text": "Item 1"}, {"@name": "item_2", "#text": "Item 2"}]}}, + { + "data": { + "item": [ + {"@name": "item_1", "#text": "Item 1"}, + {"@name": "item_2", "#text": "Item 2"}, + ] + } + }, ), (None, {}), ('', {}), @@ -21,7 +28,13 @@ {"item": {"@xmlns:ns": "https://airbyte.io", "ns:id": "1", "ns:name": "Item 1"}}, ), ], - ids=["one_element_response", "multi_element_response", "empty_response", "malformed_xml_response", "xml_with_namespace_response"], + ids=[ + "one_element_response", + "multi_element_response", + "empty_response", + "malformed_xml_response", + "xml_with_namespace_response", + ], ) def test_xml_decoder(requests_mock, response_body, expected): requests_mock.register_uri("GET", "https://airbyte.io/", text=response_body) diff --git a/unit_tests/sources/declarative/extractors/test_dpath_extractor.py b/unit_tests/sources/declarative/extractors/test_dpath_extractor.py index 92b4ffbb..c5c40dd2 100644 --- a/unit_tests/sources/declarative/extractors/test_dpath_extractor.py +++ b/unit_tests/sources/declarative/extractors/test_dpath_extractor.py @@ -8,7 +8,11 @@ import pytest import requests from airbyte_cdk import Decoder -from airbyte_cdk.sources.declarative.decoders.json_decoder import IterableDecoder, JsonDecoder, JsonlDecoder +from airbyte_cdk.sources.declarative.decoders.json_decoder import ( + IterableDecoder, + JsonDecoder, + JsonlDecoder, +) from airbyte_cdk.sources.declarative.extractors.dpath_extractor import DpathExtractor config = {"field": "record_array"} @@ -32,15 +36,35 @@ def create_response(body: Union[Dict, bytes]): (["data"], decoder_json, {"data": {"id": 1}}, [{"id": 1}]), ([], decoder_json, {"id": 1}, [{"id": 1}]), ([], decoder_json, [{"id": 1}, {"id": 2}], [{"id": 1}, {"id": 2}]), - (["data", "records"], decoder_json, {"data": {"records": [{"id": 1}, {"id": 2}]}}, [{"id": 1}, {"id": 2}]), - (["{{ config['field'] }}"], decoder_json, {"record_array": [{"id": 1}, {"id": 2}]}, [{"id": 1}, {"id": 2}]), - (["{{ parameters['parameters_field'] }}"], decoder_json, {"record_array": [{"id": 1}, {"id": 2}]}, [{"id": 1}, {"id": 2}]), + ( + ["data", "records"], + decoder_json, + {"data": {"records": [{"id": 1}, {"id": 2}]}}, + [{"id": 1}, {"id": 2}], + ), + ( + ["{{ config['field'] }}"], + decoder_json, + {"record_array": [{"id": 1}, {"id": 2}]}, + [{"id": 1}, {"id": 2}], + ), + ( + ["{{ parameters['parameters_field'] }}"], + decoder_json, + {"record_array": [{"id": 1}, {"id": 2}]}, + [{"id": 1}, {"id": 2}], + ), (["record"], decoder_json, {"id": 1}, []), (["list", "*", "item"], decoder_json, {"list": [{"item": {"id": "1"}}]}, [{"id": "1"}]), ( ["data", "*", "list", "data2", "*"], decoder_json, - {"data": [{"list": {"data2": [{"id": 1}, {"id": 2}]}}, {"list": {"data2": [{"id": 3}, {"id": 4}]}}]}, + { + "data": [ + {"list": {"data2": [{"id": 1}, {"id": 2}]}}, + {"list": {"data2": [{"id": 3}, {"id": 4}]}}, + ] + }, [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}], ), ([], decoder_jsonl, {"id": 1}, [{"id": 1}]), @@ -88,7 +112,9 @@ def create_response(body: Union[Dict, bytes]): ], ) def test_dpath_extractor(field_path: List, decoder: Decoder, body, expected_records: List): - extractor = DpathExtractor(field_path=field_path, config=config, decoder=decoder, parameters=parameters) + extractor = DpathExtractor( + field_path=field_path, config=config, decoder=decoder, parameters=parameters + ) response = create_response(body) actual_records = list(extractor.extract_records(response)) diff --git a/unit_tests/sources/declarative/extractors/test_record_filter.py b/unit_tests/sources/declarative/extractors/test_record_filter.py index 5e73d78e..c4824c64 100644 --- a/unit_tests/sources/declarative/extractors/test_record_filter.py +++ b/unit_tests/sources/declarative/extractors/test_record_filter.py @@ -5,7 +5,10 @@ import pytest from airbyte_cdk.sources.declarative.datetime import MinMaxDatetime -from airbyte_cdk.sources.declarative.extractors.record_filter import ClientSideIncrementalRecordFilterDecorator, RecordFilter +from airbyte_cdk.sources.declarative.extractors.record_filter import ( + ClientSideIncrementalRecordFilterDecorator, + RecordFilter, +) from airbyte_cdk.sources.declarative.incremental import ( CursorFactory, DatetimeBasedCursor, @@ -13,7 +16,11 @@ PerPartitionWithGlobalCursor, ) from airbyte_cdk.sources.declarative.interpolation import InterpolatedString -from airbyte_cdk.sources.declarative.models import CustomRetriever, DeclarativeStream, ParentStreamConfig +from airbyte_cdk.sources.declarative.models import ( + CustomRetriever, + DeclarativeStream, + ParentStreamConfig, +) from airbyte_cdk.sources.declarative.partition_routers import SubstreamPartitionRouter from airbyte_cdk.sources.declarative.types import StreamSlice @@ -48,12 +55,20 @@ [ ( "{{ record['created_at'] > stream_state['created_at'] }}", - [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}], + [ + {"id": 1, "created_at": "06-06-21"}, + {"id": 2, "created_at": "06-07-21"}, + {"id": 3, "created_at": "06-08-21"}, + ], [{"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}], ), ( "{{ record['last_seen'] >= stream_slice['last_seen'] }}", - [{"id": 1, "last_seen": "06-06-21"}, {"id": 2, "last_seen": "06-07-21"}, {"id": 3, "last_seen": "06-10-21"}], + [ + {"id": 1, "last_seen": "06-06-21"}, + {"id": 2, "last_seen": "06-07-21"}, + {"id": 3, "last_seen": "06-10-21"}, + ], [{"id": 3, "last_seen": "06-10-21"}], ), ( @@ -68,12 +83,20 @@ ), ( "{{ record['created_at'] > parameters['created_at'] }}", - [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}], + [ + {"id": 1, "created_at": "06-06-21"}, + {"id": 2, "created_at": "06-07-21"}, + {"id": 3, "created_at": "06-08-21"}, + ], [{"id": 3, "created_at": "06-08-21"}], ), ( "{{ record['created_at'] > stream_slice.extra_fields['created_at'] }}", - [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}], + [ + {"id": 1, "created_at": "06-06-21"}, + {"id": 2, "created_at": "06-07-21"}, + {"id": 3, "created_at": "06-08-21"}, + ], [{"id": 3, "created_at": "06-08-21"}], ), ], @@ -86,16 +109,27 @@ "test_using_extra_fields_filter", ], ) -def test_record_filter(filter_template: str, records: List[Mapping], expected_records: List[Mapping]): +def test_record_filter( + filter_template: str, records: List[Mapping], expected_records: List[Mapping] +): config = {"response_override": "stop_if_you_see_me"} parameters = {"created_at": "06-07-21"} stream_state = {"created_at": "06-06-21"} - stream_slice = StreamSlice(partition={}, cursor_slice={"last_seen": "06-10-21"}, extra_fields={"created_at": "06-07-21"}) + stream_slice = StreamSlice( + partition={}, + cursor_slice={"last_seen": "06-10-21"}, + extra_fields={"created_at": "06-07-21"}, + ) next_page_token = {"last_seen_id": 14} record_filter = RecordFilter(config=config, condition=filter_template, parameters=parameters) actual_records = list( - record_filter.filter_records(records, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + record_filter.filter_records( + records, + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) ) assert actual_records == expected_records @@ -105,11 +139,46 @@ def test_record_filter(filter_template: str, records: List[Mapping], expected_re [ (DATE_FORMAT, {}, None, "2021-01-05", RECORDS_TO_FILTER_DATE_FORMAT, [2, 3, 5]), (DATE_FORMAT, {}, None, None, RECORDS_TO_FILTER_DATE_FORMAT, [2, 3, 4, 5]), - (DATE_FORMAT, {"created_at": "2021-01-04"}, None, "2021-01-05", RECORDS_TO_FILTER_DATE_FORMAT, [3]), - (DATE_FORMAT, {"created_at": "2021-01-04"}, None, None, RECORDS_TO_FILTER_DATE_FORMAT, [3, 4]), - (DATE_FORMAT, {}, "{{ record['id'] % 2 == 1 }}", "2021-01-05", RECORDS_TO_FILTER_DATE_FORMAT, [3, 5]), - (DATE_TIME_WITH_TZ_FORMAT, {}, None, "2021-01-05T00:00:00+00:00", RECORDS_TO_FILTER_DATE_TIME_WITH_TZ_FORMAT, [2, 3]), - (DATE_TIME_WITH_TZ_FORMAT, {}, None, None, RECORDS_TO_FILTER_DATE_TIME_WITH_TZ_FORMAT, [2, 3, 4]), + ( + DATE_FORMAT, + {"created_at": "2021-01-04"}, + None, + "2021-01-05", + RECORDS_TO_FILTER_DATE_FORMAT, + [3], + ), + ( + DATE_FORMAT, + {"created_at": "2021-01-04"}, + None, + None, + RECORDS_TO_FILTER_DATE_FORMAT, + [3, 4], + ), + ( + DATE_FORMAT, + {}, + "{{ record['id'] % 2 == 1 }}", + "2021-01-05", + RECORDS_TO_FILTER_DATE_FORMAT, + [3, 5], + ), + ( + DATE_TIME_WITH_TZ_FORMAT, + {}, + None, + "2021-01-05T00:00:00+00:00", + RECORDS_TO_FILTER_DATE_TIME_WITH_TZ_FORMAT, + [2, 3], + ), + ( + DATE_TIME_WITH_TZ_FORMAT, + {}, + None, + None, + RECORDS_TO_FILTER_DATE_TIME_WITH_TZ_FORMAT, + [2, 3, 4], + ), ( DATE_TIME_WITH_TZ_FORMAT, {"created_at": "2021-01-04T00:00:00+00:00"}, @@ -134,8 +203,22 @@ def test_record_filter(filter_template: str, records: List[Mapping], expected_re RECORDS_TO_FILTER_DATE_TIME_WITH_TZ_FORMAT, [3], ), - (DATE_TIME_WITHOUT_TZ_FORMAT, {}, None, "2021-01-05T00:00:00", RECORDS_TO_FILTER_DATE_TIME_WITHOUT_TZ_FORMAT, [2, 3]), - (DATE_TIME_WITHOUT_TZ_FORMAT, {}, None, None, RECORDS_TO_FILTER_DATE_TIME_WITHOUT_TZ_FORMAT, [2, 3, 4]), + ( + DATE_TIME_WITHOUT_TZ_FORMAT, + {}, + None, + "2021-01-05T00:00:00", + RECORDS_TO_FILTER_DATE_TIME_WITHOUT_TZ_FORMAT, + [2, 3], + ), + ( + DATE_TIME_WITHOUT_TZ_FORMAT, + {}, + None, + None, + RECORDS_TO_FILTER_DATE_TIME_WITHOUT_TZ_FORMAT, + [2, 3, 4], + ), ( DATE_TIME_WITHOUT_TZ_FORMAT, {"created_at": "2021-01-04T00:00:00"}, @@ -188,7 +271,9 @@ def test_client_side_record_filter_decorator_no_parent_stream( expected_record_ids: List[int], ): date_time_based_cursor = DatetimeBasedCursor( - start_datetime=MinMaxDatetime(datetime="2021-01-01", datetime_format=DATE_FORMAT, parameters={}), + start_datetime=MinMaxDatetime( + datetime="2021-01-01", datetime_format=DATE_FORMAT, parameters={} + ), end_datetime=MinMaxDatetime(datetime=end_datetime, parameters={}) if end_datetime else None, step="P10Y", cursor_field=InterpolatedString.create("created_at", parameters={}), @@ -208,7 +293,12 @@ def test_client_side_record_filter_decorator_no_parent_stream( ) filtered_records = list( - record_filter_decorator.filter_records(records=records_to_filter, stream_state=stream_state, stream_slice={}, next_page_token=None) + record_filter_decorator.filter_records( + records=records_to_filter, + stream_state=stream_state, + stream_slice={}, + next_page_token=None, + ) ) assert [x.get("id") for x in filtered_records] == expected_record_ids @@ -228,7 +318,12 @@ def test_client_side_record_filter_decorator_no_parent_stream( { "use_global_cursor": False, "state": {"created_at": "2021-01-10"}, - "states": [{"partition": {"id": "some_parent_id", "parent_slice": {}}, "cursor": {"created_at": "2021-01-03"}}], + "states": [ + { + "partition": {"id": "some_parent_id", "parent_slice": {}}, + "cursor": {"created_at": "2021-01-03"}, + } + ], }, "per_partition_with_global", [2, 3], @@ -238,13 +333,22 @@ def test_client_side_record_filter_decorator_no_parent_stream( { "use_global_cursor": True, "state": {"created_at": "2021-01-03"}, - "states": [{"partition": {"id": "some_parent_id", "parent_slice": {}}, "cursor": {"created_at": "2021-01-13"}}], + "states": [ + { + "partition": {"id": "some_parent_id", "parent_slice": {}}, + "cursor": {"created_at": "2021-01-13"}, + } + ], }, "per_partition_with_global", [2, 3], ), # Use PerPartitionWithGlobalCursor with partition state missing, global cursor used - ({"use_global_cursor": True, "state": {"created_at": "2021-01-03"}}, "per_partition_with_global", [2, 3]), + ( + {"use_global_cursor": True, "state": {"created_at": "2021-01-03"}}, + "per_partition_with_global", + [2, 3], + ), # Use PerPartitionWithGlobalCursor with partition state missing, global cursor not used ( {"use_global_cursor": False, "state": {"created_at": "2021-01-03"}}, @@ -267,8 +371,12 @@ def test_client_side_record_filter_decorator_with_cursor_types( ): def date_time_based_cursor_factory() -> DatetimeBasedCursor: return DatetimeBasedCursor( - start_datetime=MinMaxDatetime(datetime="2021-01-01", datetime_format=DATE_FORMAT, parameters={}), - end_datetime=MinMaxDatetime(datetime="2021-01-05", datetime_format=DATE_FORMAT, parameters={}), + start_datetime=MinMaxDatetime( + datetime="2021-01-01", datetime_format=DATE_FORMAT, parameters={} + ), + end_datetime=MinMaxDatetime( + datetime="2021-01-05", datetime_format=DATE_FORMAT, parameters={} + ), step="P10Y", cursor_field=InterpolatedString.create("created_at", parameters={}), datetime_format=DATE_FORMAT, @@ -289,7 +397,8 @@ def date_time_based_cursor_factory() -> DatetimeBasedCursor: parent_key="id", partition_field="id", stream=DeclarativeStream( - type="DeclarativeStream", retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name") + type="DeclarativeStream", + retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name"), ), ) ], @@ -330,7 +439,9 @@ def date_time_based_cursor_factory() -> DatetimeBasedCursor: ) # The partition we're testing - stream_slice = StreamSlice(partition={"id": "some_parent_id", "parent_slice": {}}, cursor_slice={}) + stream_slice = StreamSlice( + partition={"id": "some_parent_id", "parent_slice": {}}, cursor_slice={} + ) filtered_records = list( record_filter_decorator.filter_records( diff --git a/unit_tests/sources/declarative/extractors/test_record_selector.py b/unit_tests/sources/declarative/extractors/test_record_selector.py index fc2bcd6d..a83586f7 100644 --- a/unit_tests/sources/declarative/extractors/test_record_selector.py +++ b/unit_tests/sources/declarative/extractors/test_record_selector.py @@ -23,7 +23,13 @@ "test_with_extractor_and_filter", ["data"], "{{ record['created_at'] > stream_state['created_at'] }}", - {"data": [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}]}, + { + "data": [ + {"id": 1, "created_at": "06-06-21"}, + {"id": 2, "created_at": "06-07-21"}, + {"id": 3, "created_at": "06-08-21"}, + ] + }, [{"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}], ), ( @@ -37,7 +43,13 @@ "test_with_extractor_and_filter_with_parameters", ["{{ parameters['parameters_field'] }}"], "{{ record['created_at'] > parameters['created_at'] }}", - {"data": [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}]}, + { + "data": [ + {"id": 1, "created_at": "06-06-21"}, + {"id": 2, "created_at": "06-07-21"}, + {"id": 3, "created_at": "06-08-21"}, + ] + }, [{"id": 3, "created_at": "06-08-21"}], ), ( @@ -76,11 +88,15 @@ def test_record_filter(test_name, field_path, filter_template, body, expected_da response = create_response(body) decoder = JsonDecoder(parameters={}) - extractor = DpathExtractor(field_path=field_path, decoder=decoder, config=config, parameters=parameters) + extractor = DpathExtractor( + field_path=field_path, decoder=decoder, config=config, parameters=parameters + ) if filter_template is None: record_filter = None else: - record_filter = RecordFilter(config=config, condition=filter_template, parameters=parameters) + record_filter = RecordFilter( + config=config, condition=filter_template, parameters=parameters + ) record_selector = RecordSelector( extractor=extractor, record_filter=record_filter, @@ -92,14 +108,20 @@ def test_record_filter(test_name, field_path, filter_template, body, expected_da actual_records = list( record_selector.select_records( - response=response, records_schema=schema, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + response=response, + records_schema=schema, + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ) ) assert actual_records == [Record(data, stream_slice) for data in expected_data] calls = [] for record in expected_data: - calls.append(call(record, config=config, stream_state=stream_state, stream_slice=stream_slice)) + calls.append( + call(record, config=config, stream_state=stream_state, stream_slice=stream_slice) + ) for transformation in transformations: assert transformation.transform.call_count == len(expected_data) transformation.transform.assert_has_calls(calls) @@ -112,21 +134,33 @@ def test_record_filter(test_name, field_path, filter_template, body, expected_da "test_with_empty_schema", {}, TransformConfig.NoTransform, - {"data": [{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}]}, + { + "data": [ + {"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"} + ] + }, [{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}], ), ( "test_with_schema_none_normalizer", {}, TransformConfig.NoTransform, - {"data": [{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}]}, + { + "data": [ + {"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"} + ] + }, [{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}], ), ( "test_with_schema_and_default_normalizer", {}, TransformConfig.DefaultSchemaNormalization, - {"data": [{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}]}, + { + "data": [ + {"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"} + ] + }, [{"id": "1", "created_at": "06-06-21", "field_int": 100, "field_float": 123.3}], ), ], @@ -141,7 +175,9 @@ def test_schema_normalization(test_name, schema, schema_transformation, body, ex response = create_response(body) schema = create_schema() decoder = JsonDecoder(parameters={}) - extractor = DpathExtractor(field_path=["data"], decoder=decoder, config=config, parameters=parameters) + extractor = DpathExtractor( + field_path=["data"], decoder=decoder, config=config, parameters=parameters + ) record_selector = RecordSelector( extractor=extractor, record_filter=None, diff --git a/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py b/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py index 5b89e04f..7b651e25 100644 --- a/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py @@ -9,7 +9,10 @@ from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.types import Record, StreamSlice datetime_format = "%Y-%m-%dT%H:%M:%S.%f%z" @@ -50,16 +53,46 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-01T00:00:00.000000+0000", "end_time": "2021-01-01T23:59:59.999999+0000"}, - {"start_time": "2021-01-02T00:00:00.000000+0000", "end_time": "2021-01-02T23:59:59.999999+0000"}, - {"start_time": "2021-01-03T00:00:00.000000+0000", "end_time": "2021-01-03T23:59:59.999999+0000"}, - {"start_time": "2021-01-04T00:00:00.000000+0000", "end_time": "2021-01-04T23:59:59.999999+0000"}, - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-05T23:59:59.999999+0000"}, - {"start_time": "2021-01-06T00:00:00.000000+0000", "end_time": "2021-01-06T23:59:59.999999+0000"}, - {"start_time": "2021-01-07T00:00:00.000000+0000", "end_time": "2021-01-07T23:59:59.999999+0000"}, - {"start_time": "2021-01-08T00:00:00.000000+0000", "end_time": "2021-01-08T23:59:59.999999+0000"}, - {"start_time": "2021-01-09T00:00:00.000000+0000", "end_time": "2021-01-09T23:59:59.999999+0000"}, - {"start_time": "2021-01-10T00:00:00.000000+0000", "end_time": "2021-01-10T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "end_time": "2021-01-01T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-02T00:00:00.000000+0000", + "end_time": "2021-01-02T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-03T00:00:00.000000+0000", + "end_time": "2021-01-03T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-04T00:00:00.000000+0000", + "end_time": "2021-01-04T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-05T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-06T00:00:00.000000+0000", + "end_time": "2021-01-06T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-07T00:00:00.000000+0000", + "end_time": "2021-01-07T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-08T00:00:00.000000+0000", + "end_time": "2021-01-08T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-09T00:00:00.000000+0000", + "end_time": "2021-01-09T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-10T00:00:00.000000+0000", + "end_time": "2021-01-10T00:00:00.000000+0000", + }, ], ), ( @@ -74,11 +107,26 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-01T00:00:00.000000+0000", "end_time": "2021-01-02T23:59:59.999999+0000"}, - {"start_time": "2021-01-03T00:00:00.000000+0000", "end_time": "2021-01-04T23:59:59.999999+0000"}, - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-06T23:59:59.999999+0000"}, - {"start_time": "2021-01-07T00:00:00.000000+0000", "end_time": "2021-01-08T23:59:59.999999+0000"}, - {"start_time": "2021-01-09T00:00:00.000000+0000", "end_time": "2021-01-10T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "end_time": "2021-01-02T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-03T00:00:00.000000+0000", + "end_time": "2021-01-04T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-06T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-07T00:00:00.000000+0000", + "end_time": "2021-01-08T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-09T00:00:00.000000+0000", + "end_time": "2021-01-10T00:00:00.000000+0000", + }, ], ), ( @@ -93,12 +141,30 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-01T00:00:00.000000+0000", "end_time": "2021-01-07T23:59:59.999999+0000"}, - {"start_time": "2021-01-08T00:00:00.000000+0000", "end_time": "2021-01-14T23:59:59.999999+0000"}, - {"start_time": "2021-01-15T00:00:00.000000+0000", "end_time": "2021-01-21T23:59:59.999999+0000"}, - {"start_time": "2021-01-22T00:00:00.000000+0000", "end_time": "2021-01-28T23:59:59.999999+0000"}, - {"start_time": "2021-01-29T00:00:00.000000+0000", "end_time": "2021-02-04T23:59:59.999999+0000"}, - {"start_time": "2021-02-05T00:00:00.000000+0000", "end_time": "2021-02-10T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "end_time": "2021-01-07T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-08T00:00:00.000000+0000", + "end_time": "2021-01-14T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-15T00:00:00.000000+0000", + "end_time": "2021-01-21T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-22T00:00:00.000000+0000", + "end_time": "2021-01-28T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-29T00:00:00.000000+0000", + "end_time": "2021-02-04T23:59:59.999999+0000", + }, + { + "start_time": "2021-02-05T00:00:00.000000+0000", + "end_time": "2021-02-10T00:00:00.000000+0000", + }, ], ), ( @@ -113,12 +179,30 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-01T00:00:00.000000+0000", "end_time": "2021-01-31T23:59:59.999999+0000"}, - {"start_time": "2021-02-01T00:00:00.000000+0000", "end_time": "2021-02-28T23:59:59.999999+0000"}, - {"start_time": "2021-03-01T00:00:00.000000+0000", "end_time": "2021-03-31T23:59:59.999999+0000"}, - {"start_time": "2021-04-01T00:00:00.000000+0000", "end_time": "2021-04-30T23:59:59.999999+0000"}, - {"start_time": "2021-05-01T00:00:00.000000+0000", "end_time": "2021-05-31T23:59:59.999999+0000"}, - {"start_time": "2021-06-01T00:00:00.000000+0000", "end_time": "2021-06-10T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "end_time": "2021-01-31T23:59:59.999999+0000", + }, + { + "start_time": "2021-02-01T00:00:00.000000+0000", + "end_time": "2021-02-28T23:59:59.999999+0000", + }, + { + "start_time": "2021-03-01T00:00:00.000000+0000", + "end_time": "2021-03-31T23:59:59.999999+0000", + }, + { + "start_time": "2021-04-01T00:00:00.000000+0000", + "end_time": "2021-04-30T23:59:59.999999+0000", + }, + { + "start_time": "2021-05-01T00:00:00.000000+0000", + "end_time": "2021-05-31T23:59:59.999999+0000", + }, + { + "start_time": "2021-06-01T00:00:00.000000+0000", + "end_time": "2021-06-10T00:00:00.000000+0000", + }, ], ), ( @@ -133,8 +217,14 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-01T00:00:00.000000+0000", "end_time": "2021-12-31T23:59:59.999999+0000"}, - {"start_time": "2022-01-01T00:00:00.000000+0000", "end_time": "2022-01-01T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "end_time": "2021-12-31T23:59:59.999999+0000", + }, + { + "start_time": "2022-01-01T00:00:00.000000+0000", + "end_time": "2022-01-01T00:00:00.000000+0000", + }, ], ), ( @@ -149,12 +239,30 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-05T23:59:59.999999+0000"}, - {"start_time": "2021-01-06T00:00:00.000000+0000", "end_time": "2021-01-06T23:59:59.999999+0000"}, - {"start_time": "2021-01-07T00:00:00.000000+0000", "end_time": "2021-01-07T23:59:59.999999+0000"}, - {"start_time": "2021-01-08T00:00:00.000000+0000", "end_time": "2021-01-08T23:59:59.999999+0000"}, - {"start_time": "2021-01-09T00:00:00.000000+0000", "end_time": "2021-01-09T23:59:59.999999+0000"}, - {"start_time": "2021-01-10T00:00:00.000000+0000", "end_time": "2021-01-10T00:00:00.000000+0000"}, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-05T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-06T00:00:00.000000+0000", + "end_time": "2021-01-06T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-07T00:00:00.000000+0000", + "end_time": "2021-01-07T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-08T00:00:00.000000+0000", + "end_time": "2021-01-08T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-09T00:00:00.000000+0000", + "end_time": "2021-01-09T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-10T00:00:00.000000+0000", + "end_time": "2021-01-10T00:00:00.000000+0000", + }, ], ), ( @@ -169,14 +277,20 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-01T00:00:00.000000+0000", "end_time": "2021-01-10T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "end_time": "2021-01-10T00:00:00.000000+0000", + }, ], ), ( "test_end_time_greater_than_now", NO_STATE, MinMaxDatetime(datetime="2021-12-28T00:00:00.000000+0000", parameters={}), - MinMaxDatetime(datetime=f"{(FAKE_NOW + datetime.timedelta(days=1)).strftime(datetime_format)}", parameters={}), + MinMaxDatetime( + datetime=f"{(FAKE_NOW + datetime.timedelta(days=1)).strftime(datetime_format)}", + parameters={}, + ), "P1D", cursor_field, None, @@ -184,11 +298,26 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-12-28T00:00:00.000000+0000", "end_time": "2021-12-28T23:59:59.999999+0000"}, - {"start_time": "2021-12-29T00:00:00.000000+0000", "end_time": "2021-12-29T23:59:59.999999+0000"}, - {"start_time": "2021-12-30T00:00:00.000000+0000", "end_time": "2021-12-30T23:59:59.999999+0000"}, - {"start_time": "2021-12-31T00:00:00.000000+0000", "end_time": "2021-12-31T23:59:59.999999+0000"}, - {"start_time": "2022-01-01T00:00:00.000000+0000", "end_time": "2022-01-01T00:00:00.000000+0000"}, + { + "start_time": "2021-12-28T00:00:00.000000+0000", + "end_time": "2021-12-28T23:59:59.999999+0000", + }, + { + "start_time": "2021-12-29T00:00:00.000000+0000", + "end_time": "2021-12-29T23:59:59.999999+0000", + }, + { + "start_time": "2021-12-30T00:00:00.000000+0000", + "end_time": "2021-12-30T23:59:59.999999+0000", + }, + { + "start_time": "2021-12-31T00:00:00.000000+0000", + "end_time": "2021-12-31T23:59:59.999999+0000", + }, + { + "start_time": "2022-01-01T00:00:00.000000+0000", + "end_time": "2022-01-01T00:00:00.000000+0000", + }, ], ), ( @@ -203,7 +332,10 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-05T00:00:00.000000+0000"}, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-05T00:00:00.000000+0000", + }, ], ), ( @@ -218,12 +350,30 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-05T23:59:59.999999+0000"}, - {"start_time": "2021-01-06T00:00:00.000000+0000", "end_time": "2021-01-06T23:59:59.999999+0000"}, - {"start_time": "2021-01-07T00:00:00.000000+0000", "end_time": "2021-01-07T23:59:59.999999+0000"}, - {"start_time": "2021-01-08T00:00:00.000000+0000", "end_time": "2021-01-08T23:59:59.999999+0000"}, - {"start_time": "2021-01-09T00:00:00.000000+0000", "end_time": "2021-01-09T23:59:59.999999+0000"}, - {"start_time": "2021-01-10T00:00:00.000000+0000", "end_time": "2021-01-10T00:00:00.000000+0000"}, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-05T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-06T00:00:00.000000+0000", + "end_time": "2021-01-06T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-07T00:00:00.000000+0000", + "end_time": "2021-01-07T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-08T00:00:00.000000+0000", + "end_time": "2021-01-08T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-09T00:00:00.000000+0000", + "end_time": "2021-01-09T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-10T00:00:00.000000+0000", + "end_time": "2021-01-10T00:00:00.000000+0000", + }, ], ), ( @@ -238,9 +388,18 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-06T23:59:59.999999+0000"}, - {"start_time": "2021-01-07T00:00:00.000000+0000", "end_time": "2021-01-08T23:59:59.999999+0000"}, - {"start_time": "2021-01-09T00:00:00.000000+0000", "end_time": "2021-01-10T00:00:00.000000+0000"}, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-06T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-07T00:00:00.000000+0000", + "end_time": "2021-01-08T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-09T00:00:00.000000+0000", + "end_time": "2021-01-10T00:00:00.000000+0000", + }, ], ), ( @@ -255,10 +414,22 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-05T23:59:59.999999+0000"}, - {"start_time": "2021-01-06T00:00:00.000000+0000", "end_time": "2021-01-06T23:59:59.999999+0000"}, - {"start_time": "2021-01-07T00:00:00.000000+0000", "end_time": "2021-01-07T23:59:59.999999+0000"}, - {"start_time": "2021-01-08T00:00:00.000000+0000", "end_time": "2021-01-08T00:00:00.000000+0000"}, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-05T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-06T00:00:00.000000+0000", + "end_time": "2021-01-06T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-07T00:00:00.000000+0000", + "end_time": "2021-01-07T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-08T00:00:00.000000+0000", + "end_time": "2021-01-08T00:00:00.000000+0000", + }, ], ), ( @@ -273,11 +444,26 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-02T00:00:00.000000+0000", "end_time": "2021-01-02T23:59:59.999999+0000"}, - {"start_time": "2021-01-03T00:00:00.000000+0000", "end_time": "2021-01-03T23:59:59.999999+0000"}, - {"start_time": "2021-01-04T00:00:00.000000+0000", "end_time": "2021-01-04T23:59:59.999999+0000"}, - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-05T23:59:59.999999+0000"}, - {"start_time": "2021-01-06T00:00:00.000000+0000", "end_time": "2021-01-06T00:00:00.000000+0000"}, + { + "start_time": "2021-01-02T00:00:00.000000+0000", + "end_time": "2021-01-02T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-03T00:00:00.000000+0000", + "end_time": "2021-01-03T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-04T00:00:00.000000+0000", + "end_time": "2021-01-04T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-05T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-06T00:00:00.000000+0000", + "end_time": "2021-01-06T00:00:00.000000+0000", + }, ], ), ( @@ -292,11 +478,26 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-01T00:00:00.000000+0000", "end_time": "2021-01-01T23:59:59.999999+0000"}, - {"start_time": "2021-01-02T00:00:00.000000+0000", "end_time": "2021-01-02T23:59:59.999999+0000"}, - {"start_time": "2021-01-03T00:00:00.000000+0000", "end_time": "2021-01-03T23:59:59.999999+0000"}, - {"start_time": "2021-01-04T00:00:00.000000+0000", "end_time": "2021-01-04T23:59:59.999999+0000"}, - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-05T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "end_time": "2021-01-01T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-02T00:00:00.000000+0000", + "end_time": "2021-01-02T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-03T00:00:00.000000+0000", + "end_time": "2021-01-03T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-04T00:00:00.000000+0000", + "end_time": "2021-01-04T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-05T00:00:00.000000+0000", + }, ], ), ( @@ -311,12 +512,30 @@ def mock_datetime_now(monkeypatch): cursor_granularity, None, [ - {"start_time": "2021-01-05T00:00:00.000000+0000", "end_time": "2021-01-05T23:59:59.999999+0000"}, - {"start_time": "2021-01-06T00:00:00.000000+0000", "end_time": "2021-01-06T23:59:59.999999+0000"}, - {"start_time": "2021-01-07T00:00:00.000000+0000", "end_time": "2021-01-07T23:59:59.999999+0000"}, - {"start_time": "2021-01-08T00:00:00.000000+0000", "end_time": "2021-01-08T23:59:59.999999+0000"}, - {"start_time": "2021-01-09T00:00:00.000000+0000", "end_time": "2021-01-09T23:59:59.999999+0000"}, - {"start_time": "2021-01-10T00:00:00.000000+0000", "end_time": "2021-01-10T00:00:00.000000+0000"}, + { + "start_time": "2021-01-05T00:00:00.000000+0000", + "end_time": "2021-01-05T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-06T00:00:00.000000+0000", + "end_time": "2021-01-06T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-07T00:00:00.000000+0000", + "end_time": "2021-01-07T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-08T00:00:00.000000+0000", + "end_time": "2021-01-08T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-09T00:00:00.000000+0000", + "end_time": "2021-01-09T23:59:59.999999+0000", + }, + { + "start_time": "2021-01-10T00:00:00.000000+0000", + "end_time": "2021-01-10T00:00:00.000000+0000", + }, ], ), ( @@ -331,7 +550,10 @@ def mock_datetime_now(monkeypatch): cursor_granularity, True, [ - {"start_time": "2021-01-01T00:00:00.000000+0000", "end_time": "2021-01-31T23:59:59.999999+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "end_time": "2021-01-31T23:59:59.999999+0000", + }, ], ), ], @@ -350,7 +572,9 @@ def test_stream_slices( is_compare_strictly, expected_slices, ): - lookback_window = InterpolatedString(string=lookback_window, parameters={}) if lookback_window else None + lookback_window = ( + InterpolatedString(string=lookback_window, parameters={}) if lookback_window else None + ) cursor = DatetimeBasedCursor( start_datetime=start, end_datetime=end, @@ -375,70 +599,98 @@ def test_stream_slices( ( "test_close_slice_previous_cursor_is_highest", "2023-01-01", - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"}), + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"} + ), [{cursor_field: "2021-01-01"}], {cursor_field: "2023-01-01"}, ), ( "test_close_slice_stream_slice_partition_end_is_highest", "2020-01-01", - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2023-01-01"}), + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2023-01-01"} + ), [{cursor_field: "2021-01-01"}], {cursor_field: "2021-01-01"}, ), ( "test_close_slice_latest_record_cursor_value_is_higher_than_slice_end", "2021-01-01", - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"}), + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"} + ), [{cursor_field: "2023-01-01"}], {cursor_field: "2021-01-01"}, ), ( "test_close_slice_with_no_records_observed", "2021-01-01", - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"}), + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"} + ), [], {cursor_field: "2021-01-01"}, ), ( "test_close_slice_with_no_records_observed_and_no_previous_state", None, - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"}), + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"} + ), [], {}, ), ( "test_close_slice_without_previous_cursor", None, - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2023-01-01"}), + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2023-01-01"} + ), [{cursor_field: "2022-01-01"}], {cursor_field: "2022-01-01"}, ), ( "test_close_slice_with_out_of_order_records", "2021-01-01", - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"}), - [{cursor_field: "2021-04-01"}, {cursor_field: "2021-02-01"}, {cursor_field: "2021-03-01"}], + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"} + ), + [ + {cursor_field: "2021-04-01"}, + {cursor_field: "2021-02-01"}, + {cursor_field: "2021-03-01"}, + ], {cursor_field: "2021-04-01"}, ), ( "test_close_slice_with_some_records_out_of_slice_boundaries", "2021-01-01", - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"}), - [{cursor_field: "2021-02-01"}, {cursor_field: "2021-03-01"}, {cursor_field: "2023-01-01"}], + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"} + ), + [ + {cursor_field: "2021-02-01"}, + {cursor_field: "2021-03-01"}, + {cursor_field: "2023-01-01"}, + ], {cursor_field: "2021-03-01"}, ), ( "test_close_slice_with_all_records_out_of_slice_boundaries", "2021-01-01", - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"}), + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"} + ), [{cursor_field: "2023-01-01"}], {cursor_field: "2021-01-01"}, ), ( "test_close_slice_with_all_records_out_of_slice_and_no_previous_cursor", None, - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"}), + StreamSlice( + partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2022-01-01"} + ), [{cursor_field: "2023-01-01"}], {}, ), @@ -485,7 +737,9 @@ def test_compares_cursor_values_by_chronological_order(): parameters={}, ) - _slice = StreamSlice(partition={}, cursor_slice={"start_time": "01-01-2023", "end_time": "01-04-2023"}) + _slice = StreamSlice( + partition={}, cursor_slice={"start_time": "01-01-2023", "end_time": "01-04-2023"} + ) first_record = Record({cursor_field: "21-02-2023"}, _slice) cursor.observe(_slice, first_record) second_record = Record({cursor_field: "01-03-2023"}, _slice) @@ -505,7 +759,13 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state parameters={}, ) - _slice = StreamSlice(partition={}, cursor_slice={"start_time": "2023-01-01T17:30:19.000Z", "end_time": "2023-01-04T17:30:19.000Z"}) + _slice = StreamSlice( + partition={}, + cursor_slice={ + "start_time": "2023-01-01T17:30:19.000Z", + "end_time": "2023-01-04T17:30:19.000Z", + }, + ) record_cursor_value = "2023-01-03" record = Record({cursor_field: record_cursor_value}, _slice) cursor.observe(_slice, record) @@ -522,7 +782,10 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state "test_start_time_passed_by_req_param", RequestOptionType.request_parameter, "start_time", - {"start_time": "2021-01-01T00:00:00.000000+0000", "endtime": "2021-01-04T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "endtime": "2021-01-04T00:00:00.000000+0000", + }, {}, {}, {}, @@ -532,7 +795,10 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state RequestOptionType.header, "start_time", {}, - {"start_time": "2021-01-01T00:00:00.000000+0000", "endtime": "2021-01-04T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "endtime": "2021-01-04T00:00:00.000000+0000", + }, {}, {}, ), @@ -542,7 +808,10 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state "start_time", {}, {}, - {"start_time": "2021-01-01T00:00:00.000000+0000", "endtime": "2021-01-04T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "endtime": "2021-01-04T00:00:00.000000+0000", + }, {}, ), ( @@ -552,13 +821,32 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state {}, {}, {}, - {"start_time": "2021-01-01T00:00:00.000000+0000", "endtime": "2021-01-04T00:00:00.000000+0000"}, + { + "start_time": "2021-01-01T00:00:00.000000+0000", + "endtime": "2021-01-04T00:00:00.000000+0000", + }, ), ], ) -def test_request_option(test_name, inject_into, field_name, expected_req_params, expected_headers, expected_body_json, expected_body_data): - start_request_option = RequestOption(inject_into=inject_into, parameters={}, field_name=field_name) if inject_into else None - end_request_option = RequestOption(inject_into=inject_into, parameters={}, field_name="endtime") if inject_into else None +def test_request_option( + test_name, + inject_into, + field_name, + expected_req_params, + expected_headers, + expected_body_json, + expected_body_data, +): + start_request_option = ( + RequestOption(inject_into=inject_into, parameters={}, field_name=field_name) + if inject_into + else None + ) + end_request_option = ( + RequestOption(inject_into=inject_into, parameters={}, field_name="endtime") + if inject_into + else None + ) slicer = DatetimeBasedCursor( start_datetime=MinMaxDatetime(datetime="2021-01-01T00:00:00.000000+0000", parameters={}), end_datetime=MinMaxDatetime(datetime="2021-01-10T00:00:00.000000+0000", parameters={}), @@ -572,7 +860,10 @@ def test_request_option(test_name, inject_into, field_name, expected_req_params, config=config, parameters={}, ) - stream_slice = {"start_time": "2021-01-01T00:00:00.000000+0000", "end_time": "2021-01-04T00:00:00.000000+0000"} + stream_slice = { + "start_time": "2021-01-01T00:00:00.000000+0000", + "end_time": "2021-01-04T00:00:00.000000+0000", + } assert slicer.get_request_params(stream_slice=stream_slice) == expected_req_params assert slicer.get_request_headers(stream_slice=stream_slice) == expected_headers assert slicer.get_request_body_json(stream_slice=stream_slice) == expected_body_json @@ -587,8 +878,12 @@ def test_request_option(test_name, inject_into, field_name, expected_req_params, ], ) def test_request_option_with_empty_stream_slice(stream_slice): - start_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="starttime") - end_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="endtime") + start_request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, parameters={}, field_name="starttime" + ) + end_request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, parameters={}, field_name="endtime" + ) slicer = DatetimeBasedCursor( start_datetime=MinMaxDatetime(datetime="2021-01-01T00:00:00.000000+0000", parameters={}), end_datetime=MinMaxDatetime(datetime="2021-01-10T00:00:00.000000+0000", parameters={}), @@ -622,7 +917,13 @@ def test_request_option_with_empty_stream_slice(stream_slice): "PT1S", datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), ), - ("test_parse_date_number", "20210101", "%Y%m%d", "P1D", datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)), + ( + "test_parse_date_number", + "20210101", + "%Y%m%d", + "P1D", + datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), + ), ], ) def test_parse_date_legacy_merge_datetime_format_in_cursor_datetime_format( @@ -688,12 +989,32 @@ def test_given_unknown_format_when_parse_date_then_raise_error(): @pytest.mark.parametrize( "test_name, input_dt, datetimeformat, datetimeformat_granularity, expected_output", [ - ("test_format_timestamp", datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), "%s", "PT1S", "1609459200"), - ("test_format_string", datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), "%Y-%m-%d", "P1D", "2021-01-01"), - ("test_format_to_number", datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), "%Y%m%d", "P1D", "20210101"), + ( + "test_format_timestamp", + datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), + "%s", + "PT1S", + "1609459200", + ), + ( + "test_format_string", + datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), + "%Y-%m-%d", + "P1D", + "2021-01-01", + ), + ( + "test_format_to_number", + datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), + "%Y%m%d", + "P1D", + "20210101", + ), ], ) -def test_format_datetime(test_name, input_dt, datetimeformat, datetimeformat_granularity, expected_output): +def test_format_datetime( + test_name, input_dt, datetimeformat, datetimeformat_granularity, expected_output +): slicer = DatetimeBasedCursor( start_datetime=MinMaxDatetime("2021-01-01T00:00:00.000000+0000", parameters={}), end_datetime=MinMaxDatetime("2021-01-10T00:00:00.000000+0000", parameters={}), @@ -772,7 +1093,9 @@ def test_no_end_datetime(mock_datetime_now): parameters={}, ) stream_slices = cursor.stream_slices() - assert stream_slices == [{"start_time": "2021-01-01", "end_time": FAKE_NOW.strftime("%Y-%m-%d")}] + assert stream_slices == [ + {"start_time": "2021-01-01", "end_time": FAKE_NOW.strftime("%Y-%m-%d")} + ] def test_given_no_state_and_start_before_cursor_value_when_should_be_synced_then_return_true(): @@ -852,7 +1175,9 @@ def test_given_first_greater_than_second_then_return_true(): config=config, parameters={}, ) - assert cursor.is_greater_than_or_equal(Record({"cursor_field": "2023-01-01"}, {}), Record({"cursor_field": "2021-01-01"}, {})) + assert cursor.is_greater_than_or_equal( + Record({"cursor_field": "2023-01-01"}, {}), Record({"cursor_field": "2021-01-01"}, {}) + ) def test_given_first_lesser_than_second_then_return_false(): @@ -863,7 +1188,9 @@ def test_given_first_lesser_than_second_then_return_false(): config=config, parameters={}, ) - assert not cursor.is_greater_than_or_equal(Record({"cursor_field": "2021-01-01"}, {}), Record({"cursor_field": "2023-01-01"}, {})) + assert not cursor.is_greater_than_or_equal( + Record({"cursor_field": "2021-01-01"}, {}), Record({"cursor_field": "2023-01-01"}, {}) + ) def test_given_no_cursor_value_for_second_than_second_then_return_true(): @@ -874,7 +1201,9 @@ def test_given_no_cursor_value_for_second_than_second_then_return_true(): config=config, parameters={}, ) - assert cursor.is_greater_than_or_equal(Record({"cursor_field": "2021-01-01"}, {}), Record({}, {})) + assert cursor.is_greater_than_or_equal( + Record({"cursor_field": "2021-01-01"}, {}), Record({}, {}) + ) def test_given_no_cursor_value_for_first_than_second_then_return_false(): @@ -885,7 +1214,9 @@ def test_given_no_cursor_value_for_first_than_second_then_return_false(): config=config, parameters={}, ) - assert not cursor.is_greater_than_or_equal(Record({}, {}), Record({"cursor_field": "2021-01-01"}, {})) + assert not cursor.is_greater_than_or_equal( + Record({}, {}), Record({"cursor_field": "2021-01-01"}, {}) + ) if __name__ == "__main__": diff --git a/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py b/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py index 823405cb..e1cd6d19 100644 --- a/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py @@ -7,7 +7,11 @@ import pytest from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor -from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, PerPartitionKeySerializer, StreamSlice +from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import ( + PerPartitionCursor, + PerPartitionKeySerializer, + StreamSlice, +) from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.types import Record @@ -16,7 +20,10 @@ "partition_key int": 1, "partition_key list str": ["list item 1", "list item 2"], "partition_key list dict": [ - {"dict within list key 1-1": "dict within list value 1-1", "dict within list key 1-2": "dict within list value 1-2"}, + { + "dict within list key 1-1": "dict within list value 1-1", + "dict within list key 1-2": "dict within list value 1-2", + }, {"dict within list key 2": "dict within list value 2"}, ], "partition_key nested dict": { @@ -59,7 +66,9 @@ def test_partition_with_different_key_orders(): same_dict_with_different_order = OrderedDict({"2": 2, "1": 1}) serializer = PerPartitionKeySerializer() - assert serializer.to_partition_key(ordered_dict) == serializer.to_partition_key(same_dict_with_different_order) + assert serializer.to_partition_key(ordered_dict) == serializer.to_partition_key( + same_dict_with_different_order + ) def test_given_tuples_in_json_then_deserialization_convert_to_list(): @@ -70,17 +79,24 @@ def test_given_tuples_in_json_then_deserialization_convert_to_list(): serializer = PerPartitionKeySerializer() partition_with_tuple = {"key": (1, 2, 3)} - assert partition_with_tuple != serializer.to_partition(serializer.to_partition_key(partition_with_tuple)) + assert partition_with_tuple != serializer.to_partition( + serializer.to_partition_key(partition_with_tuple) + ) def test_stream_slice_merge_dictionaries(): - stream_slice = StreamSlice(partition={"partition key": "partition value"}, cursor_slice={"cursor key": "cursor value"}) + stream_slice = StreamSlice( + partition={"partition key": "partition value"}, cursor_slice={"cursor key": "cursor value"} + ) assert stream_slice == {"partition key": "partition value", "cursor key": "cursor value"} def test_overlapping_slice_keys_raise_error(): with pytest.raises(ValueError): - StreamSlice(partition={"overlapping key": "partition value"}, cursor_slice={"overlapping key": "cursor value"}) + StreamSlice( + partition={"overlapping key": "partition value"}, + cursor_slice={"overlapping key": "cursor value"}, + ) class MockedCursorBuilder: @@ -115,7 +131,9 @@ def mocked_cursor_factory(): return cursor_factory -def test_given_no_partition_when_stream_slices_then_no_slices(mocked_cursor_factory, mocked_partition_router): +def test_given_no_partition_when_stream_slices_then_no_slices( + mocked_cursor_factory, mocked_partition_router +): mocked_partition_router.stream_slices.return_value = [] cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) @@ -127,27 +145,42 @@ def test_given_no_partition_when_stream_slices_then_no_slices(mocked_cursor_fact def test_given_partition_router_without_state_has_one_partition_then_return_one_slice_per_cursor_slice( mocked_cursor_factory, mocked_partition_router ): - partition = StreamSlice(partition={"partition_field_1": "a value", "partition_field_2": "another value"}, cursor_slice={}) + partition = StreamSlice( + partition={"partition_field_1": "a value", "partition_field_2": "another value"}, + cursor_slice={}, + ) mocked_partition_router.stream_slices.return_value = [partition] cursor_slices = [{"start_datetime": 1}, {"start_datetime": 2}] - mocked_cursor_factory.create.return_value = MockedCursorBuilder().with_stream_slices(cursor_slices).build() + mocked_cursor_factory.create.return_value = ( + MockedCursorBuilder().with_stream_slices(cursor_slices).build() + ) cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) slices = cursor.stream_slices() - assert list(slices) == [StreamSlice(partition=partition, cursor_slice=cursor_slice) for cursor_slice in cursor_slices] + assert list(slices) == [ + StreamSlice(partition=partition, cursor_slice=cursor_slice) + for cursor_slice in cursor_slices + ] def test_given_partition_associated_with_state_when_stream_slices_then_do_not_recreate_cursor( mocked_cursor_factory, mocked_partition_router ): - partition = StreamSlice(partition={"partition_field_1": "a value", "partition_field_2": "another value"}, cursor_slice={}) + partition = StreamSlice( + partition={"partition_field_1": "a value", "partition_field_2": "another value"}, + cursor_slice={}, + ) mocked_partition_router.stream_slices.return_value = [partition] cursor_slices = [{"start_datetime": 1}] - mocked_cursor_factory.create.return_value = MockedCursorBuilder().with_stream_slices(cursor_slices).build() + mocked_cursor_factory.create.return_value = ( + MockedCursorBuilder().with_stream_slices(cursor_slices).build() + ) cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) - cursor.set_initial_state({"states": [{"partition": partition.partition, "cursor": CURSOR_STATE}]}) + cursor.set_initial_state( + {"states": [{"partition": partition.partition, "cursor": CURSOR_STATE}]} + ) mocked_cursor_factory.create.assert_called_once() slices = list(cursor.stream_slices()) @@ -155,14 +188,24 @@ def test_given_partition_associated_with_state_when_stream_slices_then_do_not_re assert len(slices) == 1 -def test_given_multiple_partitions_then_each_have_their_state(mocked_cursor_factory, mocked_partition_router): +def test_given_multiple_partitions_then_each_have_their_state( + mocked_cursor_factory, mocked_partition_router +): first_partition = {"first_partition_key": "first_partition_value"} mocked_partition_router.stream_slices.return_value = [ StreamSlice(partition=first_partition, cursor_slice={}), StreamSlice(partition={"second_partition_key": "second_partition_value"}, cursor_slice={}), ] - first_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() - second_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "second slice cursor value"}]).build() + first_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) + second_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "second slice cursor value"}]) + .build() + ) mocked_cursor_factory.create.side_effect = [first_cursor, second_cursor] cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) @@ -173,18 +216,26 @@ def test_given_multiple_partitions_then_each_have_their_state(mocked_cursor_fact second_cursor.stream_slices.assert_called_once() assert slices == [ StreamSlice( - partition={"first_partition_key": "first_partition_value"}, cursor_slice={CURSOR_SLICE_FIELD: "first slice cursor value"} + partition={"first_partition_key": "first_partition_value"}, + cursor_slice={CURSOR_SLICE_FIELD: "first slice cursor value"}, ), StreamSlice( - partition={"second_partition_key": "second_partition_value"}, cursor_slice={CURSOR_SLICE_FIELD: "second slice cursor value"} + partition={"second_partition_key": "second_partition_value"}, + cursor_slice={CURSOR_SLICE_FIELD: "second slice cursor value"}, ), ] -def test_given_stream_slices_when_get_stream_state_then_return_updated_state(mocked_cursor_factory, mocked_partition_router): +def test_given_stream_slices_when_get_stream_state_then_return_updated_state( + mocked_cursor_factory, mocked_partition_router +): mocked_cursor_factory.create.side_effect = [ - MockedCursorBuilder().with_stream_state({CURSOR_STATE_KEY: "first slice cursor value"}).build(), - MockedCursorBuilder().with_stream_state({CURSOR_STATE_KEY: "second slice cursor value"}).build(), + MockedCursorBuilder() + .with_stream_state({CURSOR_STATE_KEY: "first slice cursor value"}) + .build(), + MockedCursorBuilder() + .with_stream_state({CURSOR_STATE_KEY: "second slice cursor value"}) + .build(), ] mocked_partition_router.stream_slices.return_value = [ StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), @@ -198,16 +249,30 @@ def test_given_stream_slices_when_get_stream_state_then_return_updated_state(moc list(cursor.stream_slices()) assert cursor.get_stream_state() == { "states": [ - {"partition": {"partition key": "first partition"}, "cursor": {CURSOR_STATE_KEY: "first slice cursor value"}}, - {"partition": {"partition key": "second partition"}, "cursor": {CURSOR_STATE_KEY: "second slice cursor value"}}, + { + "partition": {"partition key": "first partition"}, + "cursor": {CURSOR_STATE_KEY: "first slice cursor value"}, + }, + { + "partition": {"partition key": "second partition"}, + "cursor": {CURSOR_STATE_KEY: "second slice cursor value"}, + }, ] } -def test_when_get_stream_state_then_delegate_to_underlying_cursor(mocked_cursor_factory, mocked_partition_router): - underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() +def test_when_get_stream_state_then_delegate_to_underlying_cursor( + mocked_cursor_factory, mocked_partition_router +): + underlying_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) mocked_cursor_factory.create.side_effect = [underlying_cursor] - mocked_partition_router.stream_slices.return_value = [StreamSlice(partition={"partition key": "first partition"}, cursor_slice={})] + mocked_partition_router.stream_slices.return_value = [ + StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}) + ] cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) first_slice = list(cursor.stream_slices())[0] @@ -217,7 +282,11 @@ def test_when_get_stream_state_then_delegate_to_underlying_cursor(mocked_cursor_ def test_close_slice(mocked_cursor_factory, mocked_partition_router): - underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() + underlying_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) mocked_cursor_factory.create.side_effect = [underlying_cursor] stream_slice = StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}) mocked_partition_router.stream_slices.return_value = [stream_slice] @@ -229,8 +298,14 @@ def test_close_slice(mocked_cursor_factory, mocked_partition_router): underlying_cursor.close_slice.assert_called_once_with(stream_slice.cursor_slice) -def test_given_no_last_record_when_close_slice_then_do_not_raise_error(mocked_cursor_factory, mocked_partition_router): - underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() +def test_given_no_last_record_when_close_slice_then_do_not_raise_error( + mocked_cursor_factory, mocked_partition_router +): + underlying_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) mocked_cursor_factory.create.side_effect = [underlying_cursor] stream_slice = StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}) mocked_partition_router.stream_slices.return_value = [stream_slice] @@ -256,7 +331,9 @@ def test_given_unknown_partition_when_should_be_synced_then_raise_error(): any_partition_router = Mock() cursor = PerPartitionCursor(any_cursor_factory, any_partition_router) with pytest.raises(ValueError): - cursor.should_be_synced(Record({}, StreamSlice(partition={"unknown_partition": "unknown"}, cursor_slice={}))) + cursor.should_be_synced( + Record({}, StreamSlice(partition={"unknown_partition": "unknown"}, cursor_slice={})) + ) def test_given_records_with_different_slice_when_is_greater_than_or_equal_then_raise_error(): @@ -273,16 +350,28 @@ def test_given_records_with_different_slice_when_is_greater_than_or_equal_then_r @pytest.mark.parametrize( "first_record_slice, second_record_slice", [ - pytest.param(StreamSlice(partition={"a slice": "value"}, cursor_slice={}), None, id="second record does not have a slice"), - pytest.param(None, StreamSlice(partition={"a slice": "value"}, cursor_slice={}), id="first record does not have a slice"), + pytest.param( + StreamSlice(partition={"a slice": "value"}, cursor_slice={}), + None, + id="second record does not have a slice", + ), + pytest.param( + None, + StreamSlice(partition={"a slice": "value"}, cursor_slice={}), + id="first record does not have a slice", + ), ], ) -def test_given_records_without_a_slice_when_is_greater_than_or_equal_then_raise_error(first_record_slice, second_record_slice): +def test_given_records_without_a_slice_when_is_greater_than_or_equal_then_raise_error( + first_record_slice, second_record_slice +): any_cursor_factory = Mock() any_partition_router = Mock() cursor = PerPartitionCursor(any_cursor_factory, any_partition_router) with pytest.raises(ValueError): - cursor.is_greater_than_or_equal(Record({}, first_record_slice), Record({}, second_record_slice)) + cursor.is_greater_than_or_equal( + Record({}, first_record_slice), Record({}, second_record_slice) + ) def test_given_slice_is_unknown_when_is_greater_than_or_equal_then_raise_error(): @@ -296,8 +385,14 @@ def test_given_slice_is_unknown_when_is_greater_than_or_equal_then_raise_error() ) -def test_when_is_greater_than_or_equal_then_return_underlying_cursor_response(mocked_cursor_factory, mocked_partition_router): - underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() +def test_when_is_greater_than_or_equal_then_return_underlying_cursor_response( + mocked_cursor_factory, mocked_partition_router +): + underlying_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) mocked_cursor_factory.create.side_effect = [underlying_cursor] stream_slice = StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}) mocked_partition_router.stream_slices.return_value = [stream_slice] @@ -323,21 +418,31 @@ def test_when_is_greater_than_or_equal_then_return_underlying_cursor_response(mo pytest.param(None, None, id="first partition"), ], ) -def test_get_request_params(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output): - underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() +def test_get_request_params( + mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output +): + underlying_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) underlying_cursor.get_request_params.return_value = {"cursor": "params"} mocked_cursor_factory.create.side_effect = [underlying_cursor] mocked_partition_router.stream_slices.return_value = [stream_slice] mocked_partition_router.get_request_params.return_value = {"router": "params"} cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) if stream_slice: - cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]}) + cursor.set_initial_state( + {"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]} + ) params = cursor.get_request_params(stream_slice=stream_slice) assert params == expected_output mocked_partition_router.get_request_params.assert_called_once_with( stream_state=None, stream_slice=stream_slice, next_page_token=None ) - underlying_cursor.get_request_params.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None) + underlying_cursor.get_request_params.assert_called_once_with( + stream_state=None, stream_slice={}, next_page_token=None + ) else: with pytest.raises(ValueError): cursor.get_request_params(stream_slice=stream_slice) @@ -354,21 +459,31 @@ def test_get_request_params(mocked_cursor_factory, mocked_partition_router, stre pytest.param(None, None, id="first partition"), ], ) -def test_get_request_headers(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output): - underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() +def test_get_request_headers( + mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output +): + underlying_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) underlying_cursor.get_request_headers.return_value = {"cursor": "params"} mocked_cursor_factory.create.side_effect = [underlying_cursor] mocked_partition_router.stream_slices.return_value = [stream_slice] mocked_partition_router.get_request_headers.return_value = {"router": "params"} cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) if stream_slice: - cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]}) + cursor.set_initial_state( + {"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]} + ) params = cursor.get_request_headers(stream_slice=stream_slice) assert params == expected_output mocked_partition_router.get_request_headers.assert_called_once_with( stream_state=None, stream_slice=stream_slice, next_page_token=None ) - underlying_cursor.get_request_headers.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None) + underlying_cursor.get_request_headers.assert_called_once_with( + stream_state=None, stream_slice={}, next_page_token=None + ) else: with pytest.raises(ValueError): cursor.get_request_headers(stream_slice=stream_slice) @@ -385,21 +500,31 @@ def test_get_request_headers(mocked_cursor_factory, mocked_partition_router, str pytest.param(None, None, id="first partition"), ], ) -def test_get_request_body_data(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output): - underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() +def test_get_request_body_data( + mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output +): + underlying_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) underlying_cursor.get_request_body_data.return_value = {"cursor": "params"} mocked_cursor_factory.create.side_effect = [underlying_cursor] mocked_partition_router.stream_slices.return_value = [stream_slice] mocked_partition_router.get_request_body_data.return_value = {"router": "params"} cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) if stream_slice: - cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]}) + cursor.set_initial_state( + {"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]} + ) params = cursor.get_request_body_data(stream_slice=stream_slice) assert params == expected_output mocked_partition_router.get_request_body_data.assert_called_once_with( stream_state=None, stream_slice=stream_slice, next_page_token=None ) - underlying_cursor.get_request_body_data.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None) + underlying_cursor.get_request_body_data.assert_called_once_with( + stream_state=None, stream_slice={}, next_page_token=None + ) else: with pytest.raises(ValueError): cursor.get_request_body_data(stream_slice=stream_slice) @@ -416,32 +541,47 @@ def test_get_request_body_data(mocked_cursor_factory, mocked_partition_router, s pytest.param(None, None, id="first partition"), ], ) -def test_get_request_body_json(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output): - underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() +def test_get_request_body_json( + mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output +): + underlying_cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) underlying_cursor.get_request_body_json.return_value = {"cursor": "params"} mocked_cursor_factory.create.side_effect = [underlying_cursor] mocked_partition_router.stream_slices.return_value = [stream_slice] mocked_partition_router.get_request_body_json.return_value = {"router": "params"} cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) if stream_slice: - cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]}) + cursor.set_initial_state( + {"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]} + ) params = cursor.get_request_body_json(stream_slice=stream_slice) assert params == expected_output mocked_partition_router.get_request_body_json.assert_called_once_with( stream_state=None, stream_slice=stream_slice, next_page_token=None ) - underlying_cursor.get_request_body_json.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None) + underlying_cursor.get_request_body_json.assert_called_once_with( + stream_state=None, stream_slice={}, next_page_token=None + ) else: with pytest.raises(ValueError): cursor.get_request_body_json(stream_slice=stream_slice) -def test_parent_state_is_set_for_per_partition_cursor(mocked_cursor_factory, mocked_partition_router): +def test_parent_state_is_set_for_per_partition_cursor( + mocked_cursor_factory, mocked_partition_router +): # Define the parent state to be used in the test parent_state = {"parent_cursor": "parent_state_value"} # Mock the partition router to return a stream slice - partition = StreamSlice(partition={"partition_field_1": "a value", "partition_field_2": "another value"}, cursor_slice={}) + partition = StreamSlice( + partition={"partition_field_1": "a value", "partition_field_2": "another value"}, + cursor_slice={}, + ) mocked_partition_router.stream_slices.return_value = [partition] # Mock the cursor factory to create cursors with specific states @@ -517,7 +657,9 @@ def test_get_stream_state_includes_parent_state(mocked_cursor_factory, mocked_pa assert stream_state == expected_state -def test_per_partition_state_when_set_initial_global_state(mocked_cursor_factory, mocked_partition_router) -> None: +def test_per_partition_state_when_set_initial_global_state( + mocked_cursor_factory, mocked_partition_router +) -> None: first_partition = {"first_partition_key": "first_partition_value"} second_partition = {"second_partition_key": "second_partition_value"} global_state = {"global_state_format_key": "global_state_format_value"} @@ -535,16 +677,29 @@ def test_per_partition_state_when_set_initial_global_state(mocked_cursor_factory cursor.set_initial_state(global_state) assert cursor._state_to_migrate_from == global_state list(cursor.stream_slices()) - assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_count == 1 - assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_args[0] == ( - {"global_state_format_key": "global_state_format_value"}, + assert ( + cursor._cursor_per_partition[ + '{"first_partition_key":"first_partition_value"}' + ].set_initial_state.call_count + == 1 ) - assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_count == 1 - assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_args[0] == ( - {"global_state_format_key": "global_state_format_value"}, + assert cursor._cursor_per_partition[ + '{"first_partition_key":"first_partition_value"}' + ].set_initial_state.call_args[0] == ({"global_state_format_key": "global_state_format_value"},) + assert ( + cursor._cursor_per_partition[ + '{"second_partition_key":"second_partition_value"}' + ].set_initial_state.call_count + == 1 ) + assert cursor._cursor_per_partition[ + '{"second_partition_key":"second_partition_value"}' + ].set_initial_state.call_args[0] == ({"global_state_format_key": "global_state_format_value"},) expected_state = [ - {"cursor": {"global_state_format_key": "global_state_format_value"}, "partition": {"first_partition_key": "first_partition_value"}}, + { + "cursor": {"global_state_format_key": "global_state_format_value"}, + "partition": {"first_partition_key": "first_partition_value"}, + }, { "cursor": {"global_state_format_key": "global_state_format_value"}, "partition": {"second_partition_key": "second_partition_value"}, diff --git a/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py b/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py index 4fff298b..9d0216ff 100644 --- a/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py +++ b/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py @@ -17,7 +17,10 @@ StreamDescriptor, SyncMode, ) -from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, StreamSlice +from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import ( + PerPartitionCursor, + StreamSlice, +) from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.sources.types import Record @@ -55,7 +58,16 @@ def with_substream_partition_router(self, stream_name): } return self - def with_incremental_sync(self, stream_name, start_datetime, end_datetime, datetime_format, cursor_field, step, cursor_granularity): + def with_incremental_sync( + self, + stream_name, + start_datetime, + end_datetime, + datetime_format, + cursor_field, + step, + cursor_granularity, + ): self._incremental_sync[stream_name] = { "type": "DatetimeBasedCursor", "start_datetime": start_datetime, @@ -79,7 +91,11 @@ def build(self): "primary_key": [], "schema_loader": { "type": "InlineSchemaLoader", - "schema": {"$schema": "http://json-schema.org/schema#", "properties": {"id": {"type": "string"}}, "type": "object"}, + "schema": { + "$schema": "http://json-schema.org/schema#", + "properties": {"id": {"type": "string"}}, + "type": "object", + }, }, "retriever": { "type": "SimpleRetriever", @@ -89,7 +105,10 @@ def build(self): "path": "/exchangerates_data/latest", "http_method": "GET", }, - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}}, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": []}, + }, }, }, "Rates": { @@ -98,7 +117,11 @@ def build(self): "primary_key": [], "schema_loader": { "type": "InlineSchemaLoader", - "schema": {"$schema": "http://json-schema.org/schema#", "properties": {}, "type": "object"}, + "schema": { + "$schema": "http://json-schema.org/schema#", + "properties": {}, + "type": "object", + }, }, "retriever": { "type": "SimpleRetriever", @@ -108,7 +131,10 @@ def build(self): "path": "/exchangerates_data/latest", "http_method": "GET", }, - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}}, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": []}, + }, }, }, }, @@ -128,9 +154,13 @@ def build(self): for stream_name, incremental_sync_definition in self._incremental_sync.items(): manifest["definitions"][stream_name]["incremental_sync"] = incremental_sync_definition for stream_name, partition_router_definition in self._partition_router.items(): - manifest["definitions"][stream_name]["retriever"]["partition_router"] = partition_router_definition + manifest["definitions"][stream_name]["retriever"]["partition_router"] = ( + partition_router_definition + ) for stream_name, partition_router_definition in self._substream_partition_router.items(): - manifest["definitions"][stream_name]["retriever"]["partition_router"] = partition_router_definition + manifest["definitions"][stream_name]["retriever"]["partition_router"] = ( + partition_router_definition + ) return manifest @@ -189,9 +219,16 @@ def test_given_record_for_partition_when_read_then_update_state(): stream_instance = source.streams({})[0] list(stream_instance.stream_slices(sync_mode=SYNC_MODE)) - stream_slice = StreamSlice(partition={"partition_field": "1"}, cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"}) + stream_slice = StreamSlice( + partition={"partition_field": "1"}, + cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"}, + ) with patch.object( - SimpleRetriever, "_read_pages", side_effect=[[Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, stream_slice)]] + SimpleRetriever, + "_read_pages", + side_effect=[ + [Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, stream_slice)] + ], ): list( stream_instance.read_records( @@ -241,7 +278,9 @@ def test_substream_without_input_state(): stream_instance = test_source.streams({})[1] - parent_stream_slice = StreamSlice(partition={}, cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"}) + parent_stream_slice = StreamSlice( + partition={}, cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"} + ) # This mocks the resulting records of the Rates stream which acts as the parent stream of the SubstreamPartitionRouter being tested with patch.object( @@ -316,18 +355,38 @@ def test_partition_limitation(caplog): records_list = [ [ - Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, partition_slices[0]), - Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[0]), + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, partition_slices[0] + ), + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[0] + ), + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-15"}, partition_slices[0] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[1] + ) ], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-15"}, partition_slices[0])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[1])], [], [], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-17"}, partition_slices[2])], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-17"}, partition_slices[2] + ) + ], ] configured_stream = ConfiguredAirbyteStream( - stream=AirbyteStream(name="Rates", json_schema={}, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental]), + stream=AirbyteStream( + name="Rates", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.append, ) @@ -369,12 +428,14 @@ def test_partition_limitation(caplog): # Check if the warning was logged logged_messages = [record.message for record in caplog.records if record.levelname == "WARNING"] - warning_message = ( - 'The maximum number of partitions has been reached. Dropping the oldest partition: {"partition_field":"1"}. Over limit: 1.' - ) + warning_message = 'The maximum number of partitions has been reached. Dropping the oldest partition: {"partition_field":"1"}. Over limit: 1.' assert warning_message in logged_messages - final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state] + final_state = [ + orjson.loads(orjson.dumps(message.state.stream.stream_state)) + for message in output + if message.state + ] assert final_state[-1] == { "lookback_window": 1, "state": {"cursor_field": "2022-02-17"}, @@ -414,28 +475,70 @@ def test_perpartition_with_fallback(caplog): .build() ) - partition_slices = [StreamSlice(partition={"partition_field": str(i)}, cursor_slice={}) for i in range(1, 7)] + partition_slices = [ + StreamSlice(partition={"partition_field": str(i)}, cursor_slice={}) for i in range(1, 7) + ] records_list = [ [ - Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, partition_slices[0]), - Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[0]), + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, partition_slices[0] + ), + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[0] + ), + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-15"}, partition_slices[0] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[1] + ) ], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-15"}, partition_slices[0])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[1])], [], [], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-17"}, partition_slices[2])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-17"}, partition_slices[3])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-19"}, partition_slices[3])], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-17"}, partition_slices[2] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-17"}, partition_slices[3] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-19"}, partition_slices[3] + ) + ], [], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-18"}, partition_slices[4])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-13"}, partition_slices[3])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-18"}, partition_slices[3])], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-18"}, partition_slices[4] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-13"}, partition_slices[3] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-18"}, partition_slices[3] + ) + ], ] configured_stream = ConfiguredAirbyteStream( - stream=AirbyteStream(name="Rates", json_schema={}, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental]), + stream=AirbyteStream( + name="Rates", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.append, ) @@ -488,8 +591,16 @@ def test_perpartition_with_fallback(caplog): assert expected_message in logged_messages # Proceed with existing assertions - final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state] - assert final_state[-1] == {"use_global_cursor": True, "state": {"cursor_field": "2022-02-19"}, "lookback_window": 1} + final_state = [ + orjson.loads(orjson.dumps(message.state.stream.stream_state)) + for message in output + if message.state + ] + assert final_state[-1] == { + "use_global_cursor": True, + "state": {"cursor_field": "2022-02-19"}, + "lookback_window": 1, + } def test_per_partition_cursor_within_limit(caplog): @@ -514,22 +625,64 @@ def test_per_partition_cursor_within_limit(caplog): .build() ) - partition_slices = [StreamSlice(partition={"partition_field": str(i)}, cursor_slice={}) for i in range(1, 4)] + partition_slices = [ + StreamSlice(partition={"partition_field": str(i)}, cursor_slice={}) for i in range(1, 4) + ] records_list = [ - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, partition_slices[0])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-20"}, partition_slices[0])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-03-25"}, partition_slices[0])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[1])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-18"}, partition_slices[1])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-03-28"}, partition_slices[1])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-17"}, partition_slices[2])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-02-19"}, partition_slices[2])], - [Record({"a record key": "a record value", CURSOR_FIELD: "2022-03-29"}, partition_slices[2])], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, partition_slices[0] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-20"}, partition_slices[0] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-03-25"}, partition_slices[0] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-16"}, partition_slices[1] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-18"}, partition_slices[1] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-03-28"}, partition_slices[1] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-01-17"}, partition_slices[2] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-02-19"}, partition_slices[2] + ) + ], + [ + Record( + {"a record key": "a record value", CURSOR_FIELD: "2022-03-29"}, partition_slices[2] + ) + ], ] configured_stream = ConfiguredAirbyteStream( - stream=AirbyteStream(name="Rates", json_schema={}, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental]), + stream=AirbyteStream( + name="Rates", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.append, ) @@ -549,7 +702,11 @@ def test_per_partition_cursor_within_limit(caplog): assert len(logged_warnings) == 0 # Proceed with existing assertions - final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state] + final_state = [ + orjson.loads(orjson.dumps(message.state.stream.stream_state)) + for message in output + if message.state + ] assert final_state[-1] == { "lookback_window": 1, "state": {"cursor_field": "2022-03-29"}, diff --git a/unit_tests/sources/declarative/incremental/test_resumable_full_refresh_cursor.py b/unit_tests/sources/declarative/incremental/test_resumable_full_refresh_cursor.py index b4597328..90321449 100644 --- a/unit_tests/sources/declarative/incremental/test_resumable_full_refresh_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_resumable_full_refresh_cursor.py @@ -1,7 +1,10 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. import pytest -from airbyte_cdk.sources.declarative.incremental import ChildPartitionResumableFullRefreshCursor, ResumableFullRefreshCursor +from airbyte_cdk.sources.declarative.incremental import ( + ChildPartitionResumableFullRefreshCursor, + ResumableFullRefreshCursor, +) from airbyte_cdk.sources.types import StreamSlice diff --git a/unit_tests/sources/declarative/interpolation/test_filters.py b/unit_tests/sources/declarative/interpolation/test_filters.py index 912abf3d..82dd2bf1 100644 --- a/unit_tests/sources/declarative/interpolation/test_filters.py +++ b/unit_tests/sources/declarative/interpolation/test_filters.py @@ -97,7 +97,9 @@ def test_regex_search_no_match_group() -> None: def test_regex_search_no_match() -> None: # If no group is set in the regular expression, the result will be an empty string - expression_with_regex = "{{ '; rel=\"next\"' | regex_search('WATWATWAT') }}" + expression_with_regex = ( + "{{ '; rel=\"next\"' | regex_search('WATWATWAT') }}" + ) val = interpolation.eval(expression_with_regex, {}) diff --git a/unit_tests/sources/declarative/interpolation/test_interpolated_boolean.py b/unit_tests/sources/declarative/interpolation/test_interpolated_boolean.py index 6b4ba534..015d45aa 100644 --- a/unit_tests/sources/declarative/interpolation/test_interpolated_boolean.py +++ b/unit_tests/sources/declarative/interpolation/test_interpolated_boolean.py @@ -21,7 +21,11 @@ [ ("test_interpolated_true_value", "{{ config['parent']['key_with_true'] }}", True), ("test_interpolated_true_comparison", "{{ config['string_key'] == \"compare_me\" }}", True), - ("test_interpolated_false_condition", "{{ config['string_key'] == \"witness_me\" }}", False), + ( + "test_interpolated_false_condition", + "{{ config['string_key'] == \"witness_me\" }}", + False, + ), ("test_path_has_value_returns_true", "{{ config['string_key'] }}", True), ("test_zero_is_false", "{{ config['zero_value'] }}", False), ("test_empty_array_is_false", "{{ config['empty_array'] }}", False), @@ -32,9 +36,15 @@ ("test_True", "{{ True }}", True), ("test_value_in_array", "{{ 1 in config['non_empty_array'] }}", True), ("test_value_not_in_array", "{{ 2 in config['non_empty_array'] }}", False), - ("test_interpolation_using_parameters", "{{ parameters['from_parameters'] == \"come_find_me\" }}", True), + ( + "test_interpolation_using_parameters", + "{{ parameters['from_parameters'] == \"come_find_me\" }}", + True, + ), ], ) def test_interpolated_boolean(test_name, template, expected_result): - interpolated_bool = InterpolatedBoolean(condition=template, parameters={"from_parameters": "come_find_me"}) + interpolated_bool = InterpolatedBoolean( + condition=template, parameters={"from_parameters": "come_find_me"} + ) assert interpolated_bool.eval(config) == expected_result diff --git a/unit_tests/sources/declarative/interpolation/test_interpolated_mapping.py b/unit_tests/sources/declarative/interpolation/test_interpolated_mapping.py index f4c77c3c..87843915 100644 --- a/unit_tests/sources/declarative/interpolation/test_interpolated_mapping.py +++ b/unit_tests/sources/declarative/interpolation/test_interpolated_mapping.py @@ -11,9 +11,21 @@ [ ("test_field_value", "field", "value"), ("test_number", "number", 100), - ("test_field_to_interpolate_from_config", "field_to_interpolate_from_config", "VALUE_FROM_CONFIG"), - ("test_field_to_interpolate_from_kwargs", "field_to_interpolate_from_kwargs", "VALUE_FROM_KWARGS"), - ("test_field_to_interpolate_from_parameters", "field_to_interpolate_from_parameters", "VALUE_FROM_PARAMETERS"), + ( + "test_field_to_interpolate_from_config", + "field_to_interpolate_from_config", + "VALUE_FROM_CONFIG", + ), + ( + "test_field_to_interpolate_from_kwargs", + "field_to_interpolate_from_kwargs", + "VALUE_FROM_KWARGS", + ), + ( + "test_field_to_interpolate_from_parameters", + "field_to_interpolate_from_parameters", + "VALUE_FROM_PARAMETERS", + ), ("test_key_is_interpolated", "key", "VALUE"), ], ) diff --git a/unit_tests/sources/declarative/interpolation/test_interpolated_nested_mapping.py b/unit_tests/sources/declarative/interpolation/test_interpolated_nested_mapping.py index c1dea1d6..809f368c 100644 --- a/unit_tests/sources/declarative/interpolation/test_interpolated_nested_mapping.py +++ b/unit_tests/sources/declarative/interpolation/test_interpolated_nested_mapping.py @@ -4,7 +4,9 @@ import dpath import pytest -from airbyte_cdk.sources.declarative.interpolation.interpolated_nested_mapping import InterpolatedNestedMapping +from airbyte_cdk.sources.declarative.interpolation.interpolated_nested_mapping import ( + InterpolatedNestedMapping, +) @pytest.mark.parametrize( @@ -16,7 +18,11 @@ ("test_interpolated_boolean", "nested/nested_array/2/value", True), ("test_field_to_interpolate_from_config", "nested/config_value", "VALUE_FROM_CONFIG"), ("test_field_to_interpolate_from_kwargs", "nested/kwargs_value", "VALUE_FROM_KWARGS"), - ("test_field_to_interpolate_from_parameters", "nested/parameters_value", "VALUE_FROM_PARAMETERS"), + ( + "test_field_to_interpolate_from_parameters", + "nested/parameters_value", + "VALUE_FROM_PARAMETERS", + ), ("test_key_is_interpolated", "nested/nested_array/0/key", "VALUE"), ], ) @@ -38,7 +44,9 @@ def test(test_name, path, expected_value): config = {"c": "VALUE_FROM_CONFIG", "num_value": 3} kwargs = {"a": "VALUE_FROM_KWARGS"} - mapping = InterpolatedNestedMapping(mapping=d, parameters={"b": "VALUE_FROM_PARAMETERS", "k": "key"}) + mapping = InterpolatedNestedMapping( + mapping=d, parameters={"b": "VALUE_FROM_PARAMETERS", "k": "key"} + ) interpolated = mapping.eval(config, **{"kwargs": kwargs}) diff --git a/unit_tests/sources/declarative/interpolation/test_jinja.py b/unit_tests/sources/declarative/interpolation/test_jinja.py index 207e6fae..7534e929 100644 --- a/unit_tests/sources/declarative/interpolation/test_jinja.py +++ b/unit_tests/sources/declarative/interpolation/test_jinja.py @@ -129,7 +129,9 @@ def test_positive_day_delta(): val = interpolation.eval(delta_template, {}) # We need to assert against an earlier delta since the interpolation function runs datetime.now() a few milliseconds earlier - assert val > (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=24, hours=23)).strftime("%Y-%m-%dT%H:%M:%S.%f%z") + assert val > ( + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=24, hours=23) + ).strftime("%Y-%m-%dT%H:%M:%S.%f%z") def test_positive_day_delta_with_format(): @@ -144,7 +146,9 @@ def test_negative_day_delta(): delta_template = "{{ day_delta(-25) }}" val = interpolation.eval(delta_template, {}) - assert val <= (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=25)).strftime("%Y-%m-%dT%H:%M:%S.%f%z") + assert val <= ( + datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=25) + ).strftime("%Y-%m-%dT%H:%M:%S.%f%z") @pytest.mark.parametrize( @@ -173,7 +177,11 @@ def test_to_string(test_name, input_value, expected_output): [ pytest.param("{{ timestamp(1621439283) }}", 1621439283, id="test_timestamp_from_timestamp"), pytest.param("{{ timestamp('2021-05-19') }}", 1621382400, id="test_timestamp_from_string"), - pytest.param("{{ timestamp('2017-01-01T00:00:00.0Z') }}", 1483228800, id="test_timestamp_from_rfc3339"), + pytest.param( + "{{ timestamp('2017-01-01T00:00:00.0Z') }}", + 1483228800, + id="test_timestamp_from_rfc3339", + ), pytest.param("{{ max(1,2) }}", 2, id="test_max"), ], ) @@ -187,7 +195,9 @@ def test_macros(s, expected_value): "template_string", [ pytest.param("{{ import os) }}", id="test_jinja_with_import"), - pytest.param("{{ [a for a in range(1000000000)] }}", id="test_jinja_with_list_comprehension"), + pytest.param( + "{{ [a for a in range(1000000000)] }}", id="test_jinja_with_list_comprehension" + ), ], ) def test_invalid_jinja_statements(template_string): @@ -229,7 +239,9 @@ def test_restricted_builtin_functions_are_not_executed(template_string): [ pytest.param("{{ to_be }}", "that_is_the_question", None, id="valid_template_variable"), pytest.param("{{ missingno }}", None, ValueError, id="undeclared_template_variable"), - pytest.param("{{ to_be and or_not_to_be }}", None, ValueError, id="one_undeclared_template_variable"), + pytest.param( + "{{ to_be and or_not_to_be }}", None, ValueError, id="one_undeclared_template_variable" + ), ], ) def test_undeclared_variables(template_string, expected_error, expected_value): @@ -239,7 +251,9 @@ def test_undeclared_variables(template_string, expected_error, expected_value): with pytest.raises(expected_error): interpolation.eval(template_string, config=config, **{"to_be": "that_is_the_question"}) else: - actual_value = interpolation.eval(template_string, config=config, **{"to_be": "that_is_the_question"}) + actual_value = interpolation.eval( + template_string, config=config, **{"to_be": "that_is_the_question"} + ) assert actual_value == expected_value @@ -248,28 +262,56 @@ def test_undeclared_variables(template_string, expected_error, expected_value): "template_string, expected_value", [ pytest.param("{{ now_utc() }}", "2021-09-01 00:00:00+00:00", id="test_now_utc"), - pytest.param("{{ now_utc().strftime('%Y-%m-%d') }}", "2021-09-01", id="test_now_utc_strftime"), + pytest.param( + "{{ now_utc().strftime('%Y-%m-%d') }}", "2021-09-01", id="test_now_utc_strftime" + ), pytest.param("{{ today_utc() }}", "2021-09-01", id="test_today_utc"), - pytest.param("{{ today_utc().strftime('%Y/%m/%d') }}", "2021/09/01", id="test_todat_utc_stftime"), + pytest.param( + "{{ today_utc().strftime('%Y/%m/%d') }}", "2021/09/01", id="test_todat_utc_stftime" + ), pytest.param("{{ timestamp(1646006400) }}", 1646006400, id="test_timestamp_from_timestamp"), - pytest.param("{{ timestamp('2022-02-28') }}", 1646006400, id="test_timestamp_from_timestamp"), - pytest.param("{{ timestamp('2022-02-28T00:00:00Z') }}", 1646006400, id="test_timestamp_from_timestamp"), - pytest.param("{{ timestamp('2022-02-28 00:00:00Z') }}", 1646006400, id="test_timestamp_from_timestamp"), - pytest.param("{{ timestamp('2022-02-28T00:00:00-08:00') }}", 1646035200, id="test_timestamp_from_date_with_tz"), + pytest.param( + "{{ timestamp('2022-02-28') }}", 1646006400, id="test_timestamp_from_timestamp" + ), + pytest.param( + "{{ timestamp('2022-02-28T00:00:00Z') }}", + 1646006400, + id="test_timestamp_from_timestamp", + ), + pytest.param( + "{{ timestamp('2022-02-28 00:00:00Z') }}", + 1646006400, + id="test_timestamp_from_timestamp", + ), + pytest.param( + "{{ timestamp('2022-02-28T00:00:00-08:00') }}", + 1646035200, + id="test_timestamp_from_date_with_tz", + ), pytest.param("{{ max(2, 3) }}", 3, id="test_max_with_arguments"), pytest.param("{{ max([2, 3]) }}", 3, id="test_max_with_list"), pytest.param("{{ day_delta(1) }}", "2021-09-02T00:00:00.000000+0000", id="test_day_delta"), - pytest.param("{{ day_delta(-1) }}", "2021-08-31T00:00:00.000000+0000", id="test_day_delta_negative"), - pytest.param("{{ day_delta(1, format='%Y-%m-%d') }}", "2021-09-02", id="test_day_delta_with_format"), + pytest.param( + "{{ day_delta(-1) }}", "2021-08-31T00:00:00.000000+0000", id="test_day_delta_negative" + ), + pytest.param( + "{{ day_delta(1, format='%Y-%m-%d') }}", "2021-09-02", id="test_day_delta_with_format" + ), pytest.param("{{ duration('P1D') }}", "1 day, 0:00:00", id="test_duration_one_day"), - pytest.param("{{ duration('P6DT23H') }}", "6 days, 23:00:00", id="test_duration_six_days_and_23_hours"), + pytest.param( + "{{ duration('P6DT23H') }}", + "6 days, 23:00:00", + id="test_duration_six_days_and_23_hours", + ), pytest.param( "{{ (now_utc() - duration('P1D')).strftime('%Y-%m-%dT%H:%M:%SZ') }}", "2021-08-31T00:00:00Z", id="test_now_utc_with_duration_and_format", ), pytest.param("{{ 1 | string }}", "1", id="test_int_to_string"), - pytest.param('{{ ["hello", "world"] | string }}', '["hello", "world"]', id="test_array_to_string"), + pytest.param( + '{{ ["hello", "world"] | string }}', '["hello", "world"]', id="test_array_to_string" + ), ], ) def test_macros_examples(template_string, expected_value): @@ -283,7 +325,11 @@ def test_macros_examples(template_string, expected_value): @pytest.mark.parametrize( "template_string, expected_value", [ - pytest.param("{{ today_with_timezone('Pacific/Kiritimati') }}", "2021-09-02", id="test_today_timezone_pacific"), + pytest.param( + "{{ today_with_timezone('Pacific/Kiritimati') }}", + "2021-09-02", + id="test_today_timezone_pacific", + ), ], ) def test_macros_timezone(template_string: str, expected_value: str): diff --git a/unit_tests/sources/declarative/interpolation/test_macros.py b/unit_tests/sources/declarative/interpolation/test_macros.py index d2a2a291..cd16bd9f 100644 --- a/unit_tests/sources/declarative/interpolation/test_macros.py +++ b/unit_tests/sources/declarative/interpolation/test_macros.py @@ -33,12 +33,48 @@ def test_macros_export(test_name, fn_name, found_in_macros): ("test_datetime_string_to_date", "2022-01-01T01:01:01Z", "%Y-%m-%d", None, "2022-01-01"), ("test_date_string_to_date", "2022-01-01", "%Y-%m-%d", None, "2022-01-01"), ("test_datetime_string_to_date", "2022-01-01T00:00:00Z", "%Y-%m-%d", None, "2022-01-01"), - ("test_datetime_with_tz_string_to_date", "2022-01-01T00:00:00Z", "%Y-%m-%d", None, "2022-01-01"), - ("test_datetime_string_to_datetime", "2022-01-01T01:01:01Z", "%Y-%m-%dT%H:%M:%SZ", None, "2022-01-01T01:01:01Z"), - ("test_datetime_string_with_tz_to_datetime", "2022-01-01T01:01:01-0800", "%Y-%m-%dT%H:%M:%SZ", None, "2022-01-01T09:01:01Z"), - ("test_datetime_object_tz_to_date", datetime.datetime(2022, 1, 1, 1, 1, 1), "%Y-%m-%d", None, "2022-01-01"), - ("test_datetime_object_tz_to_datetime", datetime.datetime(2022, 1, 1, 1, 1, 1), "%Y-%m-%dT%H:%M:%SZ", None, "2022-01-01T01:01:01Z"), - ("test_datetime_string_to_rfc2822_date", "Sat, 01 Jan 2022 01:01:01 +0000", "%Y-%m-%d", "%a, %d %b %Y %H:%M:%S %z", "2022-01-01"), + ( + "test_datetime_with_tz_string_to_date", + "2022-01-01T00:00:00Z", + "%Y-%m-%d", + None, + "2022-01-01", + ), + ( + "test_datetime_string_to_datetime", + "2022-01-01T01:01:01Z", + "%Y-%m-%dT%H:%M:%SZ", + None, + "2022-01-01T01:01:01Z", + ), + ( + "test_datetime_string_with_tz_to_datetime", + "2022-01-01T01:01:01-0800", + "%Y-%m-%dT%H:%M:%SZ", + None, + "2022-01-01T09:01:01Z", + ), + ( + "test_datetime_object_tz_to_date", + datetime.datetime(2022, 1, 1, 1, 1, 1), + "%Y-%m-%d", + None, + "2022-01-01", + ), + ( + "test_datetime_object_tz_to_datetime", + datetime.datetime(2022, 1, 1, 1, 1, 1), + "%Y-%m-%dT%H:%M:%SZ", + None, + "2022-01-01T01:01:01Z", + ), + ( + "test_datetime_string_to_rfc2822_date", + "Sat, 01 Jan 2022 01:01:01 +0000", + "%Y-%m-%d", + "%a, %d %b %Y %H:%M:%S %z", + "2022-01-01", + ), ], ) def test_format_datetime(test_name, input_value, format, input_format, expected_output): @@ -48,7 +84,10 @@ def test_format_datetime(test_name, input_value, format, input_format, expected_ @pytest.mark.parametrize( "test_name, input_value, expected_output", - [("test_one_day", "P1D", datetime.timedelta(days=1)), ("test_6_days_23_hours", "P6DT23H", datetime.timedelta(days=6, hours=23))], + [ + ("test_one_day", "P1D", datetime.timedelta(days=1)), + ("test_6_days_23_hours", "P6DT23H", datetime.timedelta(days=6, hours=23)), + ], ) def test_duration(test_name, input_value, expected_output): duration_fn = macros["duration"] diff --git a/unit_tests/sources/declarative/migrations/test_legacy_to_per_partition_migration.py b/unit_tests/sources/declarative/migrations/test_legacy_to_per_partition_migration.py index 97e5efd6..442e444a 100644 --- a/unit_tests/sources/declarative/migrations/test_legacy_to_per_partition_migration.py +++ b/unit_tests/sources/declarative/migrations/test_legacy_to_per_partition_migration.py @@ -5,13 +5,32 @@ from unittest.mock import MagicMock import pytest -from airbyte_cdk.sources.declarative.migrations.legacy_to_per_partition_state_migration import LegacyToPerPartitionStateMigration -from airbyte_cdk.sources.declarative.models import CustomPartitionRouter, CustomRetriever, DatetimeBasedCursor, DeclarativeStream -from airbyte_cdk.sources.declarative.models import LegacyToPerPartitionStateMigration as LegacyToPerPartitionStateMigrationModel -from airbyte_cdk.sources.declarative.models import ParentStreamConfig, SimpleRetriever, SubstreamPartitionRouter -from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ManifestComponentTransformer -from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import ManifestReferenceResolver -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory +from airbyte_cdk.sources.declarative.migrations.legacy_to_per_partition_state_migration import ( + LegacyToPerPartitionStateMigration, +) +from airbyte_cdk.sources.declarative.models import ( + CustomPartitionRouter, + CustomRetriever, + DatetimeBasedCursor, + DeclarativeStream, +) +from airbyte_cdk.sources.declarative.models import ( + LegacyToPerPartitionStateMigration as LegacyToPerPartitionStateMigrationModel, +) +from airbyte_cdk.sources.declarative.models import ( + ParentStreamConfig, + SimpleRetriever, + SubstreamPartitionRouter, +) +from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ( + ManifestComponentTransformer, +) +from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import ( + ManifestReferenceResolver, +) +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, +) from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource factory = ModelToComponentFactory() @@ -33,8 +52,14 @@ def test_migrate_a_valid_legacy_state_to_per_partition(): expected_state = { "states": [ - {"partition": {"parent_id": "13506132"}, "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, - {"partition": {"parent_id": "14351124"}, "cursor": {"last_changed": "2022-12-27T08:35:39+00:00"}}, + { + "partition": {"parent_id": "13506132"}, + "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}, + }, + { + "partition": {"parent_id": "14351124"}, + "cursor": {"last_changed": "2022-12-27T08:35:39+00:00"}, + }, ] } @@ -47,8 +72,14 @@ def test_migrate_a_valid_legacy_state_to_per_partition(): pytest.param( { "states": [ - {"partition": {"id": "13506132"}, "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, - {"partition": {"id": "14351124"}, "cursor": {"last_changed": "2022-12-27T08:35:39+00:00"}}, + { + "partition": {"id": "13506132"}, + "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}, + }, + { + "partition": {"id": "14351124"}, + "cursor": {"last_changed": "2022-12-27T08:35:39+00:00"}, + }, ] }, id="test_should_not_migrate_a_per_partition_state", @@ -56,7 +87,10 @@ def test_migrate_a_valid_legacy_state_to_per_partition(): pytest.param( { "states": [ - {"partition": {"id": "13506132"}, "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, + { + "partition": {"id": "13506132"}, + "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}, + }, { "partition": {"id": "14351124"}, }, @@ -67,8 +101,14 @@ def test_migrate_a_valid_legacy_state_to_per_partition(): pytest.param( { "states": [ - {"partition": {"id": "13506132"}, "cursor": {"updated_at": "2022-12-27T08:34:39+00:00"}}, - {"partition": {"id": "14351124"}, "cursor": {"updated_at": "2022-12-27T08:35:39+00:00"}}, + { + "partition": {"id": "13506132"}, + "cursor": {"updated_at": "2022-12-27T08:34:39+00:00"}, + }, + { + "partition": {"id": "14351124"}, + "cursor": {"updated_at": "2022-12-27T08:35:39+00:00"}, + }, ] }, id="test_should_not_migrate_a_per_partition_state_with_wrong_cursor_field", @@ -76,8 +116,17 @@ def test_migrate_a_valid_legacy_state_to_per_partition(): pytest.param( { "states": [ - {"partition": {"id": "13506132"}, "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, - {"partition": {"id": "14351124"}, "cursor": {"last_changed": "2022-12-27T08:35:39+00:00", "updated_at": "2021-01-01"}}, + { + "partition": {"id": "13506132"}, + "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}, + }, + { + "partition": {"id": "14351124"}, + "cursor": { + "last_changed": "2022-12-27T08:35:39+00:00", + "updated_at": "2021-01-01", + }, + }, ] }, id="test_should_not_migrate_a_per_partition_state_with_multiple_cursor_fields", @@ -85,7 +134,10 @@ def test_migrate_a_valid_legacy_state_to_per_partition(): pytest.param( { "states": [ - {"partition": {"id": "13506132"}, "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, + { + "partition": {"id": "13506132"}, + "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}, + }, {"cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, ] }, @@ -94,8 +146,14 @@ def test_migrate_a_valid_legacy_state_to_per_partition(): pytest.param( { "states": [ - {"partition": {"id": "13506132", "another_id": "A"}, "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, - {"partition": {"id": "13506134"}, "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, + { + "partition": {"id": "13506132", "another_id": "A"}, + "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}, + }, + { + "partition": {"id": "13506134"}, + "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}, + }, ] }, id="test_should_not_migrate_state_if_multiple_partition_keys", @@ -103,8 +161,14 @@ def test_migrate_a_valid_legacy_state_to_per_partition(): pytest.param( { "states": [ - {"partition": {"identifier": "13506132"}, "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, - {"partition": {"id": "13506134"}, "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}}, + { + "partition": {"identifier": "13506132"}, + "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}, + }, + { + "partition": {"id": "13506134"}, + "cursor": {"last_changed": "2022-12-27T08:34:39+00:00"}, + }, ] }, id="test_should_not_migrate_state_if_invalid_partition_key", @@ -112,7 +176,10 @@ def test_migrate_a_valid_legacy_state_to_per_partition(): pytest.param( { "13506132": {"last_changed": "2022-12-27T08:34:39+00:00"}, - "14351124": {"last_changed": "2022-12-27T08:35:39+00:00", "another_key": "2022-12-27T08:35:39+00:00"}, + "14351124": { + "last_changed": "2022-12-27T08:35:39+00:00", + "another_key": "2022-12-27T08:35:39+00:00", + }, }, id="test_should_not_migrate_if_the_partitioned_state_has_more_than_one_key", ), @@ -150,7 +217,8 @@ def _migrator(): parent_key="{{ parameters['parent_key_id'] }}", partition_field="parent_id", stream=DeclarativeStream( - type="DeclarativeStream", retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name") + type="DeclarativeStream", + retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name"), ), ) ], @@ -175,7 +243,8 @@ def _migrator_with_multiple_parent_streams(): parent_key="id", partition_field="parent_id", stream=DeclarativeStream( - type="DeclarativeStream", retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name") + type="DeclarativeStream", + retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name"), ), ), ParentStreamConfig( @@ -183,7 +252,8 @@ def _migrator_with_multiple_parent_streams(): parent_key="id", partition_field="parent_id", stream=DeclarativeStream( - type="DeclarativeStream", retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name") + type="DeclarativeStream", + retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name"), ), ), ], @@ -233,7 +303,11 @@ def test_create_legacy_to_per_partition_state_migration( expected_exception, expected_error_message, ): - partition_router = partition_router_class(type="CustomPartitionRouter", class_name="a_class_namer") if partition_router_class else None + partition_router = ( + partition_router_class(type="CustomPartitionRouter", class_name="a_class_namer") + if partition_router_class + else None + ) stream = MagicMock() stream.retriever = MagicMock(spec=retriever_type) @@ -245,7 +319,9 @@ def test_create_legacy_to_per_partition_state_migration( """ resolved_manifest = resolver.preprocess_manifest(YamlDeclarativeSource._parse(content)) - state_migrations_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["state_migrations"][0], {}) + state_migrations_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["state_migrations"][0], {} + ) if is_parent_stream_config: parent_stream_config = ParentStreamConfig( @@ -253,7 +329,8 @@ def test_create_legacy_to_per_partition_state_migration( parent_key="id", partition_field="parent_id", stream=DeclarativeStream( - type="DeclarativeStream", retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name") + type="DeclarativeStream", + retriever=CustomRetriever(type="CustomRetriever", class_name="a_class_name"), ), ) partition_router.parent_stream_configs = [parent_stream_config] diff --git a/unit_tests/sources/declarative/parsers/test_manifest_component_transformer.py b/unit_tests/sources/declarative/parsers/test_manifest_component_transformer.py index 63efac66..3cd06273 100644 --- a/unit_tests/sources/declarative/parsers/test_manifest_component_transformer.py +++ b/unit_tests/sources/declarative/parsers/test_manifest_component_transformer.py @@ -3,14 +3,19 @@ # import pytest -from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ManifestComponentTransformer +from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ( + ManifestComponentTransformer, +) @pytest.mark.parametrize( "component, expected_component", [ pytest.param( - {"type": "DeclarativeSource", "streams": [{"type": "DeclarativeStream", "retriever": {}, "schema_loader": {}}]}, + { + "type": "DeclarativeSource", + "streams": [{"type": "DeclarativeStream", "retriever": {}, "schema_loader": {}}], + }, { "type": "DeclarativeSource", "streams": [ @@ -26,7 +31,12 @@ pytest.param( { "type": "DeclarativeStream", - "retriever": {"type": "SimpleRetriever", "paginator": {}, "record_selector": {}, "requester": {}}, + "retriever": { + "type": "SimpleRetriever", + "paginator": {}, + "record_selector": {}, + "requester": {}, + }, }, { "type": "DeclarativeStream", @@ -40,7 +50,10 @@ id="test_simple_retriever", ), pytest.param( - {"type": "DeclarativeStream", "requester": {"type": "HttpRequester", "error_handler": {}}}, + { + "type": "DeclarativeStream", + "requester": {"type": "HttpRequester", "error_handler": {}}, + }, { "type": "DeclarativeStream", "requester": { @@ -51,7 +64,14 @@ id="test_http_requester", ), pytest.param( - {"type": "SimpleRetriever", "paginator": {"type": "DefaultPaginator", "page_size_option": {}, "page_token_option": {}}}, + { + "type": "SimpleRetriever", + "paginator": { + "type": "DefaultPaginator", + "page_size_option": {}, + "page_token_option": {}, + }, + }, { "type": "SimpleRetriever", "paginator": { @@ -63,7 +83,13 @@ id="test_default_paginator", ), pytest.param( - {"type": "SimpleRetriever", "partition_router": {"type": "SubstreamPartitionRouter", "parent_stream_configs": [{}, {}, {}]}}, + { + "type": "SimpleRetriever", + "partition_router": { + "type": "SubstreamPartitionRouter", + "parent_stream_configs": [{}, {}, {}], + }, + }, { "type": "SimpleRetriever", "partition_router": { @@ -92,13 +118,21 @@ def test_find_default_types(component, expected_component): pytest.param( { "type": "SimpleRetriever", - "requester": {"type": "HttpRequester", "authenticator": {"class_name": "source_greenhouse.components.NewAuthenticator"}}, + "requester": { + "type": "HttpRequester", + "authenticator": { + "class_name": "source_greenhouse.components.NewAuthenticator" + }, + }, }, { "type": "SimpleRetriever", "requester": { "type": "HttpRequester", - "authenticator": {"type": "CustomAuthenticator", "class_name": "source_greenhouse.components.NewAuthenticator"}, + "authenticator": { + "type": "CustomAuthenticator", + "class_name": "source_greenhouse.components.NewAuthenticator", + }, }, }, id="test_custom_authenticator", @@ -115,7 +149,10 @@ def test_find_default_types(component, expected_component): "type": "SimpleRetriever", "record_selector": { "type": "RecordSelector", - "extractor": {"type": "CustomRecordExtractor", "class_name": "source_greenhouse.components.NewRecordExtractor"}, + "extractor": { + "type": "CustomRecordExtractor", + "class_name": "source_greenhouse.components.NewRecordExtractor", + }, }, }, id="test_custom_extractor", @@ -138,7 +175,10 @@ def test_propagate_parameters_to_all_components(): "$parameters": {"name": "roasters", "primary_key": "id"}, "retriever": { "type": "SimpleRetriever", - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}}, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": []}, + }, "requester": { "type": "HttpRequester", "name": '{{ parameters["name"] }}', @@ -252,7 +292,10 @@ def test_do_not_propagate_parameters_that_have_the_same_field_name(): "$parameters": { "name": "roasters", "primary_key": "id", - "schema_loader": {"type": "JsonFileSchemaLoader", "file_path": './source_coffee/schemas/{{ parameters["name"] }}.json'}, + "schema_loader": { + "type": "JsonFileSchemaLoader", + "file_path": './source_coffee/schemas/{{ parameters["name"] }}.json', + }, }, } ], @@ -278,7 +321,10 @@ def test_do_not_propagate_parameters_that_have_the_same_field_name(): "$parameters": { "name": "roasters", "primary_key": "id", - "schema_loader": {"type": "JsonFileSchemaLoader", "file_path": './source_coffee/schemas/{{ parameters["name"] }}.json'}, + "schema_loader": { + "type": "JsonFileSchemaLoader", + "file_path": './source_coffee/schemas/{{ parameters["name"] }}.json', + }, }, } ], @@ -295,7 +341,10 @@ def test_ignore_empty_parameters(): "type": "DeclarativeStream", "retriever": { "type": "SimpleRetriever", - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}}, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": []}, + }, }, } @@ -317,7 +366,10 @@ def test_only_propagate_parameters_to_components(): "some_option": "already", }, }, - "dictionary_field": {"details": "should_not_contain_parameters", "other": "no_parameters_as_fields"}, + "dictionary_field": { + "details": "should_not_contain_parameters", + "other": "no_parameters_as_fields", + }, "$parameters": { "included": "not!", }, @@ -335,7 +387,10 @@ def test_only_propagate_parameters_to_components(): "included": "not!", "$parameters": {"some_option": "already", "included": "not!"}, }, - "dictionary_field": {"details": "should_not_contain_parameters", "other": "no_parameters_as_fields"}, + "dictionary_field": { + "details": "should_not_contain_parameters", + "other": "no_parameters_as_fields", + }, "included": "not!", "$parameters": { "included": "not!", diff --git a/unit_tests/sources/declarative/parsers/test_manifest_reference_resolver.py b/unit_tests/sources/declarative/parsers/test_manifest_reference_resolver.py index 75ee51c8..36ae03cc 100644 --- a/unit_tests/sources/declarative/parsers/test_manifest_reference_resolver.py +++ b/unit_tests/sources/declarative/parsers/test_manifest_reference_resolver.py @@ -3,8 +3,14 @@ # import pytest -from airbyte_cdk.sources.declarative.parsers.custom_exceptions import CircularReferenceException, UndefinedReferenceException -from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import ManifestReferenceResolver, _parse_path +from airbyte_cdk.sources.declarative.parsers.custom_exceptions import ( + CircularReferenceException, + UndefinedReferenceException, +) +from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import ( + ManifestReferenceResolver, + _parse_path, +) resolver = ManifestReferenceResolver() @@ -29,7 +35,13 @@ def test_refer_to_non_existant_struct(): def test_refer_in_dict(): - content = {"limit": 50, "offset_request_parameters": {"offset": "{{ next_page_token['offset'] }}", "limit": "#/limit"}} + content = { + "limit": 50, + "offset_request_parameters": { + "offset": "{{ next_page_token['offset'] }}", + "limit": "#/limit", + }, + } config = resolver.preprocess_manifest(content) assert config["offset_request_parameters"]["offset"] == "{{ next_page_token['offset'] }}" assert config["offset_request_parameters"]["limit"] == 50 @@ -38,7 +50,10 @@ def test_refer_in_dict(): def test_refer_to_dict(): content = { "limit": 50, - "offset_request_parameters": {"offset": "{{ next_page_token['offset'] }}", "limit": "#/limit"}, + "offset_request_parameters": { + "offset": "{{ next_page_token['offset'] }}", + "limit": "#/limit", + }, "offset_pagination_request_parameters": { "class": "InterpolatedRequestParameterProvider", "request_parameters": "#/offset_request_parameters", @@ -49,15 +64,24 @@ def test_refer_to_dict(): assert config["offset_request_parameters"]["limit"] == 50 assert len(config["offset_pagination_request_parameters"]) == 2 assert config["offset_pagination_request_parameters"]["request_parameters"]["limit"] == 50 - assert config["offset_pagination_request_parameters"]["request_parameters"]["offset"] == "{{ next_page_token['offset'] }}" + assert ( + config["offset_pagination_request_parameters"]["request_parameters"]["offset"] + == "{{ next_page_token['offset'] }}" + ) def test_refer_and_overwrite(): content = { "limit": 50, "custom_limit": 25, - "offset_request_parameters": {"offset": "{{ next_page_token['offset'] }}", "limit": "#/limit"}, - "custom_request_parameters": {"$ref": "#/offset_request_parameters", "limit": "#/custom_limit"}, + "offset_request_parameters": { + "offset": "{{ next_page_token['offset'] }}", + "limit": "#/limit", + }, + "custom_request_parameters": { + "$ref": "#/offset_request_parameters", + "limit": "#/custom_limit", + }, } config = resolver.preprocess_manifest(content) assert config["offset_request_parameters"]["limit"] == 50 diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index e14c730d..c8d0781a 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -26,7 +26,9 @@ from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream from airbyte_cdk.sources.declarative.decoders import JsonDecoder, PaginationDecoderDecorator from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordFilter, RecordSelector -from airbyte_cdk.sources.declarative.extractors.record_filter import ClientSideIncrementalRecordFilterDecorator +from airbyte_cdk.sources.declarative.extractors.record_filter import ( + ClientSideIncrementalRecordFilterDecorator, +) from airbyte_cdk.sources.declarative.incremental import ( CursorFactory, DatetimeBasedCursor, @@ -36,10 +38,14 @@ ) from airbyte_cdk.sources.declarative.interpolation import InterpolatedString from airbyte_cdk.sources.declarative.models import CheckStream as CheckStreamModel -from airbyte_cdk.sources.declarative.models import CompositeErrorHandler as CompositeErrorHandlerModel +from airbyte_cdk.sources.declarative.models import ( + CompositeErrorHandler as CompositeErrorHandlerModel, +) from airbyte_cdk.sources.declarative.models import ConcurrencyLevel as ConcurrencyLevelModel from airbyte_cdk.sources.declarative.models import CustomErrorHandler as CustomErrorHandlerModel -from airbyte_cdk.sources.declarative.models import CustomPartitionRouter as CustomPartitionRouterModel +from airbyte_cdk.sources.declarative.models import ( + CustomPartitionRouter as CustomPartitionRouterModel, +) from airbyte_cdk.sources.declarative.models import CustomSchemaLoader as CustomSchemaLoaderModel from airbyte_cdk.sources.declarative.models import DatetimeBasedCursor as DatetimeBasedCursorModel from airbyte_cdk.sources.declarative.models import DeclarativeStream as DeclarativeStreamModel @@ -51,13 +57,27 @@ from airbyte_cdk.sources.declarative.models import RecordSelector as RecordSelectorModel from airbyte_cdk.sources.declarative.models import SimpleRetriever as SimpleRetrieverModel from airbyte_cdk.sources.declarative.models import Spec as SpecModel -from airbyte_cdk.sources.declarative.models import SubstreamPartitionRouter as SubstreamPartitionRouterModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import OffsetIncrement as OffsetIncrementModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import PageIncrement as PageIncrementModel -from airbyte_cdk.sources.declarative.models.declarative_component_schema import SelectiveAuthenticator -from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ManifestComponentTransformer -from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import ManifestReferenceResolver -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory +from airbyte_cdk.sources.declarative.models import ( + SubstreamPartitionRouter as SubstreamPartitionRouterModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + OffsetIncrement as OffsetIncrementModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + PageIncrement as PageIncrementModel, +) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + SelectiveAuthenticator, +) +from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import ( + ManifestComponentTransformer, +) +from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import ( + ManifestReferenceResolver, +) +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, +) from airbyte_cdk.sources.declarative.partition_routers import ( CartesianProductStreamSlicer, ListPartitionRouter, @@ -65,7 +85,11 @@ SubstreamPartitionRouter, ) from airbyte_cdk.sources.declarative.requesters import HttpRequester -from airbyte_cdk.sources.declarative.requesters.error_handlers import CompositeErrorHandler, DefaultErrorHandler, HttpResponseFilter +from airbyte_cdk.sources.declarative.requesters.error_handlers import ( + CompositeErrorHandler, + DefaultErrorHandler, + HttpResponseFilter, +) from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies import ( ConstantBackoffStrategy, ExponentialBackoffStrategy, @@ -79,7 +103,10 @@ PageIncrement, StopConditionPaginationStrategyDecorator, ) -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.declarative.requesters.request_options import ( DatetimeBasedRequestOptionsProvider, DefaultRequestOptionsProvider, @@ -87,7 +114,10 @@ ) from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod -from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever, SimpleRetrieverTestReadDecorator +from airbyte_cdk.sources.declarative.retrievers import ( + SimpleRetriever, + SimpleRetrieverTestReadDecorator, +) from airbyte_cdk.sources.declarative.schema import JsonFileSchemaLoader from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader from airbyte_cdk.sources.declarative.spec import Spec @@ -99,8 +129,13 @@ CustomFormatConcurrentStreamStateConverter, ) from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction -from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import SingleUseRefreshTokenOauth2Authenticator -from unit_tests.sources.declarative.parsers.testing_components import TestingCustomSubstreamPartitionRouter, TestingSomeComponent +from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import ( + SingleUseRefreshTokenOauth2Authenticator, +) +from unit_tests.sources.declarative.parsers.testing_components import ( + TestingCustomSubstreamPartitionRouter, + TestingSomeComponent, +) factory = ModelToComponentFactory() @@ -242,7 +277,9 @@ def test_full_config_stream(): stream_manifest = manifest["list_stream"] assert stream_manifest["type"] == "DeclarativeStream" - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config + ) assert isinstance(stream, DeclarativeStream) assert stream.primary_key == "id" @@ -266,23 +303,39 @@ def test_full_config_stream(): assert isinstance(stream.retriever.record_selector.extractor, DpathExtractor) assert isinstance(stream.retriever.record_selector.extractor.decoder, JsonDecoder) - assert [fp.eval(input_config) for fp in stream.retriever.record_selector.extractor._field_path] == ["lists"] + assert [ + fp.eval(input_config) for fp in stream.retriever.record_selector.extractor._field_path + ] == ["lists"] assert isinstance(stream.retriever.record_selector.record_filter, RecordFilter) - assert stream.retriever.record_selector.record_filter._filter_interpolator.condition == "{{ record['id'] > stream_state['id'] }}" + assert ( + stream.retriever.record_selector.record_filter._filter_interpolator.condition + == "{{ record['id'] > stream_state['id'] }}" + ) assert isinstance(stream.retriever.paginator, DefaultPaginator) assert isinstance(stream.retriever.paginator.decoder, PaginationDecoderDecorator) assert stream.retriever.paginator.page_size_option.field_name.eval(input_config) == "page_size" - assert stream.retriever.paginator.page_size_option.inject_into == RequestOptionType.request_parameter + assert ( + stream.retriever.paginator.page_size_option.inject_into + == RequestOptionType.request_parameter + ) assert isinstance(stream.retriever.paginator.page_token_option, RequestPath) assert stream.retriever.paginator.url_base.string == "https://api.sendgrid.com/v3/" assert stream.retriever.paginator.url_base.default == "https://api.sendgrid.com/v3/" assert isinstance(stream.retriever.paginator.pagination_strategy, CursorPaginationStrategy) - assert isinstance(stream.retriever.paginator.pagination_strategy.decoder, PaginationDecoderDecorator) - assert stream.retriever.paginator.pagination_strategy._cursor_value.string == "{{ response._metadata.next }}" - assert stream.retriever.paginator.pagination_strategy._cursor_value.default == "{{ response._metadata.next }}" + assert isinstance( + stream.retriever.paginator.pagination_strategy.decoder, PaginationDecoderDecorator + ) + assert ( + stream.retriever.paginator.pagination_strategy._cursor_value.string + == "{{ response._metadata.next }}" + ) + assert ( + stream.retriever.paginator.pagination_strategy._cursor_value.default + == "{{ response._metadata.next }}" + ) assert stream.retriever.paginator.pagination_strategy.page_size == 10 assert isinstance(stream.retriever.requester, HttpRequester) @@ -292,27 +345,51 @@ def test_full_config_stream(): assert stream.retriever.requester._path.default == "{{ next_page_token['next_page_url'] }}" assert isinstance(stream.retriever.request_option_provider, DatetimeBasedRequestOptionsProvider) - assert stream.retriever.request_option_provider.start_time_option.inject_into == RequestOptionType.request_parameter - assert stream.retriever.request_option_provider.start_time_option.field_name.eval(config=input_config) == "after" - assert stream.retriever.request_option_provider.end_time_option.inject_into == RequestOptionType.request_parameter - assert stream.retriever.request_option_provider.end_time_option.field_name.eval(config=input_config) == "before" + assert ( + stream.retriever.request_option_provider.start_time_option.inject_into + == RequestOptionType.request_parameter + ) + assert ( + stream.retriever.request_option_provider.start_time_option.field_name.eval( + config=input_config + ) + == "after" + ) + assert ( + stream.retriever.request_option_provider.end_time_option.inject_into + == RequestOptionType.request_parameter + ) + assert ( + stream.retriever.request_option_provider.end_time_option.field_name.eval( + config=input_config + ) + == "before" + ) assert stream.retriever.request_option_provider._partition_field_start.string == "start_time" assert stream.retriever.request_option_provider._partition_field_end.string == "end_time" assert isinstance(stream.retriever.requester.authenticator, BearerAuthenticator) assert stream.retriever.requester.authenticator.token_provider.get_token() == "verysecrettoken" - assert isinstance(stream.retriever.requester.request_options_provider, InterpolatedRequestOptionsProvider) - assert stream.retriever.requester.request_options_provider.request_parameters.get("unit") == "day" + assert isinstance( + stream.retriever.requester.request_options_provider, InterpolatedRequestOptionsProvider + ) + assert ( + stream.retriever.requester.request_options_provider.request_parameters.get("unit") == "day" + ) - checker = factory.create_component(model_type=CheckStreamModel, component_definition=manifest["check"], config=input_config) + checker = factory.create_component( + model_type=CheckStreamModel, component_definition=manifest["check"], config=input_config + ) assert isinstance(checker, CheckStream) streams_to_check = checker.stream_names assert len(streams_to_check) == 1 assert list(streams_to_check)[0] == "list_stream" - spec = factory.create_component(model_type=SpecModel, component_definition=manifest["spec"], config=input_config) + spec = factory.create_component( + model_type=SpecModel, component_definition=manifest["spec"], config=input_config + ) assert isinstance(spec, Spec) documentation_url = spec.documentation_url @@ -331,7 +408,9 @@ def test_full_config_stream(): assert advanced_auth.auth_flow_type.value == "oauth2.0" concurrency_level = factory.create_component( - model_type=ConcurrencyLevelModel, component_definition=manifest["concurrency_level"], config=input_config + model_type=ConcurrencyLevelModel, + component_definition=manifest["concurrency_level"], + config=input_config, ) assert isinstance(concurrency_level, ConcurrencyLevel) assert isinstance(concurrency_level._default_concurrency, InterpolatedString) @@ -353,19 +432,32 @@ def test_interpolate_config(): """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - authenticator_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["authenticator"], {}) + authenticator_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["authenticator"], {} + ) authenticator = factory.create_component( - model_type=OAuthAuthenticatorModel, component_definition=authenticator_manifest, config=input_config + model_type=OAuthAuthenticatorModel, + component_definition=authenticator_manifest, + config=input_config, ) assert isinstance(authenticator, DeclarativeOauth2Authenticator) assert authenticator._client_id.eval(input_config) == "some_client_id" assert authenticator._client_secret.string == "some_client_secret" - assert authenticator._token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth" + assert ( + authenticator._token_refresh_endpoint.eval(input_config) + == "https://api.sendgrid.com/v3/auth" + ) assert authenticator._refresh_token.eval(input_config) == "verysecrettoken" - assert authenticator._refresh_request_body.mapping == {"body_field": "yoyoyo", "interpolated_body_field": "{{ config['apikey'] }}"} - assert authenticator.get_refresh_request_body() == {"body_field": "yoyoyo", "interpolated_body_field": "verysecrettoken"} + assert authenticator._refresh_request_body.mapping == { + "body_field": "yoyoyo", + "interpolated_body_field": "{{ config['apikey'] }}", + } + assert authenticator.get_refresh_request_body() == { + "body_field": "yoyoyo", + "interpolated_body_field": "verysecrettoken", + } def test_interpolate_config_with_token_expiry_date_format(): @@ -380,10 +472,14 @@ def test_interpolate_config_with_token_expiry_date_format(): """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - authenticator_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["authenticator"], {}) + authenticator_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["authenticator"], {} + ) authenticator = factory.create_component( - model_type=OAuthAuthenticatorModel, component_definition=authenticator_manifest, config=input_config + model_type=OAuthAuthenticatorModel, + component_definition=authenticator_manifest, + config=input_config, ) assert isinstance(authenticator, DeclarativeOauth2Authenticator) @@ -391,7 +487,10 @@ def test_interpolate_config_with_token_expiry_date_format(): assert authenticator.token_expiry_is_time_of_expiration assert authenticator._client_id.eval(input_config) == "some_client_id" assert authenticator._client_secret.string == "some_client_secret" - assert authenticator._token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth" + assert ( + authenticator._token_refresh_endpoint.eval(input_config) + == "https://api.sendgrid.com/v3/auth" + ) def test_single_use_oauth_branch(): @@ -421,10 +520,14 @@ def test_single_use_oauth_branch(): """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - authenticator_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["authenticator"], {}) + authenticator_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["authenticator"], {} + ) authenticator: SingleUseRefreshTokenOauth2Authenticator = factory.create_component( - model_type=OAuthAuthenticatorModel, component_definition=authenticator_manifest, config=single_use_input_config + model_type=OAuthAuthenticatorModel, + component_definition=authenticator_manifest, + config=single_use_input_config, ) assert isinstance(authenticator, SingleUseRefreshTokenOauth2Authenticator) @@ -432,7 +535,10 @@ def test_single_use_oauth_branch(): assert authenticator._client_secret == "some_client_secret" assert authenticator._token_refresh_endpoint == "https://api.sendgrid.com/v3/auth" assert authenticator._refresh_token == "verysecrettoken" - assert authenticator._refresh_request_body == {"body_field": "yoyoyo", "interpolated_body_field": "verysecrettoken"} + assert authenticator._refresh_request_body == { + "body_field": "yoyoyo", + "interpolated_body_field": "verysecrettoken", + } assert authenticator._refresh_token_name == "the_refresh_token" assert authenticator._refresh_token_config_path == ["apikey"] # default values @@ -453,10 +559,14 @@ def test_list_based_stream_slicer_with_values_refd(): """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - partition_router_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["partition_router"], {}) + partition_router_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["partition_router"], {} + ) partition_router = factory.create_component( - model_type=ListPartitionRouterModel, component_definition=partition_router_manifest, config=input_config + model_type=ListPartitionRouterModel, + component_definition=partition_router_manifest, + config=input_config, ) assert isinstance(partition_router, ListPartitionRouter) @@ -476,10 +586,14 @@ def test_list_based_stream_slicer_with_values_defined_in_config(): """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - partition_router_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["partition_router"], {}) + partition_router_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["partition_router"], {} + ) partition_router = factory.create_component( - model_type=ListPartitionRouterModel, component_definition=partition_router_manifest, config=input_config + model_type=ListPartitionRouterModel, + component_definition=partition_router_manifest, + config=input_config, ) assert isinstance(partition_router, ListPartitionRouter) @@ -532,10 +646,14 @@ def test_create_substream_partition_router(): """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - partition_router_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["partition_router"], {}) + partition_router_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["partition_router"], {} + ) partition_router = factory.create_component( - model_type=SubstreamPartitionRouterModel, component_definition=partition_router_manifest, config=input_config + model_type=SubstreamPartitionRouterModel, + component_definition=partition_router_manifest, + config=input_config, ) assert isinstance(partition_router, SubstreamPartitionRouter) @@ -546,8 +664,16 @@ def test_create_substream_partition_router(): assert partition_router.parent_stream_configs[0].parent_key.eval({}) == "id" assert partition_router.parent_stream_configs[0].partition_field.eval({}) == "repository_id" - assert partition_router.parent_stream_configs[0].request_option.inject_into == RequestOptionType.request_parameter - assert partition_router.parent_stream_configs[0].request_option.field_name.eval(config=input_config) == "repository_id" + assert ( + partition_router.parent_stream_configs[0].request_option.inject_into + == RequestOptionType.request_parameter + ) + assert ( + partition_router.parent_stream_configs[0].request_option.field_name.eval( + config=input_config + ) + == "repository_id" + ) assert partition_router.parent_stream_configs[1].parent_key.eval({}) == "someid" assert partition_router.parent_stream_configs[1].partition_field.eval({}) == "word_id" @@ -582,9 +708,15 @@ def test_datetime_based_cursor(): """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - slicer_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["incremental"], {"cursor_field": "created_at"}) + slicer_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["incremental"], {"cursor_field": "created_at"} + ) - stream_slicer = factory.create_component(model_type=DatetimeBasedCursorModel, component_definition=slicer_manifest, config=input_config) + stream_slicer = factory.create_component( + model_type=DatetimeBasedCursorModel, + component_definition=slicer_manifest, + config=input_config, + ) assert isinstance(stream_slicer, DatetimeBasedCursor) assert stream_slicer._step == datetime.timedelta(days=10) @@ -592,7 +724,12 @@ def test_datetime_based_cursor(): assert stream_slicer.cursor_granularity == "PT0.000001S" assert stream_slicer._lookback_window.string == "P5D" assert stream_slicer.start_time_option.inject_into == RequestOptionType.request_parameter - assert stream_slicer.start_time_option.field_name.eval(config=input_config | {"cursor_field": "updated_at"}) == "since_updated_at" + assert ( + stream_slicer.start_time_option.field_name.eval( + config=input_config | {"cursor_field": "updated_at"} + ) + == "since_updated_at" + ) assert stream_slicer.end_time_option.inject_into == RequestOptionType.body_json assert stream_slicer.end_time_option.field_name.eval({}) == "before_created_at" assert stream_slicer._partition_field_start.eval({}) == "star" @@ -601,7 +738,10 @@ def test_datetime_based_cursor(): assert isinstance(stream_slicer._start_datetime, MinMaxDatetime) assert stream_slicer.start_datetime._datetime_format == "%Y-%m-%dT%H:%M:%S.%f%z" assert stream_slicer.start_datetime.datetime.string == "{{ config['start_time'] }}" - assert stream_slicer.start_datetime.min_datetime.string == "{{ config['start_time'] + day_delta(2) }}" + assert ( + stream_slicer.start_datetime.min_datetime.string + == "{{ config['start_time'] + day_delta(2) }}" + ) assert isinstance(stream_slicer._end_datetime, MinMaxDatetime) assert stream_slicer._end_datetime.datetime.string == "{{ config['end_time'] }}" @@ -708,15 +848,21 @@ def test_stream_with_incremental_and_retriever_with_partition_router(): parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["list_stream"], {}) + stream_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["list_stream"], {} + ) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config + ) assert isinstance(stream, DeclarativeStream) assert isinstance(stream.retriever, SimpleRetriever) assert isinstance(stream.retriever.stream_slicer, PerPartitionWithGlobalCursor) - datetime_stream_slicer = stream.retriever.stream_slicer._per_partition_cursor._cursor_factory.create() + datetime_stream_slicer = ( + stream.retriever.stream_slicer._per_partition_cursor._cursor_factory.create() + ) assert isinstance(datetime_stream_slicer, DatetimeBasedCursor) assert isinstance(datetime_stream_slicer._start_datetime, MinMaxDatetime) assert datetime_stream_slicer._start_datetime.datetime.string == "{{ config['start_time'] }}" @@ -825,7 +971,9 @@ def test_resumable_full_refresh_stream(): stream_manifest = manifest["list_stream"] assert stream_manifest["type"] == "DeclarativeStream" - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config + ) assert isinstance(stream, DeclarativeStream) assert stream.primary_key == "id" @@ -844,18 +992,31 @@ def test_resumable_full_refresh_stream(): assert isinstance(stream.retriever.paginator, DefaultPaginator) assert isinstance(stream.retriever.paginator.decoder, PaginationDecoderDecorator) assert stream.retriever.paginator.page_size_option.field_name.eval(input_config) == "page_size" - assert stream.retriever.paginator.page_size_option.inject_into == RequestOptionType.request_parameter + assert ( + stream.retriever.paginator.page_size_option.inject_into + == RequestOptionType.request_parameter + ) assert isinstance(stream.retriever.paginator.page_token_option, RequestPath) assert stream.retriever.paginator.url_base.string == "https://api.sendgrid.com/v3/" assert stream.retriever.paginator.url_base.default == "https://api.sendgrid.com/v3/" assert isinstance(stream.retriever.paginator.pagination_strategy, CursorPaginationStrategy) - assert isinstance(stream.retriever.paginator.pagination_strategy.decoder, PaginationDecoderDecorator) - assert stream.retriever.paginator.pagination_strategy._cursor_value.string == "{{ response._metadata.next }}" - assert stream.retriever.paginator.pagination_strategy._cursor_value.default == "{{ response._metadata.next }}" + assert isinstance( + stream.retriever.paginator.pagination_strategy.decoder, PaginationDecoderDecorator + ) + assert ( + stream.retriever.paginator.pagination_strategy._cursor_value.string + == "{{ response._metadata.next }}" + ) + assert ( + stream.retriever.paginator.pagination_strategy._cursor_value.default + == "{{ response._metadata.next }}" + ) assert stream.retriever.paginator.pagination_strategy.page_size == 10 - checker = factory.create_component(model_type=CheckStreamModel, component_definition=manifest["check"], config=input_config) + checker = factory.create_component( + model_type=CheckStreamModel, component_definition=manifest["check"], config=input_config + ) assert isinstance(checker, CheckStream) streams_to_check = checker.stream_names @@ -907,11 +1068,17 @@ def test_incremental_data_feed(): parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["list_stream"], {}) + stream_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["list_stream"], {} + ) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config + ) - assert isinstance(stream.retriever.paginator.pagination_strategy, StopConditionPaginationStrategyDecorator) + assert isinstance( + stream.retriever.paginator.pagination_strategy, StopConditionPaginationStrategyDecorator + ) def test_given_data_feed_and_incremental_then_raise_error(): @@ -927,11 +1094,15 @@ def test_given_data_feed_and_incremental_then_raise_error(): parsed_incremental_sync = YamlDeclarativeSource._parse(content) resolved_incremental_sync = resolver.preprocess_manifest(parsed_incremental_sync) - datetime_based_cursor_definition = transformer.propagate_types_and_parameters("", resolved_incremental_sync["incremental_sync"], {}) + datetime_based_cursor_definition = transformer.propagate_types_and_parameters( + "", resolved_incremental_sync["incremental_sync"], {} + ) with pytest.raises(ValueError): factory.create_component( - model_type=DatetimeBasedCursorModel, component_definition=datetime_based_cursor_definition, config=input_config + model_type=DatetimeBasedCursorModel, + component_definition=datetime_based_cursor_definition, + config=input_config, ) @@ -979,11 +1150,17 @@ def test_client_side_incremental(): parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["list_stream"], {}) + stream_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["list_stream"], {} + ) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config + ) - assert isinstance(stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator) + assert isinstance( + stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator + ) def test_client_side_incremental_with_partition_router(): @@ -1054,12 +1231,21 @@ def test_client_side_incremental_with_partition_router(): parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["list_stream"], {}) + stream_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["list_stream"], {} + ) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config + ) - assert isinstance(stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator) - assert isinstance(stream.retriever.record_selector.record_filter._substream_cursor, PerPartitionWithGlobalCursor) + assert isinstance( + stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator + ) + assert isinstance( + stream.retriever.record_selector.record_filter._substream_cursor, + PerPartitionWithGlobalCursor, + ) def test_given_data_feed_and_client_side_incremental_then_raise_error(): @@ -1076,18 +1262,28 @@ def test_given_data_feed_and_client_side_incremental_then_raise_error(): parsed_incremental_sync = YamlDeclarativeSource._parse(content) resolved_incremental_sync = resolver.preprocess_manifest(parsed_incremental_sync) - datetime_based_cursor_definition = transformer.propagate_types_and_parameters("", resolved_incremental_sync["incremental_sync"], {}) + datetime_based_cursor_definition = transformer.propagate_types_and_parameters( + "", resolved_incremental_sync["incremental_sync"], {} + ) with pytest.raises(ValueError) as e: factory.create_component( - model_type=DatetimeBasedCursorModel, component_definition=datetime_based_cursor_definition, config=input_config + model_type=DatetimeBasedCursorModel, + component_definition=datetime_based_cursor_definition, + config=input_config, ) - assert e.value.args[0] == "`Client side incremental` cannot be applied with `data feed`. Choose only 1 from them." + assert ( + e.value.args[0] + == "`Client side incremental` cannot be applied with `data feed`. Choose only 1 from them." + ) @pytest.mark.parametrize( "test_name, record_selector, expected_runtime_selector", - [("test_static_record_selector", "result", "result"), ("test_options_record_selector", "{{ parameters['name'] }}", "lists")], + [ + ("test_static_record_selector", "result", "result"), + ("test_options_record_selector", "{{ parameters['name'] }}", "lists"), + ], ) def test_create_record_selector(test_name, record_selector, expected_runtime_selector): content = f""" @@ -1106,15 +1302,23 @@ def test_create_record_selector(test_name, record_selector, expected_runtime_sel """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - selector_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["selector"], {}) + selector_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["selector"], {} + ) selector = factory.create_component( - model_type=RecordSelectorModel, component_definition=selector_manifest, decoder=None, transformations=[], config=input_config + model_type=RecordSelectorModel, + component_definition=selector_manifest, + decoder=None, + transformations=[], + config=input_config, ) assert isinstance(selector, RecordSelector) assert isinstance(selector.extractor, DpathExtractor) - assert [fp.eval(input_config) for fp in selector.extractor._field_path] == [expected_runtime_selector] + assert [fp.eval(input_config) for fp in selector.extractor._field_path] == [ + expected_runtime_selector + ] assert isinstance(selector.record_filter, RecordFilter) assert selector.record_filter.condition == "{{ record['id'] > stream_state['id'] }}" @@ -1186,7 +1390,9 @@ def test_create_requester(test_name, error_handler, expected_backoff_strategy_ty name = "name" parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - requester_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["requester"], {}) + requester_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["requester"], {} + ) selector = factory.create_component( model_type=HttpRequesterModel, @@ -1205,15 +1411,25 @@ def test_create_requester(test_name, error_handler, expected_backoff_strategy_ty assert isinstance(selector.error_handler, DefaultErrorHandler) if expected_backoff_strategy_type: assert len(selector.error_handler.backoff_strategies) == 1 - assert isinstance(selector.error_handler.backoff_strategies[0], expected_backoff_strategy_type) + assert isinstance( + selector.error_handler.backoff_strategies[0], expected_backoff_strategy_type + ) assert isinstance(selector.authenticator, BasicHttpAuthenticator) assert selector.authenticator._username.eval(input_config) == "lists" assert selector.authenticator._password.eval(input_config) == "verysecrettoken" assert isinstance(selector._request_options_provider, InterpolatedRequestOptionsProvider) - assert selector._request_options_provider._parameter_interpolator._interpolator.mapping["a_parameter"] == "something_here" - assert selector._request_options_provider._headers_interpolator._interpolator.mapping["header"] == "header_value" + assert ( + selector._request_options_provider._parameter_interpolator._interpolator.mapping[ + "a_parameter" + ] + == "something_here" + ) + assert ( + selector._request_options_provider._headers_interpolator._interpolator.mapping["header"] + == "header_value" + ) def test_create_request_with_legacy_session_authenticator(): @@ -1240,10 +1456,16 @@ def test_create_request_with_legacy_session_authenticator(): name = "name" parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - requester_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["requester"], {}) + requester_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["requester"], {} + ) selector = factory.create_component( - model_type=HttpRequesterModel, component_definition=requester_manifest, config=input_config, name=name, decoder=None + model_type=HttpRequesterModel, + component_definition=requester_manifest, + config=input_config, + name=name, + decoder=None, ) assert isinstance(selector, HttpRequester) @@ -1290,10 +1512,16 @@ def test_create_request_with_session_authenticator(): name = "name" parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - requester_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["requester"], {}) + requester_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["requester"], {} + ) selector = factory.create_component( - model_type=HttpRequesterModel, component_definition=requester_manifest, config=input_config, name=name, decoder=None + model_type=HttpRequesterModel, + component_definition=requester_manifest, + config=input_config, + name=name, + decoder=None, ) assert isinstance(selector.authenticator, ApiKeyAuthenticator) @@ -1301,14 +1529,19 @@ def test_create_request_with_session_authenticator(): assert selector.authenticator.token_provider.session_token_path == ["id"] assert isinstance(selector.authenticator.token_provider.login_requester, HttpRequester) assert selector.authenticator.token_provider.session_token_path == ["id"] - assert selector.authenticator.token_provider.login_requester._url_base.eval(input_config) == "https://api.sendgrid.com" + assert ( + selector.authenticator.token_provider.login_requester._url_base.eval(input_config) + == "https://api.sendgrid.com" + ) assert selector.authenticator.token_provider.login_requester.get_request_body_json() == { "username": "lists", "password": "verysecrettoken", } -def test_given_composite_error_handler_does_not_match_response_then_fallback_on_default_error_handler(requests_mock): +def test_given_composite_error_handler_does_not_match_response_then_fallback_on_default_error_handler( + requests_mock, +): content = """ requester: type: HttpRequester @@ -1328,7 +1561,9 @@ def test_given_composite_error_handler_does_not_match_response_then_fallback_on_ """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - requester_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["requester"], {}) + requester_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["requester"], {} + ) http_requester = factory.create_component( model_type=HttpRequesterModel, component_definition=requester_manifest, @@ -1383,10 +1618,15 @@ def test_create_requester_with_selective_authenticator(input_config, expected_au name = "name" parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - authenticator_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["authenticator"], {}) + authenticator_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["authenticator"], {} + ) authenticator = factory.create_component( - model_type=SelectiveAuthenticator, component_definition=authenticator_manifest, config=input_config, name=name + model_type=SelectiveAuthenticator, + component_definition=authenticator_manifest, + config=input_config, + name=name, ) assert isinstance(authenticator, expected_authenticator_class) @@ -1406,10 +1646,14 @@ def test_create_composite_error_handler(): """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - error_handler_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["error_handler"], {}) + error_handler_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["error_handler"], {} + ) error_handler = factory.create_component( - model_type=CompositeErrorHandlerModel, component_definition=error_handler_manifest, config=input_config + model_type=CompositeErrorHandlerModel, + component_definition=error_handler_manifest, + config=input_config, ) assert isinstance(error_handler, CompositeErrorHandler) @@ -1470,9 +1714,13 @@ def test_config_with_defaults(): parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) resolved_manifest["type"] = "DeclarativeSource" - stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["lists_stream"], {}) + stream_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["lists_stream"], {} + ) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config + ) assert isinstance(stream, DeclarativeStream) assert stream.primary_key == "id" @@ -1482,8 +1730,14 @@ def test_config_with_defaults(): assert stream.retriever.primary_key == stream.primary_key assert isinstance(stream.schema_loader, JsonFileSchemaLoader) - assert stream.schema_loader.file_path.string == "./source_sendgrid/schemas/{{ parameters.name }}.yaml" - assert stream.schema_loader.file_path.default == "./source_sendgrid/schemas/{{ parameters.name }}.yaml" + assert ( + stream.schema_loader.file_path.string + == "./source_sendgrid/schemas/{{ parameters.name }}.yaml" + ) + assert ( + stream.schema_loader.file_path.default + == "./source_sendgrid/schemas/{{ parameters.name }}.yaml" + ) assert isinstance(stream.retriever.requester, HttpRequester) assert stream.retriever.requester.http_method == HttpMethod.GET @@ -1493,7 +1747,9 @@ def test_config_with_defaults(): assert isinstance(stream.retriever.record_selector, RecordSelector) assert isinstance(stream.retriever.record_selector.extractor, DpathExtractor) - assert [fp.eval(input_config) for fp in stream.retriever.record_selector.extractor._field_path] == ["result"] + assert [ + fp.eval(input_config) for fp in stream.retriever.record_selector.extractor._field_path + ] == ["result"] assert isinstance(stream.retriever.paginator, DefaultPaginator) assert stream.retriever.paginator.url_base.string == "https://api.sendgrid.com" @@ -1517,7 +1773,9 @@ def test_create_default_paginator(): """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - paginator_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["paginator"], {}) + paginator_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["paginator"], {} + ) paginator = factory.create_component( model_type=DefaultPaginatorModel, @@ -1548,7 +1806,11 @@ def test_create_default_paginator(): { "type": "CustomErrorHandler", "class_name": "unit_tests.sources.declarative.parsers.testing_components.TestingSomeComponent", - "subcomponent_field_with_hint": {"type": "DpathExtractor", "field_path": [], "decoder": {"type": "JsonDecoder"}}, + "subcomponent_field_with_hint": { + "type": "DpathExtractor", + "field_path": [], + "decoder": {"type": "JsonDecoder"}, + }, }, "subcomponent_field_with_hint", DpathExtractor( @@ -1567,7 +1829,11 @@ def test_create_default_paginator(): "subcomponent_field_with_hint": {"field_path": []}, }, "subcomponent_field_with_hint", - DpathExtractor(field_path=[], config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}, parameters={}), + DpathExtractor( + field_path=[], + config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}, + parameters={}, + ), None, id="test_create_custom_component_with_subcomponent_that_must_infer_type_from_explicit_hints", ), @@ -1586,10 +1852,18 @@ def test_create_default_paginator(): { "type": "CustomErrorHandler", "class_name": "unit_tests.sources.declarative.parsers.testing_components.TestingSomeComponent", - "optional_subcomponent_field": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "destination"}, + "optional_subcomponent_field": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "destination", + }, }, "optional_subcomponent_field", - RequestOption(inject_into=RequestOptionType.request_parameter, field_name="destination", parameters={}), + RequestOption( + inject_into=RequestOptionType.request_parameter, + field_name="destination", + parameters={}, + ), None, id="test_create_custom_component_with_subcomponent_wrapped_in_optional", ), @@ -1599,13 +1873,23 @@ def test_create_default_paginator(): "class_name": "unit_tests.sources.declarative.parsers.testing_components.TestingSomeComponent", "list_of_subcomponents": [ {"inject_into": "header", "field_name": "store_me"}, - {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "destination"}, + { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "destination", + }, ], }, "list_of_subcomponents", [ - RequestOption(inject_into=RequestOptionType.header, field_name="store_me", parameters={}), - RequestOption(inject_into=RequestOptionType.request_parameter, field_name="destination", parameters={}), + RequestOption( + inject_into=RequestOptionType.header, field_name="store_me", parameters={} + ), + RequestOption( + inject_into=RequestOptionType.request_parameter, + field_name="destination", + parameters={}, + ), ], None, id="test_create_custom_component_with_subcomponent_wrapped_in_list", @@ -1634,7 +1918,9 @@ def test_create_default_paginator(): "paginator", DefaultPaginator( pagination_strategy=OffsetIncrement( - page_size=10, config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}, parameters={} + page_size=10, + config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}, + parameters={}, ), url_base="https://physical_100.com", config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}, @@ -1705,7 +1991,11 @@ def test_custom_components_do_not_contain_extra_fields(): "type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}, }, - "requester": {"type": "HttpRequester", "url_base": "https://airbyte.io", "path": "some"}, + "requester": { + "type": "HttpRequester", + "url_base": "https://airbyte.io", + "path": "some", + }, }, "schema_loader": { "type": "JsonFileSchemaLoader", @@ -1714,7 +2004,11 @@ def test_custom_components_do_not_contain_extra_fields(): }, "parent_key": "id", "partition_field": "repository_id", - "request_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "repository_id"}, + "request_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "repository_id", + }, } ], } @@ -1726,9 +2020,20 @@ def test_custom_components_do_not_contain_extra_fields(): assert len(custom_substream_partition_router.parent_stream_configs) == 1 assert custom_substream_partition_router.parent_stream_configs[0].parent_key.eval({}) == "id" - assert custom_substream_partition_router.parent_stream_configs[0].partition_field.eval({}) == "repository_id" - assert custom_substream_partition_router.parent_stream_configs[0].request_option.inject_into == RequestOptionType.request_parameter - assert custom_substream_partition_router.parent_stream_configs[0].request_option.field_name.eval(config=input_config) == "repository_id" + assert ( + custom_substream_partition_router.parent_stream_configs[0].partition_field.eval({}) + == "repository_id" + ) + assert ( + custom_substream_partition_router.parent_stream_configs[0].request_option.inject_into + == RequestOptionType.request_parameter + ) + assert ( + custom_substream_partition_router.parent_stream_configs[0].request_option.field_name.eval( + config=input_config + ) + == "repository_id" + ) assert isinstance(custom_substream_partition_router.custom_pagination_strategy, PageIncrement) assert custom_substream_partition_router.custom_pagination_strategy.page_size == 100 @@ -1753,7 +2058,11 @@ def test_parse_custom_component_fields_if_subcomponent(): "type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}, }, - "requester": {"type": "HttpRequester", "url_base": "https://airbyte.io", "path": "some"}, + "requester": { + "type": "HttpRequester", + "url_base": "https://airbyte.io", + "path": "some", + }, }, "schema_loader": { "type": "JsonFileSchemaLoader", @@ -1762,7 +2071,11 @@ def test_parse_custom_component_fields_if_subcomponent(): }, "parent_key": "id", "partition_field": "repository_id", - "request_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "repository_id"}, + "request_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "repository_id", + }, } ], } @@ -1775,9 +2088,20 @@ def test_parse_custom_component_fields_if_subcomponent(): assert len(custom_substream_partition_router.parent_stream_configs) == 1 assert custom_substream_partition_router.parent_stream_configs[0].parent_key.eval({}) == "id" - assert custom_substream_partition_router.parent_stream_configs[0].partition_field.eval({}) == "repository_id" - assert custom_substream_partition_router.parent_stream_configs[0].request_option.inject_into == RequestOptionType.request_parameter - assert custom_substream_partition_router.parent_stream_configs[0].request_option.field_name.eval(config=input_config) == "repository_id" + assert ( + custom_substream_partition_router.parent_stream_configs[0].partition_field.eval({}) + == "repository_id" + ) + assert ( + custom_substream_partition_router.parent_stream_configs[0].request_option.inject_into + == RequestOptionType.request_parameter + ) + assert ( + custom_substream_partition_router.parent_stream_configs[0].request_option.field_name.eval( + config=input_config + ) + == "repository_id" + ) assert isinstance(custom_substream_partition_router.custom_pagination_strategy, PageIncrement) assert custom_substream_partition_router.custom_pagination_strategy.page_size == 100 @@ -1813,9 +2137,15 @@ def test_no_transformations(self): parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) resolved_manifest["type"] = "DeclarativeSource" - stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["the_stream"], {}) + stream_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["the_stream"], {} + ) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, + component_definition=stream_manifest, + config=input_config, + ) assert isinstance(stream, DeclarativeStream) assert [] == stream.retriever.record_selector.transformations @@ -1835,12 +2165,20 @@ def test_remove_fields(self): parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) resolved_manifest["type"] = "DeclarativeSource" - stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["the_stream"], {}) + stream_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["the_stream"], {} + ) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, + component_definition=stream_manifest, + config=input_config, + ) assert isinstance(stream, DeclarativeStream) - expected = [RemoveFields(field_pointers=[["path", "to", "field1"], ["path2"]], parameters={})] + expected = [ + RemoveFields(field_pointers=[["path", "to", "field1"], ["path2"]], parameters={}) + ] assert stream.retriever.record_selector.transformations == expected def test_add_fields_no_value_type(self): @@ -1860,7 +2198,9 @@ def test_add_fields_no_value_type(self): fields=[ AddedFieldDefinition( path=["field1"], - value=InterpolatedString(string="static_value", default="static_value", parameters={}), + value=InterpolatedString( + string="static_value", default="static_value", parameters={} + ), value_type=None, parameters={}, ) @@ -1888,7 +2228,9 @@ def test_add_fields_value_type_is_string(self): fields=[ AddedFieldDefinition( path=["field1"], - value=InterpolatedString(string="static_value", default="static_value", parameters={}), + value=InterpolatedString( + string="static_value", default="static_value", parameters={} + ), value_type=str, parameters={}, ) @@ -1986,9 +2328,15 @@ def _test_add_fields(self, content, expected): parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) resolved_manifest["type"] = "DeclarativeSource" - stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["the_stream"], {}) + stream_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["the_stream"], {} + ) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, + component_definition=stream_manifest, + config=input_config, + ) assert isinstance(stream, DeclarativeStream) assert stream.retriever.record_selector.transformations == expected @@ -2010,9 +2358,15 @@ def test_default_schema_loader(self): "request_body_json": {}, "type": "InterpolatedRequestOptionsProvider", }, - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config['api_key'] }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config['api_key'] }}", + }, + }, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": ["items"]}, }, - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": ["items"]}}, "paginator": {"type": "NoPagination"}, }, } @@ -2020,10 +2374,15 @@ def test_default_schema_loader(self): ws = ManifestComponentTransformer() propagated_source_config = ws.propagate_types_and_parameters("", resolved_manifest, {}) stream = factory.create_component( - model_type=DeclarativeStreamModel, component_definition=propagated_source_config, config=input_config + model_type=DeclarativeStreamModel, + component_definition=propagated_source_config, + config=input_config, ) schema_loader = stream.schema_loader - assert schema_loader.default_loader._get_json_filepath().split("/")[-1] == f"{stream.name}.json" + assert ( + schema_loader.default_loader._get_json_filepath().split("/")[-1] + == f"{stream.name}.json" + ) @pytest.mark.parametrize( @@ -2096,7 +2455,12 @@ def test_default_schema_loader(self): PerPartitionWithGlobalCursor, id="test_create_simple_retriever_with_partition_routers_multiple_components", ), - pytest.param(None, None, SinglePartitionRouter, id="test_create_simple_retriever_with_no_incremental_or_partition_router"), + pytest.param( + None, + None, + SinglePartitionRouter, + id="test_create_simple_retriever_with_no_incremental_or_partition_router", + ), ], ) def test_merge_incremental_and_partition_router(incremental, partition_router, expected_type): @@ -2126,7 +2490,9 @@ def test_merge_incremental_and_partition_router(incremental, partition_router, e if partition_router: stream_model["retriever"]["partition_router"] = partition_router - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_model, config=input_config) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_model, config=input_config + ) assert isinstance(stream, DeclarativeStream) assert isinstance(stream.retriever, SimpleRetriever) @@ -2136,8 +2502,12 @@ def test_merge_incremental_and_partition_router(incremental, partition_router, e if incremental and partition_router: assert isinstance(stream.retriever.stream_slicer, PerPartitionWithGlobalCursor) if isinstance(partition_router, list) and len(partition_router) > 1: - assert isinstance(stream.retriever.stream_slicer._partition_router, CartesianProductStreamSlicer) - assert len(stream.retriever.stream_slicer._partition_router.stream_slicers) == len(partition_router) + assert isinstance( + stream.retriever.stream_slicer._partition_router, CartesianProductStreamSlicer + ) + assert len(stream.retriever.stream_slicer._partition_router.stream_slicers) == len( + partition_router + ) elif partition_router and isinstance(partition_router, list) and len(partition_router) > 1: assert isinstance(stream.retriever.stream_slicer, PerPartitionWithGlobalCursor) assert len(stream.retriever.stream_slicer.stream_slicerS) == len(partition_router) @@ -2153,7 +2523,12 @@ def test_simple_retriever_emit_log_messages(): "field_path": [], }, }, - "requester": {"type": "HttpRequester", "name": "list", "url_base": "orange.com", "path": "/v1/api"}, + "requester": { + "type": "HttpRequester", + "name": "list", + "url_base": "orange.com", + "path": "/v1/api", + }, } connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) @@ -2178,7 +2553,13 @@ def test_create_page_increment(): start_from_page=1, inject_on_first_request=True, ) - expected_strategy = PageIncrement(page_size=10, start_from_page=1, inject_on_first_request=True, parameters={}, config=input_config) + expected_strategy = PageIncrement( + page_size=10, + start_from_page=1, + inject_on_first_request=True, + parameters={}, + config=input_config, + ) strategy = factory.create_page_increment(model, input_config) @@ -2195,7 +2576,9 @@ def test_create_page_increment_with_interpolated_page_size(): inject_on_first_request=True, ) config = {**input_config, "page_size": 5} - expected_strategy = PageIncrement(page_size=5, start_from_page=1, inject_on_first_request=True, parameters={}, config=config) + expected_strategy = PageIncrement( + page_size=5, start_from_page=1, inject_on_first_request=True, parameters={}, config=config + ) strategy = factory.create_page_increment(model, config) @@ -2210,9 +2593,13 @@ def test_create_offset_increment(): page_size=10, inject_on_first_request=True, ) - expected_strategy = OffsetIncrement(page_size=10, inject_on_first_request=True, parameters={}, config=input_config) + expected_strategy = OffsetIncrement( + page_size=10, inject_on_first_request=True, parameters={}, config=input_config + ) - strategy = factory.create_offset_increment(model, input_config, decoder=JsonDecoder(parameters={})) + strategy = factory.create_offset_increment( + model, input_config, decoder=JsonDecoder(parameters={}) + ) assert strategy.page_size == expected_strategy.page_size assert strategy.inject_on_first_request == expected_strategy.inject_on_first_request @@ -2355,16 +2742,22 @@ def test_create_jwt_authenticator(config, manifest, expected): parsed_manifest = YamlDeclarativeSource._parse(manifest) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - authenticator_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["authenticator"], {}) + authenticator_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["authenticator"], {} + ) if expected.get("expect_error"): with pytest.raises(ValueError): authenticator = factory.create_component( - model_type=JwtAuthenticatorModel, component_definition=authenticator_manifest, config=config + model_type=JwtAuthenticatorModel, + component_definition=authenticator_manifest, + config=config, ) return - authenticator = factory.create_component(model_type=JwtAuthenticatorModel, component_definition=authenticator_manifest, config=config) + authenticator = factory.create_component( + model_type=JwtAuthenticatorModel, component_definition=authenticator_manifest, config=config + ) assert isinstance(authenticator, JwtAuthenticator) assert authenticator._secret_key.eval(config) == expected["secret_key"] @@ -2399,7 +2792,12 @@ def test_use_request_options_provider_for_datetime_based_cursor(): "field_path": [], }, }, - "requester": {"type": "HttpRequester", "name": "list", "url_base": "orange.com", "path": "/v1/api"}, + "requester": { + "type": "HttpRequester", + "name": "list", + "url_base": "orange.com", + "path": "/v1/api", + }, } datetime_based_cursor = DatetimeBasedCursor( @@ -2448,10 +2846,22 @@ def test_use_request_options_provider_for_datetime_based_cursor(): assert isinstance(retriever.stream_slicer, DatetimeBasedCursor) assert isinstance(retriever.request_option_provider, DatetimeBasedRequestOptionsProvider) - assert retriever.request_option_provider.start_time_option.inject_into == RequestOptionType.request_parameter - assert retriever.request_option_provider.start_time_option.field_name.eval(config=input_config) == "after" - assert retriever.request_option_provider.end_time_option.inject_into == RequestOptionType.request_parameter - assert retriever.request_option_provider.end_time_option.field_name.eval(config=input_config) == "before" + assert ( + retriever.request_option_provider.start_time_option.inject_into + == RequestOptionType.request_parameter + ) + assert ( + retriever.request_option_provider.start_time_option.field_name.eval(config=input_config) + == "after" + ) + assert ( + retriever.request_option_provider.end_time_option.inject_into + == RequestOptionType.request_parameter + ) + assert ( + retriever.request_option_provider.end_time_option.field_name.eval(config=input_config) + == "before" + ) assert retriever.request_option_provider._partition_field_start.string == "start_time" assert retriever.request_option_provider._partition_field_end.string == "end_time" @@ -2473,7 +2883,12 @@ def test_do_not_separate_request_options_provider_for_non_datetime_based_cursor( "field_path": [], }, }, - "requester": {"type": "HttpRequester", "name": "list", "url_base": "orange.com", "path": "/v1/api"}, + "requester": { + "type": "HttpRequester", + "name": "list", + "url_base": "orange.com", + "path": "/v1/api", + }, } datetime_based_cursor = DatetimeBasedCursor( @@ -2533,7 +2948,12 @@ def test_use_default_request_options_provider(): "field_path": [], }, }, - "requester": {"type": "HttpRequester", "name": "list", "url_base": "orange.com", "path": "/v1/api"}, + "requester": { + "type": "HttpRequester", + "name": "list", + "url_base": "orange.com", + "path": "/v1/api", + }, } connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) @@ -2559,14 +2979,23 @@ def test_use_default_request_options_provider(): @pytest.mark.parametrize( "stream_state,expected_start", [ - pytest.param({}, "2024-08-01T00:00:00.000000Z", id="test_create_concurrent_cursor_without_state"), pytest.param( - {"updated_at": "2024-10-01T00:00:00.000000Z"}, "2024-10-01T00:00:00.000000Z", id="test_create_concurrent_cursor_with_state" + {}, "2024-08-01T00:00:00.000000Z", id="test_create_concurrent_cursor_without_state" + ), + pytest.param( + {"updated_at": "2024-10-01T00:00:00.000000Z"}, + "2024-10-01T00:00:00.000000Z", + id="test_create_concurrent_cursor_with_state", ), ], ) -def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields(stream_state, expected_start): - config = {"start_time": "2024-08-01T00:00:00.000000Z", "end_time": "2024-10-15T00:00:00.000000Z"} +def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields( + stream_state, expected_start +): + config = { + "start_time": "2024-08-01T00:00:00.000000Z", + "end_time": "2024-10-15T00:00:00.000000Z", + } expected_cursor_field = "updated_at" expected_start_boundary = "custom_start" @@ -2577,7 +3006,9 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields(stream_s expected_cursor_granularity = datetime.timedelta(microseconds=1) expected_start = pendulum.parse(expected_start) - expected_end = datetime.datetime(year=2024, month=10, day=15, second=0, microsecond=0, tzinfo=datetime.timezone.utc) + expected_end = datetime.datetime( + year=2024, month=10, day=15, second=0, microsecond=0, tzinfo=datetime.timezone.utc + ) if stream_state: # Using incoming state, the resulting already completed partition is the start_time up to the last successful # partition indicated by the legacy sequential state @@ -2622,14 +3053,16 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields(stream_s "lookback_window": "P3D", } - concurrent_cursor, stream_state_converter = connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor( - state_manager=connector_state_manager, - model_type=DatetimeBasedCursorModel, - component_definition=cursor_component_definition, - stream_name=stream_name, - stream_namespace=None, - config=config, - stream_state=stream_state, + concurrent_cursor, stream_state_converter = ( + connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor( + state_manager=connector_state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=cursor_component_definition, + stream_name=stream_name, + stream_namespace=None, + config=config, + stream_state=stream_state, + ) ) assert concurrent_cursor._stream_name == stream_name @@ -2639,8 +3072,14 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields(stream_s assert concurrent_cursor._slice_range == expected_step assert concurrent_cursor._lookback_window == expected_lookback_window - assert concurrent_cursor.slice_boundary_fields[ConcurrentCursor._START_BOUNDARY] == expected_start_boundary - assert concurrent_cursor.slice_boundary_fields[ConcurrentCursor._END_BOUNDARY] == expected_end_boundary + assert ( + concurrent_cursor.slice_boundary_fields[ConcurrentCursor._START_BOUNDARY] + == expected_start_boundary + ) + assert ( + concurrent_cursor.slice_boundary_fields[ConcurrentCursor._END_BOUNDARY] + == expected_end_boundary + ) assert concurrent_cursor.start == expected_start assert concurrent_cursor._end_provider() == expected_end @@ -2656,15 +3095,37 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields(stream_s "cursor_fields_to_replace,assertion_field,expected_value,expected_error", [ pytest.param( - {"partition_field_start": None}, "slice_boundary_fields", ("start_time", "custom_end"), None, id="test_no_partition_field_start" + {"partition_field_start": None}, + "slice_boundary_fields", + ("start_time", "custom_end"), + None, + id="test_no_partition_field_start", + ), + pytest.param( + {"partition_field_end": None}, + "slice_boundary_fields", + ("custom_start", "end_time"), + None, + id="test_no_partition_field_end", ), pytest.param( - {"partition_field_end": None}, "slice_boundary_fields", ("custom_start", "end_time"), None, id="test_no_partition_field_end" + {"lookback_window": None}, "_lookback_window", None, None, id="test_no_lookback_window" + ), + pytest.param( + {"lookback_window": "{{ config.does_not_exist }}"}, + "_lookback_window", + None, + None, + id="test_no_lookback_window", ), - pytest.param({"lookback_window": None}, "_lookback_window", None, None, id="test_no_lookback_window"), - pytest.param({"lookback_window": "{{ config.does_not_exist }}"}, "_lookback_window", None, None, id="test_no_lookback_window"), pytest.param({"step": None}, None, None, ValueError, id="test_no_step_raises_exception"), - pytest.param({"cursor_granularity": None}, None, None, ValueError, id="test_no_cursor_granularity_exception"), + pytest.param( + {"cursor_granularity": None}, + None, + None, + ValueError, + id="test_no_cursor_granularity_exception", + ), pytest.param( { "end_time": None, @@ -2679,10 +3140,15 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields(stream_s ], ) @freezegun.freeze_time("2024-10-01T00:00:00") -def test_create_concurrent_cursor_from_datetime_based_cursor(cursor_fields_to_replace, assertion_field, expected_value, expected_error): +def test_create_concurrent_cursor_from_datetime_based_cursor( + cursor_fields_to_replace, assertion_field, expected_value, expected_error +): connector_state_manager = ConnectorStateManager() - config = {"start_time": "2024-08-01T00:00:00.000000Z", "end_time": "2024-09-01T00:00:00.000000Z"} + config = { + "start_time": "2024-08-01T00:00:00.000000Z", + "end_time": "2024-09-01T00:00:00.000000Z", + } stream_name = "test" @@ -2719,14 +3185,16 @@ def test_create_concurrent_cursor_from_datetime_based_cursor(cursor_fields_to_re stream_state={}, ) else: - concurrent_cursor, stream_state_converter = connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor( - state_manager=connector_state_manager, - model_type=DatetimeBasedCursorModel, - component_definition=cursor_component_definition, - stream_name=stream_name, - stream_namespace=None, - config=config, - stream_state={}, + concurrent_cursor, stream_state_converter = ( + connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor( + state_manager=connector_state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=cursor_component_definition, + stream_name=stream_name, + stream_namespace=None, + config=config, + stream_state={}, + ) ) assert getattr(concurrent_cursor, assertion_field) == expected_value @@ -2738,8 +3206,12 @@ def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined(): string parser should not inherit from the parent DatetimeBasedCursor.datetime_format. The parent which uses an incorrect precision would fail if it were used by the dependent children. """ - expected_start = datetime.datetime(year=2024, month=8, day=1, second=0, microsecond=0, tzinfo=datetime.timezone.utc) - expected_end = datetime.datetime(year=2024, month=9, day=1, second=0, microsecond=0, tzinfo=datetime.timezone.utc) + expected_start = datetime.datetime( + year=2024, month=8, day=1, second=0, microsecond=0, tzinfo=datetime.timezone.utc + ) + expected_end = datetime.datetime( + year=2024, month=9, day=1, second=0, microsecond=0, tzinfo=datetime.timezone.utc + ) connector_state_manager = ConnectorStateManager() @@ -2753,8 +3225,16 @@ def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined(): "type": "DatetimeBasedCursor", "cursor_field": "updated_at", "datetime_format": "%Y-%m-%dT%H:%MZ", - "start_datetime": {"type": "MinMaxDatetime", "datetime": "{{ config.start_time }}", "datetime_format": "%Y-%m-%dT%H:%M:%SZ"}, - "end_datetime": {"type": "MinMaxDatetime", "datetime": "{{ config.end_time }}", "datetime_format": "%Y-%m-%dT%H:%M:%SZ"}, + "start_datetime": { + "type": "MinMaxDatetime", + "datetime": "{{ config.start_time }}", + "datetime_format": "%Y-%m-%dT%H:%M:%SZ", + }, + "end_datetime": { + "type": "MinMaxDatetime", + "datetime": "{{ config.end_time }}", + "datetime_format": "%Y-%m-%dT%H:%M:%SZ", + }, "partition_field_start": "custom_start", "partition_field_end": "custom_end", "step": "P10D", @@ -2762,14 +3242,16 @@ def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined(): "lookback_window": "P3D", } - concurrent_cursor, stream_state_converter = connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor( - state_manager=connector_state_manager, - model_type=DatetimeBasedCursorModel, - component_definition=cursor_component_definition, - stream_name=stream_name, - stream_namespace=None, - config=config, - stream_state={}, + concurrent_cursor, stream_state_converter = ( + connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor( + state_manager=connector_state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=cursor_component_definition, + stream_name=stream_name, + stream_namespace=None, + config=config, + stream_state={}, + ) ) assert concurrent_cursor.start == expected_start diff --git a/unit_tests/sources/declarative/parsers/testing_components.py b/unit_tests/sources/declarative/parsers/testing_components.py index db85283b..0d49e862 100644 --- a/unit_tests/sources/declarative/parsers/testing_components.py +++ b/unit_tests/sources/declarative/parsers/testing_components.py @@ -9,7 +9,10 @@ from airbyte_cdk.sources.declarative.partition_routers import SubstreamPartitionRouter from airbyte_cdk.sources.declarative.requesters import RequestOption from airbyte_cdk.sources.declarative.requesters.error_handlers import DefaultErrorHandler -from airbyte_cdk.sources.declarative.requesters.paginators import DefaultPaginator, PaginationStrategy +from airbyte_cdk.sources.declarative.requesters.paginators import ( + DefaultPaginator, + PaginationStrategy, +) @dataclass @@ -18,7 +21,9 @@ class TestingSomeComponent(DefaultErrorHandler): A basic test class with various field permutations used to test manifests with custom components """ - subcomponent_field_with_hint: DpathExtractor = field(default_factory=lambda: DpathExtractor(field_path=[], config={}, parameters={})) + subcomponent_field_with_hint: DpathExtractor = field( + default_factory=lambda: DpathExtractor(field_path=[], config={}, parameters={}) + ) basic_field: str = "" optional_subcomponent_field: Optional[RequestOption] = None list_of_subcomponents: List[RequestOption] = None diff --git a/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py index 2b9313b3..697a0605 100644 --- a/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py +++ b/unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py @@ -6,8 +6,14 @@ from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString -from airbyte_cdk.sources.declarative.partition_routers import CartesianProductStreamSlicer, ListPartitionRouter -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.partition_routers import ( + CartesianProductStreamSlicer, + ListPartitionRouter, +) +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.types import StreamSlice @@ -16,7 +22,14 @@ [ ( "test_single_stream_slicer", - [ListPartitionRouter(values=["customer", "store", "subscription"], cursor_field="owner_resource", config={}, parameters={})], + [ + ListPartitionRouter( + values=["customer", "store", "subscription"], + cursor_field="owner_resource", + config={}, + parameters={}, + ) + ], [ StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={}), StreamSlice(partition={"owner_resource": "store"}, cursor_slice={}), @@ -26,24 +39,43 @@ ( "test_two_stream_slicers", [ - ListPartitionRouter(values=["customer", "store", "subscription"], cursor_field="owner_resource", config={}, parameters={}), - ListPartitionRouter(values=["A", "B"], cursor_field="letter", config={}, parameters={}), + ListPartitionRouter( + values=["customer", "store", "subscription"], + cursor_field="owner_resource", + config={}, + parameters={}, + ), + ListPartitionRouter( + values=["A", "B"], cursor_field="letter", config={}, parameters={} + ), ], [ - StreamSlice(partition={"owner_resource": "customer", "letter": "A"}, cursor_slice={}), - StreamSlice(partition={"owner_resource": "customer", "letter": "B"}, cursor_slice={}), + StreamSlice( + partition={"owner_resource": "customer", "letter": "A"}, cursor_slice={} + ), + StreamSlice( + partition={"owner_resource": "customer", "letter": "B"}, cursor_slice={} + ), StreamSlice(partition={"owner_resource": "store", "letter": "A"}, cursor_slice={}), StreamSlice(partition={"owner_resource": "store", "letter": "B"}, cursor_slice={}), - StreamSlice(partition={"owner_resource": "subscription", "letter": "A"}, cursor_slice={}), - StreamSlice(partition={"owner_resource": "subscription", "letter": "B"}, cursor_slice={}), + StreamSlice( + partition={"owner_resource": "subscription", "letter": "A"}, cursor_slice={} + ), + StreamSlice( + partition={"owner_resource": "subscription", "letter": "B"}, cursor_slice={} + ), ], ), ( "test_singledatetime", [ DatetimeBasedCursor( - start_datetime=MinMaxDatetime(datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={}), - end_datetime=MinMaxDatetime(datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={}), + start_datetime=MinMaxDatetime( + datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={} + ), + end_datetime=MinMaxDatetime( + datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={} + ), step="P1D", cursor_field=InterpolatedString.create("", parameters={}), datetime_format="%Y-%m-%d", @@ -53,18 +85,36 @@ ), ], [ - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}), - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}), - StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}), + StreamSlice( + partition={}, + cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}, + ), + StreamSlice( + partition={}, + cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}, + ), + StreamSlice( + partition={}, + cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}, + ), ], ), ( "test_list_and_datetime", [ - ListPartitionRouter(values=["customer", "store", "subscription"], cursor_field="owner_resource", config={}, parameters={}), + ListPartitionRouter( + values=["customer", "store", "subscription"], + cursor_field="owner_resource", + config={}, + parameters={}, + ), DatetimeBasedCursor( - start_datetime=MinMaxDatetime(datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={}), - end_datetime=MinMaxDatetime(datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={}), + start_datetime=MinMaxDatetime( + datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={} + ), + end_datetime=MinMaxDatetime( + datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={} + ), step="P1D", cursor_field=InterpolatedString.create("", parameters={}), datetime_format="%Y-%m-%d", @@ -74,20 +124,41 @@ ), ], [ - StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}), - StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}), - StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}), - StreamSlice(partition={"owner_resource": "store"}, cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}), - StreamSlice(partition={"owner_resource": "store"}, cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}), - StreamSlice(partition={"owner_resource": "store"}, cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}), StreamSlice( - partition={"owner_resource": "subscription"}, cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"} + partition={"owner_resource": "customer"}, + cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}, + ), + StreamSlice( + partition={"owner_resource": "customer"}, + cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}, ), StreamSlice( - partition={"owner_resource": "subscription"}, cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"} + partition={"owner_resource": "customer"}, + cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}, ), StreamSlice( - partition={"owner_resource": "subscription"}, cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"} + partition={"owner_resource": "store"}, + cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}, + ), + StreamSlice( + partition={"owner_resource": "store"}, + cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}, + ), + StreamSlice( + partition={"owner_resource": "store"}, + cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}, + ), + StreamSlice( + partition={"owner_resource": "subscription"}, + cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}, + ), + StreamSlice( + partition={"owner_resource": "subscription"}, + cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}, + ), + StreamSlice( + partition={"owner_resource": "subscription"}, + cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}, ), ], ), @@ -102,8 +173,12 @@ def test_substream_slicer(test_name, stream_slicers, expected_slices): def test_stream_slices_raises_exception_if_multiple_cursor_slice_components(): stream_slicers = [ DatetimeBasedCursor( - start_datetime=MinMaxDatetime(datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={}), - end_datetime=MinMaxDatetime(datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={}), + start_datetime=MinMaxDatetime( + datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={} + ), + end_datetime=MinMaxDatetime( + datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={} + ), step="P1D", cursor_field=InterpolatedString.create("", parameters={}), datetime_format="%Y-%m-%d", @@ -112,8 +187,12 @@ def test_stream_slices_raises_exception_if_multiple_cursor_slice_components(): parameters={}, ), DatetimeBasedCursor( - start_datetime=MinMaxDatetime(datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={}), - end_datetime=MinMaxDatetime(datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={}), + start_datetime=MinMaxDatetime( + datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={} + ), + end_datetime=MinMaxDatetime( + datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={} + ), step="P1D", cursor_field=InterpolatedString.create("", parameters={}), datetime_format="%Y-%m-%d", @@ -132,7 +211,9 @@ def test_stream_slices_raises_exception_if_multiple_cursor_slice_components(): [ ( "test_param_header", - RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="owner"), + RequestOption( + inject_into=RequestOptionType.request_parameter, parameters={}, field_name="owner" + ), RequestOption(inject_into=RequestOptionType.header, parameters={}, field_name="repo"), {"owner": "customer"}, {"repo": "airbyte"}, @@ -150,8 +231,12 @@ def test_stream_slices_raises_exception_if_multiple_cursor_slice_components(): ), ( "test_body_data", - RequestOption(inject_into=RequestOptionType.body_data, parameters={}, field_name="owner"), - RequestOption(inject_into=RequestOptionType.body_data, parameters={}, field_name="repo"), + RequestOption( + inject_into=RequestOptionType.body_data, parameters={}, field_name="owner" + ), + RequestOption( + inject_into=RequestOptionType.body_data, parameters={}, field_name="repo" + ), {}, {}, {}, @@ -159,8 +244,12 @@ def test_stream_slices_raises_exception_if_multiple_cursor_slice_components(): ), ( "test_body_json", - RequestOption(inject_into=RequestOptionType.body_json, parameters={}, field_name="owner"), - RequestOption(inject_into=RequestOptionType.body_json, parameters={}, field_name="repo"), + RequestOption( + inject_into=RequestOptionType.body_json, parameters={}, field_name="owner" + ), + RequestOption( + inject_into=RequestOptionType.body_json, parameters={}, field_name="repo" + ), {}, {}, {"owner": "customer", "repo": "airbyte"}, @@ -205,8 +294,12 @@ def test_request_option( def test_request_option_before_updating_cursor(): - stream_1_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="owner") - stream_2_request_option = RequestOption(inject_into=RequestOptionType.header, parameters={}, field_name="repo") + stream_1_request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, parameters={}, field_name="owner" + ) + stream_2_request_option = RequestOption( + inject_into=RequestOptionType.header, parameters={}, field_name="repo" + ) slicer = CartesianProductStreamSlicer( stream_slicers=[ ListPartitionRouter( diff --git a/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py index 87aa18f5..baa3ad8d 100644 --- a/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py +++ b/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py @@ -3,8 +3,13 @@ # import pytest as pytest -from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ListPartitionRouter -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ( + ListPartitionRouter, +) +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.types import StreamSlice partition_values = ["customer", "store", "subscription"] @@ -50,7 +55,9 @@ ], ) def test_list_partition_router(partition_values, cursor_field, expected_slices): - slicer = ListPartitionRouter(values=partition_values, cursor_field=cursor_field, config={}, parameters=parameters) + slicer = ListPartitionRouter( + values=partition_values, cursor_field=cursor_field, config={}, parameters=parameters + ) slices = [s for s in slicer.stream_slices()] assert slices == expected_slices assert all(isinstance(s, StreamSlice) for s in slices) @@ -60,28 +67,38 @@ def test_list_partition_router(partition_values, cursor_field, expected_slices): "request_option, expected_req_params, expected_headers, expected_body_json, expected_body_data", [ ( - RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="owner_resource"), + RequestOption( + inject_into=RequestOptionType.request_parameter, + parameters={}, + field_name="owner_resource", + ), {"owner_resource": "customer"}, {}, {}, {}, ), ( - RequestOption(inject_into=RequestOptionType.header, parameters={}, field_name="owner_resource"), + RequestOption( + inject_into=RequestOptionType.header, parameters={}, field_name="owner_resource" + ), {}, {"owner_resource": "customer"}, {}, {}, ), ( - RequestOption(inject_into=RequestOptionType.body_json, parameters={}, field_name="owner_resource"), + RequestOption( + inject_into=RequestOptionType.body_json, parameters={}, field_name="owner_resource" + ), {}, {}, {"owner_resource": "customer"}, {}, ), ( - RequestOption(inject_into=RequestOptionType.body_data, parameters={}, field_name="owner_resource"), + RequestOption( + inject_into=RequestOptionType.body_data, parameters={}, field_name="owner_resource" + ), {}, {}, {}, @@ -95,9 +112,15 @@ def test_list_partition_router(partition_values, cursor_field, expected_slices): "test_inject_into_body_data", ], ) -def test_request_option(request_option, expected_req_params, expected_headers, expected_body_json, expected_body_data): +def test_request_option( + request_option, expected_req_params, expected_headers, expected_body_json, expected_body_data +): partition_router = ListPartitionRouter( - values=partition_values, cursor_field=cursor_field, config={}, request_option=request_option, parameters={} + values=partition_values, + cursor_field=cursor_field, + config={}, + request_option=request_option, + parameters={}, ) stream_slice = {cursor_field: "customer"} @@ -111,14 +134,23 @@ def test_request_option(request_option, expected_req_params, expected_headers, e "stream_slice", [ pytest.param({}, id="test_request_option_is_empty_if_empty_stream_slice"), - pytest.param({"not the cursor": "value"}, id="test_request_option_is_empty_if_the_stream_slice_does_not_have_cursor_field"), + pytest.param( + {"not the cursor": "value"}, + id="test_request_option_is_empty_if_the_stream_slice_does_not_have_cursor_field", + ), pytest.param(None, id="test_request_option_is_empty_if_no_stream_slice"), ], ) def test_request_option_is_empty_if_no_stream_slice(stream_slice): - request_option = RequestOption(inject_into=RequestOptionType.body_data, parameters={}, field_name="owner_resource") + request_option = RequestOption( + inject_into=RequestOptionType.body_data, parameters={}, field_name="owner_resource" + ) partition_router = ListPartitionRouter( - values=partition_values, cursor_field=cursor_field, config={}, request_option=request_option, parameters={} + values=partition_values, + cursor_field=cursor_field, + config={}, + request_option=request_option, + parameters={}, ) assert {} == partition_router.get_request_body_data(stream_slice=stream_slice) @@ -134,14 +166,22 @@ def test_request_option_is_empty_if_no_stream_slice(stream_slice): "config_interpolation", ], ) -def test_request_options_interpolation(field_name_interpolation: str, expected_request_params: dict): +def test_request_options_interpolation( + field_name_interpolation: str, expected_request_params: dict +): config = {"partition_name": "config_partition"} parameters = {"partition_name": "parameters_partition"} request_option = RequestOption( - inject_into=RequestOptionType.request_parameter, parameters=parameters, field_name=field_name_interpolation + inject_into=RequestOptionType.request_parameter, + parameters=parameters, + field_name=field_name_interpolation, ) partition_router = ListPartitionRouter( - values=partition_values, cursor_field=cursor_field, config=config, request_option=request_option, parameters=parameters + values=partition_values, + cursor_field=cursor_field, + config=config, + request_option=request_option, + parameters=parameters, ) stream_slice = {cursor_field: "customer"} @@ -149,9 +189,15 @@ def test_request_options_interpolation(field_name_interpolation: str, expected_r def test_request_option_before_updating_cursor(): - request_option = RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="owner_resource") + request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, parameters={}, field_name="owner_resource" + ) partition_router = ListPartitionRouter( - values=partition_values, cursor_field=cursor_field, config={}, request_option=request_option, parameters={} + values=partition_values, + cursor_field=cursor_field, + config={}, + request_option=request_option, + parameters={}, ) stream_slice = {cursor_field: "customer"} diff --git a/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py b/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py index 76a8f082..81de362d 100644 --- a/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py +++ b/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py @@ -50,7 +50,11 @@ }, "paginator": { "type": "DefaultPaginator", - "page_size_option": {"type": "RequestOption", "field_name": "per_page", "inject_into": "request_parameter"}, + "page_size_option": { + "type": "RequestOption", + "field_name": "per_page", + "inject_into": "request_parameter", + }, "pagination_strategy": { "type": "CursorPagination", "page_size": 100, @@ -66,7 +70,11 @@ "datetime_format": "%Y-%m-%dT%H:%M:%SZ", "cursor_field": "{{ parameters.get('cursor_field', 'updated_at') }}", "start_datetime": {"datetime": "{{ config.get('start_date')}}"}, - "start_time_option": {"inject_into": "request_parameter", "field_name": "start_time", "type": "RequestOption"}, + "start_time_option": { + "inject_into": "request_parameter", + "field_name": "start_time", + "type": "RequestOption", + }, }, "posts_stream": { "type": "DeclarativeStream", @@ -235,7 +243,11 @@ def _run_read( catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name=stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental]), + stream=AirbyteStream( + name=stream_name, + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.append, ) @@ -245,7 +257,9 @@ def _run_read( return list(source.read(logger, config, catalog, state)) -def run_incremental_parent_state_test(manifest, mock_requests, expected_records, initial_state, expected_states): +def run_incremental_parent_state_test( + manifest, mock_requests, expected_records, initial_state, expected_states +): """ Run an incremental parent state test for the specified stream. @@ -265,7 +279,10 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, expected_states (list): A list of expected final states after the read operation. """ _stream_name = "post_comment_votes" - config = {"start_date": "2024-01-01T00:00:01Z", "credentials": {"email": "email", "api_token": "api_token"}} + config = { + "start_date": "2024-01-01T00:00:01Z", + "credentials": {"email": "email", "api_token": "api_token"}, + } with requests_mock.Mocker() as m: for url, response in mock_requests: @@ -284,7 +301,11 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, final_states = [] # To store the final state after each read # Store the final state after the initial read - final_state_initial = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state] + final_state_initial = [ + orjson.loads(orjson.dumps(message.state.stream.stream_state)) + for message in output + if message.state + ] final_states.append(final_state_initial[-1]) for message in output: @@ -300,29 +321,40 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, # For each intermediate state, perform another read starting from that state for state, records_before_state in intermediate_states[:-1]: output_intermediate = _run_read(manifest, config, _stream_name, [state]) - records_from_state = [message.record.data for message in output_intermediate if message.record] + records_from_state = [ + message.record.data for message in output_intermediate if message.record + ] # Combine records produced before the state with records from the new read cumulative_records_state = records_before_state + records_from_state # Duplicates may occur because the state matches the cursor of the last record, causing it to be re-emitted in the next sync. - cumulative_records_state_deduped = list({orjson.dumps(record): record for record in cumulative_records_state}.values()) + cumulative_records_state_deduped = list( + {orjson.dumps(record): record for record in cumulative_records_state}.values() + ) # Compare the cumulative records with the expected records - expected_records_set = list({orjson.dumps(record): record for record in expected_records}.values()) - assert sorted(cumulative_records_state_deduped, key=lambda x: orjson.dumps(x)) == sorted( - expected_records_set, key=lambda x: orjson.dumps(x) + expected_records_set = list( + {orjson.dumps(record): record for record in expected_records}.values() + ) + assert ( + sorted(cumulative_records_state_deduped, key=lambda x: orjson.dumps(x)) + == sorted(expected_records_set, key=lambda x: orjson.dumps(x)) ), f"Records mismatch with intermediate state {state}. Expected {expected_records}, got {cumulative_records_state_deduped}" # Store the final state after each intermediate read final_state_intermediate = [ - orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output_intermediate if message.state + orjson.loads(orjson.dumps(message.state.stream.stream_state)) + for message in output_intermediate + if message.state ] final_states.append(final_state_intermediate[-1]) # Assert that the final state matches the expected state for all runs for i, final_state in enumerate(final_states): - assert final_state in expected_states, f"Final state mismatch at run {i + 1}. Expected {expected_states}, got {final_state}" + assert ( + final_state in expected_states + ), f"Final state mismatch at run {i + 1}. Expected {expected_states}, got {final_state}" @pytest.mark.parametrize( @@ -336,7 +368,10 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, ( "https://api.example.com/community/posts?per_page=100&start_time=2024-01-05T00:00:00Z", { - "posts": [{"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}], + "posts": [ + {"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, + {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}, + ], "next_page": "https://api.example.com/community/posts?per_page=100&start_time=2024-01-05T00:00:00Z&page=2", }, ), @@ -366,27 +401,42 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&start_time=2024-01-02T00:00:00Z", { - "votes": [{"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"}], + "votes": [ + {"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-01T00:00:01Z", }, ), # Fetch the second page of votes for comment 10 of post 1 ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-01T00:00:01Z", - {"votes": [{"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"}]}, + { + "votes": [ + {"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 11 of post 1 ( "https://api.example.com/community/posts/1/comments/11/votes?per_page=100&start_time=2024-01-03T00:00:00Z", - {"votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}]}, + { + "votes": [ + {"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 12 of post 1 - ("https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-01T00:00:01Z", {"votes": []}), + ( + "https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-01T00:00:01Z", + {"votes": []}, + ), # Fetch the first page of comments for post 2 ( "https://api.example.com/community/posts/2/comments?per_page=100", { - "comments": [{"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"}], + "comments": [ + {"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/2/comments?per_page=100&page=2", }, ), @@ -398,12 +448,20 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, # Fetch the first page of votes for comment 20 of post 2 ( "https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-01T00:00:01Z", - {"votes": [{"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"}]}, + { + "votes": [ + {"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 21 of post 2 ( "https://api.example.com/community/posts/2/comments/21/votes?per_page=100&start_time=2024-01-01T00:00:01Z", - {"votes": [{"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"}]}, + { + "votes": [ + {"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"} + ] + }, ), # Fetch the first page of comments for post 3 ( @@ -413,21 +471,29 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, # Fetch the first page of votes for comment 30 of post 3 ( "https://api.example.com/community/posts/3/comments/30/votes?per_page=100", - {"votes": [{"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"}]}, + { + "votes": [ + {"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"} + ] + }, ), # Requests with intermediate states # Fetch votes for comment 10 of post 1 ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&start_time=2024-01-15T00:00:00Z", { - "votes": [{"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"}], + "votes": [ + {"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"} + ], }, ), # Fetch votes for comment 11 of post 1 ( "https://api.example.com/community/posts/1/comments/11/votes?per_page=100&start_time=2024-01-13T00:00:00Z", { - "votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}], + "votes": [ + {"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"} + ], }, ), # Fetch votes for comment 12 of post 1 @@ -440,12 +506,20 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, # Fetch votes for comment 20 of post 2 ( "https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-12T00:00:00Z", - {"votes": [{"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"}]}, + { + "votes": [ + {"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"} + ] + }, ), # Fetch votes for comment 21 of post 2 ( "https://api.example.com/community/posts/2/comments/21/votes?per_page=100&start_time=2024-01-12T00:00:15Z", - {"votes": [{"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"}]}, + { + "votes": [ + {"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"} + ] + }, ), ], # Expected records @@ -462,24 +536,37 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="post_comment_votes", namespace=None), + stream_descriptor=StreamDescriptor( + name="post_comment_votes", namespace=None + ), stream_state=AirbyteStateBlob( { "parent_state": { "post_comments": { "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2023-01-04T00:00:00Z"}} + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2023-01-04T00:00:00Z"}, + } ], - "parent_state": {"posts": {"updated_at": "2024-01-05T00:00:00Z"}}, + "parent_state": { + "posts": {"updated_at": "2024-01-05T00:00:00Z"} + }, } }, "states": [ { - "partition": {"id": 10, "parent_slice": {"id": 1, "parent_slice": {}}}, + "partition": { + "id": 10, + "parent_slice": {"id": 1, "parent_slice": {}}, + }, "cursor": {"created_at": "2024-01-02T00:00:00Z"}, }, { - "partition": {"id": 11, "parent_slice": {"id": 1, "parent_slice": {}}}, + "partition": { + "id": 11, + "parent_slice": {"id": 1, "parent_slice": {}}, + }, "cursor": {"created_at": "2024-01-03T00:00:00Z"}, }, ], @@ -499,9 +586,18 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, "parent_state": {"posts": {"updated_at": "2024-01-30T00:00:00Z"}}, "lookback_window": 1, "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-25T00:00:00Z"}}, - {"partition": {"id": 2, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-22T00:00:00Z"}}, - {"partition": {"id": 3, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-09T00:00:00Z"}}, + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-25T00:00:00Z"}, + }, + { + "partition": {"id": 2, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-22T00:00:00Z"}, + }, + { + "partition": {"id": 3, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-09T00:00:00Z"}, + }, ], } }, @@ -532,12 +628,23 @@ def run_incremental_parent_state_test(manifest, mock_requests, expected_records, ), ], ) -def test_incremental_parent_state(test_name, manifest, mock_requests, expected_records, initial_state, expected_state): +def test_incremental_parent_state( + test_name, manifest, mock_requests, expected_records, initial_state, expected_state +): additional_expected_state = copy.deepcopy(expected_state) # State for empty partition (comment 12), when the global cursor is used for intermediate states - empty_state = {"cursor": {"created_at": "2024-01-15T00:00:00Z"}, "partition": {"id": 12, "parent_slice": {"id": 1, "parent_slice": {}}}} + empty_state = { + "cursor": {"created_at": "2024-01-15T00:00:00Z"}, + "partition": {"id": 12, "parent_slice": {"id": 1, "parent_slice": {}}}, + } additional_expected_state["states"].append(empty_state) - run_incremental_parent_state_test(manifest, mock_requests, expected_records, initial_state, [expected_state, additional_expected_state]) + run_incremental_parent_state_test( + manifest, + mock_requests, + expected_records, + initial_state, + [expected_state, additional_expected_state], + ) @pytest.mark.parametrize( @@ -551,7 +658,10 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r ( "https://api.example.com/community/posts?per_page=100&start_time=2024-01-02T00:00:00Z", { - "posts": [{"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}], + "posts": [ + {"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, + {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}, + ], "next_page": "https://api.example.com/community/posts?per_page=100&start_time=2024-01-02T00:00:00Z&page=2", }, ), @@ -581,27 +691,42 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&start_time=2024-01-02T00:00:00Z", { - "votes": [{"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"}], + "votes": [ + {"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-02T00:00:00Z", }, ), # Fetch the second page of votes for comment 10 of post 1 ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-02T00:00:00Z", - {"votes": [{"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"}]}, + { + "votes": [ + {"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 11 of post 1 ( "https://api.example.com/community/posts/1/comments/11/votes?per_page=100&start_time=2024-01-02T00:00:00Z", - {"votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}]}, + { + "votes": [ + {"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 12 of post 1 - ("https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-02T00:00:00Z", {"votes": []}), + ( + "https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-02T00:00:00Z", + {"votes": []}, + ), # Fetch the first page of comments for post 2 ( "https://api.example.com/community/posts/2/comments?per_page=100", { - "comments": [{"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"}], + "comments": [ + {"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/2/comments?per_page=100&page=2", }, ), @@ -613,12 +738,20 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r # Fetch the first page of votes for comment 20 of post 2 ( "https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-02T00:00:00Z", - {"votes": [{"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"}]}, + { + "votes": [ + {"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 21 of post 2 ( "https://api.example.com/community/posts/2/comments/21/votes?per_page=100&start_time=2024-01-02T00:00:00Z", - {"votes": [{"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"}]}, + { + "votes": [ + {"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"} + ] + }, ), # Fetch the first page of comments for post 3 ( @@ -628,7 +761,11 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r # Fetch the first page of votes for comment 30 of post 3 ( "https://api.example.com/community/posts/3/comments/30/votes?per_page=100&start_time=2024-01-02T00:00:00Z", - {"votes": [{"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"}]}, + { + "votes": [ + {"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"} + ] + }, ), ], # Expected records @@ -645,7 +782,9 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="post_comment_votes", namespace=None), + stream_descriptor=StreamDescriptor( + name="post_comment_votes", namespace=None + ), stream_state=AirbyteStateBlob({"created_at": "2024-01-02T00:00:00Z"}), ), ) @@ -661,9 +800,18 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r "parent_state": {"posts": {"updated_at": "2024-01-30T00:00:00Z"}}, "lookback_window": 1, "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-25T00:00:00Z"}}, - {"partition": {"id": 2, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-22T00:00:00Z"}}, - {"partition": {"id": 3, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-09T00:00:00Z"}}, + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-25T00:00:00Z"}, + }, + { + "partition": {"id": 2, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-22T00:00:00Z"}, + }, + { + "partition": {"id": 3, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-09T00:00:00Z"}, + }, ], } }, @@ -698,12 +846,17 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r ), ], ) -def test_incremental_parent_state_migration(test_name, manifest, mock_requests, expected_records, initial_state, expected_state): +def test_incremental_parent_state_migration( + test_name, manifest, mock_requests, expected_records, initial_state, expected_state +): """ Test incremental partition router with parent state migration """ _stream_name = "post_comment_votes" - config = {"start_date": "2024-01-01T00:00:01Z", "credentials": {"email": "email", "api_token": "api_token"}} + config = { + "start_date": "2024-01-01T00:00:01Z", + "credentials": {"email": "email", "api_token": "api_token"}, + } with requests_mock.Mocker() as m: for url, response in mock_requests: @@ -713,7 +866,11 @@ def test_incremental_parent_state_migration(test_name, manifest, mock_requests, output_data = [message.record.data for message in output if message.record] assert output_data == expected_records - final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state] + final_state = [ + orjson.loads(orjson.dumps(message.state.stream.stream_state)) + for message in output + if message.state + ] assert final_state[-1] == expected_state @@ -769,7 +926,10 @@ def test_incremental_parent_state_migration(test_name, manifest, mock_requests, {"votes": []}, ), # Fetch the first page of votes for comment 12 of post 1 - ("https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-01T00:00:01Z", {"votes": []}), + ( + "https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-01T00:00:01Z", + {"votes": []}, + ), # Fetch the first page of comments for post 2 ( "https://api.example.com/community/posts/2/comments?per_page=100", @@ -811,24 +971,37 @@ def test_incremental_parent_state_migration(test_name, manifest, mock_requests, AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="post_comment_votes", namespace=None), + stream_descriptor=StreamDescriptor( + name="post_comment_votes", namespace=None + ), stream_state=AirbyteStateBlob( { "parent_state": { "post_comments": { "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2023-01-04T00:00:00Z"}} + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2023-01-04T00:00:00Z"}, + } ], - "parent_state": {"posts": {"updated_at": "2024-01-05T00:00:00Z"}}, + "parent_state": { + "posts": {"updated_at": "2024-01-05T00:00:00Z"} + }, } }, "states": [ { - "partition": {"id": 10, "parent_slice": {"id": 1, "parent_slice": {}}}, + "partition": { + "id": 10, + "parent_slice": {"id": 1, "parent_slice": {}}, + }, "cursor": {"created_at": "2024-01-02T00:00:00Z"}, }, { - "partition": {"id": 11, "parent_slice": {"id": 1, "parent_slice": {}}}, + "partition": { + "id": 11, + "parent_slice": {"id": 1, "parent_slice": {}}, + }, "cursor": {"created_at": "2024-01-03T00:00:00Z"}, }, ], @@ -849,7 +1022,12 @@ def test_incremental_parent_state_migration(test_name, manifest, mock_requests, "use_global_cursor": False, "state": {}, "parent_state": {"posts": {"updated_at": "2024-01-05T00:00:00Z"}}, - "states": [{"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2023-01-04T00:00:00Z"}}], + "states": [ + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2023-01-04T00:00:00Z"}, + } + ], } }, "states": [ @@ -866,12 +1044,17 @@ def test_incremental_parent_state_migration(test_name, manifest, mock_requests, ), ], ) -def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests, expected_records, initial_state, expected_state): +def test_incremental_parent_state_no_slices( + test_name, manifest, mock_requests, expected_records, initial_state, expected_state +): """ Test incremental partition router with no parent records """ _stream_name = "post_comment_votes" - config = {"start_date": "2024-01-01T00:00:01Z", "credentials": {"email": "email", "api_token": "api_token"}} + config = { + "start_date": "2024-01-01T00:00:01Z", + "credentials": {"email": "email", "api_token": "api_token"}, + } with requests_mock.Mocker() as m: for url, response in mock_requests: @@ -881,7 +1064,11 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests, output_data = [message.record.data for message in output if message.record] assert output_data == expected_records - final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state] + final_state = [ + orjson.loads(orjson.dumps(message.state.stream.stream_state)) + for message in output + if message.state + ] assert final_state[-1] == expected_state @@ -896,7 +1083,10 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests, ( "https://api.example.com/community/posts?per_page=100&start_time=2024-01-05T00:00:00Z", { - "posts": [{"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}], + "posts": [ + {"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, + {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}, + ], "next_page": "https://api.example.com/community/posts?per_page=100&start_time=2024-01-05T00:00:00Z&page=2", }, ), @@ -941,12 +1131,17 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests, {"votes": []}, ), # Fetch the first page of votes for comment 12 of post 1 - ("https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-03T00:00:00Z", {"votes": []}), + ( + "https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-03T00:00:00Z", + {"votes": []}, + ), # Fetch the first page of comments for post 2 ( "https://api.example.com/community/posts/2/comments?per_page=100", { - "comments": [{"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"}], + "comments": [ + {"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/2/comments?per_page=100&page=2", }, ), @@ -983,24 +1178,37 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests, AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="post_comment_votes", namespace=None), + stream_descriptor=StreamDescriptor( + name="post_comment_votes", namespace=None + ), stream_state=AirbyteStateBlob( { "parent_state": { "post_comments": { "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2023-01-04T00:00:00Z"}} + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2023-01-04T00:00:00Z"}, + } ], - "parent_state": {"posts": {"updated_at": "2024-01-05T00:00:00Z"}}, + "parent_state": { + "posts": {"updated_at": "2024-01-05T00:00:00Z"} + }, } }, "states": [ { - "partition": {"id": 10, "parent_slice": {"id": 1, "parent_slice": {}}}, + "partition": { + "id": 10, + "parent_slice": {"id": 1, "parent_slice": {}}, + }, "cursor": {"created_at": "2024-01-02T00:00:00Z"}, }, { - "partition": {"id": 11, "parent_slice": {"id": 1, "parent_slice": {}}}, + "partition": { + "id": 11, + "parent_slice": {"id": 1, "parent_slice": {}}, + }, "cursor": {"created_at": "2024-01-03T00:00:00Z"}, }, ], @@ -1024,9 +1232,18 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests, "parent_state": {"posts": {"updated_at": "2024-01-30T00:00:00Z"}}, "lookback_window": 1, "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-25T00:00:00Z"}}, - {"partition": {"id": 2, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-22T00:00:00Z"}}, - {"partition": {"id": 3, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-09T00:00:00Z"}}, + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-25T00:00:00Z"}, + }, + { + "partition": {"id": 2, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-22T00:00:00Z"}, + }, + { + "partition": {"id": 3, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-09T00:00:00Z"}, + }, ], } }, @@ -1034,12 +1251,17 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests, ), ], ) -def test_incremental_parent_state_no_records(test_name, manifest, mock_requests, expected_records, initial_state, expected_state): +def test_incremental_parent_state_no_records( + test_name, manifest, mock_requests, expected_records, initial_state, expected_state +): """ Test incremental partition router with no child records """ _stream_name = "post_comment_votes" - config = {"start_date": "2024-01-01T00:00:01Z", "credentials": {"email": "email", "api_token": "api_token"}} + config = { + "start_date": "2024-01-01T00:00:01Z", + "credentials": {"email": "email", "api_token": "api_token"}, + } with requests_mock.Mocker() as m: for url, response in mock_requests: @@ -1049,7 +1271,11 @@ def test_incremental_parent_state_no_records(test_name, manifest, mock_requests, output_data = [message.record.data for message in output if message.record] assert output_data == expected_records - final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state] + final_state = [ + orjson.loads(orjson.dumps(message.state.stream.stream_state)) + for message in output + if message.state + ] assert final_state[-1] == expected_state @@ -1064,7 +1290,10 @@ def test_incremental_parent_state_no_records(test_name, manifest, mock_requests, ( "https://api.example.com/community/posts?per_page=100&start_time=2024-01-01T00:00:01Z", { - "posts": [{"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}], + "posts": [ + {"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, + {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}, + ], "next_page": "https://api.example.com/community/posts?per_page=100&start_time=2024-01-05T00:00:00Z&page=2", }, ), @@ -1094,27 +1323,42 @@ def test_incremental_parent_state_no_records(test_name, manifest, mock_requests, ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&start_time=2024-01-02T00:00:00Z", { - "votes": [{"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"}], + "votes": [ + {"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-01T00:00:01Z", }, ), # Fetch the second page of votes for comment 10 of post 1 ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-01T00:00:01Z", - {"votes": [{"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"}]}, + { + "votes": [ + {"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 11 of post 1 ( "https://api.example.com/community/posts/1/comments/11/votes?per_page=100&start_time=2024-01-03T00:00:00Z", - {"votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}]}, + { + "votes": [ + {"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 12 of post 1 - ("https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-01T00:00:01Z", {"votes": []}), + ( + "https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-01T00:00:01Z", + {"votes": []}, + ), # Fetch the first page of comments for post 2 ( "https://api.example.com/community/posts/2/comments?per_page=100", { - "comments": [{"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"}], + "comments": [ + {"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/2/comments?per_page=100&page=2", }, ), @@ -1126,12 +1370,20 @@ def test_incremental_parent_state_no_records(test_name, manifest, mock_requests, # Fetch the first page of votes for comment 20 of post 2 ( "https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-01T00:00:01Z", - {"votes": [{"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"}]}, + { + "votes": [ + {"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 21 of post 2 ( "https://api.example.com/community/posts/2/comments/21/votes?per_page=100&start_time=2024-01-01T00:00:01Z", - {"votes": [{"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"}]}, + { + "votes": [ + {"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"} + ] + }, ), # Fetch the first page of comments for post 3 ( @@ -1141,7 +1393,11 @@ def test_incremental_parent_state_no_records(test_name, manifest, mock_requests, # Fetch the first page of votes for comment 30 of post 3 ( "https://api.example.com/community/posts/3/comments/30/votes?per_page=100", - {"votes": [{"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"}]}, + { + "votes": [ + {"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"} + ] + }, ), ], # Expected records @@ -1158,7 +1414,9 @@ def test_incremental_parent_state_no_records(test_name, manifest, mock_requests, AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="post_comment_votes", namespace=None), + stream_descriptor=StreamDescriptor( + name="post_comment_votes", namespace=None + ), stream_state=AirbyteStateBlob( { # This should not happen since parent state is disabled, but I've added this to validate that and @@ -1166,18 +1424,29 @@ def test_incremental_parent_state_no_records(test_name, manifest, mock_requests, "parent_state": { "post_comments": { "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2023-01-04T00:00:00Z"}} + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2023-01-04T00:00:00Z"}, + } ], - "parent_state": {"posts": {"updated_at": "2024-01-05T00:00:00Z"}}, + "parent_state": { + "posts": {"updated_at": "2024-01-05T00:00:00Z"} + }, } }, "states": [ { - "partition": {"id": 10, "parent_slice": {"id": 1, "parent_slice": {}}}, + "partition": { + "id": 10, + "parent_slice": {"id": 1, "parent_slice": {}}, + }, "cursor": {"created_at": "2024-01-02T00:00:00Z"}, }, { - "partition": {"id": 11, "parent_slice": {"id": 1, "parent_slice": {}}}, + "partition": { + "id": 11, + "parent_slice": {"id": 1, "parent_slice": {}}, + }, "cursor": {"created_at": "2024-01-03T00:00:00Z"}, }, ], @@ -1234,15 +1503,18 @@ def test_incremental_parent_state_no_incremental_dependency( """ _stream_name = "post_comment_votes" - config = {"start_date": "2024-01-01T00:00:01Z", "credentials": {"email": "email", "api_token": "api_token"}} + config = { + "start_date": "2024-01-01T00:00:01Z", + "credentials": {"email": "email", "api_token": "api_token"}, + } # Disable incremental_dependency - manifest["definitions"]["post_comments_stream"]["retriever"]["partition_router"]["parent_stream_configs"][0][ - "incremental_dependency" - ] = False - manifest["definitions"]["post_comment_votes_stream"]["retriever"]["partition_router"]["parent_stream_configs"][0][ - "incremental_dependency" - ] = False + manifest["definitions"]["post_comments_stream"]["retriever"]["partition_router"][ + "parent_stream_configs" + ][0]["incremental_dependency"] = False + manifest["definitions"]["post_comment_votes_stream"]["retriever"]["partition_router"][ + "parent_stream_configs" + ][0]["incremental_dependency"] = False with requests_mock.Mocker() as m: for url, response in mock_requests: @@ -1252,7 +1524,11 @@ def test_incremental_parent_state_no_incremental_dependency( output_data = [message.record.data for message in output if message.record] assert output_data == expected_records - final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state] + final_state = [ + orjson.loads(orjson.dumps(message.state.stream.stream_state)) + for message in output + if message.state + ] assert final_state[-1] == expected_state @@ -1284,7 +1560,11 @@ def test_incremental_parent_state_no_incremental_dependency( }, "paginator": { "type": "DefaultPaginator", - "page_size_option": {"type": "RequestOption", "field_name": "per_page", "inject_into": "request_parameter"}, + "page_size_option": { + "type": "RequestOption", + "field_name": "per_page", + "inject_into": "request_parameter", + }, "pagination_strategy": { "type": "CursorPagination", "page_size": 100, @@ -1300,7 +1580,11 @@ def test_incremental_parent_state_no_incremental_dependency( "datetime_format": "%Y-%m-%dT%H:%M:%SZ", "cursor_field": "{{ parameters.get('cursor_field', 'updated_at') }}", "start_datetime": {"datetime": "{{ config.get('start_date')}}"}, - "start_time_option": {"inject_into": "request_parameter", "field_name": "start_time", "type": "RequestOption"}, + "start_time_option": { + "inject_into": "request_parameter", + "field_name": "start_time", + "type": "RequestOption", + }, }, "posts_stream": { "type": "DeclarativeStream", @@ -1447,7 +1731,11 @@ def test_incremental_parent_state_no_incremental_dependency( "datetime_format": "%Y-%m-%dT%H:%M:%SZ", "cursor_field": "{{ parameters.get('cursor_field', 'updated_at') }}", "start_datetime": {"datetime": "{{ config.get('start_date')}}"}, - "start_time_option": {"inject_into": "request_parameter", "field_name": "start_time", "type": "RequestOption"}, + "start_time_option": { + "inject_into": "request_parameter", + "field_name": "start_time", + "type": "RequestOption", + }, "global_substream_cursor": True, }, "$parameters": { @@ -1465,10 +1753,12 @@ def test_incremental_parent_state_no_incremental_dependency( {"$ref": "#/definitions/post_comment_votes_stream"}, ], } -SUBSTREAM_MANIFEST_GLOBAL_PARENT_CURSOR_NO_DEPENDENCY = copy.deepcopy(SUBSTREAM_MANIFEST_GLOBAL_PARENT_CURSOR) -SUBSTREAM_MANIFEST_GLOBAL_PARENT_CURSOR_NO_DEPENDENCY["definitions"]["post_comment_votes_stream"]["retriever"]["partition_router"][ - "parent_stream_configs" -][0]["incremental_dependency"] = False +SUBSTREAM_MANIFEST_GLOBAL_PARENT_CURSOR_NO_DEPENDENCY = copy.deepcopy( + SUBSTREAM_MANIFEST_GLOBAL_PARENT_CURSOR +) +SUBSTREAM_MANIFEST_GLOBAL_PARENT_CURSOR_NO_DEPENDENCY["definitions"]["post_comment_votes_stream"][ + "retriever" +]["partition_router"]["parent_stream_configs"][0]["incremental_dependency"] = False @pytest.mark.parametrize( @@ -1482,7 +1772,10 @@ def test_incremental_parent_state_no_incremental_dependency( ( "https://api.example.com/community/posts?per_page=100&start_time=2024-01-05T00:00:00Z", { - "posts": [{"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}], + "posts": [ + {"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, + {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}, + ], "next_page": "https://api.example.com/community/posts?per_page=100&start_time=2024-01-05T00:00:00Z&page=2", }, ), @@ -1512,27 +1805,42 @@ def test_incremental_parent_state_no_incremental_dependency( ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&start_time=2024-01-03T00:00:00Z", { - "votes": [{"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"}], + "votes": [ + {"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-03T00:00:01Z", }, ), # Fetch the second page of votes for comment 10 of post 1 ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-03T00:00:01Z", - {"votes": [{"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"}]}, + { + "votes": [ + {"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 11 of post 1 ( "https://api.example.com/community/posts/1/comments/11/votes?per_page=100&start_time=2024-01-03T00:00:00Z", - {"votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}]}, + { + "votes": [ + {"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 12 of post 1 - ("https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-03T00:00:00Z", {"votes": []}), + ( + "https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-03T00:00:00Z", + {"votes": []}, + ), # Fetch the first page of comments for post 2 ( "https://api.example.com/community/posts/2/comments?per_page=100", { - "comments": [{"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"}], + "comments": [ + {"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/2/comments?per_page=100&page=2", }, ), @@ -1544,12 +1852,20 @@ def test_incremental_parent_state_no_incremental_dependency( # Fetch the first page of votes for comment 20 of post 2 ( "https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-03T00:00:00Z", - {"votes": [{"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"}]}, + { + "votes": [ + {"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 21 of post 2 ( "https://api.example.com/community/posts/2/comments/21/votes?per_page=100&start_time=2024-01-03T00:00:00Z", - {"votes": [{"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"}]}, + { + "votes": [ + {"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"} + ] + }, ), # Fetch the first page of comments for post 3 ( @@ -1559,21 +1875,29 @@ def test_incremental_parent_state_no_incremental_dependency( # Fetch the first page of votes for comment 30 of post 3 ( "https://api.example.com/community/posts/3/comments/30/votes?per_page=100", - {"votes": [{"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"}]}, + { + "votes": [ + {"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"} + ] + }, ), # Requests with intermediate states # Fetch votes for comment 10 of post 1 ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&start_time=2024-01-14T23:59:59Z", { - "votes": [{"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"}], + "votes": [ + {"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"} + ], }, ), # Fetch votes for comment 11 of post 1 ( "https://api.example.com/community/posts/1/comments/11/votes?per_page=100&start_time=2024-01-14T23:59:59Z", { - "votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}], + "votes": [ + {"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"} + ], }, ), # Fetch votes for comment 12 of post 1 @@ -1586,12 +1910,20 @@ def test_incremental_parent_state_no_incremental_dependency( # Fetch votes for comment 20 of post 2 ( "https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-14T23:59:59Z", - {"votes": [{"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"}]}, + { + "votes": [ + {"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"} + ] + }, ), # Fetch votes for comment 21 of post 2 ( "https://api.example.com/community/posts/2/comments/21/votes?per_page=100&start_time=2024-01-14T23:59:59Z", - {"votes": [{"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"}]}, + { + "votes": [ + {"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"} + ] + }, ), ], # Expected records @@ -1608,15 +1940,22 @@ def test_incremental_parent_state_no_incremental_dependency( AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="post_comment_votes", namespace=None), + stream_descriptor=StreamDescriptor( + name="post_comment_votes", namespace=None + ), stream_state=AirbyteStateBlob( { "parent_state": { "post_comments": { "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2023-01-04T00:00:00Z"}} + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2023-01-04T00:00:00Z"}, + } ], - "parent_state": {"posts": {"updated_at": "2024-01-05T00:00:00Z"}}, + "parent_state": { + "posts": {"updated_at": "2024-01-05T00:00:00Z"} + }, } }, "state": {"created_at": "2024-01-04T02:03:04Z"}, @@ -1637,9 +1976,18 @@ def test_incremental_parent_state_no_incremental_dependency( "parent_state": {"posts": {"updated_at": "2024-01-30T00:00:00Z"}}, "lookback_window": 1, "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-25T00:00:00Z"}}, - {"partition": {"id": 2, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-22T00:00:00Z"}}, - {"partition": {"id": 3, "parent_slice": {}}, "cursor": {"updated_at": "2024-01-09T00:00:00Z"}}, + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-25T00:00:00Z"}, + }, + { + "partition": {"id": 2, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-22T00:00:00Z"}, + }, + { + "partition": {"id": 3, "parent_slice": {}}, + "cursor": {"updated_at": "2024-01-09T00:00:00Z"}, + }, ], } }, @@ -1653,7 +2001,10 @@ def test_incremental_parent_state_no_incremental_dependency( ( "https://api.example.com/community/posts?per_page=100&start_time=2024-01-01T00:00:01Z", { - "posts": [{"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}], + "posts": [ + {"id": 1, "updated_at": "2024-01-30T00:00:00Z"}, + {"id": 2, "updated_at": "2024-01-29T00:00:00Z"}, + ], "next_page": "https://api.example.com/community/posts?per_page=100&start_time=2024-01-01T00:00:01Z&page=2", }, ), @@ -1683,27 +2034,42 @@ def test_incremental_parent_state_no_incremental_dependency( ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&start_time=2024-01-03T00:00:00Z", { - "votes": [{"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"}], + "votes": [ + {"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-03T00:00:00Z", }, ), # Fetch the second page of votes for comment 10 of post 1 ( "https://api.example.com/community/posts/1/comments/10/votes?per_page=100&page=2&start_time=2024-01-03T00:00:00Z", - {"votes": [{"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"}]}, + { + "votes": [ + {"id": 101, "comment_id": 10, "created_at": "2024-01-14T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 11 of post 1 ( "https://api.example.com/community/posts/1/comments/11/votes?per_page=100&start_time=2024-01-03T00:00:00Z", - {"votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}]}, + { + "votes": [ + {"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 12 of post 1 - ("https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-03T00:00:00Z", {"votes": []}), + ( + "https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-03T00:00:00Z", + {"votes": []}, + ), # Fetch the first page of comments for post 2 ( "https://api.example.com/community/posts/2/comments?per_page=100", { - "comments": [{"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"}], + "comments": [ + {"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"} + ], "next_page": "https://api.example.com/community/posts/2/comments?per_page=100&page=2", }, ), @@ -1715,12 +2081,20 @@ def test_incremental_parent_state_no_incremental_dependency( # Fetch the first page of votes for comment 20 of post 2 ( "https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-03T00:00:00Z", - {"votes": [{"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"}]}, + { + "votes": [ + {"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"} + ] + }, ), # Fetch the first page of votes for comment 21 of post 2 ( "https://api.example.com/community/posts/2/comments/21/votes?per_page=100&start_time=2024-01-03T00:00:00Z", - {"votes": [{"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"}]}, + { + "votes": [ + {"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"} + ] + }, ), # Fetch the first page of comments for post 3 ( @@ -1730,7 +2104,11 @@ def test_incremental_parent_state_no_incremental_dependency( # Fetch the first page of votes for comment 30 of post 3 ( "https://api.example.com/community/posts/3/comments/30/votes?per_page=100", - {"votes": [{"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"}]}, + { + "votes": [ + {"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"} + ] + }, ), ], # Expected records @@ -1747,15 +2125,22 @@ def test_incremental_parent_state_no_incremental_dependency( AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="post_comment_votes", namespace=None), + stream_descriptor=StreamDescriptor( + name="post_comment_votes", namespace=None + ), stream_state=AirbyteStateBlob( { "parent_state": { "post_comments": { "states": [ - {"partition": {"id": 1, "parent_slice": {}}, "cursor": {"updated_at": "2023-01-04T00:00:00Z"}} + { + "partition": {"id": 1, "parent_slice": {}}, + "cursor": {"updated_at": "2023-01-04T00:00:00Z"}, + } ], - "parent_state": {"posts": {"updated_at": "2024-01-05T00:00:00Z"}}, + "parent_state": { + "posts": {"updated_at": "2024-01-05T00:00:00Z"} + }, } }, "state": {"created_at": "2024-01-04T02:03:04Z"}, @@ -1770,5 +2155,9 @@ def test_incremental_parent_state_no_incremental_dependency( ), ], ) -def test_incremental_global_parent_state(test_name, manifest, mock_requests, expected_records, initial_state, expected_state): - run_incremental_parent_state_test(manifest, mock_requests, expected_records, initial_state, [expected_state]) +def test_incremental_global_parent_state( + test_name, manifest, mock_requests, expected_records, initial_state, expected_state +): + run_incremental_parent_state_test( + manifest, mock_requests, expected_records, initial_state, [expected_state] + ) diff --git a/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py index b1512dc5..82cc7ba3 100644 --- a/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py +++ b/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py @@ -2,7 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import SinglePartitionRouter +from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import ( + SinglePartitionRouter, +) from airbyte_cdk.sources.types import StreamSlice diff --git a/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py index f29917ab..f42bd554 100644 --- a/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py +++ b/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py @@ -9,20 +9,42 @@ import pytest as pytest from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, SyncMode, Type from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream -from airbyte_cdk.sources.declarative.incremental import ChildPartitionResumableFullRefreshCursor, ResumableFullRefreshCursor -from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import CursorFactory, PerPartitionCursor, StreamSlice +from airbyte_cdk.sources.declarative.incremental import ( + ChildPartitionResumableFullRefreshCursor, + ResumableFullRefreshCursor, +) +from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import ( + CursorFactory, + PerPartitionCursor, + StreamSlice, +) from airbyte_cdk.sources.declarative.interpolation import InterpolatedString -from airbyte_cdk.sources.declarative.partition_routers import CartesianProductStreamSlicer, ListPartitionRouter -from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ParentStreamConfig, SubstreamPartitionRouter -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.partition_routers import ( + CartesianProductStreamSlicer, + ListPartitionRouter, +) +from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ( + ParentStreamConfig, + SubstreamPartitionRouter, +) +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.streams.checkpoint import Cursor from airbyte_cdk.sources.types import Record from airbyte_cdk.utils import AirbyteTracedException parent_records = [{"id": 1, "data": "data1"}, {"id": 2, "data": "data2"}] -more_records = [{"id": 10, "data": "data10", "slice": "second_parent"}, {"id": 20, "data": "data20", "slice": "second_parent"}] +more_records = [ + {"id": 10, "data": "data10", "slice": "second_parent"}, + {"id": 20, "data": "data20", "slice": "second_parent"}, +] -data_first_parent_slice = [{"id": 0, "slice": "first", "data": "A"}, {"id": 1, "slice": "first", "data": "B"}] +data_first_parent_slice = [ + {"id": 0, "slice": "first", "data": "A"}, + {"id": 1, "slice": "first", "data": "B"}, +] data_second_parent_slice = [{"id": 2, "slice": "second", "data": "C"}] data_third_parent_slice = [] all_parent_data = data_first_parent_slice + data_second_parent_slice + data_third_parent_slice @@ -33,8 +55,12 @@ {"id": 0, "slice": "first", "data": "A", "cursor": "first_cursor_0"}, {"id": 1, "slice": "first", "data": "B", "cursor": "first_cursor_1"}, ] -data_second_parent_slice_with_cursor = [{"id": 2, "slice": "second", "data": "C", "cursor": "second_cursor_2"}] -all_parent_data_with_cursor = data_first_parent_slice_with_cursor + data_second_parent_slice_with_cursor +data_second_parent_slice_with_cursor = [ + {"id": 2, "slice": "second", "data": "C", "cursor": "second_cursor_2"} +] +all_parent_data_with_cursor = ( + data_first_parent_slice_with_cursor + data_second_parent_slice_with_cursor +) class MockStream(DeclarativeStream): @@ -43,7 +69,9 @@ def __init__(self, slices, records, name, cursor_field="", cursor=None): self._slices = slices self._records = records self._stream_cursor_field = ( - InterpolatedString.create(cursor_field, parameters={}) if isinstance(cursor_field, str) else cursor_field + InterpolatedString.create(cursor_field, parameters={}) + if isinstance(cursor_field, str) + else cursor_field ) self._name = name self._state = {"states": []} @@ -73,7 +101,11 @@ def get_cursor(self) -> Optional[Cursor]: return self._cursor def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None + self, + *, + sync_mode: SyncMode, + cursor_field: List[str] = None, + stream_state: Mapping[str, Any] = None, ) -> Iterable[Optional[StreamSlice]]: for s in self._slices: if isinstance(s, StreamSlice): @@ -94,14 +126,20 @@ def read_records( if not stream_slice: result = self._records else: - result = [Record(data=r, associated_slice=stream_slice) for r in self._records if r["slice"] == stream_slice["slice"]] + result = [ + Record(data=r, associated_slice=stream_slice) + for r in self._records + if r["slice"] == stream_slice["slice"] + ] yield from result # Update the state only after reading the full slice cursor_field = self._stream_cursor_field.eval(config=self.config) if stream_slice and cursor_field and result: - self._state["states"].append({cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]}) + self._state["states"].append( + {cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]} + ) def get_json_schema(self) -> Mapping[str, Any]: return {} @@ -122,7 +160,11 @@ def read_records( stream_slice: Mapping[str, Any] = None, stream_state: Mapping[str, Any] = None, ) -> Iterable[Mapping[str, Any]]: - results = [record for record in self._records if stream_slice["start_time"] <= record["updated_at"] <= stream_slice["end_time"]] + results = [ + record + for record in self._records + if stream_slice["start_time"] <= record["updated_at"] <= stream_slice["end_time"] + ] print(f"about to emit {results}") yield from results print(f"setting state to {stream_slice}") @@ -130,7 +172,14 @@ def read_records( class MockResumableFullRefreshStream(MockStream): - def __init__(self, slices, name, cursor_field="", cursor=None, record_pages: Optional[List[List[Mapping[str, Any]]]] = None): + def __init__( + self, + slices, + name, + cursor_field="", + cursor=None, + record_pages: Optional[List[List[Mapping[str, Any]]]] = None, + ): super().__init__(slices, [], name, cursor_field, cursor) if record_pages: self._record_pages = record_pages @@ -150,9 +199,13 @@ def read_records( cursor = self.get_cursor() if page_number < len(self._record_pages): - cursor.close_slice(StreamSlice(cursor_slice={"next_page_token": page_number + 1}, partition={})) + cursor.close_slice( + StreamSlice(cursor_slice={"next_page_token": page_number + 1}, partition={}) + ) else: - cursor.close_slice(StreamSlice(cursor_slice={"__ab_full_refresh_sync_complete": True}, partition={})) + cursor.close_slice( + StreamSlice(cursor_slice={"__ab_full_refresh_sync_complete": True}, partition={}) + ) @property def state(self) -> Mapping[str, Any]: @@ -190,7 +243,10 @@ def state(self, value: Mapping[str, Any]) -> None: config={}, ) ], - [{"first_stream_id": 1, "parent_slice": {}}, {"first_stream_id": 2, "parent_slice": {}}], + [ + {"first_stream_id": 1, "parent_slice": {}}, + {"first_stream_id": 2, "parent_slice": {}}, + ], ), ( [ @@ -212,7 +268,10 @@ def state(self, value: Mapping[str, Any]) -> None: [ ParentStreamConfig( stream=MockStream( - [StreamSlice(partition=p, cursor_slice={"start": 0, "end": 1}) for p in parent_slices], + [ + StreamSlice(partition=p, cursor_slice={"start": 0, "end": 1}) + for p in parent_slices + ], all_parent_data, "first_stream", ), @@ -231,7 +290,11 @@ def state(self, value: Mapping[str, Any]) -> None: ( [ ParentStreamConfig( - stream=MockStream(parent_slices, data_first_parent_slice + data_second_parent_slice, "first_stream"), + stream=MockStream( + parent_slices, + data_first_parent_slice + data_second_parent_slice, + "first_stream", + ), parent_key="id", partition_field="first_stream_id", parameters={}, @@ -256,7 +319,9 @@ def state(self, value: Mapping[str, Any]) -> None: ( [ ParentStreamConfig( - stream=MockStream([{}], [{"id": 0}, {"id": 1}, {"_id": 2}, {"id": 3}], "first_stream"), + stream=MockStream( + [{}], [{"id": 0}, {"id": 1}, {"_id": 2}, {"id": 3}], "first_stream" + ), parent_key="id", partition_field="first_stream_id", parameters={}, @@ -272,7 +337,11 @@ def state(self, value: Mapping[str, Any]) -> None: ( [ ParentStreamConfig( - stream=MockStream([{}], [{"a": {"b": 0}}, {"a": {"b": 1}}, {"a": {"c": 2}}, {"a": {"b": 3}}], "first_stream"), + stream=MockStream( + [{}], + [{"a": {"b": 0}}, {"a": {"b": 1}}, {"a": {"c": 2}}, {"a": {"b": 3}}], + "first_stream", + ), parent_key="a/b", partition_field="first_stream_id", parameters={}, @@ -300,11 +369,15 @@ def state(self, value: Mapping[str, Any]) -> None: def test_substream_partition_router(parent_stream_configs, expected_slices): if expected_slices is None: try: - SubstreamPartitionRouter(parent_stream_configs=parent_stream_configs, parameters={}, config={}) + SubstreamPartitionRouter( + parent_stream_configs=parent_stream_configs, parameters={}, config={} + ) assert False except ValueError: return - partition_router = SubstreamPartitionRouter(parent_stream_configs=parent_stream_configs, parameters={}, config={}) + partition_router = SubstreamPartitionRouter( + parent_stream_configs=parent_stream_configs, parameters={}, config={} + ) slices = [s for s in partition_router.stream_slices()] assert slices == expected_slices @@ -333,8 +406,16 @@ def test_substream_partition_router_invalid_parent_record_type(): [ ( [ - RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="first_stream"), - RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="second_stream"), + RequestOption( + inject_into=RequestOptionType.request_parameter, + parameters={}, + field_name="first_stream", + ), + RequestOption( + inject_into=RequestOptionType.request_parameter, + parameters={}, + field_name="second_stream", + ), ], {"first_stream": "1234", "second_stream": "4567"}, {}, @@ -343,8 +424,12 @@ def test_substream_partition_router_invalid_parent_record_type(): ), ( [ - RequestOption(inject_into=RequestOptionType.header, parameters={}, field_name="first_stream"), - RequestOption(inject_into=RequestOptionType.header, parameters={}, field_name="second_stream"), + RequestOption( + inject_into=RequestOptionType.header, parameters={}, field_name="first_stream" + ), + RequestOption( + inject_into=RequestOptionType.header, parameters={}, field_name="second_stream" + ), ], {}, {"first_stream": "1234", "second_stream": "4567"}, @@ -353,8 +438,14 @@ def test_substream_partition_router_invalid_parent_record_type(): ), ( [ - RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="first_stream"), - RequestOption(inject_into=RequestOptionType.header, parameters={}, field_name="second_stream"), + RequestOption( + inject_into=RequestOptionType.request_parameter, + parameters={}, + field_name="first_stream", + ), + RequestOption( + inject_into=RequestOptionType.header, parameters={}, field_name="second_stream" + ), ], {"first_stream": "1234"}, {"second_stream": "4567"}, @@ -363,8 +454,16 @@ def test_substream_partition_router_invalid_parent_record_type(): ), ( [ - RequestOption(inject_into=RequestOptionType.body_json, parameters={}, field_name="first_stream"), - RequestOption(inject_into=RequestOptionType.body_json, parameters={}, field_name="second_stream"), + RequestOption( + inject_into=RequestOptionType.body_json, + parameters={}, + field_name="first_stream", + ), + RequestOption( + inject_into=RequestOptionType.body_json, + parameters={}, + field_name="second_stream", + ), ], {}, {}, @@ -373,8 +472,16 @@ def test_substream_partition_router_invalid_parent_record_type(): ), ( [ - RequestOption(inject_into=RequestOptionType.body_data, parameters={}, field_name="first_stream"), - RequestOption(inject_into=RequestOptionType.body_data, parameters={}, field_name="second_stream"), + RequestOption( + inject_into=RequestOptionType.body_data, + parameters={}, + field_name="first_stream", + ), + RequestOption( + inject_into=RequestOptionType.body_data, + parameters={}, + field_name="second_stream", + ), ], {}, {}, @@ -400,7 +507,11 @@ def test_request_option( partition_router = SubstreamPartitionRouter( parent_stream_configs=[ ParentStreamConfig( - stream=MockStream(parent_slices, data_first_parent_slice + data_second_parent_slice, "first_stream"), + stream=MockStream( + parent_slices, + data_first_parent_slice + data_second_parent_slice, + "first_stream", + ), parent_key="id", partition_field="first_stream_id", parameters={}, @@ -432,7 +543,12 @@ def test_request_option( [ ( ParentStreamConfig( - stream=MockStream(parent_slices, all_parent_data_with_cursor, "first_stream", cursor_field="cursor"), + stream=MockStream( + parent_slices, + all_parent_data_with_cursor, + "first_stream", + cursor_field="cursor", + ), parent_key="id", partition_field="first_stream_id", parameters={}, @@ -441,7 +557,10 @@ def test_request_option( ), { "first_stream": { - "states": [{"cursor": "first_cursor_1", "partition": "first"}, {"cursor": "second_cursor_2", "partition": "second"}] + "states": [ + {"cursor": "first_cursor_1", "partition": "first"}, + {"cursor": "second_cursor_2", "partition": "second"}, + ] } }, ), @@ -451,7 +570,9 @@ def test_request_option( ], ) def test_substream_slicer_parent_state_update_with_cursor(parent_stream_config, expected_state): - partition_router = SubstreamPartitionRouter(parent_stream_configs=[parent_stream_config], parameters={}, config={}) + partition_router = SubstreamPartitionRouter( + parent_stream_configs=[parent_stream_config], parameters={}, config={} + ) # Simulate reading the records and updating the state for _ in partition_router.stream_slices(): @@ -484,18 +605,30 @@ def test_substream_slicer_parent_state_update_with_cursor(parent_stream_config, def test_request_params_interpolation_for_parent_stream( field_name_first_stream: str, field_name_second_stream: str, expected_request_params: dict ): - config = {"field_name_first_stream": "config_first_stream_id", "field_name_second_stream": "config_second_stream_id"} - parameters = {"field_name_first_stream": "parameter_first_stream_id", "field_name_second_stream": "parameter_second_stream_id"} + config = { + "field_name_first_stream": "config_first_stream_id", + "field_name_second_stream": "config_second_stream_id", + } + parameters = { + "field_name_first_stream": "parameter_first_stream_id", + "field_name_second_stream": "parameter_second_stream_id", + } partition_router = SubstreamPartitionRouter( parent_stream_configs=[ ParentStreamConfig( - stream=MockStream(parent_slices, data_first_parent_slice + data_second_parent_slice, "first_stream"), + stream=MockStream( + parent_slices, + data_first_parent_slice + data_second_parent_slice, + "first_stream", + ), parent_key="id", partition_field="first_stream_id", parameters=parameters, config=config, request_option=RequestOption( - inject_into=RequestOptionType.request_parameter, parameters=parameters, field_name=field_name_first_stream + inject_into=RequestOptionType.request_parameter, + parameters=parameters, + field_name=field_name_first_stream, ), ), ParentStreamConfig( @@ -505,7 +638,9 @@ def test_request_params_interpolation_for_parent_stream( parameters=parameters, config=config, request_option=RequestOption( - inject_into=RequestOptionType.request_parameter, parameters=parameters, field_name=field_name_second_stream + inject_into=RequestOptionType.request_parameter, + parameters=parameters, + field_name=field_name_second_stream, ), ), ], @@ -526,7 +661,10 @@ def test_given_record_is_airbyte_message_when_stream_slices_then_use_record_data [parent_slice], [ AirbyteMessage( - type=Type.RECORD, record=AirbyteRecordMessage(data={"id": "record value"}, emitted_at=0, stream="stream") + type=Type.RECORD, + record=AirbyteRecordMessage( + data={"id": "record value"}, emitted_at=0, stream="stream" + ), ) ], "first_stream", @@ -550,7 +688,9 @@ def test_given_record_is_record_object_when_stream_slices_then_use_record_data() partition_router = SubstreamPartitionRouter( parent_stream_configs=[ ParentStreamConfig( - stream=MockStream([parent_slice], [Record({"id": "record value"}, {})], "first_stream"), + stream=MockStream( + [parent_slice], [Record({"id": "record value"}, {})], "first_stream" + ), parent_key="id", partition_field="partition_field", parameters={}, @@ -567,8 +707,12 @@ def test_given_record_is_record_object_when_stream_slices_then_use_record_data() def test_substream_using_incremental_parent_stream(): mock_slices = [ - StreamSlice(cursor_slice={"start_time": "2024-04-27", "end_time": "2024-05-27"}, partition={}), - StreamSlice(cursor_slice={"start_time": "2024-05-27", "end_time": "2024-06-27"}, partition={}), + StreamSlice( + cursor_slice={"start_time": "2024-04-27", "end_time": "2024-05-27"}, partition={} + ), + StreamSlice( + cursor_slice={"start_time": "2024-05-27", "end_time": "2024-06-27"}, partition={} + ), ] expected_slices = [ @@ -612,8 +756,12 @@ def test_substream_checkpoints_after_each_parent_partition(): parent records for the parent slice (not just the substream) """ mock_slices = [ - StreamSlice(cursor_slice={"start_time": "2024-04-27", "end_time": "2024-05-27"}, partition={}), - StreamSlice(cursor_slice={"start_time": "2024-05-27", "end_time": "2024-06-27"}, partition={}), + StreamSlice( + cursor_slice={"start_time": "2024-04-27", "end_time": "2024-05-27"}, partition={} + ), + StreamSlice( + cursor_slice={"start_time": "2024-05-27", "end_time": "2024-06-27"}, partition={} + ), ] expected_slices = [ @@ -667,7 +815,10 @@ def test_substream_checkpoints_after_each_parent_partition(): "use_incremental_dependency", [ pytest.param(False, id="test_resumable_full_refresh_stream_without_parent_checkpoint"), - pytest.param(True, id="test_resumable_full_refresh_stream_with_use_incremental_dependency_for_parent_checkpoint"), + pytest.param( + True, + id="test_resumable_full_refresh_stream_with_use_incremental_dependency_for_parent_checkpoint", + ), ], ) def test_substream_using_resumable_full_refresh_parent_stream(use_incremental_dependency): @@ -742,8 +893,13 @@ def test_substream_using_resumable_full_refresh_parent_stream(use_incremental_de @pytest.mark.parametrize( "use_incremental_dependency", [ - pytest.param(False, id="test_substream_resumable_full_refresh_stream_without_parent_checkpoint"), - pytest.param(True, id="test_substream_resumable_full_refresh_stream_with_use_incremental_dependency_for_parent_checkpoint"), + pytest.param( + False, id="test_substream_resumable_full_refresh_stream_without_parent_checkpoint" + ), + pytest.param( + True, + id="test_substream_resumable_full_refresh_stream_with_use_incremental_dependency_for_parent_checkpoint", + ), ], ) def test_substream_using_resumable_full_refresh_parent_stream_slices(use_incremental_dependency): @@ -774,12 +930,30 @@ def test_substream_using_resumable_full_refresh_parent_stream_slices(use_increme expected_substream_state = { "states": [ - {"partition": {"parent_slice": {}, "partition_field": "makoto_yuki"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"parent_slice": {}, "partition_field": "yukari_takeba"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"parent_slice": {}, "partition_field": "mitsuru_kirijo"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"parent_slice": {}, "partition_field": "akihiko_sanada"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"parent_slice": {}, "partition_field": "junpei_iori"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"parent_slice": {}, "partition_field": "fuuka_yamagishi"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"parent_slice": {}, "partition_field": "makoto_yuki"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"parent_slice": {}, "partition_field": "yukari_takeba"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"parent_slice": {}, "partition_field": "mitsuru_kirijo"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"parent_slice": {}, "partition_field": "akihiko_sanada"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"parent_slice": {}, "partition_field": "junpei_iori"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"parent_slice": {}, "partition_field": "fuuka_yamagishi"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ], "parent_state": {"persona_3_characters": {"__ab_full_refresh_sync_complete": True}}, } @@ -792,16 +966,31 @@ def test_substream_using_resumable_full_refresh_parent_stream_slices(use_increme cursor=ResumableFullRefreshCursor(parameters={}), record_pages=[ [ - Record(data={"id": "makoto_yuki"}, associated_slice=mock_parent_slices[0]), - Record(data={"id": "yukari_takeba"}, associated_slice=mock_parent_slices[0]), + Record( + data={"id": "makoto_yuki"}, associated_slice=mock_parent_slices[0] + ), + Record( + data={"id": "yukari_takeba"}, associated_slice=mock_parent_slices[0] + ), ], [ - Record(data={"id": "mitsuru_kirijo"}, associated_slice=mock_parent_slices[1]), - Record(data={"id": "akihiko_sanada"}, associated_slice=mock_parent_slices[1]), + Record( + data={"id": "mitsuru_kirijo"}, + associated_slice=mock_parent_slices[1], + ), + Record( + data={"id": "akihiko_sanada"}, + associated_slice=mock_parent_slices[1], + ), ], [ - Record(data={"id": "junpei_iori"}, associated_slice=mock_parent_slices[2]), - Record(data={"id": "fuuka_yamagishi"}, associated_slice=mock_parent_slices[2]), + Record( + data={"id": "junpei_iori"}, associated_slice=mock_parent_slices[2] + ), + Record( + data={"id": "fuuka_yamagishi"}, + associated_slice=mock_parent_slices[2], + ), ], ], name="persona_3_characters", @@ -818,7 +1007,9 @@ def test_substream_using_resumable_full_refresh_parent_stream_slices(use_increme ) substream_cursor_slicer = PerPartitionCursor( - cursor_factory=CursorFactory(create_function=partial(ChildPartitionResumableFullRefreshCursor, {})), + cursor_factory=CursorFactory( + create_function=partial(ChildPartitionResumableFullRefreshCursor, {}) + ), partition_router=partition_router, ) @@ -830,17 +1021,27 @@ def test_substream_using_resumable_full_refresh_parent_stream_slices(use_increme assert actual_slice == expected_parent_slices[expected_counter] # check for parent state if use_incremental_dependency: - assert substream_cursor_slicer._partition_router.get_stream_state() == expected_parent_state[expected_counter] + assert ( + substream_cursor_slicer._partition_router.get_stream_state() + == expected_parent_state[expected_counter] + ) expected_counter += 1 if use_incremental_dependency: - assert substream_cursor_slicer._partition_router.get_stream_state() == expected_parent_state[expected_counter] + assert ( + substream_cursor_slicer._partition_router.get_stream_state() + == expected_parent_state[expected_counter] + ) # validate final state for closed substream slices final_state = substream_cursor_slicer.get_stream_state() if not use_incremental_dependency: - assert final_state["states"] == expected_substream_state["states"], "State for substreams is not valid!" + assert ( + final_state["states"] == expected_substream_state["states"] + ), "State for substreams is not valid!" else: - assert final_state == expected_substream_state, "State for substreams with incremental dependency is not valid!" + assert ( + final_state == expected_substream_state + ), "State for substreams with incremental dependency is not valid!" @pytest.mark.parametrize( @@ -852,8 +1053,16 @@ def test_substream_using_resumable_full_refresh_parent_stream_slices(use_increme stream=MockStream( [{}], [ - {"id": 1, "field_1": "value_1", "field_2": {"nested_field": "nested_value_1"}}, - {"id": 2, "field_1": "value_2", "field_2": {"nested_field": "nested_value_2"}}, + { + "id": 1, + "field_1": "value_1", + "field_2": {"nested_field": "nested_value_1"}, + }, + { + "id": 2, + "field_1": "value_2", + "field_2": {"nested_field": "nested_value_2"}, + }, ], "first_stream", ), @@ -872,7 +1081,11 @@ def test_substream_using_resumable_full_refresh_parent_stream_slices(use_increme ( [ ParentStreamConfig( - stream=MockStream([{}], [{"id": 1, "field_1": "value_1"}, {"id": 2, "field_1": "value_2"}], "first_stream"), + stream=MockStream( + [{}], + [{"id": 1, "field_1": "value_1"}, {"id": 2, "field_1": "value_2"}], + "first_stream", + ), parent_key="id", partition_field="first_stream_id", extra_fields=[["field_1"]], @@ -889,7 +1102,9 @@ def test_substream_using_resumable_full_refresh_parent_stream_slices(use_increme ], ) def test_substream_partition_router_with_extra_keys(parent_stream_configs, expected_slices): - partition_router = SubstreamPartitionRouter(parent_stream_configs=parent_stream_configs, parameters={}, config={}) + partition_router = SubstreamPartitionRouter( + parent_stream_configs=parent_stream_configs, parameters={}, config={} + ) slices = [s.extra_fields for s in partition_router.stream_slices()] assert slices == expected_slices @@ -900,19 +1115,34 @@ def test_substream_partition_router_with_extra_keys(parent_stream_configs, expec # Case with two ListPartitionRouters, no warning expected ( [ - ListPartitionRouter(values=["1", "2", "3"], cursor_field="partition_field", config={}, parameters={}), - ListPartitionRouter(values=["1", "2", "3"], cursor_field="partition_field", config={}, parameters={}), + ListPartitionRouter( + values=["1", "2", "3"], cursor_field="partition_field", config={}, parameters={} + ), + ListPartitionRouter( + values=["1", "2", "3"], cursor_field="partition_field", config={}, parameters={} + ), ], False, ), # Case with a SubstreamPartitionRouter, warning expected ( [ - ListPartitionRouter(values=["1", "2", "3"], cursor_field="partition_field", config={}, parameters={}), + ListPartitionRouter( + values=["1", "2", "3"], cursor_field="partition_field", config={}, parameters={} + ), SubstreamPartitionRouter( parent_stream_configs=[ ParentStreamConfig( - stream=MockStream([{}], [{"a": {"b": 0}}, {"a": {"b": 1}}, {"a": {"c": 2}}, {"a": {"b": 3}}], "first_stream"), + stream=MockStream( + [{}], + [ + {"a": {"b": 0}}, + {"a": {"b": 1}}, + {"a": {"c": 2}}, + {"a": {"b": 3}}, + ], + "first_stream", + ), parent_key="a/b", partition_field="first_stream_id", parameters={}, @@ -928,15 +1158,29 @@ def test_substream_partition_router_with_extra_keys(parent_stream_configs, expec # Case with nested CartesianProductStreamSlicer containing a SubstreamPartitionRouter, warning expected ( [ - ListPartitionRouter(values=["1", "2", "3"], cursor_field="partition_field", config={}, parameters={}), + ListPartitionRouter( + values=["1", "2", "3"], cursor_field="partition_field", config={}, parameters={} + ), CartesianProductStreamSlicer( stream_slicers=[ - ListPartitionRouter(values=["1", "2", "3"], cursor_field="partition_field", config={}, parameters={}), + ListPartitionRouter( + values=["1", "2", "3"], + cursor_field="partition_field", + config={}, + parameters={}, + ), SubstreamPartitionRouter( parent_stream_configs=[ ParentStreamConfig( stream=MockStream( - [{}], [{"a": {"b": 0}}, {"a": {"b": 1}}, {"a": {"c": 2}}, {"a": {"b": 3}}], "first_stream" + [{}], + [ + {"a": {"b": 0}}, + {"a": {"b": 1}}, + {"a": {"c": 2}}, + {"a": {"b": 3}}, + ], + "first_stream", ), parent_key="a/b", partition_field="first_stream_id", @@ -955,7 +1199,9 @@ def test_substream_partition_router_with_extra_keys(parent_stream_configs, expec ), ], ) -def test_cartesian_product_stream_slicer_warning_log_message(caplog, stream_slicers, expect_warning): +def test_cartesian_product_stream_slicer_warning_log_message( + caplog, stream_slicers, expect_warning +): """Test that a warning is logged when SubstreamPartitionRouter is used within a CartesianProductStreamSlicer.""" warning_message = "Parent state handling is not supported for CartesianProductStreamSlicer." diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_constant_backoff.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_constant_backoff.py index 3a931a0d..eb2ecc1d 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_constant_backoff.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_constant_backoff.py @@ -5,7 +5,9 @@ from unittest.mock import MagicMock import pytest -from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.constant_backoff_strategy import ConstantBackoffStrategy +from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.constant_backoff_strategy import ( + ConstantBackoffStrategy, +) BACKOFF_TIME = 10 PARAMETERS_BACKOFF_TIME = 20 @@ -21,14 +23,21 @@ ("test_constant_backoff_attempt_round_float", 1.5, 6.7, 6.7), ("test_constant_backoff_first_attempt_round_float", 1, 10.0, BACKOFF_TIME), ("test_constant_backoff_second_attempt_round_float", 2, 10.0, BACKOFF_TIME), - ("test_constant_backoff_from_parameters", 1, "{{ parameters['backoff'] }}", PARAMETERS_BACKOFF_TIME), + ( + "test_constant_backoff_from_parameters", + 1, + "{{ parameters['backoff'] }}", + PARAMETERS_BACKOFF_TIME, + ), ("test_constant_backoff_from_config", 1, "{{ config['backoff'] }}", CONFIG_BACKOFF_TIME), ], ) def test_constant_backoff(test_name, attempt_count, backofftime, expected_backoff_time): response_mock = MagicMock() backoff_strategy = ConstantBackoffStrategy( - parameters={"backoff": PARAMETERS_BACKOFF_TIME}, backoff_time_in_seconds=backofftime, config={"backoff": CONFIG_BACKOFF_TIME} + parameters={"backoff": PARAMETERS_BACKOFF_TIME}, + backoff_time_in_seconds=backofftime, + config={"backoff": CONFIG_BACKOFF_TIME}, ) backoff = backoff_strategy.backoff_time(response_mock, attempt_count=attempt_count) assert backoff == expected_backoff_time diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_exponential_backoff.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_exponential_backoff.py index a99050a7..3e5b4c90 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_exponential_backoff.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_exponential_backoff.py @@ -24,7 +24,9 @@ ) def test_exponential_backoff(test_name, attempt_count, factor, expected_backoff_time): response_mock = MagicMock() - backoff_strategy = ExponentialBackoffStrategy(factor=factor, parameters=parameters, config=config) + backoff_strategy = ExponentialBackoffStrategy( + factor=factor, parameters=parameters, config=config + ) backoff = backoff_strategy.backoff_time(response_mock, attempt_count=attempt_count) assert backoff == expected_backoff_time diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_header_helper.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_header_helper.py index 81c2d34e..99af3ca2 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_header_helper.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_header_helper.py @@ -6,7 +6,9 @@ from unittest.mock import MagicMock import pytest -from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.header_helper import get_numeric_value_from_header +from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.header_helper import ( + get_numeric_value_from_header, +) @pytest.mark.parametrize( @@ -17,9 +19,27 @@ ("test_get_numeric_value_from_string_value", {"header": "10.9"}, "header", None, 10.9), ("test_get_numeric_value_from_non_numeric", {"header": "60,120"}, "header", None, None), ("test_get_numeric_value_from_missing_header", {"header": 1}, "notheader", None, None), - ("test_get_numeric_value_with_regex", {"header": "61,60"}, "header", re.compile("([-+]?\d+)"), 61), # noqa - ("test_get_numeric_value_with_regex_no_header", {"header": "61,60"}, "notheader", re.compile("([-+]?\d+)"), None), # noqa - ("test_get_numeric_value_with_regex_not_matching", {"header": "abc61,60"}, "header", re.compile("([-+]?\d+)"), None), # noqa + ( + "test_get_numeric_value_with_regex", + {"header": "61,60"}, + "header", + re.compile("([-+]?\d+)"), + 61, + ), # noqa + ( + "test_get_numeric_value_with_regex_no_header", + {"header": "61,60"}, + "notheader", + re.compile("([-+]?\d+)"), + None, + ), # noqa + ( + "test_get_numeric_value_with_regex_not_matching", + {"header": "abc61,60"}, + "header", + re.compile("([-+]?\d+)"), + None, + ), # noqa ], ) def test_get_numeric_value_from_header(test_name, headers, requested_header, regex, expected_value): diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_time_from_header.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_time_from_header.py index 59dbb6b4..6db2f9fd 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_time_from_header.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_time_from_header.py @@ -22,8 +22,20 @@ [ ("test_wait_time_from_header", "wait_time", SOME_BACKOFF_TIME, None, SOME_BACKOFF_TIME), ("test_wait_time_from_header_string", "wait_time", "60", None, SOME_BACKOFF_TIME), - ("test_wait_time_from_header_parameters", "{{ parameters['wait_time'] }}", "60", None, SOME_BACKOFF_TIME), - ("test_wait_time_from_header_config", "{{ config['wait_time'] }}", "60", None, SOME_BACKOFF_TIME), + ( + "test_wait_time_from_header_parameters", + "{{ parameters['wait_time'] }}", + "60", + None, + SOME_BACKOFF_TIME, + ), + ( + "test_wait_time_from_header_config", + "{{ config['wait_time'] }}", + "60", + None, + SOME_BACKOFF_TIME, + ), ("test_wait_time_from_header_not_a_number", "wait_time", "61,60", None, None), ("test_wait_time_from_header_with_regex", "wait_time", "61,60", "([-+]?\d+)", 61), # noqa ("test_wait_time_fœrom_header_with_regex_no_match", "wait_time", "...", "[-+]?\d+", None), # noqa @@ -34,7 +46,10 @@ def test_wait_time_from_header(test_name, header, header_value, regex, expected_ response_mock = MagicMock(spec=Response) response_mock.headers = {"wait_time": header_value} backoff_strategy = WaitTimeFromHeaderBackoffStrategy( - header=header, regex=regex, parameters={"wait_time": "wait_time"}, config={"wait_time": "wait_time"} + header=header, + regex=regex, + parameters={"wait_time": "wait_time"}, + config={"wait_time": "wait_time"}, ) backoff = backoff_strategy.backoff_time(response_mock, 1) assert backoff == expected_backoff_time diff --git a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_until_time_from_header.py b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_until_time_from_header.py index 5f2bc02f..20dba620 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_until_time_from_header.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/backoff_strategies/test_wait_until_time_from_header.py @@ -19,19 +19,75 @@ "test_name, header, wait_until, min_wait, regex, expected_backoff_time", [ ("test_wait_until_time_from_header", "wait_until", 1600000060.0, None, None, 60), - ("test_wait_until_time_from_header_parameters", "{{parameters['wait_until']}}", 1600000060.0, None, None, 60), - ("test_wait_until_time_from_header_config", "{{config['wait_until']}}", 1600000060.0, None, None, 60), + ( + "test_wait_until_time_from_header_parameters", + "{{parameters['wait_until']}}", + 1600000060.0, + None, + None, + 60, + ), + ( + "test_wait_until_time_from_header_config", + "{{config['wait_until']}}", + 1600000060.0, + None, + None, + 60, + ), ("test_wait_until_negative_time", "wait_until", 1500000000.0, None, None, None), ("test_wait_until_time_less_than_min", "wait_until", 1600000060.0, 120, None, 120), ("test_wait_until_no_header", "absent_header", 1600000000.0, None, None, None), - ("test_wait_until_time_from_header_not_numeric", "wait_until", "1600000000,1600000000", None, None, None), + ( + "test_wait_until_time_from_header_not_numeric", + "wait_until", + "1600000000,1600000000", + None, + None, + None, + ), ("test_wait_until_time_from_header_is_numeric", "wait_until", "1600000060", None, None, 60), - ("test_wait_until_time_from_header_with_regex", "wait_until", "1600000060,60", None, "[-+]?\d+", 60), # noqa - ("test_wait_until_time_from_header_with_regex_from_parameters", "wait_until", "1600000060,60", None, "{{parameters['regex']}}", 60), + ( + "test_wait_until_time_from_header_with_regex", + "wait_until", + "1600000060,60", + None, + "[-+]?\d+", + 60, + ), # noqa + ( + "test_wait_until_time_from_header_with_regex_from_parameters", + "wait_until", + "1600000060,60", + None, + "{{parameters['regex']}}", + 60, + ), # noqa - ("test_wait_until_time_from_header_with_regex_from_config", "wait_until", "1600000060,60", None, "{{config['regex']}}", 60), # noqa - ("test_wait_until_time_from_header_with_regex_no_match", "wait_time", "...", None, "[-+]?\d+", None), # noqa - ("test_wait_until_no_header_with_min", "absent_header", "1600000000.0", SOME_BACKOFF_TIME, None, SOME_BACKOFF_TIME), + ( + "test_wait_until_time_from_header_with_regex_from_config", + "wait_until", + "1600000060,60", + None, + "{{config['regex']}}", + 60, + ), # noqa + ( + "test_wait_until_time_from_header_with_regex_no_match", + "wait_time", + "...", + None, + "[-+]?\d+", + None, + ), # noqa + ( + "test_wait_until_no_header_with_min", + "absent_header", + "1600000000.0", + SOME_BACKOFF_TIME, + None, + SOME_BACKOFF_TIME, + ), ( "test_wait_until_no_header_with_min_from_parameters", "absent_header", @@ -51,7 +107,9 @@ ], ) @patch("time.time", return_value=1600000000.0) -def test_wait_untiltime_from_header(time_mock, test_name, header, wait_until, min_wait, regex, expected_backoff_time): +def test_wait_untiltime_from_header( + time_mock, test_name, header, wait_until, min_wait, regex, expected_backoff_time +): response_mock = MagicMock(spec=requests.Response) response_mock.headers = {"wait_until": wait_until} backoff_strategy = WaitUntilTimeFromHeaderBackoffStrategy( diff --git a/unit_tests/sources/declarative/requesters/error_handlers/test_composite_error_handler.py b/unit_tests/sources/declarative/requesters/error_handlers/test_composite_error_handler.py index 574f3eec..3d2f551c 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/test_composite_error_handler.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/test_composite_error_handler.py @@ -8,9 +8,16 @@ import requests from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.requesters.error_handlers import HttpResponseFilter -from airbyte_cdk.sources.declarative.requesters.error_handlers.composite_error_handler import CompositeErrorHandler -from airbyte_cdk.sources.declarative.requesters.error_handlers.default_error_handler import DefaultErrorHandler -from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction +from airbyte_cdk.sources.declarative.requesters.error_handlers.composite_error_handler import ( + CompositeErrorHandler, +) +from airbyte_cdk.sources.declarative.requesters.error_handlers.default_error_handler import ( + DefaultErrorHandler, +) +from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( + ErrorResolution, + ResponseAction, +) SOME_BACKOFF_TIME = 60 @@ -86,7 +93,9 @@ ), ], ) -def test_composite_error_handler(test_name, first_handler_behavior, second_handler_behavior, expected_behavior): +def test_composite_error_handler( + test_name, first_handler_behavior, second_handler_behavior, expected_behavior +): first_error_handler = MagicMock() first_error_handler.interpret_response.return_value = first_handler_behavior second_error_handler = MagicMock() @@ -94,7 +103,10 @@ def test_composite_error_handler(test_name, first_handler_behavior, second_handl retriers = [first_error_handler, second_error_handler] retrier = CompositeErrorHandler(error_handlers=retriers, parameters={}) response_mock = MagicMock() - response_mock.ok = first_handler_behavior.response_action == ResponseAction.SUCCESS or second_handler_behavior == ResponseAction.SUCCESS + response_mock.ok = ( + first_handler_behavior.response_action == ResponseAction.SUCCESS + or second_handler_behavior == ResponseAction.SUCCESS + ) assert retrier.interpret_response(response_mock) == expected_behavior @@ -131,25 +143,40 @@ def test_error_handler_compatibility_simple(): default_error_handler = DefaultErrorHandler( config={}, parameters={}, - response_filters=[HttpResponseFilter(action=ResponseAction.IGNORE, http_codes={403}, config={}, parameters={})], + response_filters=[ + HttpResponseFilter( + action=ResponseAction.IGNORE, http_codes={403}, config={}, parameters={} + ) + ], ) composite_error_handler = CompositeErrorHandler( error_handlers=[ DefaultErrorHandler( - response_filters=[HttpResponseFilter(action=ResponseAction.IGNORE, http_codes={403}, parameters={}, config={})], + response_filters=[ + HttpResponseFilter( + action=ResponseAction.IGNORE, http_codes={403}, parameters={}, config={} + ) + ], parameters={}, config={}, ) ], parameters={}, ) - assert default_error_handler.interpret_response(response_mock).response_action == expected_action - assert composite_error_handler.interpret_response(response_mock).response_action == expected_action + assert ( + default_error_handler.interpret_response(response_mock).response_action == expected_action + ) + assert ( + composite_error_handler.interpret_response(response_mock).response_action == expected_action + ) @pytest.mark.parametrize( "test_name, status_code, expected_action", - [("test_first_filter", 403, ResponseAction.IGNORE), ("test_second_filter", 404, ResponseAction.FAIL)], + [ + ("test_first_filter", 403, ResponseAction.IGNORE), + ("test_second_filter", 404, ResponseAction.FAIL), + ], ) def test_error_handler_compatibility_multiple_filters(test_name, status_code, expected_action): response_mock = create_response(status_code) @@ -157,8 +184,12 @@ def test_error_handler_compatibility_multiple_filters(test_name, status_code, ex error_handlers=[ DefaultErrorHandler( response_filters=[ - HttpResponseFilter(action=ResponseAction.IGNORE, http_codes={403}, parameters={}, config={}), - HttpResponseFilter(action=ResponseAction.FAIL, http_codes={404}, parameters={}, config={}), + HttpResponseFilter( + action=ResponseAction.IGNORE, http_codes={403}, parameters={}, config={} + ), + HttpResponseFilter( + action=ResponseAction.FAIL, http_codes={404}, parameters={}, config={} + ), ], parameters={}, config={}, @@ -169,22 +200,34 @@ def test_error_handler_compatibility_multiple_filters(test_name, status_code, ex composite_error_handler_with_single_filters = CompositeErrorHandler( error_handlers=[ DefaultErrorHandler( - response_filters=[HttpResponseFilter(action=ResponseAction.IGNORE, http_codes={403}, parameters={}, config={})], + response_filters=[ + HttpResponseFilter( + action=ResponseAction.IGNORE, http_codes={403}, parameters={}, config={} + ) + ], parameters={}, config={}, ), DefaultErrorHandler( - response_filters=[HttpResponseFilter(action=ResponseAction.FAIL, http_codes={404}, parameters={}, config={})], + response_filters=[ + HttpResponseFilter( + action=ResponseAction.FAIL, http_codes={404}, parameters={}, config={} + ) + ], parameters={}, config={}, ), ], parameters={}, ) - actual_action_multiple_filters = error_handler_with_multiple_filters.interpret_response(response_mock).response_action + actual_action_multiple_filters = error_handler_with_multiple_filters.interpret_response( + response_mock + ).response_action assert actual_action_multiple_filters == expected_action - actual_action_single_filters = composite_error_handler_with_single_filters.interpret_response(response_mock).response_action + actual_action_single_filters = composite_error_handler_with_single_filters.interpret_response( + response_mock + ).response_action assert actual_action_single_filters == expected_action @@ -212,7 +255,11 @@ def test_max_time_is_max_of_underlying_handlers(test_name, max_times, expected_m composite_error_handler = CompositeErrorHandler( error_handlers=[ DefaultErrorHandler( - response_filters=[HttpResponseFilter(action=ResponseAction.IGNORE, http_codes={403}, parameters={}, config={})], + response_filters=[ + HttpResponseFilter( + action=ResponseAction.IGNORE, http_codes={403}, parameters={}, config={} + ) + ], max_time=max_time, parameters={}, config={}, diff --git a/unit_tests/sources/declarative/requesters/error_handlers/test_default_error_handler.py b/unit_tests/sources/declarative/requesters/error_handlers/test_default_error_handler.py index 6fc99159..f97fa05f 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/test_default_error_handler.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/test_default_error_handler.py @@ -6,13 +6,24 @@ import pytest import requests -from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.constant_backoff_strategy import ConstantBackoffStrategy +from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.constant_backoff_strategy import ( + ConstantBackoffStrategy, +) from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies.exponential_backoff_strategy import ( ExponentialBackoffStrategy, ) -from airbyte_cdk.sources.declarative.requesters.error_handlers.default_error_handler import DefaultErrorHandler, HttpResponseFilter -from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import DEFAULT_ERROR_MAPPING -from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, FailureType, ResponseAction +from airbyte_cdk.sources.declarative.requesters.error_handlers.default_error_handler import ( + DefaultErrorHandler, + HttpResponseFilter, +) +from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import ( + DEFAULT_ERROR_MAPPING, +) +from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( + ErrorResolution, + FailureType, + ResponseAction, +) SOME_BACKOFF_TIME = 60 @@ -55,7 +66,9 @@ ), ], ) -def test_default_error_handler_with_default_response_filter(test_name, http_status_code: int, expected_error_resolution: ErrorResolution): +def test_default_error_handler_with_default_response_filter( + test_name, http_status_code: int, expected_error_resolution: ErrorResolution +): response_mock = create_response(http_status_code) error_handler = DefaultErrorHandler(config={}, parameters={}) actual_error_resolution = error_handler.interpret_response(response_mock) @@ -142,7 +155,9 @@ def test_default_error_handler_with_custom_response_filter( response_mock.json.return_value = {"error": "test"} response_filter = test_response_filter - error_handler = DefaultErrorHandler(config={}, parameters={}, response_filters=[response_filter]) + error_handler = DefaultErrorHandler( + config={}, parameters={}, response_filters=[response_filter] + ) actual_error_resolution = error_handler.interpret_response(response_mock) assert actual_error_resolution.response_action == response_action assert actual_error_resolution.failure_type == failure_type @@ -156,7 +171,9 @@ def test_default_error_handler_with_custom_response_filter( (402, ResponseAction.FAIL), ], ) -def test_default_error_handler_with_multiple_response_filters(http_status_code, expected_response_action): +def test_default_error_handler_with_multiple_response_filters( + http_status_code, expected_response_action +): response_filter_one = HttpResponseFilter( http_codes=[400], action=ResponseAction.RETRY, @@ -171,7 +188,9 @@ def test_default_error_handler_with_multiple_response_filters(http_status_code, ) response_mock = create_response(http_status_code) - error_handler = DefaultErrorHandler(config={}, parameters={}, response_filters=[response_filter_one, response_filter_two]) + error_handler = DefaultErrorHandler( + config={}, parameters={}, response_filters=[response_filter_one, response_filter_two] + ) actual_error_resolution = error_handler.interpret_response(response_mock) assert actual_error_resolution.response_action == expected_response_action @@ -202,7 +221,9 @@ def test_default_error_handler_with_conflicting_response_filters( ) response_mock = create_response(400) - error_handler = DefaultErrorHandler(config={}, parameters={}, response_filters=[response_filter_one, response_filter_two]) + error_handler = DefaultErrorHandler( + config={}, parameters={}, response_filters=[response_filter_one, response_filter_two] + ) actual_error_resolution = error_handler.interpret_response(response_mock) assert actual_error_resolution.response_action == expected_response_action @@ -210,9 +231,14 @@ def test_default_error_handler_with_conflicting_response_filters( def test_default_error_handler_with_constant_backoff_strategy(): response_mock = create_response(429) error_handler = DefaultErrorHandler( - config={}, parameters={}, backoff_strategies=[ConstantBackoffStrategy(SOME_BACKOFF_TIME, config={}, parameters={})] + config={}, + parameters={}, + backoff_strategies=[ConstantBackoffStrategy(SOME_BACKOFF_TIME, config={}, parameters={})], + ) + assert ( + error_handler.backoff_time(response_or_exception=response_mock, attempt_count=0) + == SOME_BACKOFF_TIME ) - assert error_handler.backoff_time(response_or_exception=response_mock, attempt_count=0) == SOME_BACKOFF_TIME @pytest.mark.parametrize( @@ -230,9 +256,13 @@ def test_default_error_handler_with_constant_backoff_strategy(): def test_default_error_handler_with_exponential_backoff_strategy(attempt_count): response_mock = create_response(429) error_handler = DefaultErrorHandler( - config={}, parameters={}, backoff_strategies=[ExponentialBackoffStrategy(factor=1, config={}, parameters={})] + config={}, + parameters={}, + backoff_strategies=[ExponentialBackoffStrategy(factor=1, config={}, parameters={})], ) - assert error_handler.backoff_time(response_or_exception=response_mock, attempt_count=attempt_count) == (1 * 2**attempt_count) + assert error_handler.backoff_time( + response_or_exception=response_mock, attempt_count=attempt_count + ) == (1 * 2**attempt_count) def create_response(status_code: int, headers=None, json_body=None): @@ -270,7 +300,9 @@ def test_predicate_takes_precedent_over_default_mapped_error(): parameters={}, ) - error_handler = DefaultErrorHandler(config={}, parameters={}, response_filters=[response_filter]) + error_handler = DefaultErrorHandler( + config={}, parameters={}, response_filters=[response_filter] + ) actual_error_resolution = error_handler.interpret_response(response_mock) assert actual_error_resolution.response_action == ResponseAction.FAIL assert actual_error_resolution.failure_type == FailureType.system_error diff --git a/unit_tests/sources/declarative/requesters/error_handlers/test_default_http_response_filter.py b/unit_tests/sources/declarative/requesters/error_handlers/test_default_http_response_filter.py index ada35a51..dc0c004a 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/test_default_http_response_filter.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/test_default_http_response_filter.py @@ -6,8 +6,12 @@ import pytest from airbyte_cdk.models import FailureType -from airbyte_cdk.sources.declarative.requesters.error_handlers.default_http_response_filter import DefaultHttpResponseFilter -from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import DEFAULT_ERROR_MAPPING +from airbyte_cdk.sources.declarative.requesters.error_handlers.default_http_response_filter import ( + DefaultHttpResponseFilter, +) +from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import ( + DEFAULT_ERROR_MAPPING, +) from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction from requests import RequestException, Response diff --git a/unit_tests/sources/declarative/requesters/error_handlers/test_http_response_filter.py b/unit_tests/sources/declarative/requesters/error_handlers/test_http_response_filter.py index 9c6817c2..1acc95f2 100644 --- a/unit_tests/sources/declarative/requesters/error_handlers/test_http_response_filter.py +++ b/unit_tests/sources/declarative/requesters/error_handlers/test_http_response_filter.py @@ -8,7 +8,10 @@ import requests from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.requesters.error_handlers import HttpResponseFilter -from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction +from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( + ErrorResolution, + ResponseAction, +) @pytest.mark.parametrize( @@ -23,7 +26,9 @@ "custom error message", {"status_code": 503}, ErrorResolution( - response_action=ResponseAction.FAIL, failure_type=FailureType.transient_error, error_message="custom error message" + response_action=ResponseAction.FAIL, + failure_type=FailureType.transient_error, + error_message="custom error message", ), id="test_http_code_matches", ), @@ -51,7 +56,9 @@ "", {"status_code": 429}, ErrorResolution( - response_action=ResponseAction.RETRY, failure_type=FailureType.transient_error, error_message="Too many requests." + response_action=ResponseAction.RETRY, + failure_type=FailureType.transient_error, + error_message="Too many requests.", ), id="test_http_code_matches_retry_action", ), @@ -64,7 +71,9 @@ "error message was: {{ response.failure }}", {"status_code": 404, "json": {"the_body": "do_i_match", "failure": "i failed you"}}, ErrorResolution( - response_action=ResponseAction.FAIL, failure_type=FailureType.system_error, error_message="error message was: i failed you" + response_action=ResponseAction.FAIL, + failure_type=FailureType.system_error, + error_message="error message was: i failed you", ), id="test_predicate_matches_json", ), @@ -77,7 +86,9 @@ "error from header: {{ headers.warning }}", {"status_code": 404, "headers": {"the_key": "header_match", "warning": "this failed"}}, ErrorResolution( - response_action=ResponseAction.FAIL, failure_type=FailureType.system_error, error_message="error from header: this failed" + response_action=ResponseAction.FAIL, + failure_type=FailureType.system_error, + error_message="error from header: this failed", ), id="test_predicate_matches_headers", ), @@ -103,7 +114,11 @@ '{{ headers.error == "invalid_input" or response.reason == "bad request"}}', "", "", - {"status_code": 403, "headers": {"error": "authentication_error"}, "json": {"reason": "permission denied"}}, + { + "status_code": 403, + "headers": {"error": "authentication_error"}, + "json": {"reason": "permission denied"}, + }, None, id="test_response_does_not_match_filter", ), @@ -115,7 +130,11 @@ "", "check permissions", {"status_code": 403}, - ErrorResolution(response_action=ResponseAction.FAIL, failure_type=FailureType.config_error, error_message="check permissions"), + ErrorResolution( + response_action=ResponseAction.FAIL, + failure_type=FailureType.config_error, + error_message="check permissions", + ), id="test_http_code_matches_failure_type_config_error", ), pytest.param( @@ -126,7 +145,11 @@ "", "check permissions", {"status_code": 403}, - ErrorResolution(response_action=ResponseAction.FAIL, failure_type=FailureType.system_error, error_message="check permissions"), + ErrorResolution( + response_action=ResponseAction.FAIL, + failure_type=FailureType.system_error, + error_message="check permissions", + ), id="test_http_code_matches_failure_type_system_error", ), pytest.param( @@ -137,7 +160,11 @@ "", "rate limits", {"status_code": 500}, - ErrorResolution(response_action=ResponseAction.FAIL, failure_type=FailureType.transient_error, error_message="rate limits"), + ErrorResolution( + response_action=ResponseAction.FAIL, + failure_type=FailureType.transient_error, + error_message="rate limits", + ), id="test_http_code_matches_failure_type_transient_error", ), pytest.param( @@ -148,7 +175,11 @@ "", "rate limits", {"status_code": 500}, - ErrorResolution(response_action=ResponseAction.RETRY, failure_type=FailureType.transient_error, error_message="rate limits"), + ErrorResolution( + response_action=ResponseAction.RETRY, + failure_type=FailureType.transient_error, + error_message="rate limits", + ), id="test_http_code_matches_failure_type_config_error_action_retry_uses_default_failure_type", ), pytest.param( @@ -160,14 +191,24 @@ "rate limits", {"status_code": 500}, ErrorResolution( - response_action=ResponseAction.RATE_LIMITED, failure_type=FailureType.transient_error, error_message="rate limits" + response_action=ResponseAction.RATE_LIMITED, + failure_type=FailureType.transient_error, + error_message="rate limits", ), id="test_http_code_matches_response_action_rate_limited", ), ], ) def test_matches( - requests_mock, action, failure_type, http_codes, predicate, error_contains, error_message, response, expected_error_resolution + requests_mock, + action, + failure_type, + http_codes, + predicate, + error_contains, + error_message, + response, + expected_error_resolution, ): requests_mock.register_uri( "GET", diff --git a/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py b/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py index 31d9ae5e..4d2920ea 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py @@ -8,7 +8,9 @@ import requests from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.cursor_pagination_strategy import CursorPaginationStrategy +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.cursor_pagination_strategy import ( + CursorPaginationStrategy, +) @pytest.mark.parametrize( @@ -68,7 +70,13 @@ def test_cursor_pagination_strategy(template_string, stop_condition, expected_to response = requests.Response() link_str = '; rel="next"' response.headers = {"has_more": True, "next": "ready_to_go", "link": link_str} - response_body = {"_metadata": {"content": "content_value"}, "accounts": [], "end": 99, "total": 200, "characters": {}} + response_body = { + "_metadata": {"content": "content_value"}, + "accounts": [], + "end": 99, + "total": 200, + "characters": {}, + } response._content = json.dumps(response_body).encode("utf-8") last_record = {"id": 1, "more_records": True} diff --git a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py index ff341fd8..d02562b0 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py @@ -15,8 +15,12 @@ RequestOption, RequestOptionType, ) -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.cursor_pagination_strategy import CursorPaginationStrategy -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.offset_increment import OffsetIncrement +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.cursor_pagination_strategy import ( + CursorPaginationStrategy, +) +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.offset_increment import ( + OffsetIncrement, +) from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath @@ -38,7 +42,9 @@ {"next": "https://airbyte.io/next_url"}, ), ( - RequestOption(inject_into=RequestOptionType.request_parameter, field_name="from", parameters={}), + RequestOption( + inject_into=RequestOptionType.request_parameter, field_name="from", parameters={} + ), None, None, {"limit": 2, "from": "https://airbyte.io/next_url"}, @@ -52,7 +58,9 @@ {"next": "https://airbyte.io/next_url"}, ), ( - RequestOption(inject_into=RequestOptionType.request_parameter, field_name="from", parameters={}), + RequestOption( + inject_into=RequestOptionType.request_parameter, field_name="from", parameters={} + ), InterpolatedBoolean(condition="{{True}}", parameters={}), None, {"limit": 2}, @@ -80,7 +88,9 @@ {"next": "https://airbyte.io/next_url"}, ), ( - RequestOption(inject_into=RequestOptionType.body_data, field_name="from", parameters={}), + RequestOption( + inject_into=RequestOptionType.body_data, field_name="from", parameters={} + ), None, None, {"limit": 2}, @@ -94,7 +104,9 @@ {"next": "https://airbyte.io/next_url"}, ), ( - RequestOption(inject_into=RequestOptionType.body_json, field_name="from", parameters={}), + RequestOption( + inject_into=RequestOptionType.body_json, field_name="from", parameters={} + ), None, None, {"limit": 2}, @@ -122,7 +134,9 @@ b"https://airbyte.io/next_url", ), ( - RequestOption(inject_into=RequestOptionType.request_parameter, field_name="from", parameters={}), + RequestOption( + inject_into=RequestOptionType.request_parameter, field_name="from", parameters={} + ), None, None, {"limit": 2, "from": "https://airbyte.io/next_url"}, @@ -162,7 +176,9 @@ def test_default_paginator_with_cursor( response_body, ): page_size_request_option = RequestOption( - inject_into=RequestOptionType.request_parameter, field_name="{{parameters['page_limit']}}", parameters={"page_limit": "limit"} + inject_into=RequestOptionType.request_parameter, + field_name="{{parameters['page_limit']}}", + parameters={"page_limit": "limit"}, ) cursor_value = "{{ response.next }}" url_base = "https://airbyte.io" @@ -187,7 +203,9 @@ def test_default_paginator_with_cursor( response = requests.Response() response.headers = {"A_HEADER": "HEADER_VALUE"} - response._content = json.dumps(response_body).encode("utf-8") if decoder == JsonDecoder else response_body + response._content = ( + json.dumps(response_body).encode("utf-8") if decoder == JsonDecoder else response_body + ) actual_next_page_token = paginator.next_page_token(response, 2, last_record) actual_next_path = paginator.path() @@ -211,7 +229,11 @@ def test_default_paginator_with_cursor( "{{parameters['page_token']}}", {"parameters_limit": 50, "parameters_token": "https://airbyte.io/next_url"}, ), - ("{{config['page_limit']}}", "{{config['page_token']}}", {"config_limit": 50, "config_token": "https://airbyte.io/next_url"}), + ( + "{{config['page_limit']}}", + "{{config['page_token']}}", + {"config_limit": 50, "config_token": "https://airbyte.io/next_url"}, + ), ], ids=[ "parameters_interpolation", @@ -219,7 +241,9 @@ def test_default_paginator_with_cursor( ], ) def test_paginator_request_param_interpolation( - field_name_page_size_interpolation: str, field_name_page_token_interpolation: str, expected_request_params: dict + field_name_page_size_interpolation: str, + field_name_page_token_interpolation: str, + expected_request_params: dict, ): config = {"page_limit": "config_limit", "page_token": "config_token"} parameters = {"page_limit": "parameters_limit", "page_token": "parameters_token"} @@ -242,7 +266,9 @@ def test_paginator_request_param_interpolation( paginator = DefaultPaginator( page_size_option=page_size_request_option, page_token_option=RequestOption( - inject_into=RequestOptionType.request_parameter, field_name=field_name_page_token_interpolation, parameters=parameters + inject_into=RequestOptionType.request_parameter, + field_name=field_name_page_token_interpolation, + parameters=parameters, ), pagination_strategy=strategy, config=config, @@ -260,13 +286,19 @@ def test_paginator_request_param_interpolation( def test_page_size_option_cannot_be_set_if_strategy_has_no_limit(): - page_size_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, field_name="page_size", parameters={}) - page_token_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, field_name="offset", parameters={}) + page_size_request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, field_name="page_size", parameters={} + ) + page_token_request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, field_name="offset", parameters={} + ) cursor_value = "{{ response.next }}" url_base = "https://airbyte.io" config = {} parameters = {} - strategy = CursorPaginationStrategy(page_size=None, cursor_value=cursor_value, config=config, parameters=parameters) + strategy = CursorPaginationStrategy( + page_size=None, cursor_value=cursor_value, config=config, parameters=parameters + ) try: DefaultPaginator( page_size_option=page_size_request_option, @@ -293,13 +325,24 @@ def test_page_size_option_cannot_be_set_if_strategy_has_no_limit(): ], ) def test_reset(inject_on_first_request): - page_size_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, field_name="limit", parameters={}) - page_token_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, field_name="offset", parameters={}) + page_size_request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, field_name="limit", parameters={} + ) + page_token_request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, field_name="offset", parameters={} + ) url_base = "https://airbyte.io" config = {} - strategy = OffsetIncrement(config={}, page_size=2, inject_on_first_request=inject_on_first_request, parameters={}) + strategy = OffsetIncrement( + config={}, page_size=2, inject_on_first_request=inject_on_first_request, parameters={} + ) paginator = DefaultPaginator( - strategy, config, url_base, parameters={}, page_size_option=page_size_request_option, page_token_option=page_token_request_option + strategy, + config, + url_base, + parameters={}, + page_size_option=page_size_request_option, + page_token_option=page_token_request_option, ) initial_request_parameters = paginator.get_request_params() response = requests.Response() @@ -313,13 +356,22 @@ def test_reset(inject_on_first_request): def test_initial_token_with_offset_pagination(): - page_size_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, field_name="limit", parameters={}) - page_token_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, field_name="offset", parameters={}) + page_size_request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, field_name="limit", parameters={} + ) + page_token_request_option = RequestOption( + inject_into=RequestOptionType.request_parameter, field_name="offset", parameters={} + ) url_base = "https://airbyte.io" config = {} strategy = OffsetIncrement(config={}, page_size=2, parameters={}, inject_on_first_request=True) paginator = DefaultPaginator( - strategy, config, url_base, parameters={}, page_size_option=page_size_request_option, page_token_option=page_token_request_option + strategy, + config, + url_base, + parameters={}, + page_size_option=page_size_request_option, + page_token_option=page_token_request_option, ) initial_request_parameters = paginator.get_request_params() @@ -355,7 +407,9 @@ def test_paginator_with_page_option_no_page_size(): ( DefaultPaginator( page_size_option=MagicMock(), - page_token_option=RequestOption("limit", RequestOptionType.request_parameter, parameters={}), + page_token_option=RequestOption( + "limit", RequestOptionType.request_parameter, parameters={} + ), pagination_strategy=pagination_strategy, config=MagicMock(), url_base=MagicMock(), diff --git a/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py b/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py index d655dc9e..8c357349 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py @@ -7,7 +7,9 @@ import pytest import requests -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.offset_increment import OffsetIncrement +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.offset_increment import ( + OffsetIncrement, +) @pytest.mark.parametrize( @@ -15,12 +17,30 @@ [ pytest.param("2", {}, 2, {"id": 1}, 2, 2, id="test_same_page_size"), pytest.param(2, {}, 2, {"id": 1}, 2, 2, id="test_same_page_size"), - pytest.param("{{ parameters['page_size'] }}", {"page_size": 3}, 2, {"id": 1}, None, 0, id="test_larger_page_size"), + pytest.param( + "{{ parameters['page_size'] }}", + {"page_size": 3}, + 2, + {"id": 1}, + None, + 0, + id="test_larger_page_size", + ), pytest.param(None, {}, 0, [], None, 0, id="test_stop_if_no_records"), - pytest.param("{{ response['page_metadata']['limit'] }}", {}, 2, {"id": 1}, None, 0, id="test_page_size_from_response"), + pytest.param( + "{{ response['page_metadata']['limit'] }}", + {}, + 2, + {"id": 1}, + None, + 0, + id="test_page_size_from_response", + ), ], ) -def test_offset_increment_paginator_strategy(page_size, parameters, last_page_size, last_record, expected_next_page_token, expected_offset): +def test_offset_increment_paginator_strategy( + page_size, parameters, last_page_size, last_record, expected_next_page_token, expected_offset +): paginator_strategy = OffsetIncrement(page_size=page_size, parameters=parameters, config={}) assert paginator_strategy._offset == 0 @@ -39,7 +59,11 @@ def test_offset_increment_paginator_strategy(page_size, parameters, last_page_si def test_offset_increment_paginator_strategy_rises(): - paginator_strategy = OffsetIncrement(page_size="{{ parameters['page_size'] }}", parameters={"page_size": "invalid value"}, config={}) + paginator_strategy = OffsetIncrement( + page_size="{{ parameters['page_size'] }}", + parameters={"page_size": "invalid value"}, + config={}, + ) with pytest.raises(Exception) as exc: paginator_strategy.get_page_size() assert str(exc.value) == "invalid value is of type . Expected " @@ -52,8 +76,12 @@ def test_offset_increment_paginator_strategy_rises(): pytest.param(False, None, id="test_without_inject_offset"), ], ) -def test_offset_increment_paginator_strategy_initial_token(inject_on_first_request: bool, expected_initial_token: Optional[Any]): - paginator_strategy = OffsetIncrement(page_size=20, parameters={}, config={}, inject_on_first_request=inject_on_first_request) +def test_offset_increment_paginator_strategy_initial_token( + inject_on_first_request: bool, expected_initial_token: Optional[Any] +): + paginator_strategy = OffsetIncrement( + page_size=20, parameters={}, config={}, inject_on_first_request=inject_on_first_request + ) assert paginator_strategy.initial_token == expected_initial_token @@ -67,7 +95,9 @@ def test_offset_increment_paginator_strategy_initial_token(inject_on_first_reque ], ) def test_offset_increment_reset(reset_value, expected_initial_token, expected_error): - paginator_strategy = OffsetIncrement(page_size=20, parameters={}, config={}, inject_on_first_request=True) + paginator_strategy = OffsetIncrement( + page_size=20, parameters={}, config={}, inject_on_first_request=True + ) if expected_error: with pytest.raises(expected_error): diff --git a/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py b/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py index da2bf6d9..9ec994e2 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py @@ -7,7 +7,9 @@ import pytest import requests -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.page_increment import PageIncrement +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.page_increment import ( + PageIncrement, +) @pytest.mark.parametrize( @@ -19,11 +21,17 @@ pytest.param(3, 0, 2, {"id": 1}, None, 0, id="test_larger_page_size_start_from_0"), pytest.param(None, 0, 0, None, None, 0, id="test_no_page_size"), pytest.param("2", 0, 2, {"id": 1}, 1, 1, id="test_page_size_from_string"), - pytest.param("{{ config['value'] }}", 0, 2, {"id": 1}, 1, 1, id="test_page_size_from_config"), + pytest.param( + "{{ config['value'] }}", 0, 2, {"id": 1}, 1, 1, id="test_page_size_from_config" + ), ], ) -def test_page_increment_paginator_strategy(page_size, start_from, last_page_size, last_record, expected_next_page_token, expected_offset): - paginator_strategy = PageIncrement(page_size=page_size, parameters={}, start_from_page=start_from, config={"value": 2}) +def test_page_increment_paginator_strategy( + page_size, start_from, last_page_size, last_record, expected_next_page_token, expected_offset +): + paginator_strategy = PageIncrement( + page_size=page_size, parameters={}, start_from_page=start_from, config={"value": 2} + ) assert paginator_strategy._page == start_from response = requests.Response() @@ -40,10 +48,17 @@ def test_page_increment_paginator_strategy(page_size, start_from, last_page_size assert start_from == paginator_strategy._page -@pytest.mark.parametrize("page_size", [pytest.param("{{ config['value'] }}"), pytest.param("not-an-integer")]) +@pytest.mark.parametrize( + "page_size", [pytest.param("{{ config['value'] }}"), pytest.param("not-an-integer")] +) def test_page_increment_paginator_strategy_malformed_page_size(page_size): with pytest.raises(Exception, match=".* is of type . Expected "): - PageIncrement(page_size=page_size, parameters={}, start_from_page=0, config={"value": "not-an-integer"}) + PageIncrement( + page_size=page_size, + parameters={}, + start_from_page=0, + config={"value": "not-an-integer"}, + ) @pytest.mark.parametrize( @@ -58,7 +73,11 @@ def test_page_increment_paginator_strategy_initial_token( inject_on_first_request: bool, start_from_page: int, expected_initial_token: Optional[Any] ): paginator_strategy = PageIncrement( - page_size=20, parameters={}, start_from_page=start_from_page, inject_on_first_request=inject_on_first_request, config={} + page_size=20, + parameters={}, + start_from_page=start_from_page, + inject_on_first_request=inject_on_first_request, + config={}, ) assert paginator_strategy.initial_token == expected_initial_token @@ -73,7 +92,9 @@ def test_page_increment_paginator_strategy_initial_token( ], ) def test_offset_increment_reset(reset_value, expected_initial_token, expected_error): - paginator_strategy = PageIncrement(page_size=100, parameters={}, config={}, inject_on_first_request=True) + paginator_strategy = PageIncrement( + page_size=100, parameters={}, config={}, inject_on_first_request=True + ) if expected_error: with pytest.raises(expected_error): diff --git a/unit_tests/sources/declarative/requesters/paginators/test_request_option.py b/unit_tests/sources/declarative/requesters/paginators/test_request_option.py index 5caa11f5..cef4fe87 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_request_option.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_request_option.py @@ -3,7 +3,10 @@ # import pytest -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) @pytest.mark.parametrize( @@ -13,11 +16,19 @@ (RequestOptionType.header, "field", "field"), (RequestOptionType.body_data, "field", "field"), (RequestOptionType.body_json, "field", "field"), - (RequestOptionType.request_parameter, "since_{{ parameters['cursor_field'] }}", "since_updated_at"), + ( + RequestOptionType.request_parameter, + "since_{{ parameters['cursor_field'] }}", + "since_updated_at", + ), (RequestOptionType.header, "since_{{ parameters['cursor_field'] }}", "since_updated_at"), (RequestOptionType.body_data, "since_{{ parameters['cursor_field'] }}", "since_updated_at"), (RequestOptionType.body_json, "since_{{ parameters['cursor_field'] }}", "since_updated_at"), - (RequestOptionType.request_parameter, "since_{{ config['cursor_field'] }}", "since_created_at"), + ( + RequestOptionType.request_parameter, + "since_{{ config['cursor_field'] }}", + "since_created_at", + ), (RequestOptionType.header, "since_{{ config['cursor_field'] }}", "since_created_at"), (RequestOptionType.body_data, "since_{{ config['cursor_field'] }}", "since_created_at"), (RequestOptionType.body_json, "since_{{ config['cursor_field'] }}", "since_created_at"), @@ -38,6 +49,8 @@ ], ) def test_request_option(option_type: RequestOptionType, field_name: str, expected_field_name: str): - request_option = RequestOption(inject_into=option_type, field_name=field_name, parameters={"cursor_field": "updated_at"}) + request_option = RequestOption( + inject_into=option_type, field_name=field_name, parameters={"cursor_field": "updated_at"} + ) assert request_option.field_name.eval({"cursor_field": "created_at"}) == expected_field_name assert request_option.inject_into == option_type diff --git a/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py b/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py index 86c5e65f..201636f1 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py @@ -5,7 +5,9 @@ from unittest.mock import Mock, call from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import PaginationStrategy +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import ( + PaginationStrategy, +) from airbyte_cdk.sources.declarative.requesters.paginators.strategies.stop_condition import ( CursorStopCondition, PaginationStopCondition, @@ -44,11 +46,15 @@ def test_given_record_should_not_be_synced_when_is_met_return_true(mocked_cursor assert CursorStopCondition(mocked_cursor).is_met(ANY_RECORD) -def test_given_stop_condition_is_met_when_next_page_token_then_return_none(mocked_pagination_strategy, mocked_stop_condition): +def test_given_stop_condition_is_met_when_next_page_token_then_return_none( + mocked_pagination_strategy, mocked_stop_condition +): mocked_stop_condition.is_met.return_value = True last_record = Mock(spec=Record) - decorator = StopConditionPaginationStrategyDecorator(mocked_pagination_strategy, mocked_stop_condition) + decorator = StopConditionPaginationStrategyDecorator( + mocked_pagination_strategy, mocked_stop_condition + ) assert not decorator.next_page_token(ANY_RESPONSE, 2, last_record) mocked_stop_condition.is_met.assert_has_calls([call(last_record)]) @@ -60,17 +66,21 @@ def test_given_last_record_meets_condition_when_next_page_token_then_do_not_chec mocked_stop_condition.is_met.return_value = True last_record = Mock(spec=Record) - StopConditionPaginationStrategyDecorator(mocked_pagination_strategy, mocked_stop_condition).next_page_token( - ANY_RESPONSE, 2, last_record - ) + StopConditionPaginationStrategyDecorator( + mocked_pagination_strategy, mocked_stop_condition + ).next_page_token(ANY_RESPONSE, 2, last_record) mocked_stop_condition.is_met.assert_called_once_with(last_record) -def test_given_stop_condition_is_not_met_when_next_page_token_then_delegate(mocked_pagination_strategy, mocked_stop_condition): +def test_given_stop_condition_is_not_met_when_next_page_token_then_delegate( + mocked_pagination_strategy, mocked_stop_condition +): mocked_stop_condition.is_met.return_value = False last_record = Mock(spec=Record) - decorator = StopConditionPaginationStrategyDecorator(mocked_pagination_strategy, mocked_stop_condition) + decorator = StopConditionPaginationStrategyDecorator( + mocked_pagination_strategy, mocked_stop_condition + ) next_page_token = decorator.next_page_token(ANY_RESPONSE, 2, last_record) @@ -79,8 +89,12 @@ def test_given_stop_condition_is_not_met_when_next_page_token_then_delegate(mock mocked_stop_condition.is_met.assert_has_calls([call(last_record)]) -def test_given_no_records_when_next_page_token_then_delegate(mocked_pagination_strategy, mocked_stop_condition): - decorator = StopConditionPaginationStrategyDecorator(mocked_pagination_strategy, mocked_stop_condition) +def test_given_no_records_when_next_page_token_then_delegate( + mocked_pagination_strategy, mocked_stop_condition +): + decorator = StopConditionPaginationStrategyDecorator( + mocked_pagination_strategy, mocked_stop_condition + ) next_page_token = decorator.next_page_token(ANY_RESPONSE, 0, NO_RECORD) @@ -89,13 +103,17 @@ def test_given_no_records_when_next_page_token_then_delegate(mocked_pagination_s def test_when_reset_then_delegate(mocked_pagination_strategy, mocked_stop_condition): - decorator = StopConditionPaginationStrategyDecorator(mocked_pagination_strategy, mocked_stop_condition) + decorator = StopConditionPaginationStrategyDecorator( + mocked_pagination_strategy, mocked_stop_condition + ) decorator.reset() mocked_pagination_strategy.reset.assert_called_once_with() def test_when_get_page_size_then_delegate(mocked_pagination_strategy, mocked_stop_condition): - decorator = StopConditionPaginationStrategyDecorator(mocked_pagination_strategy, mocked_stop_condition) + decorator = StopConditionPaginationStrategyDecorator( + mocked_pagination_strategy, mocked_stop_condition + ) page_size = decorator.get_page_size() diff --git a/unit_tests/sources/declarative/requesters/request_options/test_datetime_based_request_options_provider.py b/unit_tests/sources/declarative/requesters/request_options/test_datetime_based_request_options_provider.py index 9fca23be..7cbfa78d 100644 --- a/unit_tests/sources/declarative/requesters/request_options/test_datetime_based_request_options_provider.py +++ b/unit_tests/sources/declarative/requesters/request_options/test_datetime_based_request_options_provider.py @@ -3,8 +3,13 @@ # import pytest -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType -from airbyte_cdk.sources.declarative.requesters.request_options import DatetimeBasedRequestOptionsProvider +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) +from airbyte_cdk.sources.declarative.requesters.request_options import ( + DatetimeBasedRequestOptionsProvider, +) from airbyte_cdk.sources.declarative.types import StreamSlice @@ -12,44 +17,69 @@ "start_time_option, end_time_option, partition_field_start, partition_field_end, stream_slice, expected_request_options", [ pytest.param( - RequestOption(field_name="after", inject_into=RequestOptionType.request_parameter, parameters={}), - RequestOption(field_name="before", inject_into=RequestOptionType.request_parameter, parameters={}), + RequestOption( + field_name="after", inject_into=RequestOptionType.request_parameter, parameters={} + ), + RequestOption( + field_name="before", inject_into=RequestOptionType.request_parameter, parameters={} + ), "custom_start", "custom_end", - StreamSlice(cursor_slice={"custom_start": "2024-06-01", "custom_end": "2024-06-02"}, partition={}), + StreamSlice( + cursor_slice={"custom_start": "2024-06-01", "custom_end": "2024-06-02"}, + partition={}, + ), {"after": "2024-06-01", "before": "2024-06-02"}, id="test_request_params", ), pytest.param( - RequestOption(field_name="after", inject_into=RequestOptionType.request_parameter, parameters={}), - RequestOption(field_name="before", inject_into=RequestOptionType.request_parameter, parameters={}), + RequestOption( + field_name="after", inject_into=RequestOptionType.request_parameter, parameters={} + ), + RequestOption( + field_name="before", inject_into=RequestOptionType.request_parameter, parameters={} + ), None, None, - StreamSlice(cursor_slice={"start_time": "2024-06-01", "end_time": "2024-06-02"}, partition={}), + StreamSlice( + cursor_slice={"start_time": "2024-06-01", "end_time": "2024-06-02"}, partition={} + ), {"after": "2024-06-01", "before": "2024-06-02"}, id="test_request_params_with_default_partition_fields", ), pytest.param( None, - RequestOption(field_name="before", inject_into=RequestOptionType.request_parameter, parameters={}), + RequestOption( + field_name="before", inject_into=RequestOptionType.request_parameter, parameters={} + ), None, None, - StreamSlice(cursor_slice={"start_time": "2024-06-01", "end_time": "2024-06-02"}, partition={}), + StreamSlice( + cursor_slice={"start_time": "2024-06-01", "end_time": "2024-06-02"}, partition={} + ), {"before": "2024-06-02"}, id="test_request_params_no_start_time_option", ), pytest.param( - RequestOption(field_name="after", inject_into=RequestOptionType.request_parameter, parameters={}), + RequestOption( + field_name="after", inject_into=RequestOptionType.request_parameter, parameters={} + ), None, None, None, - StreamSlice(cursor_slice={"start_time": "2024-06-01", "end_time": "2024-06-02"}, partition={}), + StreamSlice( + cursor_slice={"start_time": "2024-06-01", "end_time": "2024-06-02"}, partition={} + ), {"after": "2024-06-01"}, id="test_request_params_no_end_time_option", ), pytest.param( - RequestOption(field_name="after", inject_into=RequestOptionType.request_parameter, parameters={}), - RequestOption(field_name="before", inject_into=RequestOptionType.request_parameter, parameters={}), + RequestOption( + field_name="after", inject_into=RequestOptionType.request_parameter, parameters={} + ), + RequestOption( + field_name="before", inject_into=RequestOptionType.request_parameter, parameters={} + ), None, None, None, @@ -61,32 +91,54 @@ RequestOption(field_name="before", inject_into=RequestOptionType.header, parameters={}), "custom_start", "custom_end", - StreamSlice(cursor_slice={"custom_start": "2024-06-01", "custom_end": "2024-06-02"}, partition={}), + StreamSlice( + cursor_slice={"custom_start": "2024-06-01", "custom_end": "2024-06-02"}, + partition={}, + ), {"after": "2024-06-01", "before": "2024-06-02"}, id="test_request_headers", ), pytest.param( - RequestOption(field_name="after", inject_into=RequestOptionType.body_data, parameters={}), - RequestOption(field_name="before", inject_into=RequestOptionType.body_data, parameters={}), + RequestOption( + field_name="after", inject_into=RequestOptionType.body_data, parameters={} + ), + RequestOption( + field_name="before", inject_into=RequestOptionType.body_data, parameters={} + ), "custom_start", "custom_end", - StreamSlice(cursor_slice={"custom_start": "2024-06-01", "custom_end": "2024-06-02"}, partition={}), + StreamSlice( + cursor_slice={"custom_start": "2024-06-01", "custom_end": "2024-06-02"}, + partition={}, + ), {"after": "2024-06-01", "before": "2024-06-02"}, id="test_request_request_body_data", ), pytest.param( - RequestOption(field_name="after", inject_into=RequestOptionType.body_json, parameters={}), - RequestOption(field_name="before", inject_into=RequestOptionType.body_json, parameters={}), + RequestOption( + field_name="after", inject_into=RequestOptionType.body_json, parameters={} + ), + RequestOption( + field_name="before", inject_into=RequestOptionType.body_json, parameters={} + ), "custom_start", "custom_end", - StreamSlice(cursor_slice={"custom_start": "2024-06-01", "custom_end": "2024-06-02"}, partition={}), + StreamSlice( + cursor_slice={"custom_start": "2024-06-01", "custom_end": "2024-06-02"}, + partition={}, + ), {"after": "2024-06-01", "before": "2024-06-02"}, id="test_request_request_body_json", ), ], ) def test_datetime_based_request_options_provider( - start_time_option, end_time_option, partition_field_start, partition_field_end, stream_slice, expected_request_options + start_time_option, + end_time_option, + partition_field_start, + partition_field_end, + stream_slice, + expected_request_options, ): config = {} request_options_provider = DatetimeBasedRequestOptionsProvider( @@ -98,18 +150,30 @@ def test_datetime_based_request_options_provider( parameters={}, ) - request_option_type = start_time_option.inject_into if isinstance(start_time_option, RequestOption) else None + request_option_type = ( + start_time_option.inject_into if isinstance(start_time_option, RequestOption) else None + ) match request_option_type: case RequestOptionType.request_parameter: - actual_request_options = request_options_provider.get_request_params(stream_slice=stream_slice) + actual_request_options = request_options_provider.get_request_params( + stream_slice=stream_slice + ) case RequestOptionType.header: - actual_request_options = request_options_provider.get_request_headers(stream_slice=stream_slice) + actual_request_options = request_options_provider.get_request_headers( + stream_slice=stream_slice + ) case RequestOptionType.body_data: - actual_request_options = request_options_provider.get_request_body_data(stream_slice=stream_slice) + actual_request_options = request_options_provider.get_request_body_data( + stream_slice=stream_slice + ) case RequestOptionType.body_json: - actual_request_options = request_options_provider.get_request_body_json(stream_slice=stream_slice) + actual_request_options = request_options_provider.get_request_body_json( + stream_slice=stream_slice + ) case _: # We defer to testing the default RequestOptions using get_request_params() - actual_request_options = request_options_provider.get_request_params(stream_slice=stream_slice) + actual_request_options = request_options_provider.get_request_params( + stream_slice=stream_slice + ) assert actual_request_options == expected_request_options diff --git a/unit_tests/sources/declarative/requesters/request_options/test_interpolated_request_options_provider.py b/unit_tests/sources/declarative/requesters/request_options/test_interpolated_request_options_provider.py index a233d371..3e11bfa5 100644 --- a/unit_tests/sources/declarative/requesters/request_options/test_interpolated_request_options_provider.py +++ b/unit_tests/sources/declarative/requesters/request_options/test_interpolated_request_options_provider.py @@ -16,14 +16,36 @@ @pytest.mark.parametrize( "test_name, input_request_params, expected_request_params", [ - ("test_static_param", {"a_static_request_param": "a_static_value"}, {"a_static_request_param": "a_static_value"}), - ("test_value_depends_on_state", {"read_from_state": "{{ stream_state['date'] }}"}, {"read_from_state": "2021-01-01"}), - ("test_value_depends_on_stream_slice", {"read_from_slice": "{{ stream_slice['start_date'] }}"}, {"read_from_slice": "2020-01-01"}), - ("test_value_depends_on_next_page_token", {"read_from_token": "{{ next_page_token['offset'] }}"}, {"read_from_token": "12345"}), - ("test_value_depends_on_config", {"read_from_config": "{{ config['option'] }}"}, {"read_from_config": "OPTION"}), + ( + "test_static_param", + {"a_static_request_param": "a_static_value"}, + {"a_static_request_param": "a_static_value"}, + ), + ( + "test_value_depends_on_state", + {"read_from_state": "{{ stream_state['date'] }}"}, + {"read_from_state": "2021-01-01"}, + ), + ( + "test_value_depends_on_stream_slice", + {"read_from_slice": "{{ stream_slice['start_date'] }}"}, + {"read_from_slice": "2020-01-01"}, + ), + ( + "test_value_depends_on_next_page_token", + {"read_from_token": "{{ next_page_token['offset'] }}"}, + {"read_from_token": "12345"}, + ), + ( + "test_value_depends_on_config", + {"read_from_config": "{{ config['option'] }}"}, + {"read_from_config": "OPTION"}, + ), ( "test_parameter_is_interpolated", - {"{{ stream_state['date'] }} - {{stream_slice['start_date']}} - {{next_page_token['offset']}} - {{config['option']}}": "ABC"}, + { + "{{ stream_state['date'] }} - {{stream_slice['start_date']}} - {{next_page_token['offset']}} - {{config['option']}}": "ABC" + }, {"2021-01-01 - 2020-01-01 - 12345 - OPTION": "ABC"}, ), ("test_boolean_false_value", {"boolean_false": "{{ False }}"}, {"boolean_false": "False"}), @@ -34,9 +56,13 @@ ], ) def test_interpolated_request_params(test_name, input_request_params, expected_request_params): - provider = InterpolatedRequestOptionsProvider(config=config, request_parameters=input_request_params, parameters={}) + provider = InterpolatedRequestOptionsProvider( + config=config, request_parameters=input_request_params, parameters={} + ) - actual_request_params = provider.get_request_params(stream_state=state, stream_slice=stream_slice, next_page_token=next_page_token) + actual_request_params = provider.get_request_params( + stream_state=state, stream_slice=stream_slice, next_page_token=next_page_token + ) assert actual_request_params == expected_request_params @@ -44,11 +70,31 @@ def test_interpolated_request_params(test_name, input_request_params, expected_r @pytest.mark.parametrize( "test_name, input_request_json, expected_request_json", [ - ("test_static_json", {"a_static_request_param": "a_static_value"}, {"a_static_request_param": "a_static_value"}), - ("test_value_depends_on_state", {"read_from_state": "{{ stream_state['date'] }}"}, {"read_from_state": "2021-01-01"}), - ("test_value_depends_on_stream_slice", {"read_from_slice": "{{ stream_slice['start_date'] }}"}, {"read_from_slice": "2020-01-01"}), - ("test_value_depends_on_next_page_token", {"read_from_token": "{{ next_page_token['offset'] }}"}, {"read_from_token": 12345}), - ("test_value_depends_on_config", {"read_from_config": "{{ config['option'] }}"}, {"read_from_config": "OPTION"}), + ( + "test_static_json", + {"a_static_request_param": "a_static_value"}, + {"a_static_request_param": "a_static_value"}, + ), + ( + "test_value_depends_on_state", + {"read_from_state": "{{ stream_state['date'] }}"}, + {"read_from_state": "2021-01-01"}, + ), + ( + "test_value_depends_on_stream_slice", + {"read_from_slice": "{{ stream_slice['start_date'] }}"}, + {"read_from_slice": "2020-01-01"}, + ), + ( + "test_value_depends_on_next_page_token", + {"read_from_token": "{{ next_page_token['offset'] }}"}, + {"read_from_token": 12345}, + ), + ( + "test_value_depends_on_config", + {"read_from_config": "{{ config['option'] }}"}, + {"read_from_config": "OPTION"}, + ), ( "test_interpolated_keys", {"{{ stream_state['date'] }}": 123, "{{ config['option'] }}": "ABC"}, @@ -59,8 +105,16 @@ def test_interpolated_request_params(test_name, input_request_params, expected_r ("test_number_falsy_value", {"number_falsy": "{{ 0.0 }}"}, {"number_falsy": 0.0}), ("test_string_falsy_value", {"string_falsy": "{{ '' }}"}, {}), ("test_none_value", {"none_value": "{{ None }}"}, {}), - ("test_string", """{"nested": { "key": "{{ config['option'] }}" }}""", {"nested": {"key": "OPTION"}}), - ("test_nested_objects", {"nested": {"key": "{{ config['option'] }}"}}, {"nested": {"key": "OPTION"}}), + ( + "test_string", + """{"nested": { "key": "{{ config['option'] }}" }}""", + {"nested": {"key": "OPTION"}}, + ), + ( + "test_nested_objects", + {"nested": {"key": "{{ config['option'] }}"}}, + {"nested": {"key": "OPTION"}}, + ), ( "test_nested_objects_interpolated keys", {"nested": {"{{ stream_state['date'] }}": "{{ config['option'] }}"}}, @@ -69,9 +123,13 @@ def test_interpolated_request_params(test_name, input_request_params, expected_r ], ) def test_interpolated_request_json(test_name, input_request_json, expected_request_json): - provider = InterpolatedRequestOptionsProvider(config=config, request_body_json=input_request_json, parameters={}) + provider = InterpolatedRequestOptionsProvider( + config=config, request_body_json=input_request_json, parameters={} + ) - actual_request_json = provider.get_request_body_json(stream_state=state, stream_slice=stream_slice, next_page_token=next_page_token) + actual_request_json = provider.get_request_body_json( + stream_state=state, stream_slice=stream_slice, next_page_token=next_page_token + ) assert actual_request_json == expected_request_json @@ -79,17 +137,37 @@ def test_interpolated_request_json(test_name, input_request_json, expected_reque @pytest.mark.parametrize( "test_name, input_request_data, expected_request_data", [ - ("test_static_map_data", {"a_static_request_param": "a_static_value"}, {"a_static_request_param": "a_static_value"}), - ("test_map_depends_on_stream_slice", {"read_from_slice": "{{ stream_slice['start_date'] }}"}, {"read_from_slice": "2020-01-01"}), - ("test_map_depends_on_config", {"read_from_config": "{{ config['option'] }}"}, {"read_from_config": "OPTION"}), + ( + "test_static_map_data", + {"a_static_request_param": "a_static_value"}, + {"a_static_request_param": "a_static_value"}, + ), + ( + "test_map_depends_on_stream_slice", + {"read_from_slice": "{{ stream_slice['start_date'] }}"}, + {"read_from_slice": "2020-01-01"}, + ), + ( + "test_map_depends_on_config", + {"read_from_config": "{{ config['option'] }}"}, + {"read_from_config": "OPTION"}, + ), ("test_defaults_to_empty_dict", None, {}), - ("test_interpolated_keys", {"{{ stream_state['date'] }} - {{ next_page_token['offset'] }}": "ABC"}, {"2021-01-01 - 12345": "ABC"}), + ( + "test_interpolated_keys", + {"{{ stream_state['date'] }} - {{ next_page_token['offset'] }}": "ABC"}, + {"2021-01-01 - 12345": "ABC"}, + ), ], ) def test_interpolated_request_data(test_name, input_request_data, expected_request_data): - provider = InterpolatedRequestOptionsProvider(config=config, request_body_data=input_request_data, parameters={}) + provider = InterpolatedRequestOptionsProvider( + config=config, request_body_data=input_request_data, parameters={} + ) - actual_request_data = provider.get_request_body_data(stream_state=state, stream_slice=stream_slice, next_page_token=next_page_token) + actual_request_data = provider.get_request_body_data( + stream_state=state, stream_slice=stream_slice, next_page_token=next_page_token + ) assert actual_request_data == expected_request_data @@ -98,16 +176,41 @@ def test_error_on_create_for_both_request_json_and_data(): request_json = {"body_key": "{{ stream_slice['start_date'] }}"} request_data = "interpolate_me=5&invalid={{ config['option'] }}" with pytest.raises(ValueError): - InterpolatedRequestOptionsProvider(config=config, request_body_json=request_json, request_body_data=request_data, parameters={}) + InterpolatedRequestOptionsProvider( + config=config, + request_body_json=request_json, + request_body_data=request_data, + parameters={}, + ) @pytest.mark.parametrize( "request_option_type,request_input,contains_state", [ - pytest.param("request_parameter", {"start": "{{ stream_state.get('start_date') }}"}, True, id="test_request_parameter_has_state"), - pytest.param("request_parameter", {"start": "{{ slice_interval.get('start_date') }}"}, False, id="test_request_parameter_no_state"), - pytest.param("request_header", {"start": "{{ stream_state.get('start_date') }}"}, True, id="test_request_header_has_state"), - pytest.param("request_header", {"start": "{{ slice_interval.get('start_date') }}"}, False, id="test_request_header_no_state"), + pytest.param( + "request_parameter", + {"start": "{{ stream_state.get('start_date') }}"}, + True, + id="test_request_parameter_has_state", + ), + pytest.param( + "request_parameter", + {"start": "{{ slice_interval.get('start_date') }}"}, + False, + id="test_request_parameter_no_state", + ), + pytest.param( + "request_header", + {"start": "{{ stream_state.get('start_date') }}"}, + True, + id="test_request_header_has_state", + ), + pytest.param( + "request_header", + {"start": "{{ slice_interval.get('start_date') }}"}, + False, + id="test_request_header_no_state", + ), pytest.param( "request_body_data", "[{'query': {'type': 'timestamp', 'value': stream_state.get('start_date')}}]", @@ -120,9 +223,17 @@ def test_error_on_create_for_both_request_json_and_data(): False, id="test_request_body_data_no_state", ), - pytest.param("request_body_json", {"start": "{{ stream_state.get('start_date') }}"}, True, id="test_request_body_json_has_state"), pytest.param( - "request_body_json", {"start": "{{ slice_interval.get('start_date') }}"}, False, id="test_request_request_body_json_no_state" + "request_body_json", + {"start": "{{ stream_state.get('start_date') }}"}, + True, + id="test_request_body_json_has_state", + ), + pytest.param( + "request_body_json", + {"start": "{{ slice_interval.get('start_date') }}"}, + False, + id="test_request_request_body_json_no_state", ), ], ) @@ -130,14 +241,24 @@ def test_request_options_contain_stream_state(request_option_type, request_input request_options_provider: InterpolatedRequestOptionsProvider match request_option_type: case "request_parameter": - request_options_provider = InterpolatedRequestOptionsProvider(config=config, request_parameters=request_input, parameters={}) + request_options_provider = InterpolatedRequestOptionsProvider( + config=config, request_parameters=request_input, parameters={} + ) case "request_header": - request_options_provider = InterpolatedRequestOptionsProvider(config=config, request_headers=request_input, parameters={}) + request_options_provider = InterpolatedRequestOptionsProvider( + config=config, request_headers=request_input, parameters={} + ) case "request_body_data": - request_options_provider = InterpolatedRequestOptionsProvider(config=config, request_body_data=request_input, parameters={}) + request_options_provider = InterpolatedRequestOptionsProvider( + config=config, request_body_data=request_input, parameters={} + ) case "request_body_json": - request_options_provider = InterpolatedRequestOptionsProvider(config=config, request_body_json=request_input, parameters={}) + request_options_provider = InterpolatedRequestOptionsProvider( + config=config, request_body_json=request_input, parameters={} + ) case _: - request_options_provider = InterpolatedRequestOptionsProvider(config=config, parameters={}) + request_options_provider = InterpolatedRequestOptionsProvider( + config=config, parameters={} + ) assert request_options_provider.request_options_contain_stream_state() == contains_state diff --git a/unit_tests/sources/declarative/requesters/test_http_job_repository.py b/unit_tests/sources/declarative/requesters/test_http_job_repository.py index 98ccd600..aa2a13f7 100644 --- a/unit_tests/sources/declarative/requesters/test_http_job_repository.py +++ b/unit_tests/sources/declarative/requesters/test_http_job_repository.py @@ -9,13 +9,22 @@ from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus from airbyte_cdk.sources.declarative.decoders import NoopDecoder from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder -from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordSelector, ResponseToFileExtractor +from airbyte_cdk.sources.declarative.extractors import ( + DpathExtractor, + RecordSelector, + ResponseToFileExtractor, +) from airbyte_cdk.sources.declarative.requesters.error_handlers import DefaultErrorHandler from airbyte_cdk.sources.declarative.requesters.http_job_repository import AsyncHttpJobRepository from airbyte_cdk.sources.declarative.requesters.http_requester import HttpRequester from airbyte_cdk.sources.declarative.requesters.paginators import DefaultPaginator -from airbyte_cdk.sources.declarative.requesters.paginators.strategies.cursor_pagination_strategy import CursorPaginationStrategy -from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.requesters.paginators.strategies.cursor_pagination_strategy import ( + CursorPaginationStrategy, +) +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.sources.types import StreamSlice @@ -122,13 +131,23 @@ def setUp(self) -> None: download_retriever=self._download_retriever, abort_requester=None, delete_requester=None, - status_extractor=DpathExtractor(decoder=JsonDecoder(parameters={}), field_path=["status"], config={}, parameters={} or {}), + status_extractor=DpathExtractor( + decoder=JsonDecoder(parameters={}), + field_path=["status"], + config={}, + parameters={} or {}, + ), status_mapping={ "ready": AsyncJobStatus.COMPLETED, "failure": AsyncJobStatus.FAILED, "pending": AsyncJobStatus.RUNNING, }, - urls_extractor=DpathExtractor(decoder=JsonDecoder(parameters={}), field_path=["urls"], config={}, parameters={} or {}), + urls_extractor=DpathExtractor( + decoder=JsonDecoder(parameters={}), + field_path=["urls"], + config={}, + parameters={} or {}, + ), ) self._http_mocker = HttpMocker() @@ -137,7 +156,9 @@ def setUp(self) -> None: def tearDown(self) -> None: self._http_mocker.__exit__(None, None, None) - def test_given_different_statuses_when_update_jobs_status_then_update_status_properly(self) -> None: + def test_given_different_statuses_when_update_jobs_status_then_update_status_properly( + self, + ) -> None: self._mock_create_response(_A_JOB_ID) self._http_mocker.get( HttpRequest(url=f"{_EXPORT_URL}/{_A_JOB_ID}"), @@ -167,7 +188,9 @@ def test_given_unknown_status_when_update_jobs_status_then_raise_error(self) -> with pytest.raises(ValueError): self._repository.update_jobs_status([job]) - def test_given_multiple_jobs_when_update_jobs_status_then_all_the_jobs_are_updated(self) -> None: + def test_given_multiple_jobs_when_update_jobs_status_then_all_the_jobs_are_updated( + self, + ) -> None: self._http_mocker.post( HttpRequest(url=_EXPORT_URL), [ @@ -195,11 +218,15 @@ def test_given_pagination_when_fetch_records_then_yield_records_from_all_pages(s self._mock_create_response(_A_JOB_ID) self._http_mocker.get( HttpRequest(url=f"{_EXPORT_URL}/{_A_JOB_ID}"), - HttpResponse(body=json.dumps({"id": _A_JOB_ID, "status": "ready", "urls": [_JOB_FIRST_URL]})), + HttpResponse( + body=json.dumps({"id": _A_JOB_ID, "status": "ready", "urls": [_JOB_FIRST_URL]}) + ), ) self._http_mocker.get( HttpRequest(url=_JOB_FIRST_URL), - HttpResponse(body=_A_CSV_WITH_ONE_RECORD, headers={"Sforce-Locator": _A_CURSOR_FOR_PAGINATION}), + HttpResponse( + body=_A_CSV_WITH_ONE_RECORD, headers={"Sforce-Locator": _A_CURSOR_FOR_PAGINATION} + ), ) self._http_mocker.get( HttpRequest(url=_JOB_FIRST_URL, query_params={"locator": _A_CURSOR_FOR_PAGINATION}), diff --git a/unit_tests/sources/declarative/requesters/test_http_requester.py b/unit_tests/sources/declarative/requesters/test_http_requester.py index 404bf9f5..1428319f 100644 --- a/unit_tests/sources/declarative/requesters/test_http_requester.py +++ b/unit_tests/sources/declarative/requesters/test_http_requester.py @@ -11,13 +11,23 @@ import requests from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString -from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies import ConstantBackoffStrategy, ExponentialBackoffStrategy -from airbyte_cdk.sources.declarative.requesters.error_handlers.default_error_handler import DefaultErrorHandler +from airbyte_cdk.sources.declarative.requesters.error_handlers.backoff_strategies import ( + ConstantBackoffStrategy, + ExponentialBackoffStrategy, +) +from airbyte_cdk.sources.declarative.requesters.error_handlers.default_error_handler import ( + DefaultErrorHandler, +) from airbyte_cdk.sources.declarative.requesters.error_handlers.error_handler import ErrorHandler from airbyte_cdk.sources.declarative.requesters.http_requester import HttpMethod, HttpRequester -from airbyte_cdk.sources.declarative.requesters.request_options import InterpolatedRequestOptionsProvider +from airbyte_cdk.sources.declarative.requesters.request_options import ( + InterpolatedRequestOptionsProvider, +) from airbyte_cdk.sources.message import MessageRepository -from airbyte_cdk.sources.streams.http.exceptions import RequestBodyException, UserDefinedBackoffException +from airbyte_cdk.sources.streams.http.exceptions import ( + RequestBodyException, + UserDefinedBackoffException, +) from airbyte_cdk.sources.types import Config from requests import PreparedRequest @@ -99,12 +109,24 @@ def test_http_requester(): ) assert requester.get_url_base() == "https://airbyte.io/" - assert requester.get_path(stream_state={}, stream_slice=stream_slice, next_page_token={}) == "v1/1234" + assert ( + requester.get_path(stream_state={}, stream_slice=stream_slice, next_page_token={}) + == "v1/1234" + ) assert requester.get_authenticator() == authenticator assert requester.get_method() == http_method - assert requester.get_request_params(stream_state={}, stream_slice=None, next_page_token=None) == request_params - assert requester.get_request_body_data(stream_state={}, stream_slice=None, next_page_token=None) == request_body_data - assert requester.get_request_body_json(stream_state={}, stream_slice=None, next_page_token=None) == request_body_json + assert ( + requester.get_request_params(stream_state={}, stream_slice=None, next_page_token=None) + == request_params + ) + assert ( + requester.get_request_body_data(stream_state={}, stream_slice=None, next_page_token=None) + == request_body_data + ) + assert ( + requester.get_request_body_json(stream_state={}, stream_slice=None, next_page_token=None) + == request_body_json + ) @pytest.mark.parametrize( @@ -200,14 +222,50 @@ def test_basic_send_request(): [ # merging data params from the three sources ({"field": "value"}, None, None, None, None, None, None, "field=value"), - ({"field": "value"}, None, {"field2": "value"}, None, None, None, None, "field=value&field2=value"), - ({"field": "value"}, None, {"field2": "value"}, None, {"authfield": "val"}, None, None, "field=value&field2=value&authfield=val"), + ( + {"field": "value"}, + None, + {"field2": "value"}, + None, + None, + None, + None, + "field=value&field2=value", + ), + ( + {"field": "value"}, + None, + {"field2": "value"}, + None, + {"authfield": "val"}, + None, + None, + "field=value&field2=value&authfield=val", + ), ({"field": "value"}, None, {"field": "value"}, None, None, None, ValueError, None), ({"field": "value"}, None, None, None, {"field": "value"}, None, ValueError, None), - ({"field": "value"}, None, {"field2": "value"}, None, {"field": "value"}, None, ValueError, None), + ( + {"field": "value"}, + None, + {"field2": "value"}, + None, + {"field": "value"}, + None, + ValueError, + None, + ), # merging json params from the three sources (None, {"field": "value"}, None, None, None, None, None, '{"field": "value"}'), - (None, {"field": "value"}, None, {"field2": "value"}, None, None, None, '{"field": "value", "field2": "value"}'), + ( + None, + {"field": "value"}, + None, + {"field2": "value"}, + None, + None, + None, + '{"field": "value", "field2": "value"}', + ), ( None, {"field": "value"}, @@ -221,15 +279,67 @@ def test_basic_send_request(): (None, {"field": "value"}, None, {"field": "value"}, None, None, ValueError, None), (None, {"field": "value"}, None, None, None, {"field": "value"}, ValueError, None), # raise on mixed data and json params - ({"field": "value"}, {"field": "value"}, None, None, None, None, RequestBodyException, None), - ({"field": "value"}, None, None, {"field": "value"}, None, None, RequestBodyException, None), - (None, None, {"field": "value"}, {"field": "value"}, None, None, RequestBodyException, None), - (None, None, None, None, {"field": "value"}, {"field": "value"}, RequestBodyException, None), - ({"field": "value"}, None, None, None, None, {"field": "value"}, RequestBodyException, None), + ( + {"field": "value"}, + {"field": "value"}, + None, + None, + None, + None, + RequestBodyException, + None, + ), + ( + {"field": "value"}, + None, + None, + {"field": "value"}, + None, + None, + RequestBodyException, + None, + ), + ( + None, + None, + {"field": "value"}, + {"field": "value"}, + None, + None, + RequestBodyException, + None, + ), + ( + None, + None, + None, + None, + {"field": "value"}, + {"field": "value"}, + RequestBodyException, + None, + ), + ( + {"field": "value"}, + None, + None, + None, + None, + {"field": "value"}, + RequestBodyException, + None, + ), ], ) def test_send_request_data_json( - provider_data, provider_json, param_data, param_json, authenticator_data, authenticator_json, expected_exception, expected_body + provider_data, + provider_json, + param_data, + param_json, + authenticator_data, + authenticator_json, + expected_exception, + expected_body, ): options_provider = MagicMock() options_provider.get_request_body_data.return_value = provider_data @@ -246,7 +356,11 @@ def test_send_request_data_json( requester.send_request(request_body_data=param_data, request_body_json=param_json) sent_request: PreparedRequest = requester._http_client._session.send.call_args_list[0][0][0] if expected_body is not None: - assert sent_request.body == expected_body.decode("UTF-8") if not isinstance(expected_body, str) else expected_body + assert ( + sent_request.body == expected_body.decode("UTF-8") + if not isinstance(expected_body, str) + else expected_body + ) @pytest.mark.parametrize( @@ -267,7 +381,9 @@ def test_send_request_data_json( ("field=value", None, {"abc": "def"}, ValueError, None), ], ) -def test_send_request_string_data(provider_data, param_data, authenticator_data, expected_exception, expected_body): +def test_send_request_string_data( + provider_data, param_data, authenticator_data, expected_exception, expected_body +): options_provider = MagicMock() options_provider.get_request_body_data.return_value = provider_data authenticator = MagicMock() @@ -289,7 +405,13 @@ def test_send_request_string_data(provider_data, param_data, authenticator_data, [ # merging headers from the three sources ({"header": "value"}, None, None, None, {"header": "value"}), - ({"header": "value"}, {"header2": "value"}, None, None, {"header": "value", "header2": "value"}), + ( + {"header": "value"}, + {"header2": "value"}, + None, + None, + {"header": "value", "header2": "value"}, + ), ( {"header": "value"}, {"header2": "value"}, @@ -303,9 +425,16 @@ def test_send_request_string_data(provider_data, param_data, authenticator_data, ({"header": "value"}, {"header2": "value"}, {"header": "value"}, ValueError, None), ], ) -def test_send_request_headers(provider_headers, param_headers, authenticator_headers, expected_exception, expected_headers): +def test_send_request_headers( + provider_headers, param_headers, authenticator_headers, expected_exception, expected_headers +): # headers set by the requests framework, do not validate - default_headers = {"User-Agent": mock.ANY, "Accept-Encoding": mock.ANY, "Accept": mock.ANY, "Connection": mock.ANY} + default_headers = { + "User-Agent": mock.ANY, + "Accept-Encoding": mock.ANY, + "Accept": mock.ANY, + "Connection": mock.ANY, + } options_provider = MagicMock() options_provider.get_request_headers.return_value = provider_headers authenticator = MagicMock() @@ -326,15 +455,29 @@ def test_send_request_headers(provider_headers, param_headers, authenticator_hea [ # merging params from the three sources ({"param": "value"}, None, None, None, {"param": "value"}), - ({"param": "value"}, {"param2": "value"}, None, None, {"param": "value", "param2": "value"}), - ({"param": "value"}, {"param2": "value"}, {"authparam": "val"}, None, {"param": "value", "param2": "value", "authparam": "val"}), + ( + {"param": "value"}, + {"param2": "value"}, + None, + None, + {"param": "value", "param2": "value"}, + ), + ( + {"param": "value"}, + {"param2": "value"}, + {"authparam": "val"}, + None, + {"param": "value", "param2": "value", "authparam": "val"}, + ), # raise on conflicting params ({"param": "value"}, {"param": "value"}, None, ValueError, None), ({"param": "value"}, None, {"param": "value"}, ValueError, None), ({"param": "value"}, {"param2": "value"}, {"param": "value"}, ValueError, None), ], ) -def test_send_request_params(provider_params, param_params, authenticator_params, expected_exception, expected_params): +def test_send_request_params( + provider_params, param_params, authenticator_params, expected_exception, expected_params +): options_provider = MagicMock() options_provider.get_request_params.return_value = provider_params authenticator = MagicMock() @@ -356,7 +499,9 @@ def test_send_request_params(provider_params, param_params, authenticator_params "request_parameters, config, expected_query_params", [ pytest.param( - {"k": '{"updatedDateFrom": "2023-08-20T00:00:00Z", "updatedDateTo": "2023-08-20T23:59:59Z"}'}, + { + "k": '{"updatedDateFrom": "2023-08-20T00:00:00Z", "updatedDateTo": "2023-08-20T23:59:59Z"}' + }, {}, "k=%7B%22updatedDateFrom%22%3A+%222023-08-20T00%3A00%3A00Z%22%2C+%22updatedDateTo%22%3A+%222023-08-20T23%3A59%3A59Z%22%7D", id="test-request-parameter-dictionary", @@ -375,7 +520,12 @@ def test_send_request_params(provider_params, param_params, authenticator_params ), pytest.param( {"k": '{{ config["k"] }}'}, - {"k": {"updatedDateFrom": "2023-08-20T00:00:00Z", "updatedDateTo": "2023-08-20T23:59:59Z"}}, + { + "k": { + "updatedDateFrom": "2023-08-20T00:00:00Z", + "updatedDateTo": "2023-08-20T23:59:59Z", + } + }, # {'updatedDateFrom': '2023-08-20T00:00:00Z', 'updatedDateTo': '2023-08-20T23:59:59Z'} "k=%7B%27updatedDateFrom%27%3A+%272023-08-20T00%3A00%3A00Z%27%2C+%27updatedDateTo%27%3A+%272023-08-20T23%3A59%3A59Z%27%7D", id="test-request-parameter-from-config-object", @@ -437,7 +587,12 @@ def test_request_param_interpolation(request_parameters, config, expected_query_ "request_parameters, config, invalid_value_for_key", [ pytest.param( - {"k": {"updatedDateFrom": "2023-08-20T00:00:00Z", "updatedDateTo": "2023-08-20T23:59:59Z"}}, + { + "k": { + "updatedDateFrom": "2023-08-20T00:00:00Z", + "updatedDateTo": "2023-08-20T23:59:59Z", + } + }, {}, "k", id="test-request-parameter-object-of-the-updated-info", @@ -450,7 +605,9 @@ def test_request_param_interpolation(request_parameters, config, expected_query_ ), ], ) -def test_request_param_interpolation_with_incorrect_values(request_parameters, config, invalid_value_for_key): +def test_request_param_interpolation_with_incorrect_values( + request_parameters, config, invalid_value_for_key +): options_provider = InterpolatedRequestOptionsProvider( config=config, request_parameters=request_parameters, @@ -464,7 +621,8 @@ def test_request_param_interpolation_with_incorrect_values(request_parameters, c requester.send_request() assert ( - error.value.args[0] == f"Invalid value for `{invalid_value_for_key}` parameter. The values of request params cannot be an object." + error.value.args[0] + == f"Invalid value for `{invalid_value_for_key}` parameter. The values of request params cannot be an object." ) @@ -472,7 +630,9 @@ def test_request_param_interpolation_with_incorrect_values(request_parameters, c "request_body_data, config, expected_request_body_data", [ pytest.param( - {"k": '{"updatedDateFrom": "2023-08-20T00:00:00Z", "updatedDateTo": "2023-08-20T23:59:59Z"}'}, + { + "k": '{"updatedDateFrom": "2023-08-20T00:00:00Z", "updatedDateTo": "2023-08-20T23:59:59Z"}' + }, {}, # k={"updatedDateFrom": "2023-08-20T00:00:00Z", "updatedDateTo": "2023-08-20T23:59:59Z"} "k=%7B%22updatedDateFrom%22%3A+%222023-08-20T00%3A00%3A00Z%22%2C+%22updatedDateTo%22%3A+%222023-08-20T23%3A59%3A59Z%22%7D", @@ -504,7 +664,12 @@ def test_request_param_interpolation_with_incorrect_values(request_parameters, c ), pytest.param( {"k": '{{ config["k"] }}'}, - {"k": {"updatedDateFrom": "2023-08-20T00:00:00Z", "updatedDateTo": "2023-08-20T23:59:59Z"}}, + { + "k": { + "updatedDateFrom": "2023-08-20T00:00:00Z", + "updatedDateTo": "2023-08-20T23:59:59Z", + } + }, # k={'updatedDateFrom': '2023-08-20T00:00:00Z', 'updatedDateTo': '2023-08-20T23:59:59Z'} "k=%7B%27updatedDateFrom%27%3A+%272023-08-20T00%3A00%3A00Z%27%2C+%27updatedDateTo%27%3A+%272023-08-20T23%3A59%3A59Z%27%7D", id="test-request-body-from-config-object", @@ -534,7 +699,9 @@ def test_request_param_interpolation_with_incorrect_values(request_parameters, c id="test-key-with-list-is-not-interpolated", ), pytest.param( - {"k": "{'updatedDateFrom': '2023-08-20T00:00:00Z', 'updatedDateTo': '2023-08-20T23:59:59Z'}"}, + { + "k": "{'updatedDateFrom': '2023-08-20T00:00:00Z', 'updatedDateTo': '2023-08-20T23:59:59Z'}" + }, {}, # k={'updatedDateFrom': '2023-08-20T00:00:00Z', 'updatedDateTo': '2023-08-20T23:59:59Z'} "k=%7B%27updatedDateFrom%27%3A+%272023-08-20T00%3A00%3A00Z%27%2C+%27updatedDateTo%27%3A+%272023-08-20T23%3A59%3A59Z%27%7D", @@ -571,8 +738,16 @@ def test_request_body_interpolation(request_body_data, config, expected_request_ ], ) def test_send_request_path(requester_path, param_path, expected_path): - requester = create_requester(config={"config_key": "config_value"}, path=requester_path, parameters={"param_key": "param_value"}) - requester.send_request(stream_slice={"start": "2012"}, next_page_token={"next_page_token": "pagetoken"}, path=param_path) + requester = create_requester( + config={"config_key": "config_value"}, + path=requester_path, + parameters={"param_key": "param_value"}, + ) + requester.send_request( + stream_slice={"start": "2012"}, + next_page_token={"next_page_token": "pagetoken"}, + path=param_path, + ) sent_request: PreparedRequest = requester._http_client._session.send.call_args_list[0][0][0] parsed_url = urlparse(sent_request.url) assert parsed_url.path == expected_path @@ -615,17 +790,42 @@ def test_send_request_stream_slice_next_page_token(): "test_name, base_url, path, expected_full_url", [ ("test_no_slashes", "https://airbyte.io", "my_endpoint", "https://airbyte.io/my_endpoint"), - ("test_trailing_slash_on_base_url", "https://airbyte.io/", "my_endpoint", "https://airbyte.io/my_endpoint"), + ( + "test_trailing_slash_on_base_url", + "https://airbyte.io/", + "my_endpoint", + "https://airbyte.io/my_endpoint", + ), ( "test_trailing_slash_on_base_url_and_leading_slash_on_path", "https://airbyte.io/", "/my_endpoint", "https://airbyte.io/my_endpoint", ), - ("test_leading_slash_on_path", "https://airbyte.io", "/my_endpoint", "https://airbyte.io/my_endpoint"), - ("test_trailing_slash_on_path", "https://airbyte.io", "/my_endpoint/", "https://airbyte.io/my_endpoint/"), - ("test_nested_path_no_leading_slash", "https://airbyte.io", "v1/my_endpoint", "https://airbyte.io/v1/my_endpoint"), - ("test_nested_path_with_leading_slash", "https://airbyte.io", "/v1/my_endpoint", "https://airbyte.io/v1/my_endpoint"), + ( + "test_leading_slash_on_path", + "https://airbyte.io", + "/my_endpoint", + "https://airbyte.io/my_endpoint", + ), + ( + "test_trailing_slash_on_path", + "https://airbyte.io", + "/my_endpoint/", + "https://airbyte.io/my_endpoint/", + ), + ( + "test_nested_path_no_leading_slash", + "https://airbyte.io", + "v1/my_endpoint", + "https://airbyte.io/v1/my_endpoint", + ), + ( + "test_nested_path_with_leading_slash", + "https://airbyte.io", + "/v1/my_endpoint", + "https://airbyte.io/v1/my_endpoint", + ), ], ) def test_join_url(test_name, base_url, path, expected_full_url): @@ -655,8 +855,12 @@ def test_request_attempt_count_is_tracked_across_retries(http_requester_factory) request_mock.url = "https://example.com/deals" request_mock.method = "GET" request_mock.body = {} - backoff_strategy = ConstantBackoffStrategy(parameters={}, config={}, backoff_time_in_seconds=0.1) - error_handler = DefaultErrorHandler(parameters={}, config={}, max_retries=1, backoff_strategies=[backoff_strategy]) + backoff_strategy = ConstantBackoffStrategy( + parameters={}, config={}, backoff_time_in_seconds=0.1 + ) + error_handler = DefaultErrorHandler( + parameters={}, config={}, max_retries=1, backoff_strategies=[backoff_strategy] + ) http_requester = http_requester_factory(error_handler=error_handler) http_requester._http_client._session.send = MagicMock() response = requests.Response() @@ -666,7 +870,10 @@ def test_request_attempt_count_is_tracked_across_retries(http_requester_factory) with pytest.raises(UserDefinedBackoffException): http_requester._http_client._send_with_retry(request=request_mock, request_kwargs={}) - assert http_requester._http_client._request_attempt_count.get(request_mock) == http_requester._http_client._max_retries + 1 + assert ( + http_requester._http_client._request_attempt_count.get(request_mock) + == http_requester._http_client._max_retries + 1 + ) @pytest.mark.usefixtures("mock_sleep") @@ -677,7 +884,9 @@ def test_request_attempt_count_with_exponential_backoff_strategy(http_requester_ request_mock.method = "GET" request_mock.body = {} backoff_strategy = ExponentialBackoffStrategy(parameters={}, config={}, factor=0.01) - error_handler = DefaultErrorHandler(parameters={}, config={}, max_retries=2, backoff_strategies=[backoff_strategy]) + error_handler = DefaultErrorHandler( + parameters={}, config={}, max_retries=2, backoff_strategies=[backoff_strategy] + ) http_requester = http_requester_factory(error_handler=error_handler) http_requester._http_client._session.send = MagicMock() response = requests.Response() @@ -687,4 +896,7 @@ def test_request_attempt_count_with_exponential_backoff_strategy(http_requester_ with pytest.raises(UserDefinedBackoffException): http_requester._http_client._send_with_retry(request=request_mock, request_kwargs={}) - assert http_requester._http_client._request_attempt_count.get(request_mock) == http_requester._http_client._max_retries + 1 + assert ( + http_requester._http_client._request_attempt_count.get(request_mock) + == http_requester._http_client._max_retries + 1 + ) diff --git a/unit_tests/sources/declarative/requesters/test_interpolated_request_input_provider.py b/unit_tests/sources/declarative/requesters/test_interpolated_request_input_provider.py index 3f80b7ee..8882e918 100644 --- a/unit_tests/sources/declarative/requesters/test_interpolated_request_input_provider.py +++ b/unit_tests/sources/declarative/requesters/test_interpolated_request_input_provider.py @@ -4,15 +4,29 @@ import pytest as pytest from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping -from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_input_provider import InterpolatedRequestInputProvider +from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_input_provider import ( + InterpolatedRequestInputProvider, +) @pytest.mark.parametrize( "test_name, input_request_data, expected_request_data", [ - ("test_static_map_data", {"a_static_request_param": "a_static_value"}, {"a_static_request_param": "a_static_value"}), - ("test_map_depends_on_stream_slice", {"read_from_slice": "{{ stream_slice['slice_key'] }}"}, {"read_from_slice": "slice_value"}), - ("test_map_depends_on_config", {"read_from_config": "{{ config['config_key'] }}"}, {"read_from_config": "value_of_config"}), + ( + "test_static_map_data", + {"a_static_request_param": "a_static_value"}, + {"a_static_request_param": "a_static_value"}, + ), + ( + "test_map_depends_on_stream_slice", + {"read_from_slice": "{{ stream_slice['slice_key'] }}"}, + {"read_from_slice": "slice_value"}, + ), + ( + "test_map_depends_on_config", + {"read_from_config": "{{ config['config_key'] }}"}, + {"read_from_config": "value_of_config"}, + ), ( "test_map_depends_on_parameters", {"read_from_parameters": "{{ parameters['read_from_parameters'] }}"}, @@ -21,11 +35,15 @@ ("test_defaults_to_empty_dictionary", None, {}), ], ) -def test_initialize_interpolated_mapping_request_input_provider(test_name, input_request_data, expected_request_data): +def test_initialize_interpolated_mapping_request_input_provider( + test_name, input_request_data, expected_request_data +): config = {"config_key": "value_of_config"} stream_slice = {"slice_key": "slice_value"} parameters = {"read_from_parameters": "value_of_parameters"} - provider = InterpolatedRequestInputProvider(request_inputs=input_request_data, config=config, parameters=parameters) + provider = InterpolatedRequestInputProvider( + request_inputs=input_request_data, config=config, parameters=parameters + ) actual_request_data = provider.eval_request_inputs(stream_state={}, stream_slice=stream_slice) assert isinstance(provider._interpolator, InterpolatedMapping) diff --git a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index 2fd0594b..d2eb2d15 100644 --- a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -10,15 +10,24 @@ from airbyte_cdk import YamlDeclarativeSource from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type from airbyte_cdk.sources.declarative.auth.declarative_authenticator import NoAuth -from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor, DeclarativeCursor, ResumableFullRefreshCursor +from airbyte_cdk.sources.declarative.incremental import ( + DatetimeBasedCursor, + DeclarativeCursor, + ResumableFullRefreshCursor, +) from airbyte_cdk.sources.declarative.models import DeclarativeStream as DeclarativeStreamModel -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, +) from airbyte_cdk.sources.declarative.partition_routers import SinglePartitionRouter from airbyte_cdk.sources.declarative.requesters.paginators import DefaultPaginator from airbyte_cdk.sources.declarative.requesters.paginators.strategies import PageIncrement from airbyte_cdk.sources.declarative.requesters.request_option import RequestOptionType from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod -from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever, SimpleRetrieverTestReadDecorator +from airbyte_cdk.sources.declarative.retrievers.simple_retriever import ( + SimpleRetriever, + SimpleRetrieverTestReadDecorator, +) from airbyte_cdk.sources.types import Record, StreamSlice A_SLICE_STATE = {"slice_state": "slice state value"} @@ -149,7 +158,9 @@ def test_simple_retriever_with_request_response_logs(mock_http_stream): pytest.param({"next_page_token": 10}, 10, 11, id="test_reset_with_next_page_token"), ], ) -def test_simple_retriever_resumable_full_refresh_cursor_page_increment(initial_state, expected_reset_value, expected_next_page): +def test_simple_retriever_resumable_full_refresh_cursor_page_increment( + initial_state, expected_reset_value, expected_next_page +): expected_records = [ Record(data={"id": "abc"}, associated_slice=None), Record(data={"id": "def"}, associated_slice=None), @@ -163,7 +174,9 @@ def test_simple_retriever_resumable_full_refresh_cursor_page_increment(initial_s response = requests.Response() response.status_code = 200 - response._content = json.dumps({"data": [record.data for record in expected_records[:5]]}).encode("utf-8") + response._content = json.dumps( + {"data": [record.data for record in expected_records[:5]]} + ).encode("utf-8") requester = MagicMock() requester.send_request.side_effect = [ @@ -188,7 +201,12 @@ def test_simple_retriever_resumable_full_refresh_cursor_page_increment(initial_s ] page_increment_strategy = PageIncrement(config={}, page_size=5, parameters={}) - paginator = DefaultPaginator(config={}, pagination_strategy=page_increment_strategy, url_base="https://airbyte.io", parameters={}) + paginator = DefaultPaginator( + config={}, + pagination_strategy=page_increment_strategy, + url_base="https://airbyte.io", + parameters={}, + ) paginator.reset = Mock(wraps=paginator.reset) stream_slicer = ResumableFullRefreshCursor(parameters={}) @@ -208,13 +226,17 @@ def test_simple_retriever_resumable_full_refresh_cursor_page_increment(initial_s ) stream_slice = list(stream_slicer.stream_slices())[0] - actual_records = [r for r in retriever.read_records(records_schema={}, stream_slice=stream_slice)] + actual_records = [ + r for r in retriever.read_records(records_schema={}, stream_slice=stream_slice) + ] assert len(actual_records) == 5 assert actual_records == expected_records[:5] assert retriever.state == {"next_page_token": expected_next_page} - actual_records = [r for r in retriever.read_records(records_schema={}, stream_slice=stream_slice)] + actual_records = [ + r for r in retriever.read_records(records_schema={}, stream_slice=stream_slice) + ] assert len(actual_records) == 3 assert actual_records == expected_records[5:] assert retriever.state == {"__ab_full_refresh_sync_complete": True} @@ -227,7 +249,9 @@ def test_simple_retriever_resumable_full_refresh_cursor_page_increment(initial_s [ pytest.param(None, None, 1, id="test_initial_sync_no_state"), pytest.param( - {"next_page_token": "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=tracy_stevens"}, + { + "next_page_token": "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=tracy_stevens" + }, "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=tracy_stevens", "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens", id="test_reset_with_next_page_token", @@ -287,17 +311,24 @@ def test_simple_retriever_resumable_full_refresh_cursor_reset_cursor_pagination( factory = ModelToComponentFactory() stream_manifest = YamlDeclarativeSource._parse(content) - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config={}) + stream = factory.create_component( + model_type=DeclarativeStreamModel, component_definition=stream_manifest, config={} + ) response_body = { "data": [r.data for r in expected_records[:5]], "next_page": "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens", } requests_mock.get("https://for-all-mankind.nasa.com/api/v1/astronauts", json=response_body) - requests_mock.get("https://for-all-mankind.nasa.com/astronauts?next_page=tracy_stevens", json=response_body) + requests_mock.get( + "https://for-all-mankind.nasa.com/astronauts?next_page=tracy_stevens", json=response_body + ) response_body_2 = { "data": [r.data for r in expected_records[5:]], } - requests_mock.get("https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens", json=response_body_2) + requests_mock.get( + "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens", + json=response_body_2, + ) stream.retriever.paginator.reset = Mock(wraps=stream.retriever.paginator.reset) stream_slicer = ResumableFullRefreshCursor(parameters={}) if initial_state: @@ -305,14 +336,24 @@ def test_simple_retriever_resumable_full_refresh_cursor_reset_cursor_pagination( stream.retriever.stream_slices = stream_slicer stream.retriever.cursor = stream_slicer stream_slice = list(stream_slicer.stream_slices())[0] - actual_records = [r for r in stream.retriever.read_records(records_schema={}, stream_slice=stream_slice)] + actual_records = [ + r for r in stream.retriever.read_records(records_schema={}, stream_slice=stream_slice) + ] assert len(actual_records) == 5 assert actual_records == expected_records[:5] - assert stream.retriever.state == {"next_page_token": "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens"} - requests_mock.get("https://for-all-mankind.nasa.com/astronauts?next_page=tracy_stevens", json=response_body) - requests_mock.get("https://for-all-mankind.nasa.com/astronauts?next_page=gordo_stevens", json=response_body_2) - actual_records = [r for r in stream.retriever.read_records(records_schema={}, stream_slice=stream_slice)] + assert stream.retriever.state == { + "next_page_token": "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens" + } + requests_mock.get( + "https://for-all-mankind.nasa.com/astronauts?next_page=tracy_stevens", json=response_body + ) + requests_mock.get( + "https://for-all-mankind.nasa.com/astronauts?next_page=gordo_stevens", json=response_body_2 + ) + actual_records = [ + r for r in stream.retriever.read_records(records_schema={}, stream_slice=stream_slice) + ] assert len(actual_records) == 3 assert actual_records == expected_records[5:] assert stream.retriever.state == {"__ab_full_refresh_sync_complete": True} @@ -342,7 +383,12 @@ def test_simple_retriever_resumable_full_refresh_cursor_reset_skip_completed_str ] page_increment_strategy = PageIncrement(config={}, page_size=5, parameters={}) - paginator = DefaultPaginator(config={}, pagination_strategy=page_increment_strategy, url_base="https://airbyte.io", parameters={}) + paginator = DefaultPaginator( + config={}, + pagination_strategy=page_increment_strategy, + url_base="https://airbyte.io", + parameters={}, + ) paginator.reset = Mock(wraps=paginator.reset) stream_slicer = ResumableFullRefreshCursor(parameters={}) @@ -361,7 +407,9 @@ def test_simple_retriever_resumable_full_refresh_cursor_reset_skip_completed_str ) stream_slice = list(stream_slicer.stream_slices())[0] - actual_records = [r for r in retriever.read_records(records_schema={}, stream_slice=stream_slice)] + actual_records = [ + r for r in retriever.read_records(records_schema={}, stream_slice=stream_slice) + ] assert len(actual_records) == 0 assert retriever.state == {"__ab_full_refresh_sync_complete": True} @@ -373,12 +421,19 @@ def test_simple_retriever_resumable_full_refresh_cursor_reset_skip_completed_str "test_name, paginator_mapping, request_options_provider_mapping, expected_mapping", [ ("test_empty_headers", {}, {}, {}), - ("test_header_from_pagination_and_slicer", {"offset": 1000}, {"key": "value"}, {"key": "value", "offset": 1000}), + ( + "test_header_from_pagination_and_slicer", + {"offset": 1000}, + {"key": "value"}, + {"key": "value", "offset": 1000}, + ), ("test_header_from_stream_slicer", {}, {"slice": "slice_value"}, {"slice": "slice_value"}), ("test_duplicate_header_slicer_paginator", {"k": "v"}, {"k": "slice_value"}, None), ], ) -def test_get_request_options_from_pagination(test_name, paginator_mapping, request_options_provider_mapping, expected_mapping): +def test_get_request_options_from_pagination( + test_name, paginator_mapping, request_options_provider_mapping, expected_mapping +): # This test does not test request headers because they must be strings paginator = MagicMock() paginator.get_request_params.return_value = paginator_mapping @@ -499,7 +554,11 @@ def test_get_request_headers(test_name, paginator_mapping, expected_mapping): ], ) def test_ignore_stream_slicer_parameters_on_paginated_requests( - test_name, paginator_mapping, ignore_stream_slicer_parameters_on_paginated_requests, next_page_token, expected_mapping + test_name, + paginator_mapping, + ignore_stream_slicer_parameters_on_paginated_requests, + next_page_token, + expected_mapping, ): # This test is separate from the other request options because request headers must be strings paginator = MagicMock() @@ -536,12 +595,19 @@ def test_ignore_stream_slicer_parameters_on_paginated_requests( [ ("test_only_slicer_mapping", {"key": "value"}, {}, {"key": "value"}), ("test_only_slicer_string", "key=value", {}, "key=value"), - ("test_slicer_mapping_and_paginator_no_duplicate", {"key": "value"}, {"offset": 1000}, {"key": "value", "offset": 1000}), + ( + "test_slicer_mapping_and_paginator_no_duplicate", + {"key": "value"}, + {"offset": 1000}, + {"key": "value", "offset": 1000}, + ), ("test_slicer_mapping_and_paginator_with_duplicate", {"key": "value"}, {"key": 1000}, None), ("test_slicer_string_and_paginator", "key=value", {"offset": 1000}, None), ], ) -def test_request_body_data(test_name, request_options_provider_body_data, paginator_body_data, expected_body_data): +def test_request_body_data( + test_name, request_options_provider_body_data, paginator_body_data, expected_body_data +): paginator = MagicMock() paginator.get_request_body_data.return_value = paginator_body_data requester = MagicMock(use_cache=False) @@ -631,7 +697,9 @@ def test_limit_stream_slices(): ("test_second_greater_than_first", False), ], ) -def test_when_read_records_then_cursor_close_slice_with_greater_record(test_name, first_greater_than_second): +def test_when_read_records_then_cursor_close_slice_with_greater_record( + test_name, first_greater_than_second +): first_record = Record({"first": 1}, StreamSlice(cursor_slice={}, partition={})) second_record = Record({"second": 2}, StreamSlice(cursor_slice={}, partition={})) records = [first_record, second_record] @@ -656,7 +724,9 @@ def test_when_read_records_then_cursor_close_slice_with_greater_record(test_name stream_slice = StreamSlice(cursor_slice={}, partition={"repository": "airbyte"}) def retriever_read_pages(_, __, ___): - return retriever._parse_records(response=MagicMock(), stream_state={}, stream_slice=stream_slice, records_schema={}) + return retriever._parse_records( + response=MagicMock(), stream_state={}, stream_slice=stream_slice, records_schema={} + ) with patch.object( SimpleRetriever, @@ -665,11 +735,17 @@ def retriever_read_pages(_, __, ___): side_effect=retriever_read_pages, ): list(retriever.read_records(stream_slice=stream_slice, records_schema={})) - cursor.close_slice.assert_called_once_with(stream_slice, first_record if first_greater_than_second else second_record) + cursor.close_slice.assert_called_once_with( + stream_slice, first_record if first_greater_than_second else second_record + ) def test_given_stream_data_is_not_record_when_read_records_then_update_slice_with_optional_record(): - stream_data = [AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="a log message"))] + stream_data = [ + AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="a log message") + ) + ] record_selector = MagicMock() record_selector.select_records.return_value = [] cursor = MagicMock(spec=DeclarativeCursor) @@ -688,7 +764,9 @@ def test_given_stream_data_is_not_record_when_read_records_then_update_slice_wit stream_slice = StreamSlice(cursor_slice={}, partition={"repository": "airbyte"}) def retriever_read_pages(_, __, ___): - return retriever._parse_records(response=MagicMock(), stream_state={}, stream_slice=stream_slice, records_schema={}) + return retriever._parse_records( + response=MagicMock(), stream_state={}, stream_slice=stream_slice, records_schema={} + ) with patch.object( SimpleRetriever, @@ -743,7 +821,9 @@ def test_emit_log_request_response_messages(mocker): response.request = request response.status_code = 200 - format_http_message_mock = mocker.patch("airbyte_cdk.sources.declarative.retrievers.simple_retriever.format_http_message") + format_http_message_mock = mocker.patch( + "airbyte_cdk.sources.declarative.retrievers.simple_retriever.format_http_message" + ) requester = MagicMock() retriever = SimpleRetrieverTestReadDecorator( name="stream_name", @@ -756,7 +836,12 @@ def test_emit_log_request_response_messages(mocker): config={}, ) - retriever._fetch_next_page(stream_state={}, stream_slice=StreamSlice(cursor_slice={}, partition={})) + retriever._fetch_next_page( + stream_state={}, stream_slice=StreamSlice(cursor_slice={}, partition={}) + ) assert requester.send_request.call_args_list[0][1]["log_formatter"] is not None - assert requester.send_request.call_args_list[0][1]["log_formatter"](response) == format_http_message_mock.return_value + assert ( + requester.send_request.call_args_list[0][1]["log_formatter"](response) + == format_http_message_mock.return_value + ) diff --git a/unit_tests/sources/declarative/schema/test_default_schema_loader.py b/unit_tests/sources/declarative/schema/test_default_schema_loader.py index c04c4fdc..38e617f9 100644 --- a/unit_tests/sources/declarative/schema/test_default_schema_loader.py +++ b/unit_tests/sources/declarative/schema/test_default_schema_loader.py @@ -12,7 +12,10 @@ "found_schema, found_error, expected_schema", [ pytest.param( - {"type": "object", "properties": {}}, None, {"type": "object", "properties": {}}, id="test_has_schema_in_default_location" + {"type": "object", "properties": {}}, + None, + {"type": "object", "properties": {}}, + id="test_has_schema_in_default_location", ), pytest.param(None, FileNotFoundError, {}, id="test_schema_file_does_not_exist"), ], diff --git a/unit_tests/sources/declarative/schema/test_json_file_schema_loader.py b/unit_tests/sources/declarative/schema/test_json_file_schema_loader.py index 8ef46761..a53a88a9 100644 --- a/unit_tests/sources/declarative/schema/test_json_file_schema_loader.py +++ b/unit_tests/sources/declarative/schema/test_json_file_schema_loader.py @@ -4,15 +4,33 @@ from unittest.mock import patch import pytest -from airbyte_cdk.sources.declarative.schema.json_file_schema_loader import JsonFileSchemaLoader, _default_file_path +from airbyte_cdk.sources.declarative.schema.json_file_schema_loader import ( + JsonFileSchemaLoader, + _default_file_path, +) @pytest.mark.parametrize( "test_name, input_path, expected_resource, expected_path", [ - ("path_prefixed_with_dot", "./source_example/schemas/lists.json", "source_example", "schemas/lists.json"), - ("path_prefixed_with_slash", "/source_example/schemas/lists.json", "source_example", "schemas/lists.json"), - ("path_starting_with_source", "source_example/schemas/lists.json", "source_example", "schemas/lists.json"), + ( + "path_prefixed_with_dot", + "./source_example/schemas/lists.json", + "source_example", + "schemas/lists.json", + ), + ( + "path_prefixed_with_slash", + "/source_example/schemas/lists.json", + "source_example", + "schemas/lists.json", + ), + ( + "path_starting_with_source", + "source_example/schemas/lists.json", + "source_example", + "schemas/lists.json", + ), ("path_starting_missing_source", "schemas/lists.json", "schemas", "lists.json"), ("path_with_file_only", "lists.json", "", "lists.json"), ("empty_path_does_not_crash", "", "", ""), @@ -29,7 +47,10 @@ def test_extract_resource_and_schema_path(test_name, input_path, expected_resour @patch("airbyte_cdk.sources.declarative.schema.json_file_schema_loader.sys") def test_exclude_cdk_packages(mocked_sys): - keys = ["airbyte_cdk.sources.concurrent_source.concurrent_source_adapter", "source_gitlab.utils"] + keys = [ + "airbyte_cdk.sources.concurrent_source.concurrent_source_adapter", + "source_gitlab.utils", + ] mocked_sys.modules = {key: "" for key in keys} default_file_path = _default_file_path() diff --git a/unit_tests/sources/declarative/spec/test_spec.py b/unit_tests/sources/declarative/spec/test_spec.py index 1e1ef498..8b924cb4 100644 --- a/unit_tests/sources/declarative/spec/test_spec.py +++ b/unit_tests/sources/declarative/spec/test_spec.py @@ -16,13 +16,25 @@ ConnectorSpecification(connectionSpecification={"client_id": "my_client_id"}), ), ( - Spec(connection_specification={"client_id": "my_client_id"}, parameters={}, documentation_url="https://airbyte.io"), - ConnectorSpecification(connectionSpecification={"client_id": "my_client_id"}, documentationUrl="https://airbyte.io"), + Spec( + connection_specification={"client_id": "my_client_id"}, + parameters={}, + documentation_url="https://airbyte.io", + ), + ConnectorSpecification( + connectionSpecification={"client_id": "my_client_id"}, + documentationUrl="https://airbyte.io", + ), ), ( - Spec(connection_specification={"client_id": "my_client_id"}, parameters={}, advanced_auth=AuthFlow(auth_flow_type="oauth2.0")), + Spec( + connection_specification={"client_id": "my_client_id"}, + parameters={}, + advanced_auth=AuthFlow(auth_flow_type="oauth2.0"), + ), ConnectorSpecification( - connectionSpecification={"client_id": "my_client_id"}, advanced_auth=AdvancedAuth(auth_flow_type=AuthFlowType.oauth2_0) + connectionSpecification={"client_id": "my_client_id"}, + advanced_auth=AdvancedAuth(auth_flow_type=AuthFlowType.oauth2_0), ), ), ], diff --git a/unit_tests/sources/declarative/test_concurrent_declarative_source.py b/unit_tests/sources/declarative/test_concurrent_declarative_source.py index 2081d6f7..bc97b2c5 100644 --- a/unit_tests/sources/declarative/test_concurrent_declarative_source.py +++ b/unit_tests/sources/declarative/test_concurrent_declarative_source.py @@ -26,7 +26,9 @@ StreamDescriptor, SyncMode, ) -from airbyte_cdk.sources.declarative.concurrent_declarative_source import ConcurrentDeclarativeSource +from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( + ConcurrentDeclarativeSource, +) from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.checkpoint import Cursor @@ -43,22 +45,32 @@ _CATALOG = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name="party_members", json_schema={}, supported_sync_modes=[SyncMode.incremental]), + stream=AirbyteStream( + name="party_members", json_schema={}, supported_sync_modes=[SyncMode.incremental] + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.append, ), ConfiguredAirbyteStream( - stream=AirbyteStream(name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), ConfiguredAirbyteStream( - stream=AirbyteStream(name="locations", json_schema={}, supported_sync_modes=[SyncMode.incremental]), + stream=AirbyteStream( + name="locations", json_schema={}, supported_sync_modes=[SyncMode.incremental] + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.append, ), ConfiguredAirbyteStream( - stream=AirbyteStream(name="party_members_skills", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="party_members_skills", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), @@ -101,7 +113,18 @@ _NO_STATE_PARTY_MEMBERS_SLICES_AND_RESPONSES = [ ( {"start": "2024-07-01", "end": "2024-07-15"}, - HttpResponse(json.dumps([{"id": "amamiya", "first_name": "ren", "last_name": "amamiya", "updated_at": "2024-07-10"}])), + HttpResponse( + json.dumps( + [ + { + "id": "amamiya", + "first_name": "ren", + "last_name": "amamiya", + "updated_at": "2024-07-10", + } + ] + ) + ), ), ({"start": "2024-07-16", "end": "2024-07-30"}, _EMPTY_RESPONSE), ( @@ -109,7 +132,12 @@ HttpResponse( json.dumps( [ - {"id": "nijima", "first_name": "makoto", "last_name": "nijima", "updated_at": "2024-08-10"}, + { + "id": "nijima", + "first_name": "makoto", + "last_name": "nijima", + "updated_at": "2024-08-10", + }, ] ) ), @@ -117,13 +145,27 @@ ({"start": "2024-08-15", "end": "2024-08-29"}, _EMPTY_RESPONSE), ( {"start": "2024-08-30", "end": "2024-09-10"}, - HttpResponse(json.dumps([{"id": "yoshizawa", "first_name": "sumire", "last_name": "yoshizawa", "updated_at": "2024-09-10"}])), + HttpResponse( + json.dumps( + [ + { + "id": "yoshizawa", + "first_name": "sumire", + "last_name": "yoshizawa", + "updated_at": "2024-09-10", + } + ] + ) + ), ), ] _MANIFEST = { "version": "5.0.0", "definitions": { - "selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}}, + "selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": []}, + }, "requester": { "type": "HttpRequester", "url_base": "https://persona.metaverse.com", @@ -142,7 +184,11 @@ "failure_type": "config_error", "error_message": "Access denied due to lack of permission or invalid API/Secret key or wrong data region.", }, - {"http_codes": [404], "action": "IGNORE", "error_message": "No data available for the time range requested."}, + { + "http_codes": [404], + "action": "IGNORE", + "error_message": "No data available for the time range requested.", + }, ], }, }, @@ -154,7 +200,9 @@ }, "incremental_cursor": { "type": "DatetimeBasedCursor", - "start_datetime": {"datetime": "{{ format_datetime(config['start_date'], '%Y-%m-%d') }}"}, + "start_datetime": { + "datetime": "{{ format_datetime(config['start_date'], '%Y-%m-%d') }}" + }, "end_datetime": {"datetime": "{{ now_utc().strftime('%Y-%m-%d') }}"}, "datetime_format": "%Y-%m-%d", "cursor_datetime_formats": ["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"], @@ -162,17 +210,31 @@ "step": "P15D", "cursor_field": "updated_at", "lookback_window": "P5D", - "start_time_option": {"type": "RequestOption", "field_name": "start", "inject_into": "request_parameter"}, - "end_time_option": {"type": "RequestOption", "field_name": "end", "inject_into": "request_parameter"}, + "start_time_option": { + "type": "RequestOption", + "field_name": "start", + "inject_into": "request_parameter", + }, + "end_time_option": { + "type": "RequestOption", + "field_name": "end", + "inject_into": "request_parameter", + }, }, "base_stream": {"retriever": {"$ref": "#/definitions/retriever"}}, "base_incremental_stream": { - "retriever": {"$ref": "#/definitions/retriever", "requester": {"$ref": "#/definitions/requester"}}, + "retriever": { + "$ref": "#/definitions/retriever", + "requester": {"$ref": "#/definitions/requester"}, + }, "incremental_sync": {"$ref": "#/definitions/incremental_cursor"}, }, "party_members_stream": { "$ref": "#/definitions/base_incremental_stream", - "retriever": {"$ref": "#/definitions/base_incremental_stream/retriever", "record_selector": {"$ref": "#/definitions/selector"}}, + "retriever": { + "$ref": "#/definitions/base_incremental_stream/retriever", + "record_selector": {"$ref": "#/definitions/selector"}, + }, "$parameters": {"name": "party_members", "primary_key": "id", "path": "/party_members"}, "schema_loader": { "type": "InlineSchemaLoader", @@ -184,7 +246,10 @@ "description": "The identifier", "type": ["null", "string"], }, - "name": {"description": "The name of the party member", "type": ["null", "string"]}, + "name": { + "description": "The name of the party member", + "type": ["null", "string"], + }, }, }, }, @@ -202,7 +267,10 @@ "description": "The identifier", "type": ["null", "string"], }, - "name": {"description": "The name of the metaverse palace", "type": ["null", "string"]}, + "name": { + "description": "The name of the metaverse palace", + "type": ["null", "string"], + }, }, }, }, @@ -217,7 +285,11 @@ }, "record_selector": {"$ref": "#/definitions/selector"}, }, - "incremental_sync": {"$ref": "#/definitions/incremental_cursor", "step": "P1M", "cursor_field": "updated_at"}, + "incremental_sync": { + "$ref": "#/definitions/incremental_cursor", + "step": "P1M", + "cursor_field": "updated_at", + }, "$parameters": {"name": "locations", "primary_key": "id", "path": "/locations"}, "schema_loader": { "type": "InlineSchemaLoader", @@ -229,7 +301,10 @@ "description": "The identifier", "type": ["null", "string"], }, - "name": {"description": "The name of the neighborhood location", "type": ["null", "string"]}, + "name": { + "description": "The name of the neighborhood location", + "type": ["null", "string"], + }, }, }, }, @@ -266,7 +341,10 @@ "description": "The identifier", "type": ["null", "string"], }, - "name": {"description": "The name of the party member", "type": ["null", "string"]}, + "name": { + "description": "The name of the party member", + "type": ["null", "string"], + }, }, }, }, @@ -301,7 +379,11 @@ class DeclarativeStreamDecorator(Stream): necessary. """ - def __init__(self, declarative_stream: DeclarativeStream, slice_to_records_mapping: Mapping[tuple[str, str], List[Mapping[str, Any]]]): + def __init__( + self, + declarative_stream: DeclarativeStream, + slice_to_records_mapping: Mapping[tuple[str, str], List[Mapping[str, Any]]], + ): self._declarative_stream = declarative_stream self._slice_to_records_mapping = slice_to_records_mapping @@ -335,7 +417,9 @@ def read_records( else: yield from [] else: - raise ValueError(f"stream_slice should be of type StreamSlice, but received {type(stream_slice)}") + raise ValueError( + f"stream_slice should be of type StreamSlice, but received {type(stream_slice)}" + ) def get_json_schema(self) -> Mapping[str, Any]: return self._declarative_stream.get_json_schema() @@ -352,22 +436,34 @@ def test_group_streams(): catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name="party_members", json_schema={}, supported_sync_modes=[SyncMode.incremental]), + stream=AirbyteStream( + name="party_members", + json_schema={}, + supported_sync_modes=[SyncMode.incremental], + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), ConfiguredAirbyteStream( - stream=AirbyteStream(name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), ConfiguredAirbyteStream( - stream=AirbyteStream(name="locations", json_schema={}, supported_sync_modes=[SyncMode.incremental]), + stream=AirbyteStream( + name="locations", json_schema={}, supported_sync_modes=[SyncMode.incremental] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), ConfiguredAirbyteStream( - stream=AirbyteStream(name="party_members_skills", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="party_members_skills", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), @@ -376,7 +472,9 @@ def test_group_streams(): state = [] - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=_CONFIG, catalog=catalog, state=state) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=_CONFIG, catalog=catalog, state=state + ) concurrent_streams = source._concurrent_streams synchronous_streams = source._synchronous_streams @@ -421,7 +519,9 @@ def test_create_concurrent_cursor(): ), ] - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=state) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=state + ) party_members_stream = source._concurrent_streams[0] assert isinstance(party_members_stream, DefaultStream) @@ -431,7 +531,9 @@ def test_create_concurrent_cursor(): assert party_members_cursor._stream_name == "party_members" assert party_members_cursor._cursor_field.cursor_field_key == "updated_at" assert party_members_cursor._start == pendulum.parse(_CONFIG.get("start_date")) - assert party_members_cursor._end_provider() == datetime(year=2024, month=9, day=1, tzinfo=timezone.utc) + assert party_members_cursor._end_provider() == datetime( + year=2024, month=9, day=1, tzinfo=timezone.utc + ) assert party_members_cursor._slice_boundary_fields == ("start_time", "end_time") assert party_members_cursor._slice_range == timedelta(days=15) assert party_members_cursor._lookback_window == timedelta(days=5) @@ -445,7 +547,9 @@ def test_create_concurrent_cursor(): assert locations_cursor._stream_name == "locations" assert locations_cursor._cursor_field.cursor_field_key == "updated_at" assert locations_cursor._start == pendulum.parse(_CONFIG.get("start_date")) - assert locations_cursor._end_provider() == datetime(year=2024, month=9, day=1, tzinfo=timezone.utc) + assert locations_cursor._end_provider() == datetime( + year=2024, month=9, day=1, tzinfo=timezone.utc + ) assert locations_cursor._slice_boundary_fields == ("start_time", "end_time") assert locations_cursor._slice_range == isodate.Duration(months=1) assert locations_cursor._lookback_window == timedelta(days=5) @@ -467,18 +571,33 @@ def test_check(): """ with HttpMocker() as http_mocker: http_mocker.get( - HttpRequest("https://persona.metaverse.com/party_members?start=2024-07-01&end=2024-07-15"), - HttpResponse(json.dumps({"id": "amamiya", "first_name": "ren", "last_name": "amamiya", "updated_at": "2024-07-10"})), + HttpRequest( + "https://persona.metaverse.com/party_members?start=2024-07-01&end=2024-07-15" + ), + HttpResponse( + json.dumps( + { + "id": "amamiya", + "first_name": "ren", + "last_name": "amamiya", + "updated_at": "2024-07-10", + } + ) + ), ) http_mocker.get( HttpRequest("https://persona.metaverse.com/palaces"), HttpResponse(json.dumps({"id": "palace_1"})), ) http_mocker.get( - HttpRequest("https://persona.metaverse.com/locations?m=active&i=1&g=country&start=2024-07-01&end=2024-07-31"), + HttpRequest( + "https://persona.metaverse.com/locations?m=active&i=1&g=country&start=2024-07-01&end=2024-07-31" + ), HttpResponse(json.dumps({"id": "location_1"})), ) - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=_CONFIG, catalog=None, state=None) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=_CONFIG, catalog=None, state=None + ) connection_status = source.check(logger=source.logger, config=_CONFIG) @@ -491,7 +610,9 @@ def test_discover(): """ expected_stream_names = ["party_members", "palaces", "locations", "party_members_skills"] - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=_CONFIG, catalog=None, state=None) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=_CONFIG, catalog=None, state=None + ) actual_catalog = source.discover(logger=source.logger, config=_CONFIG) @@ -502,14 +623,21 @@ def test_discover(): assert actual_catalog.streams[3].name in expected_stream_names -def _mock_requests(http_mocker: HttpMocker, url: str, query_params: List[Dict[str, str]], responses: List[HttpResponse]) -> None: +def _mock_requests( + http_mocker: HttpMocker, + url: str, + query_params: List[Dict[str, str]], + responses: List[HttpResponse], +) -> None: assert len(query_params) == len(responses), "Expecting as many slices as response" for i in range(len(query_params)): http_mocker.get(HttpRequest(url, query_params=query_params[i]), responses[i]) -def _mock_party_members_requests(http_mocker: HttpMocker, slices_and_responses: List[Tuple[Dict[str, str], HttpResponse]]) -> None: +def _mock_party_members_requests( + http_mocker: HttpMocker, slices_and_responses: List[Tuple[Dict[str, str], HttpResponse]] +) -> None: slices = list(map(lambda slice_and_response: slice_and_response[0], slices_and_responses)) responses = list(map(lambda slice_and_response: slice_and_response[1], slices_and_responses)) @@ -522,7 +650,9 @@ def _mock_party_members_requests(http_mocker: HttpMocker, slices_and_responses: def _mock_locations_requests(http_mocker: HttpMocker, slices: List[Dict[str, str]]) -> None: - locations_query_params = list(map(lambda _slice: _slice | {"m": "active", "i": "1", "g": "country"}, slices)) + locations_query_params = list( + map(lambda _slice: _slice | {"m": "active", "i": "1", "g": "country"}, slices) + ) _mock_requests( http_mocker, "https://persona.metaverse.com/locations", @@ -535,9 +665,18 @@ def _mock_party_members_skills_requests(http_mocker: HttpMocker) -> None: """ This method assumes _mock_party_members_requests has been called before else the stream won't work. """ - http_mocker.get(HttpRequest("https://persona.metaverse.com/party_members/amamiya/skills"), _PARTY_MEMBERS_SKILLS_RESPONSE) - http_mocker.get(HttpRequest("https://persona.metaverse.com/party_members/nijima/skills"), _PARTY_MEMBERS_SKILLS_RESPONSE) - http_mocker.get(HttpRequest("https://persona.metaverse.com/party_members/yoshizawa/skills"), _PARTY_MEMBERS_SKILLS_RESPONSE) + http_mocker.get( + HttpRequest("https://persona.metaverse.com/party_members/amamiya/skills"), + _PARTY_MEMBERS_SKILLS_RESPONSE, + ) + http_mocker.get( + HttpRequest("https://persona.metaverse.com/party_members/nijima/skills"), + _PARTY_MEMBERS_SKILLS_RESPONSE, + ) + http_mocker.get( + HttpRequest("https://persona.metaverse.com/party_members/yoshizawa/skills"), + _PARTY_MEMBERS_SKILLS_RESPONSE, + ) @freezegun.freeze_time(_NOW) @@ -550,7 +689,9 @@ def test_read_with_concurrent_and_synchronous_streams(): {"start": "2024-08-01", "end": "2024-08-31"}, {"start": "2024-09-01", "end": "2024-09-10"}, ] - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=None) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=None + ) disable_emitting_sequential_state_messages(source=source) with HttpMocker() as http_mocker: @@ -559,7 +700,9 @@ def test_read_with_concurrent_and_synchronous_streams(): http_mocker.get(HttpRequest("https://persona.metaverse.com/palaces"), _PALACES_RESPONSE) _mock_party_members_skills_requests(http_mocker) - messages = list(source.read(logger=source.logger, config=_CONFIG, catalog=_CATALOG, state=[])) + messages = list( + source.read(logger=source.logger, config=_CONFIG, catalog=_CATALOG, state=[]) + ) # See _mock_party_members_requests party_members_records = get_records_for_stream("party_members", messages) @@ -570,7 +713,14 @@ def test_read_with_concurrent_and_synchronous_streams(): assert ( party_members_states[5].stream.stream_state.__dict__ == AirbyteStateBlob( - state_type="date-range", slices=[{"start": "2024-07-01", "end": "2024-09-10", "most_recent_cursor_value": "2024-09-10"}] + state_type="date-range", + slices=[ + { + "start": "2024-07-01", + "end": "2024-09-10", + "most_recent_cursor_value": "2024-09-10", + } + ], ).__dict__ ) @@ -585,7 +735,14 @@ def test_read_with_concurrent_and_synchronous_streams(): assert ( locations_states[3].stream.stream_state.__dict__ == AirbyteStateBlob( - state_type="date-range", slices=[{"start": "2024-07-01", "end": "2024-09-10", "most_recent_cursor_value": "2024-08-10"}] + state_type="date-range", + slices=[ + { + "start": "2024-07-01", + "end": "2024-09-10", + "most_recent_cursor_value": "2024-08-10", + } + ], ).__dict__ ) @@ -595,30 +752,53 @@ def test_read_with_concurrent_and_synchronous_streams(): palaces_states = get_states_for_stream(stream_name="palaces", messages=messages) assert len(palaces_states) == 1 - assert palaces_states[0].stream.stream_state.__dict__ == AirbyteStateBlob(__ab_full_refresh_sync_complete=True).__dict__ + assert ( + palaces_states[0].stream.stream_state.__dict__ + == AirbyteStateBlob(__ab_full_refresh_sync_complete=True).__dict__ + ) # Expects 3 records, 3 slices, 3 records in slice party_members_skills_records = get_records_for_stream("party_members_skills", messages) assert len(party_members_skills_records) == 9 - party_members_skills_states = get_states_for_stream(stream_name="party_members_skills", messages=messages) + party_members_skills_states = get_states_for_stream( + stream_name="party_members_skills", messages=messages + ) assert len(party_members_skills_states) == 3 assert party_members_skills_states[0].stream.stream_state.__dict__ == { "states": [ - {"partition": {"parent_slice": {}, "party_member_id": "amamiya"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"parent_slice": {}, "party_member_id": "amamiya"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ] } assert party_members_skills_states[1].stream.stream_state.__dict__ == { "states": [ - {"partition": {"parent_slice": {}, "party_member_id": "amamiya"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"parent_slice": {}, "party_member_id": "nijima"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"parent_slice": {}, "party_member_id": "amamiya"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"parent_slice": {}, "party_member_id": "nijima"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ] } assert party_members_skills_states[2].stream.stream_state.__dict__ == { "states": [ - {"partition": {"parent_slice": {}, "party_member_id": "amamiya"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"parent_slice": {}, "party_member_id": "nijima"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"parent_slice": {}, "party_member_id": "yoshizawa"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"parent_slice": {}, "party_member_id": "amamiya"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"parent_slice": {}, "party_member_id": "nijima"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"parent_slice": {}, "party_member_id": "yoshizawa"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ] } @@ -659,7 +839,18 @@ def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state(): party_members_slices_and_responses = _NO_STATE_PARTY_MEMBERS_SLICES_AND_RESPONSES + [ ( {"start": "2024-09-04", "end": "2024-09-10"}, # considering lookback window - HttpResponse(json.dumps([{"id": "yoshizawa", "first_name": "sumire", "last_name": "yoshizawa", "updated_at": "2024-09-10"}])), + HttpResponse( + json.dumps( + [ + { + "id": "yoshizawa", + "first_name": "sumire", + "last_name": "yoshizawa", + "updated_at": "2024-09-10", + } + ] + ) + ), ) ] location_slices = [ @@ -667,7 +858,9 @@ def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state(): {"start": "2024-08-26", "end": "2024-09-10"}, ] - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=state) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=state + ) disable_emitting_sequential_state_messages(source=source) with HttpMocker() as http_mocker: @@ -676,7 +869,9 @@ def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state(): http_mocker.get(HttpRequest("https://persona.metaverse.com/palaces"), _PALACES_RESPONSE) _mock_party_members_skills_requests(http_mocker) - messages = list(source.read(logger=source.logger, config=_CONFIG, catalog=_CATALOG, state=state)) + messages = list( + source.read(logger=source.logger, config=_CONFIG, catalog=_CATALOG, state=state) + ) # Expects 8 records, skip successful intervals and are left with 2 slices, 4 records each slice locations_records = get_records_for_stream("locations", messages) @@ -687,7 +882,14 @@ def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state(): assert ( locations_states[2].stream.stream_state.__dict__ == AirbyteStateBlob( - state_type="date-range", slices=[{"start": "2024-07-01", "end": "2024-09-10", "most_recent_cursor_value": "2024-08-10"}] + state_type="date-range", + slices=[ + { + "start": "2024-07-01", + "end": "2024-09-10", + "most_recent_cursor_value": "2024-08-10", + } + ], ).__dict__ ) @@ -702,7 +904,14 @@ def test_read_with_concurrent_and_synchronous_streams_with_concurrent_state(): assert ( party_members_states[3].stream.stream_state.__dict__ == AirbyteStateBlob( - state_type="date-range", slices=[{"start": "2024-07-01", "end": "2024-09-10", "most_recent_cursor_value": "2024-09-10"}] + state_type="date-range", + slices=[ + { + "start": "2024-07-01", + "end": "2024-09-10", + "most_recent_cursor_value": "2024-09-10", + } + ], ).__dict__ ) @@ -738,17 +947,41 @@ def test_read_with_concurrent_and_synchronous_streams_with_sequential_state(): ), ] - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=state) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=_CONFIG, catalog=_CATALOG, state=state + ) disable_emitting_sequential_state_messages(source=source) party_members_slices_and_responses = _NO_STATE_PARTY_MEMBERS_SLICES_AND_RESPONSES + [ ( {"start": "2024-08-16", "end": "2024-08-30"}, - HttpResponse(json.dumps([{"id": "nijima", "first_name": "makoto", "last_name": "nijima", "updated_at": "2024-08-10"}])), + HttpResponse( + json.dumps( + [ + { + "id": "nijima", + "first_name": "makoto", + "last_name": "nijima", + "updated_at": "2024-08-10", + } + ] + ) + ), ), # considering lookback window ( {"start": "2024-08-31", "end": "2024-09-10"}, - HttpResponse(json.dumps([{"id": "yoshizawa", "first_name": "sumire", "last_name": "yoshizawa", "updated_at": "2024-09-10"}])), + HttpResponse( + json.dumps( + [ + { + "id": "yoshizawa", + "first_name": "sumire", + "last_name": "yoshizawa", + "updated_at": "2024-09-10", + } + ] + ) + ), ), ] location_slices = [ @@ -762,7 +995,9 @@ def test_read_with_concurrent_and_synchronous_streams_with_sequential_state(): http_mocker.get(HttpRequest("https://persona.metaverse.com/palaces"), _PALACES_RESPONSE) _mock_party_members_skills_requests(http_mocker) - messages = list(source.read(logger=source.logger, config=_CONFIG, catalog=_CATALOG, state=state)) + messages = list( + source.read(logger=source.logger, config=_CONFIG, catalog=_CATALOG, state=state) + ) # Expects 8 records, skip successful intervals and are left with 2 slices, 4 records each slice locations_records = get_records_for_stream("locations", messages) @@ -773,7 +1008,14 @@ def test_read_with_concurrent_and_synchronous_streams_with_sequential_state(): assert ( locations_states[2].stream.stream_state.__dict__ == AirbyteStateBlob( - state_type="date-range", slices=[{"start": "2024-07-01", "end": "2024-09-10", "most_recent_cursor_value": "2024-08-10"}] + state_type="date-range", + slices=[ + { + "start": "2024-07-01", + "end": "2024-09-10", + "most_recent_cursor_value": "2024-08-10", + } + ], ).__dict__ ) @@ -786,7 +1028,14 @@ def test_read_with_concurrent_and_synchronous_streams_with_sequential_state(): assert ( party_members_states[2].stream.stream_state.__dict__ == AirbyteStateBlob( - state_type="date-range", slices=[{"start": "2024-07-01", "end": "2024-09-10", "most_recent_cursor_value": "2024-09-10"}] + state_type="date-range", + slices=[ + { + "start": "2024-07-01", + "end": "2024-09-10", + "most_recent_cursor_value": "2024-09-10", + } + ], ).__dict__ ) @@ -811,20 +1060,27 @@ def test_read_concurrent_with_failing_partition_in_the_middle(): ] expected_stream_state = { "state_type": "date-range", - "slices": [location_slice | {"most_recent_cursor_value": "2024-08-10"} for location_slice in location_slices], + "slices": [ + location_slice | {"most_recent_cursor_value": "2024-08-10"} + for location_slice in location_slices + ], } catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name="locations", json_schema={}, supported_sync_modes=[SyncMode.incremental]), + stream=AirbyteStream( + name="locations", json_schema={}, supported_sync_modes=[SyncMode.incremental] + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.append, ), ] ) - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=_CONFIG, catalog=catalog, state=[]) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=_CONFIG, catalog=catalog, state=[] + ) disable_emitting_sequential_state_messages(source=source) location_slices = [ @@ -838,11 +1094,16 @@ def test_read_concurrent_with_failing_partition_in_the_middle(): messages = [] try: - for message in source.read(logger=source.logger, config=_CONFIG, catalog=catalog, state=[]): + for message in source.read( + logger=source.logger, config=_CONFIG, catalog=catalog, state=[] + ): messages.append(message) except AirbyteTracedException: assert ( - get_states_for_stream(stream_name="locations", messages=messages)[-1].stream.stream_state.__dict__ == expected_stream_state + get_states_for_stream(stream_name="locations", messages=messages)[ + -1 + ].stream.stream_state.__dict__ + == expected_stream_state ) @@ -855,26 +1116,36 @@ def test_read_concurrent_skip_streams_not_in_catalog(): catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), ConfiguredAirbyteStream( - stream=AirbyteStream(name="locations", json_schema={}, supported_sync_modes=[SyncMode.incremental]), + stream=AirbyteStream( + name="locations", + json_schema={}, + supported_sync_modes=[SyncMode.incremental], + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.append, ), ] ) - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=_CONFIG, catalog=catalog, state=None) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=_CONFIG, catalog=catalog, state=None + ) # locations requests location_slices = [ {"start": "2024-07-01", "end": "2024-07-31"}, {"start": "2024-08-01", "end": "2024-08-31"}, {"start": "2024-09-01", "end": "2024-09-10"}, ] - locations_query_params = list(map(lambda _slice: _slice | {"m": "active", "i": "1", "g": "country"}, location_slices)) + locations_query_params = list( + map(lambda _slice: _slice | {"m": "active", "i": "1", "g": "country"}, location_slices) + ) _mock_requests( http_mocker, "https://persona.metaverse.com/locations", @@ -887,7 +1158,9 @@ def test_read_concurrent_skip_streams_not_in_catalog(): disable_emitting_sequential_state_messages(source=source) - messages = list(source.read(logger=source.logger, config=_CONFIG, catalog=catalog, state=[])) + messages = list( + source.read(logger=source.logger, config=_CONFIG, catalog=catalog, state=[]) + ) locations_records = get_records_for_stream(stream_name="locations", messages=messages) assert len(locations_records) == 12 @@ -911,22 +1184,30 @@ def test_default_perform_interpolation_on_concurrency_level(): catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), ] ) - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=config, catalog=catalog, state=[]) - assert source._concurrent_source._initial_number_partitions_to_generate == 10 # We floor the number of initial partitions on creation + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=config, catalog=catalog, state=[] + ) + assert ( + source._concurrent_source._initial_number_partitions_to_generate == 10 + ) # We floor the number of initial partitions on creation def test_default_to_single_threaded_when_no_concurrency_level(): catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), @@ -936,7 +1217,9 @@ def test_default_to_single_threaded_when_no_concurrency_level(): manifest = copy.deepcopy(_MANIFEST) del manifest["concurrency_level"] - source = ConcurrentDeclarativeSource(source_config=manifest, config=_CONFIG, catalog=catalog, state=[]) + source = ConcurrentDeclarativeSource( + source_config=manifest, config=_CONFIG, catalog=catalog, state=[] + ) assert source._concurrent_source._initial_number_partitions_to_generate == 1 @@ -945,7 +1228,9 @@ def test_concurrency_level_initial_number_partitions_to_generate_is_always_one_o catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="palaces", json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ), @@ -959,7 +1244,9 @@ def test_concurrency_level_initial_number_partitions_to_generate_is_always_one_o "max_concurrency": 25, } - source = ConcurrentDeclarativeSource(source_config=_MANIFEST, config=config, catalog=catalog, state=[]) + source = ConcurrentDeclarativeSource( + source_config=_MANIFEST, config=config, catalog=catalog, state=[] + ) assert source._concurrent_source._initial_number_partitions_to_generate == 1 @@ -967,18 +1254,25 @@ def test_streams_with_stream_state_interpolation_should_be_synchronous(): manifest_with_stream_state_interpolation = copy.deepcopy(_MANIFEST) # Add stream_state interpolation to the location stream's HttpRequester - manifest_with_stream_state_interpolation["definitions"]["locations_stream"]["retriever"]["requester"]["request_parameters"] = { + manifest_with_stream_state_interpolation["definitions"]["locations_stream"]["retriever"][ + "requester" + ]["request_parameters"] = { "after": "{{ stream_state['updated_at'] }}", } # Add a RecordFilter component that uses stream_state interpolation to the party member stream - manifest_with_stream_state_interpolation["definitions"]["party_members_stream"]["retriever"]["record_selector"]["record_filter"] = { + manifest_with_stream_state_interpolation["definitions"]["party_members_stream"]["retriever"][ + "record_selector" + ]["record_filter"] = { "type": "RecordFilter", "condition": "{{ record.updated_at > stream_state['updated_at'] }}", } source = ConcurrentDeclarativeSource( - source_config=manifest_with_stream_state_interpolation, config=_CONFIG, catalog=_CATALOG, state=None + source_config=manifest_with_stream_state_interpolation, + config=_CONFIG, + catalog=_CATALOG, + state=None, ) assert len(source._concurrent_streams) == 0 @@ -989,7 +1283,10 @@ def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurr manifest = { "version": "5.0.0", "definitions": { - "selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}}, + "selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": []}, + }, "requester": { "type": "HttpRequester", "url_base": "https://persona.metaverse.com", @@ -1008,7 +1305,11 @@ def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurr "failure_type": "config_error", "error_message": "Access denied due to lack of permission or invalid API/Secret key or wrong data region.", }, - {"http_codes": [404], "action": "IGNORE", "error_message": "No data available for the time range requested."}, + { + "http_codes": [404], + "action": "IGNORE", + "error_message": "No data available for the time range requested.", + }, ], }, }, @@ -1020,7 +1321,9 @@ def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurr }, "incremental_cursor": { "type": "DatetimeBasedCursor", - "start_datetime": {"datetime": "{{ format_datetime(config['start_date'], '%Y-%m-%d') }}"}, + "start_datetime": { + "datetime": "{{ format_datetime(config['start_date'], '%Y-%m-%d') }}" + }, "end_datetime": {"datetime": "{{ now_utc().strftime('%Y-%m-%d') }}"}, "datetime_format": "%Y-%m-%d", "cursor_datetime_formats": ["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"], @@ -1028,12 +1331,23 @@ def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurr "step": "P15D", "cursor_field": "updated_at", "lookback_window": "P5D", - "start_time_option": {"type": "RequestOption", "field_name": "start", "inject_into": "request_parameter"}, - "end_time_option": {"type": "RequestOption", "field_name": "end", "inject_into": "request_parameter"}, + "start_time_option": { + "type": "RequestOption", + "field_name": "start", + "inject_into": "request_parameter", + }, + "end_time_option": { + "type": "RequestOption", + "field_name": "end", + "inject_into": "request_parameter", + }, }, "base_stream": {"retriever": {"$ref": "#/definitions/retriever"}}, "base_incremental_stream": { - "retriever": {"$ref": "#/definitions/retriever", "requester": {"$ref": "#/definitions/requester"}}, + "retriever": { + "$ref": "#/definitions/retriever", + "requester": {"$ref": "#/definitions/requester"}, + }, "incremental_sync": {"$ref": "#/definitions/incremental_cursor"}, }, "incremental_party_members_skills_stream": { @@ -1061,7 +1375,10 @@ def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurr "description": "The identifier", "type": ["null", "string"], }, - "name": {"description": "The name of the party member", "type": ["null", "string"]}, + "name": { + "description": "The name of the party member", + "type": ["null", "string"], + }, }, }, }, @@ -1079,7 +1396,11 @@ def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurr catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name="incremental_party_members_skills", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name="incremental_party_members_skills", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.append, ) @@ -1088,7 +1409,9 @@ def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurr state = [] - source = ConcurrentDeclarativeSource(source_config=manifest, config=_CONFIG, catalog=catalog, state=state) + source = ConcurrentDeclarativeSource( + source_config=manifest, config=_CONFIG, catalog=catalog, state=state + ) assert len(source._concurrent_streams) == 0 assert len(source._synchronous_streams) == 1 @@ -1097,7 +1420,9 @@ def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurr def create_wrapped_stream(stream: DeclarativeStream) -> Stream: slice_to_records_mapping = get_mocked_read_records_output(stream_name=stream.name) - return DeclarativeStreamDecorator(declarative_stream=stream, slice_to_records_mapping=slice_to_records_mapping) + return DeclarativeStreamDecorator( + declarative_stream=stream, slice_to_records_mapping=slice_to_records_mapping + ) def get_mocked_read_records_output(stream_name: str) -> Mapping[tuple[str, str], List[StreamData]]: @@ -1105,16 +1430,32 @@ def get_mocked_read_records_output(stream_name: str) -> Mapping[tuple[str, str], case "locations": slices = [ # Slices used during first incremental sync - StreamSlice(cursor_slice={"start": "2024-07-01", "end": "2024-07-31"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-08-01", "end": "2024-08-31"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-09-01", "end": "2024-09-09"}, partition={}), + StreamSlice( + cursor_slice={"start": "2024-07-01", "end": "2024-07-31"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-08-01", "end": "2024-08-31"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-09-01", "end": "2024-09-09"}, partition={} + ), # Slices used during incremental checkpoint sync - StreamSlice(cursor_slice={"start": "2024-07-26", "end": "2024-08-25"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-08-26", "end": "2024-09-09"}, partition={}), + StreamSlice( + cursor_slice={"start": "2024-07-26", "end": "2024-08-25"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-08-26", "end": "2024-09-09"}, partition={} + ), # Slices used during incremental sync with some partitions that exit with an error - StreamSlice(cursor_slice={"start": "2024-07-05", "end": "2024-08-04"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-08-05", "end": "2024-09-04"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-09-05", "end": "2024-09-09"}, partition={}), + StreamSlice( + cursor_slice={"start": "2024-07-05", "end": "2024-08-04"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-08-05", "end": "2024-09-04"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-09-05", "end": "2024-09-09"}, partition={} + ), ] records = [ @@ -1126,23 +1467,56 @@ def get_mocked_read_records_output(stream_name: str) -> Mapping[tuple[str, str], case "party_members": slices = [ # Slices used during first incremental sync - StreamSlice(cursor_slice={"start": "2024-07-01", "end": "2024-07-15"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-07-16", "end": "2024-07-30"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-07-31", "end": "2024-08-14"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-08-15", "end": "2024-08-29"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-08-30", "end": "2024-09-09"}, partition={}), + StreamSlice( + cursor_slice={"start": "2024-07-01", "end": "2024-07-15"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-07-16", "end": "2024-07-30"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-07-31", "end": "2024-08-14"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-08-15", "end": "2024-08-29"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-08-30", "end": "2024-09-09"}, partition={} + ), # Slices used during incremental checkpoint sync. Unsuccessful partitions use the P5D lookback window which explains # the skew of records midway through - StreamSlice(cursor_slice={"start": "2024-07-01", "end": "2024-07-16"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-07-30", "end": "2024-08-13"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-08-14", "end": "2024-08-14"}, partition={}), - StreamSlice(cursor_slice={"start": "2024-09-04", "end": "2024-09-09"}, partition={}), + StreamSlice( + cursor_slice={"start": "2024-07-01", "end": "2024-07-16"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-07-30", "end": "2024-08-13"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-08-14", "end": "2024-08-14"}, partition={} + ), + StreamSlice( + cursor_slice={"start": "2024-09-04", "end": "2024-09-09"}, partition={} + ), ] records = [ - {"id": "amamiya", "first_name": "ren", "last_name": "amamiya", "updated_at": "2024-07-10"}, - {"id": "nijima", "first_name": "makoto", "last_name": "nijima", "updated_at": "2024-08-10"}, - {"id": "yoshizawa", "first_name": "sumire", "last_name": "yoshizawa", "updated_at": "2024-09-10"}, + { + "id": "amamiya", + "first_name": "ren", + "last_name": "amamiya", + "updated_at": "2024-07-10", + }, + { + "id": "nijima", + "first_name": "makoto", + "last_name": "nijima", + "updated_at": "2024-08-10", + }, + { + "id": "yoshizawa", + "first_name": "sumire", + "last_name": "yoshizawa", + "updated_at": "2024-09-10", + }, ] case "palaces": slices = [StreamSlice(cursor_slice={}, partition={})] @@ -1169,17 +1543,31 @@ def get_mocked_read_records_output(stream_name: str) -> Mapping[tuple[str, str], raise ValueError(f"Stream '{stream_name}' does not have associated mocked records") return { - (_slice.get("start"), _slice.get("end")): [Record(data=stream_data, associated_slice=_slice) for stream_data in records] + (_slice.get("start"), _slice.get("end")): [ + Record(data=stream_data, associated_slice=_slice) for stream_data in records + ] for _slice in slices } -def get_records_for_stream(stream_name: str, messages: List[AirbyteMessage]) -> List[AirbyteRecordMessage]: - return [message.record for message in messages if message.record and message.record.stream == stream_name] +def get_records_for_stream( + stream_name: str, messages: List[AirbyteMessage] +) -> List[AirbyteRecordMessage]: + return [ + message.record + for message in messages + if message.record and message.record.stream == stream_name + ] -def get_states_for_stream(stream_name: str, messages: List[AirbyteMessage]) -> List[AirbyteStateMessage]: - return [message.state for message in messages if message.state and message.state.stream.stream_descriptor.name == stream_name] +def get_states_for_stream( + stream_name: str, messages: List[AirbyteMessage] +) -> List[AirbyteStateMessage]: + return [ + message.state + for message in messages + if message.state and message.state.stream.stream_descriptor.name == stream_name + ] def disable_emitting_sequential_state_messages(source: ConcurrentDeclarativeSource) -> None: diff --git a/unit_tests/sources/declarative/test_declarative_stream.py b/unit_tests/sources/declarative/test_declarative_stream.py index 8906b625..f34fd137 100644 --- a/unit_tests/sources/declarative/test_declarative_stream.py +++ b/unit_tests/sources/declarative/test_declarative_stream.py @@ -5,7 +5,15 @@ from unittest.mock import MagicMock import pytest -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteTraceMessage, Level, SyncMode, TraceType, Type +from airbyte_cdk.models import ( + AirbyteLogMessage, + AirbyteMessage, + AirbyteTraceMessage, + Level, + SyncMode, + TraceType, + Type, +) from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream from airbyte_cdk.sources.types import StreamSlice @@ -24,8 +32,12 @@ def test_declarative_stream(): records = [ {"pk": 1234, "field": "value"}, {"pk": 4567, "field": "different_value"}, - AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is a log message")), - AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345)), + AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is a log message") + ), + AirbyteMessage( + type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345) + ), ] stream_slices = [ StreamSlice(partition={}, cursor_slice={"date": "2021-01-01"}), @@ -54,10 +66,18 @@ def test_declarative_stream(): assert stream.get_json_schema() == _json_schema assert stream.state == state input_slice = stream_slices[0] - assert list(stream.read_records(SyncMode.full_refresh, _cursor_field, input_slice, state)) == records + assert ( + list(stream.read_records(SyncMode.full_refresh, _cursor_field, input_slice, state)) + == records + ) assert stream.primary_key == _primary_key assert stream.cursor_field == _cursor_field - assert stream.stream_slices(sync_mode=SyncMode.incremental, cursor_field=_cursor_field, stream_state=None) == stream_slices + assert ( + stream.stream_slices( + sync_mode=SyncMode.incremental, cursor_field=_cursor_field, stream_state=None + ) + == stream_slices + ) def test_declarative_stream_using_empty_slice(): @@ -69,8 +89,12 @@ def test_declarative_stream_using_empty_slice(): records = [ {"pk": 1234, "field": "value"}, {"pk": 4567, "field": "different_value"}, - AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is a log message")), - AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345)), + AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is a log message") + ), + AirbyteMessage( + type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345) + ), ] retriever = MagicMock() diff --git a/unit_tests/sources/declarative/test_manifest_declarative_source.py b/unit_tests/sources/declarative/test_manifest_declarative_source.py index 2d350fa1..19be3a82 100644 --- a/unit_tests/sources/declarative/test_manifest_declarative_source.py +++ b/unit_tests/sources/declarative/test_manifest_declarative_source.py @@ -59,7 +59,10 @@ def use_external_yaml_spec(self): test_path = os.path.dirname(module_path) spec_root = test_path.split("/sources/declarative")[0] - spec = {"documentationUrl": "https://airbyte.com/#yaml-from-external", "connectionSpecification": EXTERNAL_CONNECTION_SPECIFICATION} + spec = { + "documentationUrl": "https://airbyte.com/#yaml-from-external", + "connectionSpecification": EXTERNAL_CONNECTION_SPECIFICATION, + } yaml_path = os.path.join(spec_root, "spec.yaml") with open(yaml_path, "w") as f: @@ -75,7 +78,11 @@ def test_valid_manifest(self): "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -84,7 +91,11 @@ def test_valid_manifest(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -94,7 +105,10 @@ def test_valid_manifest(self): }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "{{ 10 }}"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -102,7 +116,11 @@ def test_valid_manifest(self): }, { "type": "DeclarativeStream", - "$parameters": {"name": "stream_with_custom_requester", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "stream_with_custom_requester", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -111,7 +129,11 @@ def test_valid_manifest(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -140,7 +162,10 @@ def test_valid_manifest(self): assert len(streams) == 2 assert isinstance(streams[0], DeclarativeStream) assert isinstance(streams[1], DeclarativeStream) - assert source.resolved_manifest["description"] == "This is a sample source connector that is very valid." + assert ( + source.resolved_manifest["description"] + == "This is a sample source connector that is very valid." + ) def test_manifest_with_spec(self): manifest = { @@ -154,13 +179,23 @@ def test_manifest_with_spec(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "{{ 10 }}"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -169,7 +204,11 @@ def test_manifest_with_spec(self): "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -178,13 +217,23 @@ def test_manifest_with_spec(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "{{ 10 }}"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -201,7 +250,13 @@ def test_manifest_with_spec(self): "required": ["api_key"], "additionalProperties": False, "properties": { - "api_key": {"type": "string", "airbyte_secret": True, "title": "API Key", "description": "Test API Key", "order": 0} + "api_key": { + "type": "string", + "airbyte_secret": True, + "title": "API Key", + "description": "Test API Key", + "order": 0, + } }, }, }, @@ -233,13 +288,23 @@ def test_manifest_with_external_spec(self, use_external_yaml_spec): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "{{ 10 }}"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -248,7 +313,11 @@ def test_manifest_with_external_spec(self, use_external_yaml_spec): "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -257,13 +326,23 @@ def test_manifest_with_external_spec(self, use_external_yaml_spec): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "{{ 10 }}"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -291,13 +370,23 @@ def test_source_is_not_created_if_toplevel_fields_are_unknown(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": 10}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -306,7 +395,11 @@ def test_source_is_not_created_if_toplevel_fields_are_unknown(self): "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -315,13 +408,23 @@ def test_source_is_not_created_if_toplevel_fields_are_unknown(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": 10}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -346,13 +449,23 @@ def test_source_missing_checker_fails_validation(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": 10}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -361,7 +474,11 @@ def test_source_missing_checker_fails_validation(self): "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -370,13 +487,23 @@ def test_source_missing_checker_fails_validation(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": 10}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -389,7 +516,11 @@ def test_source_missing_checker_fails_validation(self): ManifestDeclarativeSource(source_config=manifest) def test_source_with_missing_streams_fails(self): - manifest = {"version": "0.29.3", "definitions": None, "check": {"type": "CheckStream", "stream_names": ["lists"]}} + manifest = { + "version": "0.29.3", + "definitions": None, + "check": {"type": "CheckStream", "stream_names": ["lists"]}, + } with pytest.raises(ValidationError): ManifestDeclarativeSource(source_config=manifest) @@ -404,13 +535,23 @@ def test_source_with_missing_version_fails(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": 10}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -419,7 +560,11 @@ def test_source_with_missing_version_fails(self): "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -428,13 +573,23 @@ def test_source_with_missing_version_fails(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": 10}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -449,9 +604,18 @@ def test_source_with_missing_version_fails(self): @pytest.mark.parametrize( "cdk_version, manifest_version, expected_error", [ - pytest.param("0.35.0", "0.30.0", None, id="manifest_version_less_than_cdk_package_should_run"), - pytest.param("1.5.0", "0.29.0", None, id="manifest_version_less_than_cdk_major_package_should_run"), - pytest.param("0.29.0", "0.29.0", None, id="manifest_version_matching_cdk_package_should_run"), + pytest.param( + "0.35.0", "0.30.0", None, id="manifest_version_less_than_cdk_package_should_run" + ), + pytest.param( + "1.5.0", + "0.29.0", + None, + id="manifest_version_less_than_cdk_major_package_should_run", + ), + pytest.param( + "0.29.0", "0.29.0", None, id="manifest_version_matching_cdk_package_should_run" + ), pytest.param( "0.29.0", "0.25.0", @@ -464,12 +628,30 @@ def test_source_with_missing_version_fails(self): ValidationError, id="manifest_version_before_beta_that_uses_package_later_major_version_than_beta_0.29.0_cdk_package_should_throw_error", ), - pytest.param("0.34.0", "0.35.0", ValidationError, id="manifest_version_greater_than_cdk_package_should_throw_error"), - pytest.param("0.29.0", "-1.5.0", ValidationError, id="manifest_version_has_invalid_major_format"), - pytest.param("0.29.0", "0.invalid.0", ValidationError, id="manifest_version_has_invalid_minor_format"), - pytest.param("0.29.0", "0.29.0.1", ValidationError, id="manifest_version_has_extra_version_parts"), - pytest.param("0.29.0", "5.0", ValidationError, id="manifest_version_has_too_few_version_parts"), - pytest.param("0.29.0:dev", "0.29.0", ValidationError, id="manifest_version_has_extra_release"), + pytest.param( + "0.34.0", + "0.35.0", + ValidationError, + id="manifest_version_greater_than_cdk_package_should_throw_error", + ), + pytest.param( + "0.29.0", "-1.5.0", ValidationError, id="manifest_version_has_invalid_major_format" + ), + pytest.param( + "0.29.0", + "0.invalid.0", + ValidationError, + id="manifest_version_has_invalid_minor_format", + ), + pytest.param( + "0.29.0", "0.29.0.1", ValidationError, id="manifest_version_has_extra_version_parts" + ), + pytest.param( + "0.29.0", "5.0", ValidationError, id="manifest_version_has_too_few_version_parts" + ), + pytest.param( + "0.29.0:dev", "0.29.0", ValidationError, id="manifest_version_has_extra_release" + ), ], ) @patch("importlib.metadata.version") @@ -483,7 +665,11 @@ def test_manifest_versions(self, version, cdk_version, manifest_version, expecte "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -492,7 +678,11 @@ def test_manifest_versions(self, version, cdk_version, manifest_version, expecte "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -502,7 +692,10 @@ def test_manifest_versions(self, version, cdk_version, manifest_version, expecte }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "{{ 10 }}"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -510,7 +703,11 @@ def test_manifest_versions(self, version, cdk_version, manifest_version, expecte }, { "type": "DeclarativeStream", - "$parameters": {"name": "stream_with_custom_requester", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "stream_with_custom_requester", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -519,7 +716,11 @@ def test_manifest_versions(self, version, cdk_version, manifest_version, expecte "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -557,7 +758,11 @@ def test_source_with_invalid_stream_config_fails_validation(self): "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -581,13 +786,23 @@ def test_source_with_no_external_spec_and_no_in_yaml_spec_fails(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "{{ 10 }}"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -596,7 +811,11 @@ def test_source_with_no_external_spec_and_no_in_yaml_spec_fails(self): "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -605,13 +824,23 @@ def test_source_with_no_external_spec_and_no_in_yaml_spec_fails(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "{{ 10 }}"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -638,13 +867,23 @@ def test_manifest_without_at_least_one_stream(self): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": 10}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -669,13 +908,23 @@ def test_given_debug_when_read_then_set_log_level(self, declarative_source_read) "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, - "pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"}, + "pagination_strategy": { + "type": "CursorPagination", + "cursor_value": "{{ response._metadata.next }}", + }, }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "10"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -684,7 +933,11 @@ def test_given_debug_when_read_then_set_log_level(self, declarative_source_read) "streams": [ { "type": "DeclarativeStream", - "$parameters": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "lists", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -693,7 +946,11 @@ def test_given_debug_when_read_then_set_log_level(self, declarative_source_read) "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -703,7 +960,10 @@ def test_given_debug_when_read_then_set_log_level(self, declarative_source_read) }, "requester": { "path": "/v3/marketing/lists", - "authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"}, + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config.apikey }}", + }, "request_parameters": {"page_size": "{{ 10 }}"}, }, "record_selector": {"extractor": {"field_path": ["result"]}}, @@ -711,7 +971,11 @@ def test_given_debug_when_read_then_set_log_level(self, declarative_source_read) }, { "type": "DeclarativeStream", - "$parameters": {"name": "stream_with_custom_requester", "primary_key": "id", "url_base": "https://api.sendgrid.com"}, + "$parameters": { + "name": "stream_with_custom_requester", + "primary_key": "id", + "url_base": "https://api.sendgrid.com", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -720,7 +984,11 @@ def test_given_debug_when_read_then_set_log_level(self, declarative_source_read) "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -749,11 +1017,17 @@ def test_given_debug_when_read_then_set_log_level(self, declarative_source_read) def request_log_message(request: dict) -> AirbyteMessage: - return AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"request:{json.dumps(request)}")) + return AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage(level=Level.INFO, message=f"request:{json.dumps(request)}"), + ) def response_log_message(response: dict) -> AirbyteMessage: - return AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"response:{json.dumps(response)}")) + return AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage(level=Level.INFO, message=f"response:{json.dumps(response)}"), + ) def _create_request(): @@ -817,7 +1091,10 @@ def _create_page(response_body): "api_token": "{{ config['api_key'] }}", }, }, - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}}, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}, + }, "paginator": {"type": "NoPagination"}, }, } @@ -827,7 +1104,13 @@ def _create_page(response_body): "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "required": ["api_key"], - "properties": {"api_key": {"type": "string", "title": "API Key", "airbyte_secret": True}}, + "properties": { + "api_key": { + "type": "string", + "title": "API Key", + "airbyte_secret": True, + } + }, "additionalProperties": True, }, "documentation_url": "https://example.org", @@ -867,7 +1150,13 @@ def _create_page(response_body): "transformations": [ { "type": "AddFields", - "fields": [{"type": "AddedFieldDefinition", "path": ["added_field_key"], "value": "added_field_value"}], + "fields": [ + { + "type": "AddedFieldDefinition", + "path": ["added_field_key"], + "value": "added_field_value", + } + ], } ], "retriever": { @@ -886,7 +1175,10 @@ def _create_page(response_body): "api_token": "{{ config['api_key'] }}", }, }, - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}}, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}, + }, "paginator": {"type": "NoPagination"}, }, } @@ -896,7 +1188,13 @@ def _create_page(response_body): "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "required": ["api_key"], - "properties": {"api_key": {"type": "string", "title": "API Key", "airbyte_secret": True}}, + "properties": { + "api_key": { + "type": "string", + "title": "API Key", + "airbyte_secret": True, + } + }, "additionalProperties": True, }, "documentation_url": "https://example.org", @@ -908,7 +1206,10 @@ def _create_page(response_body): _create_page({"rates": [{"USD": 2}], "_metadata": {"next": "next"}}), ) * 10, - [{"ABC": 0, "added_field_key": "added_field_value"}, {"AED": 1, "added_field_key": "added_field_value"}], + [ + {"ABC": 0, "added_field_key": "added_field_value"}, + {"AED": 1, "added_field_key": "added_field_value"}, + ], [call({}, {})], ), ( @@ -950,11 +1251,17 @@ def _create_page(response_body): "api_token": "{{ config['api_key'] }}", }, }, - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}}, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}, + }, "paginator": { "type": "DefaultPaginator", "page_size": 2, - "page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"inject_into": "path", "type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -970,7 +1277,13 @@ def _create_page(response_body): "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "required": ["api_key"], - "properties": {"api_key": {"type": "string", "title": "API Key", "airbyte_secret": True}}, + "properties": { + "api_key": { + "type": "string", + "title": "API Key", + "airbyte_secret": True, + } + }, "additionalProperties": True, }, "documentation_url": "https://example.org", @@ -1000,7 +1313,11 @@ def _create_page(response_body): "type": "InlineSchemaLoader", "schema": { "$schema": "http://json-schema.org/schema#", - "properties": {"ABC": {"type": "number"}, "AED": {"type": "number"}, "partition": {"type": "number"}}, + "properties": { + "ABC": {"type": "number"}, + "AED": {"type": "number"}, + "partition": {"type": "number"}, + }, "type": "object", }, }, @@ -1020,8 +1337,15 @@ def _create_page(response_body): "api_token": "{{ config['api_key'] }}", }, }, - "partition_router": {"type": "ListPartitionRouter", "values": ["0", "1"], "cursor_field": "partition"}, - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}}, + "partition_router": { + "type": "ListPartitionRouter", + "values": ["0", "1"], + "cursor_field": "partition", + }, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}, + }, "paginator": {"type": "NoPagination"}, }, } @@ -1031,7 +1355,13 @@ def _create_page(response_body): "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "required": ["api_key"], - "properties": {"api_key": {"type": "string", "title": "API Key", "airbyte_secret": True}}, + "properties": { + "api_key": { + "type": "string", + "title": "API Key", + "airbyte_secret": True, + } + }, "additionalProperties": True, }, "documentation_url": "https://example.org", @@ -1039,14 +1369,28 @@ def _create_page(response_body): }, }, ( - _create_page({"rates": [{"ABC": 0, "partition": 0}, {"AED": 1, "partition": 0}], "_metadata": {"next": "next"}}), - _create_page({"rates": [{"ABC": 2, "partition": 1}], "_metadata": {"next": "next"}}), + _create_page( + { + "rates": [{"ABC": 0, "partition": 0}, {"AED": 1, "partition": 0}], + "_metadata": {"next": "next"}, + } + ), + _create_page( + {"rates": [{"ABC": 2, "partition": 1}], "_metadata": {"next": "next"}} + ), ), [{"ABC": 0, "partition": 0}, {"AED": 1, "partition": 0}, {"ABC": 2, "partition": 1}], [ call({"states": []}, {"partition": "0"}, None), call( - {"states": [{"partition": {"partition": "0"}, "cursor": {"__ab_full_refresh_sync_complete": True}}]}, + { + "states": [ + { + "partition": {"partition": "0"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + } + ] + }, {"partition": "1"}, None, ), @@ -1067,7 +1411,11 @@ def _create_page(response_body): "type": "InlineSchemaLoader", "schema": { "$schema": "http://json-schema.org/schema#", - "properties": {"ABC": {"type": "number"}, "AED": {"type": "number"}, "partition": {"type": "number"}}, + "properties": { + "ABC": {"type": "number"}, + "AED": {"type": "number"}, + "partition": {"type": "number"}, + }, "type": "object", }, }, @@ -1087,12 +1435,22 @@ def _create_page(response_body): "api_token": "{{ config['api_key'] }}", }, }, - "partition_router": {"type": "ListPartitionRouter", "values": ["0", "1"], "cursor_field": "partition"}, - "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}}, + "partition_router": { + "type": "ListPartitionRouter", + "values": ["0", "1"], + "cursor_field": "partition", + }, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": ["rates"]}, + }, "paginator": { "type": "DefaultPaginator", "page_size": 2, - "page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"}, + "page_size_option": { + "inject_into": "request_parameter", + "field_name": "page_size", + }, "page_token_option": {"inject_into": "path", "type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -1108,7 +1466,13 @@ def _create_page(response_body): "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "required": ["api_key"], - "properties": {"api_key": {"type": "string", "title": "API Key", "airbyte_secret": True}}, + "properties": { + "api_key": { + "type": "string", + "title": "API Key", + "airbyte_secret": True, + } + }, "additionalProperties": True, }, "documentation_url": "https://example.org", @@ -1116,16 +1480,33 @@ def _create_page(response_body): }, }, ( - _create_page({"rates": [{"ABC": 0, "partition": 0}, {"AED": 1, "partition": 0}], "_metadata": {"next": "next"}}), + _create_page( + { + "rates": [{"ABC": 0, "partition": 0}, {"AED": 1, "partition": 0}], + "_metadata": {"next": "next"}, + } + ), _create_page({"rates": [{"USD": 3, "partition": 0}], "_metadata": {}}), _create_page({"rates": [{"ABC": 2, "partition": 1}], "_metadata": {}}), ), - [{"ABC": 0, "partition": 0}, {"AED": 1, "partition": 0}, {"USD": 3, "partition": 0}, {"ABC": 2, "partition": 1}], + [ + {"ABC": 0, "partition": 0}, + {"AED": 1, "partition": 0}, + {"USD": 3, "partition": 0}, + {"ABC": 2, "partition": 1}, + ], [ call({"states": []}, {"partition": "0"}, None), call({"states": []}, {"partition": "0"}, {"next_page_token": "next"}), call( - {"states": [{"partition": {"partition": "0"}, "cursor": {"__ab_full_refresh_sync_complete": True}}]}, + { + "states": [ + { + "partition": {"partition": "0"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + } + ] + }, {"partition": "1"}, None, ), @@ -1133,10 +1514,14 @@ def _create_page(response_body): ), ], ) -def test_read_manifest_declarative_source(test_name, manifest, pages, expected_records, expected_calls): +def test_read_manifest_declarative_source( + test_name, manifest, pages, expected_records, expected_calls +): _stream_name = "Rates" with patch.object(SimpleRetriever, "_fetch_next_page", side_effect=pages) as mock_retriever: - output_data = [message.record.data for message in _run_read(manifest, _stream_name) if message.record] + output_data = [ + message.record.data for message in _run_read(manifest, _stream_name) if message.record + ] assert output_data == expected_records mock_retriever.assert_has_calls(expected_calls) @@ -1144,7 +1529,11 @@ def test_read_manifest_declarative_source(test_name, manifest, pages, expected_r def test_only_parent_streams_use_cache(): applications_stream = { "type": "DeclarativeStream", - "$parameters": {"name": "applications", "primary_key": "id", "url_base": "https://harvest.greenhouse.io/v1/"}, + "$parameters": { + "name": "applications", + "primary_key": "id", + "url_base": "https://harvest.greenhouse.io/v1/", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -1153,7 +1542,11 @@ def test_only_parent_streams_use_cache(): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "per_page"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "per_page", + }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -1164,7 +1557,10 @@ def test_only_parent_streams_use_cache(): }, "requester": { "path": "applications", - "authenticator": {"type": "BasicHttpAuthenticator", "username": "{{ config['api_key'] }}"}, + "authenticator": { + "type": "BasicHttpAuthenticator", + "username": "{{ config['api_key'] }}", + }, }, "record_selector": {"extractor": {"type": "DpathExtractor", "field_path": []}}, }, @@ -1177,7 +1573,11 @@ def test_only_parent_streams_use_cache(): deepcopy(applications_stream), { "type": "DeclarativeStream", - "$parameters": {"name": "applications_interviews", "primary_key": "id", "url_base": "https://harvest.greenhouse.io/v1/"}, + "$parameters": { + "name": "applications_interviews", + "primary_key": "id", + "url_base": "https://harvest.greenhouse.io/v1/", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -1186,7 +1586,11 @@ def test_only_parent_streams_use_cache(): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "per_page"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "per_page", + }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -1197,12 +1601,19 @@ def test_only_parent_streams_use_cache(): }, "requester": { "path": "applications_interviews", - "authenticator": {"type": "BasicHttpAuthenticator", "username": "{{ config['api_key'] }}"}, + "authenticator": { + "type": "BasicHttpAuthenticator", + "username": "{{ config['api_key'] }}", + }, }, "record_selector": {"extractor": {"type": "DpathExtractor", "field_path": []}}, "partition_router": { "parent_stream_configs": [ - {"parent_key": "id", "partition_field": "parent_id", "stream": deepcopy(applications_stream)} + { + "parent_key": "id", + "partition_field": "parent_id", + "stream": deepcopy(applications_stream), + } ], "type": "SubstreamPartitionRouter", }, @@ -1210,7 +1621,11 @@ def test_only_parent_streams_use_cache(): }, { "type": "DeclarativeStream", - "$parameters": {"name": "jobs", "primary_key": "id", "url_base": "https://harvest.greenhouse.io/v1/"}, + "$parameters": { + "name": "jobs", + "primary_key": "id", + "url_base": "https://harvest.greenhouse.io/v1/", + }, "schema_loader": { "name": "{{ parameters.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ parameters.name }}.yaml", @@ -1219,7 +1634,11 @@ def test_only_parent_streams_use_cache(): "paginator": { "type": "DefaultPaginator", "page_size": 10, - "page_size_option": {"type": "RequestOption", "inject_into": "request_parameter", "field_name": "per_page"}, + "page_size_option": { + "type": "RequestOption", + "inject_into": "request_parameter", + "field_name": "per_page", + }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { "type": "CursorPagination", @@ -1230,7 +1649,10 @@ def test_only_parent_streams_use_cache(): }, "requester": { "path": "jobs", - "authenticator": {"type": "BasicHttpAuthenticator", "username": "{{ config['api_key'] }}"}, + "authenticator": { + "type": "BasicHttpAuthenticator", + "username": "{{ config['api_key'] }}", + }, }, "record_selector": {"extractor": {"type": "DpathExtractor", "field_path": []}}, }, @@ -1252,8 +1674,15 @@ def test_only_parent_streams_use_cache(): assert not streams[1].retriever.requester.use_cache # Parent stream created for substream - assert streams[1].retriever.stream_slicer._partition_router.parent_stream_configs[0].stream.name == "applications" - assert streams[1].retriever.stream_slicer._partition_router.parent_stream_configs[0].stream.retriever.requester.use_cache + assert ( + streams[1].retriever.stream_slicer._partition_router.parent_stream_configs[0].stream.name + == "applications" + ) + assert ( + streams[1] + .retriever.stream_slicer._partition_router.parent_stream_configs[0] + .stream.retriever.requester.use_cache + ) # Main stream without caching assert streams[2].name == "jobs" @@ -1265,7 +1694,9 @@ def _run_read(manifest: Mapping[str, Any], stream_name: str) -> List[AirbyteMess catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name=stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name=stream_name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.append, ) @@ -1311,6 +1742,7 @@ def validate_refs(yaml_file: str) -> List[str]: return invalid_refs yaml_file_path = ( - Path(__file__).resolve().parent.parent.parent.parent / "airbyte_cdk/sources/declarative/declarative_component_schema.yaml" + Path(__file__).resolve().parent.parent.parent.parent + / "airbyte_cdk/sources/declarative/declarative_component_schema.yaml" ) assert not validate_refs(yaml_file_path) diff --git a/unit_tests/sources/declarative/test_types.py b/unit_tests/sources/declarative/test_types.py index b6eb42f9..1a15dcfa 100644 --- a/unit_tests/sources/declarative/test_types.py +++ b/unit_tests/sources/declarative/test_types.py @@ -7,21 +7,37 @@ @pytest.mark.parametrize( "stream_slice, expected_partition", [ - pytest.param(StreamSlice(partition={}, cursor_slice={}), {}, id="test_partition_with_empty_partition"), pytest.param( - StreamSlice(partition=StreamSlice(partition={}, cursor_slice={}), cursor_slice={}), {}, id="test_partition_nested_empty" + StreamSlice(partition={}, cursor_slice={}), {}, id="test_partition_with_empty_partition" ), pytest.param( - StreamSlice(partition={"key": "value"}, cursor_slice={}), {"key": "value"}, id="test_partition_with_mapping_partition" + StreamSlice(partition=StreamSlice(partition={}, cursor_slice={}), cursor_slice={}), + {}, + id="test_partition_nested_empty", + ), + pytest.param( + StreamSlice(partition={"key": "value"}, cursor_slice={}), + {"key": "value"}, + id="test_partition_with_mapping_partition", + ), + pytest.param( + StreamSlice(partition={}, cursor_slice={"cursor": "value"}), + {}, + id="test_partition_with_only_cursor", ), - pytest.param(StreamSlice(partition={}, cursor_slice={"cursor": "value"}), {}, id="test_partition_with_only_cursor"), pytest.param( - StreamSlice(partition=StreamSlice(partition={}, cursor_slice={}), cursor_slice={"cursor": "value"}), + StreamSlice( + partition=StreamSlice(partition={}, cursor_slice={}), + cursor_slice={"cursor": "value"}, + ), {}, id="test_partition_nested_empty_and_cursor_value_mapping", ), pytest.param( - StreamSlice(partition=StreamSlice(partition={}, cursor_slice={"cursor": "value"}), cursor_slice={}), + StreamSlice( + partition=StreamSlice(partition={}, cursor_slice={"cursor": "value"}), + cursor_slice={}, + ), {}, id="test_partition_nested_empty_and_cursor_value", ), @@ -36,21 +52,37 @@ def test_partition(stream_slice, expected_partition): @pytest.mark.parametrize( "stream_slice, expected_cursor_slice", [ - pytest.param(StreamSlice(partition={}, cursor_slice={}), {}, id="test_cursor_slice_with_empty_cursor"), pytest.param( - StreamSlice(partition={}, cursor_slice=StreamSlice(partition={}, cursor_slice={})), {}, id="test_cursor_slice_nested_empty" + StreamSlice(partition={}, cursor_slice={}), {}, id="test_cursor_slice_with_empty_cursor" ), pytest.param( - StreamSlice(partition={}, cursor_slice={"key": "value"}), {"key": "value"}, id="test_cursor_slice_with_mapping_cursor_slice" + StreamSlice(partition={}, cursor_slice=StreamSlice(partition={}, cursor_slice={})), + {}, + id="test_cursor_slice_nested_empty", + ), + pytest.param( + StreamSlice(partition={}, cursor_slice={"key": "value"}), + {"key": "value"}, + id="test_cursor_slice_with_mapping_cursor_slice", + ), + pytest.param( + StreamSlice(partition={"partition": "value"}, cursor_slice={}), + {}, + id="test_cursor_slice_with_only_partition", ), - pytest.param(StreamSlice(partition={"partition": "value"}, cursor_slice={}), {}, id="test_cursor_slice_with_only_partition"), pytest.param( - StreamSlice(partition={"partition": "value"}, cursor_slice=StreamSlice(partition={}, cursor_slice={})), + StreamSlice( + partition={"partition": "value"}, + cursor_slice=StreamSlice(partition={}, cursor_slice={}), + ), {}, id="test_cursor_slice_nested_empty_and_partition_mapping", ), pytest.param( - StreamSlice(partition=StreamSlice(partition={"partition": "value"}, cursor_slice={}), cursor_slice={}), + StreamSlice( + partition=StreamSlice(partition={"partition": "value"}, cursor_slice={}), + cursor_slice={}, + ), {}, id="test_cursor_slice_nested_empty_and_partition", ), diff --git a/unit_tests/sources/declarative/transformations/test_add_fields.py b/unit_tests/sources/declarative/transformations/test_add_fields.py index 9b46cf49..b598e7bc 100644 --- a/unit_tests/sources/declarative/transformations/test_add_fields.py +++ b/unit_tests/sources/declarative/transformations/test_add_fields.py @@ -13,8 +13,22 @@ @pytest.mark.parametrize( ["input_record", "field", "field_type", "kwargs", "expected"], [ - pytest.param({"k": "v"}, [(["path"], "static_value")], None, {}, {"k": "v", "path": "static_value"}, id="add new static value"), - pytest.param({"k": "v"}, [(["path"], "{{ 1 }}")], None, {}, {"k": "v", "path": 1}, id="add an expression evaluated as a number"), + pytest.param( + {"k": "v"}, + [(["path"], "static_value")], + None, + {}, + {"k": "v", "path": "static_value"}, + id="add new static value", + ), + pytest.param( + {"k": "v"}, + [(["path"], "{{ 1 }}")], + None, + {}, + {"k": "v", "path": 1}, + id="add an expression evaluated as a number", + ), pytest.param( {"k": "v"}, [(["path"], "{{ 1 }}")], @@ -39,8 +53,22 @@ {"k": "v", "nested": {"path": "static_value"}}, id="set static value at nested path", ), - pytest.param({"k": "v"}, [(["k"], "new_value")], None, {}, {"k": "new_value"}, id="update value which already exists"), - pytest.param({"k": [0, 1]}, [(["k", 3], "v")], None, {}, {"k": [0, 1, None, "v"]}, id="Set element inside array"), + pytest.param( + {"k": "v"}, + [(["k"], "new_value")], + None, + {}, + {"k": "new_value"}, + id="update value which already exists", + ), + pytest.param( + {"k": [0, 1]}, + [(["k", 3], "v")], + None, + {}, + {"k": [0, 1, None, "v"]}, + id="Set element inside array", + ), pytest.param( {"k": "v"}, [(["k2"], '{{ config["shop"] }}')], @@ -121,7 +149,14 @@ {"k": {"nested": "v"}, "k2": "v"}, id="set a value from a nested field in the record using bracket notation", ), - pytest.param({"k": "v"}, [(["k2"], "{{ 2 + 2 }}")], None, {}, {"k": "v", "k2": 4}, id="set a value from a jinja expression"), + pytest.param( + {"k": "v"}, + [(["k2"], "{{ 2 + 2 }}")], + None, + {}, + {"k": "v", "k2": 4}, + id="set a value from a jinja expression", + ), ], ) def test_add_fields( @@ -131,6 +166,9 @@ def test_add_fields( kwargs: Mapping[str, Any], expected: Mapping[str, Any], ): - inputs = [AddedFieldDefinition(path=v[0], value=v[1], value_type=field_type, parameters={}) for v in field] + inputs = [ + AddedFieldDefinition(path=v[0], value=v[1], value_type=field_type, parameters={}) + for v in field + ] AddFields(fields=inputs, parameters={"alas": "i live"}).transform(input_record, **kwargs) assert input_record == expected diff --git a/unit_tests/sources/declarative/transformations/test_keys_to_lower_transformation.py b/unit_tests/sources/declarative/transformations/test_keys_to_lower_transformation.py index 7464b9f0..cdf52615 100644 --- a/unit_tests/sources/declarative/transformations/test_keys_to_lower_transformation.py +++ b/unit_tests/sources/declarative/transformations/test_keys_to_lower_transformation.py @@ -2,7 +2,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # -from airbyte_cdk.sources.declarative.transformations.keys_to_lower_transformation import KeysToLowerTransformation +from airbyte_cdk.sources.declarative.transformations.keys_to_lower_transformation import ( + KeysToLowerTransformation, +) _ANY_VALUE = -1 diff --git a/unit_tests/sources/declarative/transformations/test_remove_fields.py b/unit_tests/sources/declarative/transformations/test_remove_fields.py index 89b17e8d..4638b7ea 100644 --- a/unit_tests/sources/declarative/transformations/test_remove_fields.py +++ b/unit_tests/sources/declarative/transformations/test_remove_fields.py @@ -12,14 +12,50 @@ @pytest.mark.parametrize( ["input_record", "field_pointers", "condition", "expected"], [ - pytest.param({"k1": "v", "k2": "v"}, [["k1"]], None, {"k2": "v"}, id="remove a field that exists (flat dict), condition = None"), - pytest.param({"k1": "v", "k2": "v"}, [["k1"]], "", {"k2": "v"}, id="remove a field that exists (flat dict)"), - pytest.param({"k1": "v", "k2": "v"}, [["k3"]], "", {"k1": "v", "k2": "v"}, id="remove a field that doesn't exist (flat dict)"), - pytest.param({"k1": "v", "k2": "v"}, [["k1"], ["k2"]], "", {}, id="remove multiple fields that exist (flat dict)"), + pytest.param( + {"k1": "v", "k2": "v"}, + [["k1"]], + None, + {"k2": "v"}, + id="remove a field that exists (flat dict), condition = None", + ), + pytest.param( + {"k1": "v", "k2": "v"}, + [["k1"]], + "", + {"k2": "v"}, + id="remove a field that exists (flat dict)", + ), + pytest.param( + {"k1": "v", "k2": "v"}, + [["k3"]], + "", + {"k1": "v", "k2": "v"}, + id="remove a field that doesn't exist (flat dict)", + ), + pytest.param( + {"k1": "v", "k2": "v"}, + [["k1"], ["k2"]], + "", + {}, + id="remove multiple fields that exist (flat dict)", + ), # TODO: should we instead splice the element out of the array? I think that's the more intuitive solution # Otherwise one could just set the field's value to null. - pytest.param({"k1": [1, 2]}, [["k1", 0]], "", {"k1": [None, 2]}, id="remove field inside array (int index)"), - pytest.param({"k1": [1, 2]}, [["k1", "0"]], "", {"k1": [None, 2]}, id="remove field inside array (string index)"), + pytest.param( + {"k1": [1, 2]}, + [["k1", 0]], + "", + {"k1": [None, 2]}, + id="remove field inside array (int index)", + ), + pytest.param( + {"k1": [1, 2]}, + [["k1", "0"]], + "", + {"k1": [None, 2]}, + id="remove field inside array (string index)", + ), pytest.param( {"k1": "v", "k2": "v", "k3": [0, 1], "k4": "v"}, [["k1"], ["k2"], ["k3", 0]], @@ -27,16 +63,40 @@ {"k3": [None, 1], "k4": "v"}, id="test all cases (flat)", ), - pytest.param({"k1": [0, 1]}, [[".", "k1", 10]], "", {"k1": [0, 1]}, id="remove array index that doesn't exist (flat)"), pytest.param( - {".": {"k1": [0, 1]}}, [[".", "k1", 10]], "", {".": {"k1": [0, 1]}}, id="remove array index that doesn't exist (nested)" + {"k1": [0, 1]}, + [[".", "k1", 10]], + "", + {"k1": [0, 1]}, + id="remove array index that doesn't exist (flat)", + ), + pytest.param( + {".": {"k1": [0, 1]}}, + [[".", "k1", 10]], + "", + {".": {"k1": [0, 1]}}, + id="remove array index that doesn't exist (nested)", + ), + pytest.param( + {".": {"k2": "v", "k1": "v"}}, + [[".", "k1"]], + "", + {".": {"k2": "v"}}, + id="remove nested field that exists", ), - pytest.param({".": {"k2": "v", "k1": "v"}}, [[".", "k1"]], "", {".": {"k2": "v"}}, id="remove nested field that exists"), pytest.param( - {".": {"k2": "v", "k1": "v"}}, [[".", "k3"]], "", {".": {"k2": "v", "k1": "v"}}, id="remove field that doesn't exist (nested)" + {".": {"k2": "v", "k1": "v"}}, + [[".", "k3"]], + "", + {".": {"k2": "v", "k1": "v"}}, + id="remove field that doesn't exist (nested)", ), pytest.param( - {".": {"k2": "v", "k1": "v"}}, [[".", "k1"], [".", "k2"]], "", {".": {}}, id="remove multiple fields that exist (nested)" + {".": {"k2": "v", "k1": "v"}}, + [[".", "k1"], [".", "k2"]], + "", + {".": {}}, + id="remove multiple fields that exist (nested)", ), pytest.param( {".": {"k1": [0, 1]}}, @@ -59,7 +119,13 @@ {"k1": "v", "k2": "v"}, id="do not remove any field if condition is boolean False", ), - pytest.param({"k1": "v", "k2": "v"}, [["**"]], "{{ True }}", {}, id="remove all field if condition is boolean True"), + pytest.param( + {"k1": "v", "k2": "v"}, + [["**"]], + "{{ True }}", + {}, + id="remove all field if condition is boolean True", + ), pytest.param( {"k1": "v", "k2": "v1", "k3": "v1", "k4": {"k_nested": "v1", "k_nested2": "v2"}}, [["**"]], @@ -68,14 +134,24 @@ id="recursively remove any field that matches property condition and leave that does not", ), pytest.param( - {"k1": "v", "k2": "some_long_string", "k3": "some_long_string", "k4": {"k_nested": "v1", "k_nested2": "v2"}}, + { + "k1": "v", + "k2": "some_long_string", + "k3": "some_long_string", + "k4": {"k_nested": "v1", "k_nested2": "v2"}, + }, [["**"]], "{{ property|length > 5 }}", {"k1": "v", "k4": {"k_nested": "v1", "k_nested2": "v2"}}, id="remove any field that have length > 5 and leave that does not", ), pytest.param( - {"k1": 255, "k2": "some_string", "k3": "some_long_string", "k4": {"k_nested": 123123, "k_nested2": "v2"}}, + { + "k1": 255, + "k2": "some_string", + "k3": "some_long_string", + "k4": {"k_nested": 123123, "k_nested2": "v2"}, + }, [["**"]], "{{ property is integer }}", {"k2": "some_string", "k3": "some_long_string", "k4": {"k_nested2": "v2"}}, @@ -83,7 +159,12 @@ ), ], ) -def test_remove_fields(input_record: Mapping[str, Any], field_pointers: List[FieldPointer], condition: str, expected: Mapping[str, Any]): +def test_remove_fields( + input_record: Mapping[str, Any], + field_pointers: List[FieldPointer], + condition: str, + expected: Mapping[str, Any], +): transformation = RemoveFields(field_pointers=field_pointers, condition=condition, parameters={}) transformation.transform(input_record) assert input_record == expected diff --git a/unit_tests/sources/embedded/test_embedded_integration.py b/unit_tests/sources/embedded/test_embedded_integration.py index 7560dc40..f8e11cff 100644 --- a/unit_tests/sources/embedded/test_embedded_integration.py +++ b/unit_tests/sources/embedded/test_embedded_integration.py @@ -52,15 +52,26 @@ def setUp(self): json_schema={}, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], ) - self.stream2 = AirbyteStream(name="test2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]) + self.stream2 = AirbyteStream( + name="test2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ) self.source.discover.return_value = AirbyteCatalog(streams=[self.stream2, self.stream1]) def test_integration(self): self.source.read.return_value = [ AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="test")), - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1)), - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 2}, emitted_at=2)), - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 3}, emitted_at=3)), + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1), + ), + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="test", data={"test": 2}, emitted_at=2), + ), + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="test", data={"test": 3}, emitted_at=3), + ), ] result = list(self.integration._load_data("test", None)) self.assertEqual( @@ -97,7 +108,10 @@ def test_state(self): state = AirbyteStateMessage(data={}) self.source.read.return_value = [ AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="test")), - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1)), + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1), + ), AirbyteMessage(type=Type.STATE, state=state), ] result = list(self.integration._load_data("test", None)) diff --git a/unit_tests/sources/file_based/availability_strategy/test_default_file_based_availability_strategy.py b/unit_tests/sources/file_based/availability_strategy/test_default_file_based_availability_strategy.py index 7206d234..b05bff03 100644 --- a/unit_tests/sources/file_based/availability_strategy/test_default_file_based_availability_strategy.py +++ b/unit_tests/sources/file_based/availability_strategy/test_default_file_based_availability_strategy.py @@ -17,7 +17,9 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream -_FILE_WITH_UNKNOWN_EXTENSION = RemoteFile(uri="a.unknown_extension", last_modified=datetime.now(), file_type="csv") +_FILE_WITH_UNKNOWN_EXTENSION = RemoteFile( + uri="a.unknown_extension", last_modified=datetime.now(), file_type="csv" +) _ANY_CONFIG = FileBasedStreamConfig( name="config.name", file_type="parquet", @@ -40,7 +42,9 @@ def setUp(self) -> None: self._stream.validation_policy = PropertyMock(validate_schema_before_sync=False) self._stream.stream_reader = self._stream_reader - def test_given_file_extension_does_not_match_when_check_availability_and_parsability_then_stream_is_still_available(self) -> None: + def test_given_file_extension_does_not_match_when_check_availability_and_parsability_then_stream_is_still_available( + self, + ) -> None: """ Before, we had a validation on the file extension but it turns out that in production, users sometimes have mismatch there. The example we've seen was for JSONL parser but the file extension was just `.json`. Note that there we more than one record extracted @@ -49,7 +53,9 @@ def test_given_file_extension_does_not_match_when_check_availability_and_parsabi self._stream.get_files.return_value = [_FILE_WITH_UNKNOWN_EXTENSION] self._parser.parse_records.return_value = [{"a record": 1}] - is_available, reason = self._strategy.check_availability_and_parsability(self._stream, Mock(), Mock()) + is_available, reason = self._strategy.check_availability_and_parsability( + self._stream, Mock(), Mock() + ) assert is_available @@ -59,7 +65,9 @@ def test_not_available_given_no_files(self) -> None: """ self._stream.get_files.return_value = [] - is_available, reason = self._strategy.check_availability_and_parsability(self._stream, Mock(), Mock()) + is_available, reason = self._strategy.check_availability_and_parsability( + self._stream, Mock(), Mock() + ) assert not is_available assert "No files were identified in the stream" in reason @@ -71,7 +79,9 @@ def test_parse_records_is_not_called_with_parser_max_n_files_for_parsability_set self._parser.parser_max_n_files_for_parsability = 0 self._stream.get_files.return_value = [_FILE_WITH_UNKNOWN_EXTENSION] - is_available, reason = self._strategy.check_availability_and_parsability(self._stream, Mock(), Mock()) + is_available, reason = self._strategy.check_availability_and_parsability( + self._stream, Mock(), Mock() + ) assert is_available assert not self._parser.parse_records.called @@ -82,7 +92,9 @@ def test_passing_config_check(self) -> None: Test if the DefaultFileBasedAvailabilityStrategy correctly handles the check_config method defined on the parser. """ self._parser.check_config.return_value = (False, "Ran into error") - is_available, error_message = self._strategy.check_availability_and_parsability(self._stream, Mock(), Mock()) + is_available, error_message = self._strategy.check_availability_and_parsability( + self._stream, Mock(), Mock() + ) assert not is_available assert "Ran into error" in error_message @@ -92,9 +104,13 @@ def test_catching_and_raising_custom_file_based_exception(self) -> None: by raising a CheckAvailabilityError when the get_files method is called. """ # Mock the get_files method to raise CustomFileBasedException when called - self._stream.get_files.side_effect = CustomFileBasedException("Custom exception for testing.") + self._stream.get_files.side_effect = CustomFileBasedException( + "Custom exception for testing." + ) # Invoke the check_availability_and_parsability method and check if it correctly handles the exception - is_available, error_message = self._strategy.check_availability_and_parsability(self._stream, Mock(), Mock()) + is_available, error_message = self._strategy.check_availability_and_parsability( + self._stream, Mock(), Mock() + ) assert not is_available assert "Custom exception for testing." in error_message diff --git a/unit_tests/sources/file_based/config/test_abstract_file_based_spec.py b/unit_tests/sources/file_based/config/test_abstract_file_based_spec.py index 3c3d72c2..84de3ad6 100644 --- a/unit_tests/sources/file_based/config/test_abstract_file_based_spec.py +++ b/unit_tests/sources/file_based/config/test_abstract_file_based_spec.py @@ -5,7 +5,11 @@ from typing import Type import pytest -from airbyte_cdk.sources.file_based.config.file_based_stream_config import AvroFormat, CsvFormat, ParquetFormat +from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( + AvroFormat, + CsvFormat, + ParquetFormat, +) from jsonschema import ValidationError, validate from pydantic.v1 import BaseModel @@ -13,12 +17,21 @@ @pytest.mark.parametrize( "file_format, file_type, expected_error", [ - pytest.param(ParquetFormat, "parquet", None, id="test_parquet_format_is_a_valid_parquet_file_type"), + pytest.param( + ParquetFormat, "parquet", None, id="test_parquet_format_is_a_valid_parquet_file_type" + ), pytest.param(AvroFormat, "avro", None, id="test_avro_format_is_a_valid_avro_file_type"), - pytest.param(CsvFormat, "parquet", ValidationError, id="test_csv_format_is_not_a_valid_parquet_file_type"), + pytest.param( + CsvFormat, + "parquet", + ValidationError, + id="test_csv_format_is_not_a_valid_parquet_file_type", + ), ], ) -def test_parquet_file_type_is_not_a_valid_csv_file_type(file_format: BaseModel, file_type: str, expected_error: Type[Exception]) -> None: +def test_parquet_file_type_is_not_a_valid_csv_file_type( + file_format: BaseModel, file_type: str, expected_error: Type[Exception] +) -> None: format_config = {file_type: {"filetype": file_type, "decimal_as_float": True}} if expected_error: diff --git a/unit_tests/sources/file_based/config/test_csv_format.py b/unit_tests/sources/file_based/config/test_csv_format.py index c233bd7a..ace9e034 100644 --- a/unit_tests/sources/file_based/config/test_csv_format.py +++ b/unit_tests/sources/file_based/config/test_csv_format.py @@ -5,7 +5,12 @@ import unittest import pytest -from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, CsvHeaderAutogenerated, CsvHeaderFromCsv, CsvHeaderUserProvided +from airbyte_cdk.sources.file_based.config.csv_format import ( + CsvFormat, + CsvHeaderAutogenerated, + CsvHeaderFromCsv, + CsvHeaderUserProvided, +) from pydantic.v1.error_wrappers import ValidationError diff --git a/unit_tests/sources/file_based/config/test_file_based_stream_config.py b/unit_tests/sources/file_based/config/test_file_based_stream_config.py index 4c5d69a7..addc7223 100644 --- a/unit_tests/sources/file_based/config/test_file_based_stream_config.py +++ b/unit_tests/sources/file_based/config/test_file_based_stream_config.py @@ -5,7 +5,10 @@ from typing import Any, Mapping, Type import pytest as pytest -from airbyte_cdk.sources.file_based.config.file_based_stream_config import CsvFormat, FileBasedStreamConfig +from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( + CsvFormat, + FileBasedStreamConfig, +) from pydantic.v1.error_wrappers import ValidationError @@ -14,8 +17,22 @@ [ pytest.param( "csv", - {"filetype": "csv", "delimiter": "d", "quote_char": "q", "escape_char": "e", "encoding": "ascii", "double_quote": True}, - {"filetype": "csv", "delimiter": "d", "quote_char": "q", "escape_char": "e", "encoding": "ascii", "double_quote": True}, + { + "filetype": "csv", + "delimiter": "d", + "quote_char": "q", + "escape_char": "e", + "encoding": "ascii", + "double_quote": True, + }, + { + "filetype": "csv", + "delimiter": "d", + "quote_char": "q", + "escape_char": "e", + "encoding": "ascii", + "double_quote": True, + }, None, id="test_valid_format", ), @@ -27,30 +44,61 @@ id="test_default_format_values", ), pytest.param( - "csv", {"filetype": "csv", "delimiter": "nope", "double_quote": True}, None, ValidationError, id="test_invalid_delimiter" + "csv", + {"filetype": "csv", "delimiter": "nope", "double_quote": True}, + None, + ValidationError, + id="test_invalid_delimiter", ), pytest.param( - "csv", {"filetype": "csv", "quote_char": "nope", "double_quote": True}, None, ValidationError, id="test_invalid_quote_char" + "csv", + {"filetype": "csv", "quote_char": "nope", "double_quote": True}, + None, + ValidationError, + id="test_invalid_quote_char", ), pytest.param( - "csv", {"filetype": "csv", "escape_char": "nope", "double_quote": True}, None, ValidationError, id="test_invalid_escape_char" + "csv", + {"filetype": "csv", "escape_char": "nope", "double_quote": True}, + None, + ValidationError, + id="test_invalid_escape_char", ), pytest.param( "csv", - {"filetype": "csv", "delimiter": ",", "quote_char": '"', "encoding": "not_a_format", "double_quote": True}, + { + "filetype": "csv", + "delimiter": ",", + "quote_char": '"', + "encoding": "not_a_format", + "double_quote": True, + }, {}, ValidationError, id="test_invalid_encoding_type", ), pytest.param( - "invalid", {"filetype": "invalid", "double_quote": False}, {}, ValidationError, id="test_config_format_file_type_mismatch" + "invalid", + {"filetype": "invalid", "double_quote": False}, + {}, + ValidationError, + id="test_config_format_file_type_mismatch", ), ], ) def test_csv_config( - file_type: str, input_format: Mapping[str, Any], expected_format: Mapping[str, Any], expected_error: Type[Exception] + file_type: str, + input_format: Mapping[str, Any], + expected_format: Mapping[str, Any], + expected_error: Type[Exception], ) -> None: - stream_config = {"name": "stream1", "file_type": file_type, "globs": ["*"], "validation_policy": "Emit Record", "format": input_format} + stream_config = { + "name": "stream1", + "file_type": file_type, + "globs": ["*"], + "validation_policy": "Emit Record", + "format": input_format, + } if expected_error: with pytest.raises(expected_error): diff --git a/unit_tests/sources/file_based/discovery_policy/test_default_discovery_policy.py b/unit_tests/sources/file_based/discovery_policy/test_default_discovery_policy.py index b7ad6711..8cb97715 100644 --- a/unit_tests/sources/file_based/discovery_policy/test_default_discovery_policy.py +++ b/unit_tests/sources/file_based/discovery_policy/test_default_discovery_policy.py @@ -5,7 +5,9 @@ import unittest from unittest.mock import Mock -from airbyte_cdk.sources.file_based.discovery_policy.default_discovery_policy import DefaultDiscoveryPolicy +from airbyte_cdk.sources.file_based.discovery_policy.default_discovery_policy import ( + DefaultDiscoveryPolicy, +) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser diff --git a/unit_tests/sources/file_based/file_types/test_avro_parser.py b/unit_tests/sources/file_based/file_types/test_avro_parser.py index a45d424b..2c52f9f8 100644 --- a/unit_tests/sources/file_based/file_types/test_avro_parser.py +++ b/unit_tests/sources/file_based/file_types/test_avro_parser.py @@ -24,7 +24,13 @@ pytest.param(_default_avro_format, "long", {"type": "integer"}, None, id="test_long"), pytest.param(_default_avro_format, "float", {"type": "number"}, None, id="test_float"), pytest.param(_default_avro_format, "double", {"type": "number"}, None, id="test_double"), - pytest.param(_double_as_string_avro_format, "double", {"type": "string"}, None, id="test_double_as_string"), + pytest.param( + _double_as_string_avro_format, + "double", + {"type": "string"}, + None, + id="test_double_as_string", + ), pytest.param(_default_avro_format, "bytes", {"type": "string"}, None, id="test_bytes"), pytest.param(_default_avro_format, "string", {"type": "string"}, None, id="test_string"), pytest.param(_default_avro_format, "void", None, ValueError, id="test_invalid_type"), @@ -34,7 +40,11 @@ { "type": "record", "name": "SubRecord", - "fields": [{"name": "precise", "type": "double"}, {"name": "robo", "type": "bytes"}, {"name": "simple", "type": "long"}], + "fields": [ + {"name": "precise", "type": "double"}, + {"name": "robo", "type": "bytes"}, + {"name": "simple", "type": "long"}, + ], }, { "type": "object", @@ -52,9 +62,18 @@ { "type": "record", "name": "SubRecord", - "fields": [{"name": "precise", "type": "double"}, {"name": "obj_array", "type": {"type": "array", "items": "float"}}], + "fields": [ + {"name": "precise", "type": "double"}, + {"name": "obj_array", "type": {"type": "array", "items": "float"}}, + ], + }, + { + "type": "object", + "properties": { + "precise": {"type": "number"}, + "obj_array": {"type": "array", "items": {"type": "number"}}, + }, }, - {"type": "object", "properties": {"precise": {"type": "number"}, "obj_array": {"type": "array", "items": {"type": "number"}}}}, None, id="test_record_with_nested_array", ), @@ -66,7 +85,11 @@ "fields": [ { "name": "nested_record", - "type": {"type": "record", "name": "SubRecord", "fields": [{"name": "question", "type": "boolean"}]}, + "type": { + "type": "record", + "name": "SubRecord", + "fields": [{"name": "question", "type": "boolean"}], + }, } ], }, @@ -83,11 +106,22 @@ id="test_record_with_nested_record", ), pytest.param( - _default_avro_format, {"type": "array", "items": "float"}, {"type": "array", "items": {"type": "number"}}, None, id="test_array" + _default_avro_format, + {"type": "array", "items": "float"}, + {"type": "array", "items": {"type": "number"}}, + None, + id="test_array", ), pytest.param( _default_avro_format, - {"type": "array", "items": {"type": "record", "name": "SubRecord", "fields": [{"name": "precise", "type": "double"}]}}, + { + "type": "array", + "items": { + "type": "record", + "name": "SubRecord", + "fields": [{"name": "precise", "type": "double"}], + }, + }, { "type": "array", "items": { @@ -100,9 +134,19 @@ None, id="test_array_of_records", ), - pytest.param(_default_avro_format, {"type": "array", "not_items": "string"}, None, ValueError, id="test_array_missing_items"), pytest.param( - _default_avro_format, {"type": "array", "items": "invalid_avro_type"}, None, ValueError, id="test_array_invalid_item_type" + _default_avro_format, + {"type": "array", "not_items": "string"}, + None, + ValueError, + id="test_array_missing_items", + ), + pytest.param( + _default_avro_format, + {"type": "array", "items": "invalid_avro_type"}, + None, + ValueError, + id="test_array_invalid_item_type", ), pytest.param( _default_avro_format, @@ -111,9 +155,19 @@ None, id="test_enum", ), - pytest.param(_default_avro_format, {"type": "enum", "name": "IMF"}, None, ValueError, id="test_enum_missing_symbols"), pytest.param( - _default_avro_format, {"type": "enum", "symbols": ["mission", "not", "accepted"]}, None, ValueError, id="test_enum_missing_name" + _default_avro_format, + {"type": "enum", "name": "IMF"}, + None, + ValueError, + id="test_enum_missing_symbols", + ), + pytest.param( + _default_avro_format, + {"type": "enum", "symbols": ["mission", "not", "accepted"]}, + None, + ValueError, + id="test_enum_missing_name", ), pytest.param( _default_avro_format, @@ -124,12 +178,27 @@ ), pytest.param( _default_avro_format, - {"type": "map", "values": {"type": "record", "name": "SubRecord", "fields": [{"name": "agent", "type": "string"}]}}, - {"type": "object", "additionalProperties": {"type": "object", "properties": {"agent": {"type": "string"}}}}, + { + "type": "map", + "values": { + "type": "record", + "name": "SubRecord", + "fields": [{"name": "agent", "type": "string"}], + }, + }, + { + "type": "object", + "additionalProperties": { + "type": "object", + "properties": {"agent": {"type": "string"}}, + }, + }, None, id="test_map_object", ), - pytest.param(_default_avro_format, {"type": "map"}, None, ValueError, id="test_map_missing_values"), + pytest.param( + _default_avro_format, {"type": "map"}, None, ValueError, id="test_map_missing_values" + ), pytest.param( _default_avro_format, {"type": "fixed", "name": "limit", "size": 12}, @@ -137,9 +206,19 @@ None, id="test_fixed", ), - pytest.param(_default_avro_format, {"type": "fixed", "name": "limit"}, None, ValueError, id="test_fixed_missing_size"), pytest.param( - _default_avro_format, {"type": "fixed", "name": "limit", "size": "50"}, None, ValueError, id="test_fixed_size_not_integer" + _default_avro_format, + {"type": "fixed", "name": "limit"}, + None, + ValueError, + id="test_fixed_missing_size", + ), + pytest.param( + _default_avro_format, + {"type": "fixed", "name": "limit", "size": "50"}, + None, + ValueError, + id="test_fixed_size_not_integer", ), # Logical types pytest.param( @@ -163,13 +242,33 @@ ValueError, id="test_decimal_missing_scale", ), - pytest.param(_default_avro_format, {"type": "bytes", "logicalType": "uuid"}, {"type": "string"}, None, id="test_uuid"), pytest.param( - _default_avro_format, {"type": "int", "logicalType": "date"}, {"type": "string", "format": "date"}, None, id="test_date" + _default_avro_format, + {"type": "bytes", "logicalType": "uuid"}, + {"type": "string"}, + None, + id="test_uuid", ), - pytest.param(_default_avro_format, {"type": "int", "logicalType": "time-millis"}, {"type": "integer"}, None, id="test_time_millis"), pytest.param( - _default_avro_format, {"type": "long", "logicalType": "time-micros"}, {"type": "integer"}, None, id="test_time_micros" + _default_avro_format, + {"type": "int", "logicalType": "date"}, + {"type": "string", "format": "date"}, + None, + id="test_date", + ), + pytest.param( + _default_avro_format, + {"type": "int", "logicalType": "time-millis"}, + {"type": "integer"}, + None, + id="test_time_millis", + ), + pytest.param( + _default_avro_format, + {"type": "long", "logicalType": "time-micros"}, + {"type": "integer"}, + None, + id="test_time_micros", ), pytest.param( _default_avro_format, @@ -179,7 +278,11 @@ id="test_timestamp_millis", ), pytest.param( - _default_avro_format, {"type": "long", "logicalType": "timestamp-micros"}, {"type": "string"}, None, id="test_timestamp_micros" + _default_avro_format, + {"type": "long", "logicalType": "timestamp-micros"}, + {"type": "string"}, + None, + id="test_timestamp_micros", ), pytest.param( _default_avro_format, @@ -204,12 +307,16 @@ ), ], ) -def test_convert_primitive_avro_type_to_json(avro_format, avro_type, expected_json_type, expected_error): +def test_convert_primitive_avro_type_to_json( + avro_format, avro_type, expected_json_type, expected_error +): if expected_error: with pytest.raises(expected_error): AvroParser._convert_avro_type_to_json(avro_format, "field_name", avro_type) else: - actual_json_type = AvroParser._convert_avro_type_to_json(avro_format, "field_name", avro_type) + actual_json_type = AvroParser._convert_avro_type_to_json( + avro_format, "field_name", avro_type + ) assert actual_json_type == expected_json_type @@ -220,15 +327,47 @@ def test_convert_primitive_avro_type_to_json(avro_format, avro_type, expected_js pytest.param(_default_avro_format, "int", 123, 123, id="test_int"), pytest.param(_default_avro_format, "long", 123, 123, id="test_long"), pytest.param(_default_avro_format, "float", 123.456, 123.456, id="test_float"), - pytest.param(_default_avro_format, "double", 123.456, 123.456, id="test_double_default_config"), - pytest.param(_double_as_string_avro_format, "double", 123.456, "123.456", id="test_double_as_string"), + pytest.param( + _default_avro_format, "double", 123.456, 123.456, id="test_double_default_config" + ), + pytest.param( + _double_as_string_avro_format, "double", 123.456, "123.456", id="test_double_as_string" + ), pytest.param(_default_avro_format, "bytes", b"hello world", "hello world", id="test_bytes"), - pytest.param(_default_avro_format, "string", "hello world", "hello world", id="test_string"), - pytest.param(_default_avro_format, {"logicalType": "decimal"}, 3.1415, "3.1415", id="test_decimal"), - pytest.param(_default_avro_format, {"logicalType": "uuid"}, _uuid_value, str(_uuid_value), id="test_uuid"), - pytest.param(_default_avro_format, {"logicalType": "date"}, datetime.date(2023, 8, 7), "2023-08-07", id="test_date"), - pytest.param(_default_avro_format, {"logicalType": "time-millis"}, 70267068, 70267068, id="test_time_millis"), - pytest.param(_default_avro_format, {"logicalType": "time-micros"}, 70267068, 70267068, id="test_time_micros"), + pytest.param( + _default_avro_format, "string", "hello world", "hello world", id="test_string" + ), + pytest.param( + _default_avro_format, {"logicalType": "decimal"}, 3.1415, "3.1415", id="test_decimal" + ), + pytest.param( + _default_avro_format, + {"logicalType": "uuid"}, + _uuid_value, + str(_uuid_value), + id="test_uuid", + ), + pytest.param( + _default_avro_format, + {"logicalType": "date"}, + datetime.date(2023, 8, 7), + "2023-08-07", + id="test_date", + ), + pytest.param( + _default_avro_format, + {"logicalType": "time-millis"}, + 70267068, + 70267068, + id="test_time_millis", + ), + pytest.param( + _default_avro_format, + {"logicalType": "time-micros"}, + 70267068, + 70267068, + id="test_time_micros", + ), pytest.param( _default_avro_format, {"logicalType": "local-timestamp-millis"}, diff --git a/unit_tests/sources/file_based/file_types/test_csv_parser.py b/unit_tests/sources/file_based/file_types/test_csv_parser.py index 9280ffb6..295c4da6 100644 --- a/unit_tests/sources/file_based/file_types/test_csv_parser.py +++ b/unit_tests/sources/file_based/file_types/test_csv_parser.py @@ -24,7 +24,10 @@ ) from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.exceptions import RecordParseError -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.file_types.csv_parser import CsvParser, _CsvReader from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.utils.traced_exception import AirbyteTracedException @@ -77,10 +80,34 @@ }, id="cast-all-cols", ), - pytest.param({"col1": "1"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col1": "1"}, id="cannot-cast-to-null"), - pytest.param({"col2": "1"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col2": True}, id="cast-1-to-bool"), - pytest.param({"col2": "0"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col2": False}, id="cast-0-to-bool"), - pytest.param({"col2": "yes"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col2": True}, id="cast-yes-to-bool"), + pytest.param( + {"col1": "1"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col1": "1"}, + id="cannot-cast-to-null", + ), + pytest.param( + {"col2": "1"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col2": True}, + id="cast-1-to-bool", + ), + pytest.param( + {"col2": "0"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col2": False}, + id="cast-0-to-bool", + ), + pytest.param( + {"col2": "yes"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col2": True}, + id="cast-yes-to-bool", + ), pytest.param( {"col2": "this_is_a_true_value"}, ["this_is_a_true_value"], @@ -95,24 +122,77 @@ {"col2": False}, id="cast-custom-false-value-to-bool", ), - pytest.param({"col2": "no"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col2": False}, id="cast-no-to-bool"), - pytest.param({"col2": "10"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col2": "10"}, id="cannot-cast-to-bool"), - pytest.param({"col3": "1.1"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col3": "1.1"}, id="cannot-cast-to-int"), - pytest.param({"col4": "asdf"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col4": "asdf"}, id="cannot-cast-to-float"), - pytest.param({"col6": "{'a': 'b'}"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col6": "{'a': 'b'}"}, id="cannot-cast-to-dict"), pytest.param( - {"col7": "['a', 'b']"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col7": "['a', 'b']"}, id="cannot-cast-to-list-of-ints" + {"col2": "no"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col2": False}, + id="cast-no-to-bool", ), pytest.param( - {"col8": "['a', 'b']"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col8": "['a', 'b']"}, id="cannot-cast-to-list-of-strings" + {"col2": "10"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col2": "10"}, + id="cannot-cast-to-bool", ), pytest.param( - {"col9": "['a', 'b']"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {"col9": "['a', 'b']"}, id="cannot-cast-to-list-of-objects" + {"col3": "1.1"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col3": "1.1"}, + id="cannot-cast-to-int", + ), + pytest.param( + {"col4": "asdf"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col4": "asdf"}, + id="cannot-cast-to-float", + ), + pytest.param( + {"col6": "{'a': 'b'}"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col6": "{'a': 'b'}"}, + id="cannot-cast-to-dict", + ), + pytest.param( + {"col7": "['a', 'b']"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col7": "['a', 'b']"}, + id="cannot-cast-to-list-of-ints", + ), + pytest.param( + {"col8": "['a', 'b']"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col8": "['a', 'b']"}, + id="cannot-cast-to-list-of-strings", + ), + pytest.param( + {"col9": "['a', 'b']"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {"col9": "['a', 'b']"}, + id="cannot-cast-to-list-of-objects", + ), + pytest.param( + {"col11": "x"}, + DEFAULT_TRUE_VALUES, + DEFAULT_FALSE_VALUES, + {}, + id="item-not-in-props-doesn't-error", ), - pytest.param({"col11": "x"}, DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES, {}, id="item-not-in-props-doesn't-error"), ], ) -def test_cast_to_python_type(row: Dict[str, str], true_values: Set[str], false_values: Set[str], expected_output: Dict[str, Any]) -> None: +def test_cast_to_python_type( + row: Dict[str, str], + true_values: Set[str], + false_values: Set[str], + expected_output: Dict[str, Any], +) -> None: csv_format = CsvFormat(true_values=true_values, false_values=false_values) assert CsvParser._cast_types(row, PROPERTY_TYPES, csv_format, logger) == expected_output @@ -189,9 +269,13 @@ def test_given_integers_only_when_infer_schema_then_type_is_integer(self) -> Non self._config_format.inference_type = InferenceType.PRIMITIVE_TYPES_ONLY self._test_infer_schema(["2", "90329", "5645"], "integer") - def test_given_integer_overlap_with_bool_value_only_when_infer_schema_then_type_is_integer(self) -> None: + def test_given_integer_overlap_with_bool_value_only_when_infer_schema_then_type_is_integer( + self, + ) -> None: self._config_format.inference_type = InferenceType.PRIMITIVE_TYPES_ONLY - self._test_infer_schema(["1", "90329", "5645"], "integer") # here, "1" is also considered a boolean + self._test_infer_schema( + ["1", "90329", "5645"], "integer" + ) # here, "1" is also considered a boolean def test_given_numbers_and_integers_when_infer_schema_then_type_is_number(self) -> None: self._config_format.inference_type = InferenceType.PRIMITIVE_TYPES_ONLY @@ -199,7 +283,9 @@ def test_given_numbers_and_integers_when_infer_schema_then_type_is_number(self) def test_given_arrays_when_infer_schema_then_type_is_string(self) -> None: self._config_format.inference_type = InferenceType.PRIMITIVE_TYPES_ONLY - self._test_infer_schema(['["first_item", "second_item"]', '["first_item_again", "second_item_again"]'], "string") + self._test_infer_schema( + ['["first_item", "second_item"]', '["first_item_again", "second_item_again"]'], "string" + ) def test_given_objects_when_infer_schema_then_type_is_object(self) -> None: self._config_format.inference_type = InferenceType.PRIMITIVE_TYPES_ONLY @@ -215,11 +301,15 @@ def test_given_a_null_value_when_infer_then_ignore_null(self) -> None: def test_given_only_null_values_when_infer_then_type_is_string(self) -> None: self._config_format.inference_type = InferenceType.PRIMITIVE_TYPES_ONLY - self._test_infer_schema([self._A_NULL_VALUE, self._A_NULL_VALUE, self._A_NULL_VALUE], "string") + self._test_infer_schema( + [self._A_NULL_VALUE, self._A_NULL_VALUE, self._A_NULL_VALUE], "string" + ) def test_given_big_file_when_infer_schema_then_stop_early(self) -> None: self._config_format.inference_type = InferenceType.PRIMITIVE_TYPES_ONLY - self._csv_reader.read_data.return_value = ({self._HEADER_NAME: row} for row in ["2." + "2" * 1_000_000] + ["this is a string"]) + self._csv_reader.read_data.return_value = ( + {self._HEADER_NAME: row} for row in ["2." + "2" * 1_000_000] + ["this is a string"] + ) inferred_schema = self._infer_schema() # since the type is number, we know the string at the end was not considered assert inferred_schema == {self._HEADER_NAME: {"type": "number"}} @@ -237,7 +327,9 @@ def _test_infer_schema(self, rows: List[str], expected_type: str) -> None: def _infer_schema(self): loop = asyncio.new_event_loop() - task = loop.create_task(self._parser.infer_schema(self._config, self._file, self._stream_reader, self._logger)) + task = loop.create_task( + self._parser.infer_schema(self._config, self._file, self._stream_reader, self._logger) + ) loop.run_until_complete(task) return task.result() @@ -292,42 +384,64 @@ def test_given_skip_rows_when_read_data_then_do_not_considered_prefixed_rows(sel assert list(data_generator) == [{"header": "a value"}, {"header": "another value"}] - def test_given_autogenerated_headers_when_read_data_then_generate_headers_with_format_fX(self) -> None: + def test_given_autogenerated_headers_when_read_data_then_generate_headers_with_format_fX( + self, + ) -> None: self._config_format.header_definition = CsvHeaderAutogenerated() - self._stream_reader.open_file.return_value = CsvFileBuilder().with_data(["0,1,2,3,4,5,6"]).build() + self._stream_reader.open_file.return_value = ( + CsvFileBuilder().with_data(["0,1,2,3,4,5,6"]).build() + ) data_generator = self._read_data() - assert list(data_generator) == [{"f0": "0", "f1": "1", "f2": "2", "f3": "3", "f4": "4", "f5": "5", "f6": "6"}] + assert list(data_generator) == [ + {"f0": "0", "f1": "1", "f2": "2", "f3": "3", "f4": "4", "f5": "5", "f6": "6"} + ] - def test_given_skip_row_before_and_after_and_autogenerated_headers_when_read_data_then_generate_headers_with_format_fX(self) -> None: + def test_given_skip_row_before_and_after_and_autogenerated_headers_when_read_data_then_generate_headers_with_format_fX( + self, + ) -> None: self._config_format.header_definition = CsvHeaderAutogenerated() self._config_format.skip_rows_before_header = 1 self._config_format.skip_rows_after_header = 2 self._stream_reader.open_file.return_value = ( - CsvFileBuilder().with_data(["skip before", "skip after 1", "skip after 2", "0,1,2,3,4,5,6"]).build() + CsvFileBuilder() + .with_data(["skip before", "skip after 1", "skip after 2", "0,1,2,3,4,5,6"]) + .build() ) data_generator = self._read_data() - assert list(data_generator) == [{"f0": "0", "f1": "1", "f2": "2", "f3": "3", "f4": "4", "f5": "5", "f6": "6"}] + assert list(data_generator) == [ + {"f0": "0", "f1": "1", "f2": "2", "f3": "3", "f4": "4", "f5": "5", "f6": "6"} + ] - def test_given_user_provided_headers_when_read_data_then_use_user_provided_headers(self) -> None: - self._config_format.header_definition = CsvHeaderUserProvided(column_names=["first", "second", "third", "fourth"]) + def test_given_user_provided_headers_when_read_data_then_use_user_provided_headers( + self, + ) -> None: + self._config_format.header_definition = CsvHeaderUserProvided( + column_names=["first", "second", "third", "fourth"] + ) self._stream_reader.open_file.return_value = CsvFileBuilder().with_data(["0,1,2,3"]).build() data_generator = self._read_data() assert list(data_generator) == [{"first": "0", "second": "1", "third": "2", "fourth": "3"}] - def test_given_len_mistmatch_on_user_provided_headers_when_read_data_then_raise_error(self) -> None: - self._config_format.header_definition = CsvHeaderUserProvided(column_names=["missing", "one", "column"]) + def test_given_len_mistmatch_on_user_provided_headers_when_read_data_then_raise_error( + self, + ) -> None: + self._config_format.header_definition = CsvHeaderUserProvided( + column_names=["missing", "one", "column"] + ) self._stream_reader.open_file.return_value = CsvFileBuilder().with_data(["0,1,2,3"]).build() with pytest.raises(RecordParseError): list(self._read_data()) - def test_given_skip_rows_after_header_when_read_data_then_do_not_parse_skipped_rows(self) -> None: + def test_given_skip_rows_after_header_when_read_data_then_do_not_parse_skipped_rows( + self, + ) -> None: self._config_format.skip_rows_after_header = 1 self._stream_reader.open_file.return_value = ( CsvFileBuilder() @@ -415,7 +529,9 @@ def test_given_double_quote_on_when_read_data_then_parse_properly(self) -> None: data_generator = self._read_data() - assert list(data_generator) == [{"header1": "1", "header2": 'Text with doublequote: "This is a text."'}] + assert list(data_generator) == [ + {"header1": "1", "header2": 'Text with doublequote: "This is a text."'} + ] def test_given_double_quote_off_when_read_data_then_parse_properly(self) -> None: self._config_format.double_quote = False @@ -432,7 +548,9 @@ def test_given_double_quote_off_when_read_data_then_parse_properly(self) -> None data_generator = self._read_data() - assert list(data_generator) == [{"header1": "1", "header2": 'Text with doublequote: "This is a text."""'}] + assert list(data_generator) == [ + {"header1": "1", "header2": 'Text with doublequote: "This is a text."""'} + ] def test_given_generator_closed_when_read_data_then_unregister_dialect(self) -> None: self._stream_reader.open_file.return_value = ( @@ -455,7 +573,9 @@ def test_given_generator_closed_when_read_data_then_unregister_dialect(self) -> data_generator.close() assert new_dialect not in csv.list_dialects() - def test_given_too_many_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect(self) -> None: + def test_given_too_many_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect( + self, + ) -> None: self._stream_reader.open_file.return_value = ( CsvFileBuilder() .with_data( @@ -478,7 +598,9 @@ def test_given_too_many_values_for_columns_when_read_data_then_raise_exception_a next(data_generator) assert new_dialect not in csv.list_dialects() - def test_given_too_few_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect(self) -> None: + def test_given_too_few_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect( + self, + ) -> None: self._stream_reader.open_file.return_value = ( CsvFileBuilder() .with_data( @@ -521,8 +643,12 @@ def test_parse_field_size_larger_than_default_python_maximum(self) -> None: assert list(data_generator) == [{"header1": "1", "header2": long_string}] def test_read_data_with_encoding_error(self) -> None: - self._stream_reader.open_file.return_value = CsvFileBuilder().with_data(["something"]).build() - self._csv_reader._get_headers = Mock(side_effect=UnicodeDecodeError("encoding", b"", 0, 1, "reason")) + self._stream_reader.open_file.return_value = ( + CsvFileBuilder().with_data(["something"]).build() + ) + self._csv_reader._get_headers = Mock( + side_effect=UnicodeDecodeError("encoding", b"", 0, 1, "reason") + ) with pytest.raises(AirbyteTracedException) as ate: data_generator = self._read_data() @@ -579,7 +705,9 @@ def _read_data(self) -> Generator[Dict[str, str], None, None]: ), ], ) -def test_mismatch_between_values_and_header(ignore_errors_on_fields_mismatch, data, error_message) -> None: +def test_mismatch_between_values_and_header( + ignore_errors_on_fields_mismatch, data, error_message +) -> None: config_format = CsvFormat() config = Mock() config.name = "config_name" @@ -619,8 +747,21 @@ def test_encoding_is_passed_to_stream_reader() -> None: mock_obj.__enter__ = Mock(return_value=io.StringIO("c1,c2\nv1,v2")) mock_obj.__exit__ = Mock(return_value=None) file = RemoteFile(uri="s3://bucket/key.csv", last_modified=datetime.now()) - config = FileBasedStreamConfig(name="test", validation_policy="Emit Record", file_type="csv", format=CsvFormat(encoding=encoding)) - list(parser.parse_records(config, file, stream_reader, logger, {"properties": {"c1": {"type": "string"}, "c2": {"type": "string"}}})) + config = FileBasedStreamConfig( + name="test", + validation_policy="Emit Record", + file_type="csv", + format=CsvFormat(encoding=encoding), + ) + list( + parser.parse_records( + config, + file, + stream_reader, + logger, + {"properties": {"c1": {"type": "string"}, "c2": {"type": "string"}}}, + ) + ) stream_reader.open_file.assert_has_calls( [ mock.call(file, FileReadMode.READ, encoding, logger), diff --git a/unit_tests/sources/file_based/file_types/test_excel_parser.py b/unit_tests/sources/file_based/file_types/test_excel_parser.py index bd9d8338..aac74be9 100644 --- a/unit_tests/sources/file_based/file_types/test_excel_parser.py +++ b/unit_tests/sources/file_based/file_types/test_excel_parser.py @@ -9,7 +9,11 @@ import pandas as pd import pytest -from airbyte_cdk.sources.file_based.config.file_based_stream_config import ExcelFormat, FileBasedStreamConfig, ValidationPolicy +from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( + ExcelFormat, + FileBasedStreamConfig, + ValidationPolicy, +) from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, RecordParseError from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader from airbyte_cdk.sources.file_based.file_types.excel_parser import ExcelParser @@ -66,7 +70,14 @@ def setup_parser(remote_file): stream_reader = MagicMock(spec=AbstractFileBasedStreamReader) stream_reader.open_file.return_value = BytesIO(excel_bytes.read()) - return parser, FileBasedStreamConfig(name="test_stream", format=ExcelFormat()), remote_file, stream_reader, MagicMock(), data + return ( + parser, + FileBasedStreamConfig(name="test_stream", format=ExcelFormat()), + remote_file, + stream_reader, + MagicMock(), + data, + ) @patch("pandas.ExcelFile") @@ -92,7 +103,9 @@ async def test_infer_schema(mock_excel_file, setup_parser): assert schema == expected_schema # Assert that the stream_reader's open_file was called correctly - stream_reader.open_file.assert_called_once_with(file, parser.file_read_mode, parser.ENCODING, logger) + stream_reader.open_file.assert_called_once_with( + file, parser.file_read_mode, parser.ENCODING, logger + ) # Assert that the logger was not used for warnings/errors logger.info.assert_not_called() @@ -119,4 +132,6 @@ def test_file_read_error(mock_stream_reader, mock_logger, file_config, remote_fi mock_excel.return_value.parse.side_effect = ValueError("Failed to parse file") with pytest.raises(RecordParseError): - list(parser.parse_records(file_config, remote_file, mock_stream_reader, mock_logger)) + list( + parser.parse_records(file_config, remote_file, mock_stream_reader, mock_logger) + ) diff --git a/unit_tests/sources/file_based/file_types/test_jsonl_parser.py b/unit_tests/sources/file_based/file_types/test_jsonl_parser.py index af5d83d7..cf924131 100644 --- a/unit_tests/sources/file_based/file_types/test_jsonl_parser.py +++ b/unit_tests/sources/file_based/file_types/test_jsonl_parser.py @@ -51,8 +51,18 @@ def _infer_schema(stream_reader: MagicMock) -> Dict[str, Any]: def test_when_infer_then_return_proper_types(stream_reader: MagicMock) -> None: - record = {"col1": 1, "col2": 2.2, "col3": "3", "col4": ["a", "list"], "col5": {"inner": "obj"}, "col6": None, "col7": True} - stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO(json.dumps(record).encode("utf-8")) + record = { + "col1": 1, + "col2": 2.2, + "col3": "3", + "col4": ["a", "list"], + "col5": {"inner": "obj"}, + "col6": None, + "col7": True, + } + stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO( + json.dumps(record).encode("utf-8") + ) schema = _infer_schema(stream_reader) @@ -88,8 +98,14 @@ def test_given_no_records_when_infer_then_return_empty_schema(stream_reader: Mag def test_given_limit_hit_when_infer_then_stop_considering_records(stream_reader: MagicMock) -> None: - jsonl_file_content = '{"key": 2.' + "2" * JsonlParser.MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE + '}\n{"key": "a string"}' - stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO(jsonl_file_content.encode("utf-8")) + jsonl_file_content = ( + '{"key": 2.' + + "2" * JsonlParser.MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE + + '}\n{"key": "a string"}' + ) + stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO( + jsonl_file_content.encode("utf-8") + ) schema = _infer_schema(stream_reader) @@ -99,34 +115,52 @@ def test_given_limit_hit_when_infer_then_stop_considering_records(stream_reader: def test_given_multiline_json_objects_and_read_limit_hit_when_infer_then_return_parse_until_at_least_one_record( stream_reader: MagicMock, ) -> None: - jsonl_file_content = '{\n"key": 2.' + "2" * JsonlParser.MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE + "\n}" - stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO(jsonl_file_content.encode("utf-8")) + jsonl_file_content = ( + '{\n"key": 2.' + "2" * JsonlParser.MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE + "\n}" + ) + stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO( + jsonl_file_content.encode("utf-8") + ) schema = _infer_schema(stream_reader) assert schema == {"key": {"type": "number"}} -def test_given_multiline_json_objects_and_hits_read_limit_when_infer_then_return_proper_types(stream_reader: MagicMock) -> None: - stream_reader.open_file.return_value.__enter__.return_value = JSONL_CONTENT_WITH_MULTILINE_JSON_OBJECTS +def test_given_multiline_json_objects_and_hits_read_limit_when_infer_then_return_proper_types( + stream_reader: MagicMock, +) -> None: + stream_reader.open_file.return_value.__enter__.return_value = ( + JSONL_CONTENT_WITH_MULTILINE_JSON_OBJECTS + ) schema = _infer_schema(stream_reader) assert schema == {"a": {"type": "integer"}, "b": {"type": "string"}} def test_given_multiple_records_then_merge_types(stream_reader: MagicMock) -> None: - stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO('{"col1": 1}\n{"col1": 2.3}'.encode("utf-8")) + stream_reader.open_file.return_value.__enter__.return_value = io.BytesIO( + '{"col1": 1}\n{"col1": 2.3}'.encode("utf-8") + ) schema = _infer_schema(stream_reader) assert schema == {"col1": {"type": "number"}} -def test_given_one_json_per_line_when_parse_records_then_return_records(stream_reader: MagicMock) -> None: - stream_reader.open_file.return_value.__enter__.return_value = JSONL_CONTENT_WITHOUT_MULTILINE_JSON_OBJECTS +def test_given_one_json_per_line_when_parse_records_then_return_records( + stream_reader: MagicMock, +) -> None: + stream_reader.open_file.return_value.__enter__.return_value = ( + JSONL_CONTENT_WITHOUT_MULTILINE_JSON_OBJECTS + ) records = list(JsonlParser().parse_records(Mock(), Mock(), stream_reader, Mock(), None)) assert records == [{"a": 1, "b": "1"}, {"a": 2, "b": "2"}] -def test_given_one_json_per_line_when_parse_records_then_do_not_send_warning(stream_reader: MagicMock) -> None: - stream_reader.open_file.return_value.__enter__.return_value = JSONL_CONTENT_WITHOUT_MULTILINE_JSON_OBJECTS +def test_given_one_json_per_line_when_parse_records_then_do_not_send_warning( + stream_reader: MagicMock, +) -> None: + stream_reader.open_file.return_value.__enter__.return_value = ( + JSONL_CONTENT_WITHOUT_MULTILINE_JSON_OBJECTS + ) logger = Mock() list(JsonlParser().parse_records(Mock(), Mock(), stream_reader, logger, None)) @@ -134,14 +168,22 @@ def test_given_one_json_per_line_when_parse_records_then_do_not_send_warning(str assert logger.warning.call_count == 0 -def test_given_multiline_json_object_when_parse_records_then_return_records(stream_reader: MagicMock) -> None: - stream_reader.open_file.return_value.__enter__.return_value = JSONL_CONTENT_WITH_MULTILINE_JSON_OBJECTS +def test_given_multiline_json_object_when_parse_records_then_return_records( + stream_reader: MagicMock, +) -> None: + stream_reader.open_file.return_value.__enter__.return_value = ( + JSONL_CONTENT_WITH_MULTILINE_JSON_OBJECTS + ) records = list(JsonlParser().parse_records(Mock(), Mock(), stream_reader, Mock(), None)) assert records == [{"a": 1, "b": "1"}, {"a": 2, "b": "2"}] -def test_given_multiline_json_object_when_parse_records_then_log_once_one_record_yielded(stream_reader: MagicMock) -> None: - stream_reader.open_file.return_value.__enter__.return_value = JSONL_CONTENT_WITH_MULTILINE_JSON_OBJECTS +def test_given_multiline_json_object_when_parse_records_then_log_once_one_record_yielded( + stream_reader: MagicMock, +) -> None: + stream_reader.open_file.return_value.__enter__.return_value = ( + JSONL_CONTENT_WITH_MULTILINE_JSON_OBJECTS + ) logger = Mock() next(iter(JsonlParser().parse_records(Mock(), Mock(), stream_reader, logger, None))) @@ -149,7 +191,9 @@ def test_given_multiline_json_object_when_parse_records_then_log_once_one_record assert logger.warning.call_count == 1 -def test_given_unparsable_json_when_parse_records_then_raise_error(stream_reader: MagicMock) -> None: +def test_given_unparsable_json_when_parse_records_then_raise_error( + stream_reader: MagicMock, +) -> None: stream_reader.open_file.return_value.__enter__.return_value = INVALID_JSON_CONTENT logger = Mock() diff --git a/unit_tests/sources/file_based/file_types/test_parquet_parser.py b/unit_tests/sources/file_based/file_types/test_parquet_parser.py index c4768fac..e0c06e86 100644 --- a/unit_tests/sources/file_based/file_types/test_parquet_parser.py +++ b/unit_tests/sources/file_based/file_types/test_parquet_parser.py @@ -11,7 +11,10 @@ import pyarrow as pa import pytest from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat -from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ValidationPolicy +from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( + FileBasedStreamConfig, + ValidationPolicy, +) from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat from airbyte_cdk.sources.file_based.file_types import ParquetParser @@ -25,25 +28,65 @@ @pytest.mark.parametrize( "parquet_type, expected_type, parquet_format", [ - pytest.param(pa.bool_(), {"type": "boolean"}, _default_parquet_format, id="test_parquet_bool"), - pytest.param(pa.int8(), {"type": "integer"}, _default_parquet_format, id="test_parquet_int8"), - pytest.param(pa.int16(), {"type": "integer"}, _default_parquet_format, id="test_parquet_int16"), - pytest.param(pa.int32(), {"type": "integer"}, _default_parquet_format, id="test_parquet_int32"), - pytest.param(pa.int64(), {"type": "integer"}, _default_parquet_format, id="test_parquet_int64"), - pytest.param(pa.uint8(), {"type": "integer"}, _default_parquet_format, id="test_parquet_uint8"), - pytest.param(pa.uint16(), {"type": "integer"}, _default_parquet_format, id="test_parquet_uint16"), - pytest.param(pa.uint32(), {"type": "integer"}, _default_parquet_format, id="test_parquet_uint32"), - pytest.param(pa.uint64(), {"type": "integer"}, _default_parquet_format, id="test_parquet_uint64"), - pytest.param(pa.float16(), {"type": "number"}, _default_parquet_format, id="test_parquet_float16"), - pytest.param(pa.float32(), {"type": "number"}, _default_parquet_format, id="test_parquet_float32"), - pytest.param(pa.float64(), {"type": "number"}, _default_parquet_format, id="test_parquet_float64"), - pytest.param(pa.time32("s"), {"type": "string"}, _default_parquet_format, id="test_parquet_time32s"), - pytest.param(pa.time32("ms"), {"type": "string"}, _default_parquet_format, id="test_parquet_time32ms"), - pytest.param(pa.time64("us"), {"type": "string"}, _default_parquet_format, id="test_parquet_time64us"), - pytest.param(pa.time64("ns"), {"type": "string"}, _default_parquet_format, id="test_parquet_time64us"), - pytest.param(pa.timestamp("s"), {"type": "string", "format": "date-time"}, _default_parquet_format, id="test_parquet_timestamps_s"), - pytest.param( - pa.timestamp("ms"), {"type": "string", "format": "date-time"}, _default_parquet_format, id="test_parquet_timestamp_ms" + pytest.param( + pa.bool_(), {"type": "boolean"}, _default_parquet_format, id="test_parquet_bool" + ), + pytest.param( + pa.int8(), {"type": "integer"}, _default_parquet_format, id="test_parquet_int8" + ), + pytest.param( + pa.int16(), {"type": "integer"}, _default_parquet_format, id="test_parquet_int16" + ), + pytest.param( + pa.int32(), {"type": "integer"}, _default_parquet_format, id="test_parquet_int32" + ), + pytest.param( + pa.int64(), {"type": "integer"}, _default_parquet_format, id="test_parquet_int64" + ), + pytest.param( + pa.uint8(), {"type": "integer"}, _default_parquet_format, id="test_parquet_uint8" + ), + pytest.param( + pa.uint16(), {"type": "integer"}, _default_parquet_format, id="test_parquet_uint16" + ), + pytest.param( + pa.uint32(), {"type": "integer"}, _default_parquet_format, id="test_parquet_uint32" + ), + pytest.param( + pa.uint64(), {"type": "integer"}, _default_parquet_format, id="test_parquet_uint64" + ), + pytest.param( + pa.float16(), {"type": "number"}, _default_parquet_format, id="test_parquet_float16" + ), + pytest.param( + pa.float32(), {"type": "number"}, _default_parquet_format, id="test_parquet_float32" + ), + pytest.param( + pa.float64(), {"type": "number"}, _default_parquet_format, id="test_parquet_float64" + ), + pytest.param( + pa.time32("s"), {"type": "string"}, _default_parquet_format, id="test_parquet_time32s" + ), + pytest.param( + pa.time32("ms"), {"type": "string"}, _default_parquet_format, id="test_parquet_time32ms" + ), + pytest.param( + pa.time64("us"), {"type": "string"}, _default_parquet_format, id="test_parquet_time64us" + ), + pytest.param( + pa.time64("ns"), {"type": "string"}, _default_parquet_format, id="test_parquet_time64us" + ), + pytest.param( + pa.timestamp("s"), + {"type": "string", "format": "date-time"}, + _default_parquet_format, + id="test_parquet_timestamps_s", + ), + pytest.param( + pa.timestamp("ms"), + {"type": "string", "format": "date-time"}, + _default_parquet_format, + id="test_parquet_timestamp_ms", ), pytest.param( pa.timestamp("s", "utc"), @@ -57,38 +100,111 @@ _default_parquet_format, id="test_parquet_timestamps_ms_with_tz", ), - pytest.param(pa.date32(), {"type": "string", "format": "date"}, _default_parquet_format, id="test_parquet_date32"), - pytest.param(pa.date64(), {"type": "string", "format": "date"}, _default_parquet_format, id="test_parquet_date64"), - pytest.param(pa.duration("s"), {"type": "integer"}, _default_parquet_format, id="test_duration_s"), - pytest.param(pa.duration("ms"), {"type": "integer"}, _default_parquet_format, id="test_duration_ms"), - pytest.param(pa.duration("us"), {"type": "integer"}, _default_parquet_format, id="test_duration_us"), - pytest.param(pa.duration("ns"), {"type": "integer"}, _default_parquet_format, id="test_duration_ns"), - pytest.param(pa.month_day_nano_interval(), {"type": "array"}, _default_parquet_format, id="test_parquet_month_day_nano_interval"), + pytest.param( + pa.date32(), + {"type": "string", "format": "date"}, + _default_parquet_format, + id="test_parquet_date32", + ), + pytest.param( + pa.date64(), + {"type": "string", "format": "date"}, + _default_parquet_format, + id="test_parquet_date64", + ), + pytest.param( + pa.duration("s"), {"type": "integer"}, _default_parquet_format, id="test_duration_s" + ), + pytest.param( + pa.duration("ms"), {"type": "integer"}, _default_parquet_format, id="test_duration_ms" + ), + pytest.param( + pa.duration("us"), {"type": "integer"}, _default_parquet_format, id="test_duration_us" + ), + pytest.param( + pa.duration("ns"), {"type": "integer"}, _default_parquet_format, id="test_duration_ns" + ), + pytest.param( + pa.month_day_nano_interval(), + {"type": "array"}, + _default_parquet_format, + id="test_parquet_month_day_nano_interval", + ), pytest.param(pa.binary(), {"type": "string"}, _default_parquet_format, id="test_binary"), - pytest.param(pa.binary(2), {"type": "string"}, _default_parquet_format, id="test_fixed_size_binary"), - pytest.param(pa.string(), {"type": "string"}, _default_parquet_format, id="test_parquet_string"), + pytest.param( + pa.binary(2), {"type": "string"}, _default_parquet_format, id="test_fixed_size_binary" + ), + pytest.param( + pa.string(), {"type": "string"}, _default_parquet_format, id="test_parquet_string" + ), pytest.param(pa.utf8(), {"type": "string"}, _default_parquet_format, id="test_utf8"), - pytest.param(pa.large_binary(), {"type": "string"}, _default_parquet_format, id="test_large_binary"), - pytest.param(pa.large_string(), {"type": "string"}, _default_parquet_format, id="test_large_string"), - pytest.param(pa.large_utf8(), {"type": "string"}, _default_parquet_format, id="test_large_utf8"), - pytest.param(pa.dictionary(pa.int32(), pa.string()), {"type": "object"}, _default_parquet_format, id="test_dictionary"), - pytest.param(pa.struct([pa.field("field", pa.int32())]), {"type": "object"}, _default_parquet_format, id="test_struct"), - pytest.param(pa.list_(pa.int32()), {"type": "array"}, _default_parquet_format, id="test_list"), - pytest.param(pa.large_list(pa.int32()), {"type": "array"}, _default_parquet_format, id="test_large_list"), - pytest.param(pa.decimal128(2), {"type": "string"}, _default_parquet_format, id="test_decimal128"), - pytest.param(pa.decimal256(2), {"type": "string"}, _default_parquet_format, id="test_decimal256"), - pytest.param(pa.decimal128(2), {"type": "number"}, _decimal_as_float_parquet_format, id="test_decimal128_as_float"), - pytest.param(pa.decimal256(2), {"type": "number"}, _decimal_as_float_parquet_format, id="test_decimal256_as_float"), - pytest.param(pa.map_(pa.int32(), pa.int32()), {"type": "object"}, _default_parquet_format, id="test_map"), + pytest.param( + pa.large_binary(), {"type": "string"}, _default_parquet_format, id="test_large_binary" + ), + pytest.param( + pa.large_string(), {"type": "string"}, _default_parquet_format, id="test_large_string" + ), + pytest.param( + pa.large_utf8(), {"type": "string"}, _default_parquet_format, id="test_large_utf8" + ), + pytest.param( + pa.dictionary(pa.int32(), pa.string()), + {"type": "object"}, + _default_parquet_format, + id="test_dictionary", + ), + pytest.param( + pa.struct([pa.field("field", pa.int32())]), + {"type": "object"}, + _default_parquet_format, + id="test_struct", + ), + pytest.param( + pa.list_(pa.int32()), {"type": "array"}, _default_parquet_format, id="test_list" + ), + pytest.param( + pa.large_list(pa.int32()), + {"type": "array"}, + _default_parquet_format, + id="test_large_list", + ), + pytest.param( + pa.decimal128(2), {"type": "string"}, _default_parquet_format, id="test_decimal128" + ), + pytest.param( + pa.decimal256(2), {"type": "string"}, _default_parquet_format, id="test_decimal256" + ), + pytest.param( + pa.decimal128(2), + {"type": "number"}, + _decimal_as_float_parquet_format, + id="test_decimal128_as_float", + ), + pytest.param( + pa.decimal256(2), + {"type": "number"}, + _decimal_as_float_parquet_format, + id="test_decimal256_as_float", + ), + pytest.param( + pa.map_(pa.int32(), pa.int32()), + {"type": "object"}, + _default_parquet_format, + id="test_map", + ), pytest.param(pa.null(), {"type": "null"}, _default_parquet_format, id="test_null"), ], ) -def test_type_mapping(parquet_type: pa.DataType, expected_type: Mapping[str, str], parquet_format: ParquetFormat) -> None: +def test_type_mapping( + parquet_type: pa.DataType, expected_type: Mapping[str, str], parquet_format: ParquetFormat +) -> None: if expected_type is None: with pytest.raises(ValueError): ParquetParser.parquet_type_to_schema_type(parquet_type, parquet_format) else: - assert ParquetParser.parquet_type_to_schema_type(parquet_type, parquet_format) == expected_type + assert ( + ParquetParser.parquet_type_to_schema_type(parquet_type, parquet_format) == expected_type + ) @pytest.mark.parametrize( @@ -105,10 +221,34 @@ def test_type_mapping(parquet_type: pa.DataType, expected_type: Mapping[str, str pytest.param(pa.uint64(), _default_parquet_format, 6, 6, id="test_parquet_uint64"), pytest.param(pa.float32(), _default_parquet_format, 2.7, 2.7, id="test_parquet_float32"), pytest.param(pa.float64(), _default_parquet_format, 3.14, 3.14, id="test_parquet_float64"), - pytest.param(pa.time32("s"), _default_parquet_format, datetime.time(1, 2, 3), "01:02:03", id="test_parquet_time32s"), - pytest.param(pa.time32("ms"), _default_parquet_format, datetime.time(3, 4, 5), "03:04:05", id="test_parquet_time32ms"), - pytest.param(pa.time64("us"), _default_parquet_format, datetime.time(6, 7, 8), "06:07:08", id="test_parquet_time64us"), - pytest.param(pa.time64("ns"), _default_parquet_format, datetime.time(9, 10, 11), "09:10:11", id="test_parquet_time64us"), + pytest.param( + pa.time32("s"), + _default_parquet_format, + datetime.time(1, 2, 3), + "01:02:03", + id="test_parquet_time32s", + ), + pytest.param( + pa.time32("ms"), + _default_parquet_format, + datetime.time(3, 4, 5), + "03:04:05", + id="test_parquet_time32ms", + ), + pytest.param( + pa.time64("us"), + _default_parquet_format, + datetime.time(6, 7, 8), + "06:07:08", + id="test_parquet_time64us", + ), + pytest.param( + pa.time64("ns"), + _default_parquet_format, + datetime.time(9, 10, 11), + "09:10:11", + id="test_parquet_time64us", + ), pytest.param( pa.timestamp("s"), _default_parquet_format, @@ -137,12 +277,30 @@ def test_type_mapping(parquet_type: pa.DataType, expected_type: Mapping[str, str "2021-02-03T04:05:00+00:00", id="test_parquet_timestamps_ms_with_tz", ), - pytest.param(pa.date32(), _default_parquet_format, datetime.date(2023, 7, 7), "2023-07-07", id="test_parquet_date32"), - pytest.param(pa.date64(), _default_parquet_format, datetime.date(2023, 7, 8), "2023-07-08", id="test_parquet_date64"), + pytest.param( + pa.date32(), + _default_parquet_format, + datetime.date(2023, 7, 7), + "2023-07-07", + id="test_parquet_date32", + ), + pytest.param( + pa.date64(), + _default_parquet_format, + datetime.date(2023, 7, 8), + "2023-07-08", + id="test_parquet_date64", + ), pytest.param(pa.duration("s"), _default_parquet_format, 12345, 12345, id="test_duration_s"), - pytest.param(pa.duration("ms"), _default_parquet_format, 12345, 12345, id="test_duration_ms"), - pytest.param(pa.duration("us"), _default_parquet_format, 12345, 12345, id="test_duration_us"), - pytest.param(pa.duration("ns"), _default_parquet_format, 12345, 12345, id="test_duration_ns"), + pytest.param( + pa.duration("ms"), _default_parquet_format, 12345, 12345, id="test_duration_ms" + ), + pytest.param( + pa.duration("us"), _default_parquet_format, 12345, 12345, id="test_duration_us" + ), + pytest.param( + pa.duration("ns"), _default_parquet_format, 12345, 12345, id="test_duration_ns" + ), pytest.param( pa.month_day_nano_interval(), _default_parquet_format, @@ -150,28 +308,91 @@ def test_type_mapping(parquet_type: pa.DataType, expected_type: Mapping[str, str [0, 3, 4000], id="test_parquet_month_day_nano_interval", ), - pytest.param(pa.binary(), _default_parquet_format, b"this is a binary string", "this is a binary string", id="test_binary"), - pytest.param(pa.binary(2), _default_parquet_format, b"t1", "t1", id="test_fixed_size_binary"), - pytest.param(pa.string(), _default_parquet_format, "this is a string", "this is a string", id="test_parquet_string"), - pytest.param(pa.utf8(), _default_parquet_format, "utf8".encode("utf8"), "utf8", id="test_utf8"), - pytest.param(pa.large_binary(), _default_parquet_format, b"large binary string", "large binary string", id="test_large_binary"), - pytest.param(pa.large_string(), _default_parquet_format, "large string", "large string", id="test_large_string"), - pytest.param(pa.large_utf8(), _default_parquet_format, "large utf8", "large utf8", id="test_large_utf8"), - pytest.param(pa.struct([pa.field("field", pa.int32())]), _default_parquet_format, {"field": 1}, {"field": 1}, id="test_struct"), - pytest.param(pa.list_(pa.int32()), _default_parquet_format, [1, 2, 3], [1, 2, 3], id="test_list"), - pytest.param(pa.large_list(pa.int32()), _default_parquet_format, [4, 5, 6], [4, 5, 6], id="test_large_list"), - pytest.param(pa.decimal128(5, 3), _default_parquet_format, 12, "12.000", id="test_decimal128"), - pytest.param(pa.decimal256(8, 2), _default_parquet_format, 13, "13.00", id="test_decimal256"), - pytest.param(pa.decimal128(5, 3), _decimal_as_float_parquet_format, 12, 12.000, id="test_decimal128"), - pytest.param(pa.decimal256(8, 2), _decimal_as_float_parquet_format, 13, 13.00, id="test_decimal256"), - pytest.param( - pa.map_(pa.string(), pa.int32()), _default_parquet_format, {"hello": 1, "world": 2}, {"hello": 1, "world": 2}, id="test_map" + pytest.param( + pa.binary(), + _default_parquet_format, + b"this is a binary string", + "this is a binary string", + id="test_binary", + ), + pytest.param( + pa.binary(2), _default_parquet_format, b"t1", "t1", id="test_fixed_size_binary" + ), + pytest.param( + pa.string(), + _default_parquet_format, + "this is a string", + "this is a string", + id="test_parquet_string", + ), + pytest.param( + pa.utf8(), _default_parquet_format, "utf8".encode("utf8"), "utf8", id="test_utf8" + ), + pytest.param( + pa.large_binary(), + _default_parquet_format, + b"large binary string", + "large binary string", + id="test_large_binary", + ), + pytest.param( + pa.large_string(), + _default_parquet_format, + "large string", + "large string", + id="test_large_string", + ), + pytest.param( + pa.large_utf8(), + _default_parquet_format, + "large utf8", + "large utf8", + id="test_large_utf8", + ), + pytest.param( + pa.struct([pa.field("field", pa.int32())]), + _default_parquet_format, + {"field": 1}, + {"field": 1}, + id="test_struct", + ), + pytest.param( + pa.list_(pa.int32()), _default_parquet_format, [1, 2, 3], [1, 2, 3], id="test_list" + ), + pytest.param( + pa.large_list(pa.int32()), + _default_parquet_format, + [4, 5, 6], + [4, 5, 6], + id="test_large_list", + ), + pytest.param( + pa.decimal128(5, 3), _default_parquet_format, 12, "12.000", id="test_decimal128" + ), + pytest.param( + pa.decimal256(8, 2), _default_parquet_format, 13, "13.00", id="test_decimal256" + ), + pytest.param( + pa.decimal128(5, 3), _decimal_as_float_parquet_format, 12, 12.000, id="test_decimal128" + ), + pytest.param( + pa.decimal256(8, 2), _decimal_as_float_parquet_format, 13, 13.00, id="test_decimal256" + ), + pytest.param( + pa.map_(pa.string(), pa.int32()), + _default_parquet_format, + {"hello": 1, "world": 2}, + {"hello": 1, "world": 2}, + id="test_map", ), pytest.param(pa.null(), _default_parquet_format, None, None, id="test_null"), ], ) def test_value_transformation( - pyarrow_type: pa.DataType, parquet_format: ParquetFormat, parquet_object: Scalar, expected_value: Any + pyarrow_type: pa.DataType, + parquet_format: ParquetFormat, + parquet_object: Scalar, + expected_value: Any, ) -> None: pyarrow_value = pa.array([parquet_object], type=pyarrow_type)[0] py_value = ParquetParser._to_output_value(pyarrow_value, parquet_format) @@ -212,15 +433,27 @@ def test_value_dictionary() -> None: pytest.param(pa.time64("ns"), _default_parquet_format, id="test_parquet_time64ns"), pytest.param(pa.timestamp("s"), _default_parquet_format, id="test_parquet_timestamps_s"), pytest.param(pa.timestamp("ms"), _default_parquet_format, id="test_parquet_timestamp_ms"), - pytest.param(pa.timestamp("s", "utc"), _default_parquet_format, id="test_parquet_timestamps_s_with_tz"), - pytest.param(pa.timestamp("ms", "est"), _default_parquet_format, id="test_parquet_timestamps_ms_with_tz"), + pytest.param( + pa.timestamp("s", "utc"), + _default_parquet_format, + id="test_parquet_timestamps_s_with_tz", + ), + pytest.param( + pa.timestamp("ms", "est"), + _default_parquet_format, + id="test_parquet_timestamps_ms_with_tz", + ), pytest.param(pa.date32(), _default_parquet_format, id="test_parquet_date32"), pytest.param(pa.date64(), _default_parquet_format, id="test_parquet_date64"), pytest.param(pa.duration("s"), _default_parquet_format, id="test_duration_s"), pytest.param(pa.duration("ms"), _default_parquet_format, id="test_duration_ms"), pytest.param(pa.duration("us"), _default_parquet_format, id="test_duration_us"), pytest.param(pa.duration("ns"), _default_parquet_format, id="test_duration_ns"), - pytest.param(pa.month_day_nano_interval(), _default_parquet_format, id="test_parquet_month_day_nano_interval"), + pytest.param( + pa.month_day_nano_interval(), + _default_parquet_format, + id="test_parquet_month_day_nano_interval", + ), pytest.param(pa.binary(), _default_parquet_format, id="test_binary"), pytest.param(pa.binary(2), _default_parquet_format, id="test_fixed_size_binary"), pytest.param(pa.string(), _default_parquet_format, id="test_parquet_string"), @@ -228,14 +461,22 @@ def test_value_dictionary() -> None: pytest.param(pa.large_binary(), _default_parquet_format, id="test_large_binary"), pytest.param(pa.large_string(), _default_parquet_format, id="test_large_string"), pytest.param(pa.large_utf8(), _default_parquet_format, id="test_large_utf8"), - pytest.param(pa.dictionary(pa.int32(), pa.string()), _default_parquet_format, id="test_dictionary"), - pytest.param(pa.struct([pa.field("field", pa.int32())]), _default_parquet_format, id="test_struct"), + pytest.param( + pa.dictionary(pa.int32(), pa.string()), _default_parquet_format, id="test_dictionary" + ), + pytest.param( + pa.struct([pa.field("field", pa.int32())]), _default_parquet_format, id="test_struct" + ), pytest.param(pa.list_(pa.int32()), _default_parquet_format, id="test_list"), pytest.param(pa.large_list(pa.int32()), _default_parquet_format, id="test_large_list"), pytest.param(pa.decimal128(2), _default_parquet_format, id="test_decimal128"), pytest.param(pa.decimal256(2), _default_parquet_format, id="test_decimal256"), - pytest.param(pa.decimal128(2), _decimal_as_float_parquet_format, id="test_decimal128_as_float"), - pytest.param(pa.decimal256(2), _decimal_as_float_parquet_format, id="test_decimal256_as_float"), + pytest.param( + pa.decimal128(2), _decimal_as_float_parquet_format, id="test_decimal128_as_float" + ), + pytest.param( + pa.decimal256(2), _decimal_as_float_parquet_format, id="test_decimal256_as_float" + ), pytest.param(pa.map_(pa.int32(), pa.int32()), _default_parquet_format, id="test_map"), pytest.param(pa.null(), _default_parquet_format, id="test_null"), ], @@ -272,4 +513,6 @@ def test_wrong_file_format(file_format: Union[CsvFormat, JsonlFormat]) -> None: stream_reader = Mock() logger = Mock() with pytest.raises(ValueError): - asyncio.get_event_loop().run_until_complete(parser.infer_schema(config, file, stream_reader, logger)) + asyncio.get_event_loop().run_until_complete( + parser.infer_schema(config, file, stream_reader, logger) + ) diff --git a/unit_tests/sources/file_based/file_types/test_unstructured_parser.py b/unit_tests/sources/file_based/file_types/test_unstructured_parser.py index 9bc096c5..ea4e091a 100644 --- a/unit_tests/sources/file_based/file_types/test_unstructured_parser.py +++ b/unit_tests/sources/file_based/file_types/test_unstructured_parser.py @@ -11,7 +11,11 @@ import requests from airbyte_cdk.models import FailureType from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig -from airbyte_cdk.sources.file_based.config.unstructured_format import APIParameterConfigModel, APIProcessingConfigModel, UnstructuredFormat +from airbyte_cdk.sources.file_based.config.unstructured_format import ( + APIParameterConfigModel, + APIProcessingConfigModel, + UnstructuredFormat, +) from airbyte_cdk.sources.file_based.exceptions import RecordParseError from airbyte_cdk.sources.file_based.file_types import UnstructuredParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile @@ -80,12 +84,22 @@ def test_infer_schema(mock_detect_filetype, filetype, format_config, raises): config.format = format_config if raises: with pytest.raises(RecordParseError): - loop.run_until_complete(UnstructuredParser().infer_schema(config, fake_file, stream_reader, logger)) + loop.run_until_complete( + UnstructuredParser().infer_schema(config, fake_file, stream_reader, logger) + ) else: - schema = loop.run_until_complete(UnstructuredParser().infer_schema(config, MagicMock(), MagicMock(), MagicMock())) + schema = loop.run_until_complete( + UnstructuredParser().infer_schema(config, MagicMock(), MagicMock(), MagicMock()) + ) assert schema == { - "content": {"type": "string", "description": "Content of the file as markdown. Might be null if the file could not be parsed"}, - "document_key": {"type": "string", "description": "Unique identifier of the document, e.g. the file path"}, + "content": { + "type": "string", + "description": "Content of the file as markdown. Might be null if the file could not be parsed", + }, + "document_key": { + "type": "string", + "description": "Unique identifier of the document, e.g. the file path", + }, "_ab_source_file_parse_error": { "type": "string", "description": "Error message if the file could not be parsed even though the file is supported", @@ -246,9 +260,20 @@ def test_parse_records( mock_partition_pdf.return_value = parse_result if raises: with pytest.raises(RecordParseError): - list(UnstructuredParser().parse_records(config, fake_file, stream_reader, logger, MagicMock())) + list( + UnstructuredParser().parse_records( + config, fake_file, stream_reader, logger, MagicMock() + ) + ) else: - assert list(UnstructuredParser().parse_records(config, fake_file, stream_reader, logger, MagicMock())) == expected_records + assert ( + list( + UnstructuredParser().parse_records( + config, fake_file, stream_reader, logger, MagicMock() + ) + ) + == expected_records + ) @pytest.mark.parametrize( @@ -279,7 +304,10 @@ def test_parse_records( id="local_unsupported_strategy", ), pytest.param( - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), False, [{"type": "Title", "text": "Airbyte source connection test"}], True, @@ -287,7 +315,10 @@ def test_parse_records( id="api_ok", ), pytest.param( - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), True, None, False, @@ -295,7 +326,10 @@ def test_parse_records( id="api_error", ), pytest.param( - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), False, {"unexpected": "response"}, False, @@ -305,13 +339,17 @@ def test_parse_records( ], ) @patch("airbyte_cdk.sources.file_based.file_types.unstructured_parser.requests") -def test_check_config(requests_mock, format_config, raises_for_status, json_response, is_ok, expected_error): +def test_check_config( + requests_mock, format_config, raises_for_status, json_response, is_ok, expected_error +): mock_response = MagicMock() mock_response.json.return_value = json_response if raises_for_status: mock_response.raise_for_status.side_effect = Exception("API error") requests_mock.post.return_value = mock_response - result, error = UnstructuredParser().check_config(FileBasedStreamConfig(name="test", format=format_config)) + result, error = UnstructuredParser().check_config( + FileBasedStreamConfig(name="test", format=format_config) + ) assert result == is_ok if expected_error: assert expected_error in error @@ -322,7 +360,10 @@ def test_check_config(requests_mock, format_config, raises_for_status, json_resp [ pytest.param( FileType.PDF, - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), None, "test", [{"type": "Text", "text": "test"}], @@ -362,7 +403,11 @@ def test_check_config(requests_mock, format_config, raises_for_status, json_resp call( "http://localhost:8000/general/v0/general", headers={"accept": "application/json", "unstructured-api-key": "test"}, - data={"strategy": "hi_res", "include_page_breaks": "true", "ocr_languages": ["eng", "kor"]}, + data={ + "strategy": "hi_res", + "include_page_breaks": "true", + "ocr_languages": ["eng", "kor"], + }, files={"files": ("filename", mock.ANY, "application/pdf")}, ) ], @@ -373,19 +418,31 @@ def test_check_config(requests_mock, format_config, raises_for_status, json_resp ), pytest.param( FileType.MD, - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), None, "# Mymarkdown", None, None, False, - [{"content": "# Mymarkdown", "document_key": FILE_URI, "_ab_source_file_parse_error": None}], + [ + { + "content": "# Mymarkdown", + "document_key": FILE_URI, + "_ab_source_file_parse_error": None, + } + ], 200, id="handle_markdown_locally", ), pytest.param( FileType.PDF, - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), [ requests.exceptions.RequestException("API error"), requests.exceptions.RequestException("API error"), @@ -439,7 +496,10 @@ def test_check_config(requests_mock, format_config, raises_for_status, json_resp ), pytest.param( FileType.PDF, - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), [ requests.exceptions.RequestException("API error"), requests.exceptions.RequestException("API error"), @@ -477,7 +537,10 @@ def test_check_config(requests_mock, format_config, raises_for_status, json_resp ), pytest.param( FileType.PDF, - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), [ Exception("Unexpected error"), ], @@ -499,9 +562,14 @@ def test_check_config(requests_mock, format_config, raises_for_status, json_resp ), pytest.param( FileType.PDF, - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), [ - requests.exceptions.RequestException("API error", response=MagicMock(status_code=400)), + requests.exceptions.RequestException( + "API error", response=MagicMock(status_code=400) + ), ], "test", [{"type": "Text", "text": "test"}], @@ -521,7 +589,10 @@ def test_check_config(requests_mock, format_config, raises_for_status, json_resp ), pytest.param( FileType.PDF, - UnstructuredFormat(skip_unprocessable_file_types=False, processing=APIProcessingConfigModel(mode="api", api_key="test")), + UnstructuredFormat( + skip_unprocessable_file_types=False, + processing=APIProcessingConfigModel(mode="api", api_key="test"), + ), None, "test", [{"detail": "Something went wrong"}], @@ -581,11 +652,22 @@ def test_parse_records_remotely( if raises: with pytest.raises(AirbyteTracedException) as exc: - list(UnstructuredParser().parse_records(config, fake_file, stream_reader, logger, MagicMock())) + list( + UnstructuredParser().parse_records( + config, fake_file, stream_reader, logger, MagicMock() + ) + ) # Failures from the API are treated as config errors assert exc.value.failure_type == FailureType.config_error else: - assert list(UnstructuredParser().parse_records(config, fake_file, stream_reader, logger, MagicMock())) == expected_records + assert ( + list( + UnstructuredParser().parse_records( + config, fake_file, stream_reader, logger, MagicMock() + ) + ) + == expected_records + ) if expected_requests: requests_mock.post.assert_has_calls(expected_requests) diff --git a/unit_tests/sources/file_based/helpers.py b/unit_tests/sources/file_based/helpers.py index 6d4966e2..2138cdc5 100644 --- a/unit_tests/sources/file_based/helpers.py +++ b/unit_tests/sources/file_based/helpers.py @@ -9,7 +9,10 @@ from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.discovery_policy import DefaultDiscoveryPolicy -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.file_types.csv_parser import CsvParser from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.file_types.jsonl_parser import JsonlParser @@ -22,7 +25,11 @@ class EmptySchemaParser(CsvParser): async def infer_schema( - self, config: FileBasedStreamConfig, file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger + self, + config: FileBasedStreamConfig, + file: RemoteFile, + stream_reader: AbstractFileBasedStreamReader, + logger: logging.Logger, ) -> Dict[str, Any]: return {} @@ -46,7 +53,13 @@ def get_matching_files( class TestErrorOpenFileInMemoryFilesStreamReader(InMemoryFilesStreamReader): - def open_file(self, file: RemoteFile, file_read_mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase: + def open_file( + self, + file: RemoteFile, + file_read_mode: FileReadMode, + encoding: Optional[str], + logger: logging.Logger, + ) -> IOBase: raise Exception("Error opening file") @@ -54,7 +67,9 @@ class FailingSchemaValidationPolicy(AbstractSchemaValidationPolicy): ALWAYS_FAIL = "always_fail" validate_schema_before_sync = True - def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool: + def record_passes_validation_policy( + self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + ) -> bool: return False @@ -67,4 +82,10 @@ class LowHistoryLimitConcurrentCursor(FileBasedConcurrentCursor): def make_remote_files(files: List[str]) -> List[RemoteFile]: - return [RemoteFile(uri=f, last_modified=datetime.strptime("2023-06-05T03:54:07.000Z", "%Y-%m-%dT%H:%M:%S.%fZ")) for f in files] + return [ + RemoteFile( + uri=f, + last_modified=datetime.strptime("2023-06-05T03:54:07.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), + ) + for f in files + ] diff --git a/unit_tests/sources/file_based/in_memory_files_source.py b/unit_tests/sources/file_based/in_memory_files_source.py index b8448dbc..1a6ef55b 100644 --- a/unit_tests/sources/file_based/in_memory_files_source.py +++ b/unit_tests/sources/file_based/in_memory_files_source.py @@ -17,15 +17,30 @@ import pyarrow as pa import pyarrow.parquet as pq from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConfiguredAirbyteCatalogSerializer -from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy, DefaultFileBasedAvailabilityStrategy +from airbyte_cdk.sources.file_based.availability_strategy import ( + AbstractFileBasedAvailabilityStrategy, + DefaultFileBasedAvailabilityStrategy, +) from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec -from airbyte_cdk.sources.file_based.discovery_policy import AbstractDiscoveryPolicy, DefaultDiscoveryPolicy +from airbyte_cdk.sources.file_based.discovery_policy import ( + AbstractDiscoveryPolicy, + DefaultDiscoveryPolicy, +) from airbyte_cdk.sources.file_based.file_based_source import FileBasedSource -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode +from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + FileReadMode, +) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy -from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor, DefaultFileBasedCursor +from airbyte_cdk.sources.file_based.schema_validation_policies import ( + DEFAULT_SCHEMA_VALIDATION_POLICIES, + AbstractSchemaValidationPolicy, +) +from airbyte_cdk.sources.file_based.stream.cursor import ( + AbstractFileBasedCursor, + DefaultFileBasedCursor, +) from airbyte_cdk.sources.source import TState from avro import datafile from pydantic.v1 import AnyUrl @@ -53,13 +68,19 @@ def __init__( self.files = files self.file_type = file_type self.catalog = catalog - self.configured_catalog = ConfiguredAirbyteCatalogSerializer.load(self.catalog) if self.catalog else None + self.configured_catalog = ( + ConfiguredAirbyteCatalogSerializer.load(self.catalog) if self.catalog else None + ) self.config = config self.state = state # Source setup - stream_reader = stream_reader or InMemoryFilesStreamReader(files=files, file_type=file_type, file_write_options=file_write_options) - availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy(stream_reader) # type: ignore[assignment] + stream_reader = stream_reader or InMemoryFilesStreamReader( + files=files, file_type=file_type, file_write_options=file_write_options + ) + availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy( + stream_reader + ) # type: ignore[assignment] super().__init__( stream_reader, spec_class=InMemorySpec, @@ -78,7 +99,12 @@ def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog: class InMemoryFilesStreamReader(AbstractFileBasedStreamReader): - def __init__(self, files: Mapping[str, Mapping[str, Any]], file_type: str, file_write_options: Optional[Mapping[str, Any]] = None): + def __init__( + self, + files: Mapping[str, Mapping[str, Any]], + file_type: str, + file_write_options: Optional[Mapping[str, Any]] = None, + ): self.files = files self.file_type = file_type self.file_write_options = file_write_options @@ -113,10 +139,14 @@ def get_matching_files( def file_size(self, file: RemoteFile) -> int: return 0 - def get_file(self, file: RemoteFile, local_directory: str, logger: logging.Logger) -> Dict[str, Any]: + def get_file( + self, file: RemoteFile, local_directory: str, logger: logging.Logger + ) -> Dict[str, Any]: return {} - def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase: + def open_file( + self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + ) -> IOBase: if self.file_type == "csv": return self._make_csv_file_contents(file.uri) elif self.file_type == "jsonl": @@ -168,7 +198,9 @@ def _make_binary_file_contents(self, file_name: str) -> IOBase: class InMemorySpec(AbstractFileBasedSpec): @classmethod def documentation_url(cls) -> AnyUrl: - return AnyUrl(scheme="https", url="https://docs.airbyte.com/integrations/sources/in_memory_files") # type: ignore + return AnyUrl( + scheme="https", url="https://docs.airbyte.com/integrations/sources/in_memory_files" + ) # type: ignore class TemporaryParquetFilesStreamReader(InMemoryFilesStreamReader): @@ -176,7 +208,9 @@ class TemporaryParquetFilesStreamReader(InMemoryFilesStreamReader): A file reader that writes RemoteFiles to a temporary file and then reads them back. """ - def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase: + def open_file( + self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + ) -> IOBase: return io.BytesIO(self._create_file(file.uri)) def _create_file(self, file_name: str) -> bytes: @@ -197,7 +231,9 @@ class TemporaryAvroFilesStreamReader(InMemoryFilesStreamReader): A file reader that writes RemoteFiles to a temporary file and then reads them back. """ - def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase: + def open_file( + self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + ) -> IOBase: return io.BytesIO(self._make_file_contents(file.uri)) def _make_file_contents(self, file_name: str) -> bytes: @@ -221,7 +257,9 @@ class TemporaryExcelFilesStreamReader(InMemoryFilesStreamReader): A file reader that writes RemoteFiles to a temporary file and then reads them back. """ - def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase: + def open_file( + self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + ) -> IOBase: return io.BytesIO(self._make_file_contents(file.uri)) def _make_file_contents(self, file_name: str) -> bytes: diff --git a/unit_tests/sources/file_based/scenarios/avro_scenarios.py b/unit_tests/sources/file_based/scenarios/avro_scenarios.py index 7b891a16..77f51c68 100644 --- a/unit_tests/sources/file_based/scenarios/avro_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/avro_scenarios.py @@ -35,12 +35,25 @@ "fields": [ {"name": "col_double", "type": "double"}, {"name": "col_string", "type": "string"}, - {"name": "col_album", "type": {"type": "record", "name": "Album", "fields": [{"name": "album", "type": "string"}]}}, + { + "name": "col_album", + "type": { + "type": "record", + "name": "Album", + "fields": [{"name": "album", "type": "string"}], + }, + }, ], }, "contents": [ (20.02, "Robbers", {"album": "The 1975"}), - (20.23, "Somebody Else", {"album": "I Like It When You Sleep, for You Are So Beautiful yet So Unaware of It"}), + ( + 20.23, + "Somebody Else", + { + "album": "I Like It When You Sleep, for You Are So Beautiful yet So Unaware of It" + }, + ), ], "last_modified": "2023-06-05T03:54:07.000Z", }, @@ -51,11 +64,22 @@ "fields": [ {"name": "col_double", "type": "double"}, {"name": "col_string", "type": "string"}, - {"name": "col_song", "type": {"type": "record", "name": "Song", "fields": [{"name": "title", "type": "string"}]}}, + { + "name": "col_song", + "type": { + "type": "record", + "name": "Song", + "fields": [{"name": "title", "type": "string"}], + }, + }, ], }, "contents": [ - (1975.1975, "It's Not Living (If It's Not with You)", {"title": "Love It If We Made It"}), + ( + 1975.1975, + "It's Not Living (If It's Not with You)", + {"title": "Love It If We Made It"}, + ), (5791.5791, "The 1975", {"title": "About You"}), ], "last_modified": "2023-06-06T03:54:07.000Z", @@ -89,18 +113,39 @@ ], }, }, - {"name": "col_enum", "type": {"type": "enum", "name": "Genre", "symbols": ["POP_ROCK", "INDIE_ROCK", "ALTERNATIVE_ROCK"]}}, + { + "name": "col_enum", + "type": { + "type": "enum", + "name": "Genre", + "symbols": ["POP_ROCK", "INDIE_ROCK", "ALTERNATIVE_ROCK"], + }, + }, {"name": "col_array", "type": {"type": "array", "items": "string"}}, {"name": "col_map", "type": {"type": "map", "values": "string"}}, {"name": "col_fixed", "type": {"type": "fixed", "name": "MyFixed", "size": 4}}, # Logical Types - {"name": "col_decimal", "type": {"type": "bytes", "logicalType": "decimal", "precision": 10, "scale": 5}}, + { + "name": "col_decimal", + "type": { + "type": "bytes", + "logicalType": "decimal", + "precision": 10, + "scale": 5, + }, + }, {"name": "col_uuid", "type": {"type": "string", "logicalType": "uuid"}}, {"name": "col_date", "type": {"type": "int", "logicalType": "date"}}, {"name": "col_time_millis", "type": {"type": "int", "logicalType": "time-millis"}}, {"name": "col_time_micros", "type": {"type": "long", "logicalType": "time-micros"}}, - {"name": "col_timestamp_millis", "type": {"type": "long", "logicalType": "timestamp-millis"}}, - {"name": "col_timestamp_micros", "type": {"type": "long", "logicalType": "timestamp-micros"}}, + { + "name": "col_timestamp_millis", + "type": {"type": "long", "logicalType": "timestamp-millis"}, + }, + { + "name": "col_timestamp_micros", + "type": {"type": "long", "logicalType": "timestamp-micros"}, + }, ], }, "contents": [ @@ -121,7 +166,12 @@ "Notes on a Conditional Form", "Being Funny in a Foreign Language", ], - {"lead_singer": "Matty Healy", "lead_guitar": "Adam Hann", "bass_guitar": "Ross MacDonald", "drummer": "George Daniel"}, + { + "lead_singer": "Matty Healy", + "lead_guitar": "Adam Hann", + "bass_guitar": "Ross MacDonald", + "drummer": "George Daniel", + }, b"\x12\x34\x56\x78", decimal.Decimal("1234.56789"), "123e4567-e89b-12d3-a456-426655440000", @@ -148,7 +198,12 @@ "type": { "type": "enum", "name": "Album", - "symbols": ["SUMMERS_GONE", "IN_RETURN", "A_MOMENT_APART", "THE_LAST_GOODBYE"], + "symbols": [ + "SUMMERS_GONE", + "IN_RETURN", + "A_MOMENT_APART", + "THE_LAST_GOODBYE", + ], }, }, {"name": "col_year", "type": "int"}, @@ -188,8 +243,16 @@ "contents": [ ("Coachella", {"country": "USA", "state": "California", "city": "Indio"}, 250000), ("CRSSD", {"country": "USA", "state": "California", "city": "San Diego"}, 30000), - ("Lightning in a Bottle", {"country": "USA", "state": "California", "city": "Buena Vista Lake"}, 18000), - ("Outside Lands", {"country": "USA", "state": "California", "city": "San Francisco"}, 220000), + ( + "Lightning in a Bottle", + {"country": "USA", "state": "California", "city": "Buena Vista Lake"}, + 18000, + ), + ( + "Outside Lands", + {"country": "USA", "state": "California", "city": "San Francisco"}, + 220000, + ), ], "last_modified": "2023-06-06T03:54:07.000Z", }, @@ -212,7 +275,9 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryAvroFilesStreamReader(files=_single_avro_file, file_type="avro")) + .set_stream_reader( + TemporaryAvroFilesStreamReader(files=_single_avro_file, file_type="avro") + ) .set_file_type("avro") ) .set_expected_check_status("SUCCEEDED") @@ -279,7 +344,11 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryAvroFilesStreamReader(files=_multiple_avro_combine_schema_file, file_type="avro")) + .set_stream_reader( + TemporaryAvroFilesStreamReader( + files=_multiple_avro_combine_schema_file, file_type="avro" + ) + ) .set_file_type("avro") ) .set_expected_records( @@ -298,7 +367,9 @@ "data": { "col_double": 20.23, "col_string": "Somebody Else", - "col_album": {"album": "I Like It When You Sleep, for You Are So Beautiful yet So Unaware of It"}, + "col_album": { + "album": "I Like It When You Sleep, for You Are So Beautiful yet So Unaware of It" + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a.avro", }, @@ -379,7 +450,9 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryAvroFilesStreamReader(files=_avro_all_types_file, file_type="avro")) + .set_stream_reader( + TemporaryAvroFilesStreamReader(files=_avro_all_types_file, file_type="avro") + ) .set_file_type("avro") ) .set_expected_records( @@ -431,16 +504,28 @@ "json_schema": { "type": "object", "properties": { - "col_array": {"items": {"type": ["null", "string"]}, "type": ["null", "array"]}, + "col_array": { + "items": {"type": ["null", "string"]}, + "type": ["null", "array"], + }, "col_bool": {"type": ["null", "boolean"]}, "col_bytes": {"type": ["null", "string"]}, "col_double": {"type": ["null", "number"]}, - "col_enum": {"enum": ["POP_ROCK", "INDIE_ROCK", "ALTERNATIVE_ROCK"], "type": ["null", "string"]}, - "col_fixed": {"pattern": "^[0-9A-Fa-f]{8}$", "type": ["null", "string"]}, + "col_enum": { + "enum": ["POP_ROCK", "INDIE_ROCK", "ALTERNATIVE_ROCK"], + "type": ["null", "string"], + }, + "col_fixed": { + "pattern": "^[0-9A-Fa-f]{8}$", + "type": ["null", "string"], + }, "col_float": {"type": ["null", "number"]}, "col_int": {"type": ["null", "integer"]}, "col_long": {"type": ["null", "integer"]}, - "col_map": {"additionalProperties": {"type": ["null", "string"]}, "type": ["null", "object"]}, + "col_map": { + "additionalProperties": {"type": ["null", "string"]}, + "type": ["null", "object"], + }, "col_record": { "properties": { "artist": {"type": ["null", "string"]}, @@ -450,12 +535,18 @@ "type": ["null", "object"], }, "col_string": {"type": ["null", "string"]}, - "col_decimal": {"pattern": "^-?\\d{(1, 5)}(?:\\.\\d(1, 5))?$", "type": ["null", "string"]}, + "col_decimal": { + "pattern": "^-?\\d{(1, 5)}(?:\\.\\d(1, 5))?$", + "type": ["null", "string"], + }, "col_uuid": {"type": ["null", "string"]}, "col_date": {"format": "date", "type": ["null", "string"]}, "col_time_millis": {"type": ["null", "integer"]}, "col_time_micros": {"type": ["null", "integer"]}, - "col_timestamp_millis": {"format": "date-time", "type": ["null", "string"]}, + "col_timestamp_millis": { + "format": "date-time", + "type": ["null", "string"], + }, "col_timestamp_micros": {"type": ["null", "string"]}, "_ab_source_file_last_modified": {"type": "string"}, "_ab_source_file_url": {"type": "string"}, @@ -494,7 +585,9 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryAvroFilesStreamReader(files=_multiple_avro_stream_file, file_type="avro")) + .set_stream_reader( + TemporaryAvroFilesStreamReader(files=_multiple_avro_stream_file, file_type="avro") + ) .set_file_type("avro") ) .set_expected_records( @@ -577,7 +670,11 @@ { "data": { "col_name": "Lightning in a Bottle", - "col_location": {"country": "USA", "state": "California", "city": "Buena Vista Lake"}, + "col_location": { + "country": "USA", + "state": "California", + "city": "Buena Vista Lake", + }, "col_attendance": 18000, "_ab_source_file_last_modified": "2023-06-06T03:54:07.000000Z", "_ab_source_file_url": "california_festivals.avro", @@ -587,7 +684,11 @@ { "data": { "col_name": "Outside Lands", - "col_location": {"country": "USA", "state": "California", "city": "San Francisco"}, + "col_location": { + "country": "USA", + "state": "California", + "city": "San Francisco", + }, "col_attendance": 220000, "_ab_source_file_last_modified": "2023-06-06T03:54:07.000000Z", "_ab_source_file_url": "california_festivals.avro", @@ -607,7 +708,12 @@ "col_title": {"type": ["null", "string"]}, "col_album": { "type": ["null", "string"], - "enum": ["SUMMERS_GONE", "IN_RETURN", "A_MOMENT_APART", "THE_LAST_GOODBYE"], + "enum": [ + "SUMMERS_GONE", + "IN_RETURN", + "A_MOMENT_APART", + "THE_LAST_GOODBYE", + ], }, "col_year": {"type": ["null", "integer"]}, "col_vocals": {"type": ["null", "boolean"]}, @@ -666,7 +772,11 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryAvroFilesStreamReader(files=_multiple_avro_combine_schema_file, file_type="avro")) + .set_stream_reader( + TemporaryAvroFilesStreamReader( + files=_multiple_avro_combine_schema_file, file_type="avro" + ) + ) .set_file_type("avro") ) .set_expected_records( @@ -685,7 +795,9 @@ "data": { "col_double": 20.23, "col_string": "Somebody Else", - "col_album": {"album": "I Like It When You Sleep, for You Are So Beautiful yet So Unaware of It"}, + "col_album": { + "album": "I Like It When You Sleep, for You Are So Beautiful yet So Unaware of It" + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a.avro", }, diff --git a/unit_tests/sources/file_based/scenarios/check_scenarios.py b/unit_tests/sources/file_based/scenarios/check_scenarios.py index 26136d9c..9a235b9e 100644 --- a/unit_tests/sources/file_based/scenarios/check_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/check_scenarios.py @@ -139,7 +139,9 @@ .set_name("error_listing_files_scenario") .set_source_builder( _base_failure_scenario.source_builder.copy().set_stream_reader( - TestErrorListMatchingFilesInMemoryFilesStreamReader(files=_base_failure_scenario.source_builder._files, file_type="csv") + TestErrorListMatchingFilesInMemoryFilesStreamReader( + files=_base_failure_scenario.source_builder._files, file_type="csv" + ) ) ) .set_expected_check_error(None, FileBasedSourceError.ERROR_LISTING_FILES.value) @@ -151,7 +153,9 @@ .set_name("error_reading_file_scenario") .set_source_builder( _base_failure_scenario.source_builder.copy().set_stream_reader( - TestErrorOpenFileInMemoryFilesStreamReader(files=_base_failure_scenario.source_builder._files, file_type="csv") + TestErrorOpenFileInMemoryFilesStreamReader( + files=_base_failure_scenario.source_builder._files, file_type="csv" + ) ) ) .set_expected_check_error(None, FileBasedSourceError.ERROR_READING_FILE.value) @@ -189,7 +193,9 @@ } ) .set_file_type("csv") - .set_validation_policies({FailingSchemaValidationPolicy.ALWAYS_FAIL: FailingSchemaValidationPolicy()}) + .set_validation_policies( + {FailingSchemaValidationPolicy.ALWAYS_FAIL: FailingSchemaValidationPolicy()} + ) ) .set_expected_check_error(None, FileBasedSourceError.ERROR_VALIDATING_RECORD.value) ).build() diff --git a/unit_tests/sources/file_based/scenarios/concurrent_incremental_scenarios.py b/unit_tests/sources/file_based/scenarios/concurrent_incremental_scenarios.py index e5a7ee41..92ce67fe 100644 --- a/unit_tests/sources/file_based/scenarios/concurrent_incremental_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/concurrent_incremental_scenarios.py @@ -6,7 +6,10 @@ from airbyte_cdk.test.state_builder import StateBuilder from unit_tests.sources.file_based.helpers import LowHistoryLimitConcurrentCursor from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder +from unit_tests.sources.file_based.scenarios.scenario_builder import ( + IncrementalScenarioConfig, + TestScenarioBuilder, +) single_csv_input_state_is_earlier_scenario_concurrent = ( TestScenarioBuilder() @@ -74,7 +77,10 @@ "stream": "stream1", }, { - "history": {"some_old_file.csv": "2023-06-01T03:54:07.000000Z", "a.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "some_old_file.csv": "2023-06-01T03:54:07.000000Z", + "a.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_a.csv", }, ] @@ -488,7 +494,10 @@ "stream": "stream1", }, { - "history": {"a.csv": "2023-06-05T03:54:07.000000Z", "b.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-05T03:54:07.000000Z", + "b.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_b.csv", }, ] @@ -718,7 +727,10 @@ "stream": "stream1", }, { - "history": {"a.csv": "2023-06-04T03:54:07.000000Z", "b.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-04T03:54:07.000000Z", + "b.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_b.csv", }, ] @@ -848,7 +860,10 @@ "stream": "stream1", }, { - "history": {"a.csv": "2023-06-05T03:54:07.000000Z", "b.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-05T03:54:07.000000Z", + "b.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_b.csv", }, { @@ -990,7 +1005,10 @@ "stream": "stream1", }, { - "history": {"a.csv": "2023-06-05T03:54:07.000000Z", "b.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-05T03:54:07.000000Z", + "b.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_b.csv", }, { @@ -1028,7 +1046,10 @@ input_state=StateBuilder() .with_stream_state( "stream1", - {"history": {"a.csv": "2023-06-05T03:54:07.000000Z"}, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_a.csv"}, + { + "history": {"a.csv": "2023-06-05T03:54:07.000000Z"}, + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_a.csv", + }, ) .build(), ) @@ -1154,7 +1175,10 @@ .with_stream_state( "stream1", { - "history": {"a.csv": "2023-06-05T03:54:07.000000Z", "c.csv": "2023-06-06T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-05T03:54:07.000000Z", + "c.csv": "2023-06-06T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-06T03:54:07.000000Z_c.csv", }, ) @@ -1282,7 +1306,10 @@ .with_stream_state( "stream1", { - "history": {"a.csv": "2023-06-05T03:54:07.000000Z", "c.csv": "2023-06-06T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-05T03:54:07.000000Z", + "c.csv": "2023-06-06T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-03T03:54:07.000000Z_x.csv", }, ) @@ -1657,7 +1684,9 @@ multi_csv_same_timestamp_more_files_than_history_size_scenario_concurrent_cursor_is_newer = ( TestScenarioBuilder() - .set_name("multi_csv_same_timestamp_more_files_than_history_size_scenario_concurrent_cursor_is_newer") + .set_name( + "multi_csv_same_timestamp_more_files_than_history_size_scenario_concurrent_cursor_is_newer" + ) .set_config( { "streams": [ @@ -1840,7 +1869,9 @@ multi_csv_same_timestamp_more_files_than_history_size_scenario_concurrent_cursor_is_older = ( TestScenarioBuilder() - .set_name("multi_csv_same_timestamp_more_files_than_history_size_scenario_concurrent_cursor_is_older") + .set_name( + "multi_csv_same_timestamp_more_files_than_history_size_scenario_concurrent_cursor_is_older" + ) .set_config( { "streams": [ @@ -2023,7 +2054,9 @@ multi_csv_sync_recent_files_if_history_is_incomplete_scenario_concurrent_cursor_is_older = ( TestScenarioBuilder() - .set_name("multi_csv_sync_recent_files_if_history_is_incomplete_scenario_concurrent_cursor_is_older") + .set_name( + "multi_csv_sync_recent_files_if_history_is_incomplete_scenario_concurrent_cursor_is_older" + ) .set_config( { "streams": [ @@ -2140,7 +2173,9 @@ multi_csv_sync_recent_files_if_history_is_incomplete_scenario_concurrent_cursor_is_newer = ( TestScenarioBuilder() - .set_name("multi_csv_sync_recent_files_if_history_is_incomplete_scenario_concurrent_cursor_is_newer") + .set_name( + "multi_csv_sync_recent_files_if_history_is_incomplete_scenario_concurrent_cursor_is_newer" + ) .set_config( { "streams": [ @@ -2258,7 +2293,9 @@ multi_csv_sync_files_within_time_window_if_history_is_incomplete__different_timestamps_scenario_concurrent_cursor_is_older = ( TestScenarioBuilder() - .set_name("multi_csv_sync_files_within_time_window_if_history_is_incomplete__different_timestamps_scenario_concurrent_cursor_is_older") + .set_name( + "multi_csv_sync_files_within_time_window_if_history_is_incomplete__different_timestamps_scenario_concurrent_cursor_is_older" + ) .set_config( { "streams": [ @@ -2397,7 +2434,9 @@ multi_csv_sync_files_within_time_window_if_history_is_incomplete__different_timestamps_scenario_concurrent_cursor_is_newer = ( TestScenarioBuilder() - .set_name("multi_csv_sync_files_within_time_window_if_history_is_incomplete__different_timestamps_scenario_concurrent_cursor_is_newer") + .set_name( + "multi_csv_sync_files_within_time_window_if_history_is_incomplete__different_timestamps_scenario_concurrent_cursor_is_newer" + ) .set_config( { "streams": [ diff --git a/unit_tests/sources/file_based/scenarios/csv_scenarios.py b/unit_tests/sources/file_based/scenarios/csv_scenarios.py index d88c38ec..2f4f02cf 100644 --- a/unit_tests/sources/file_based/scenarios/csv_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/csv_scenarios.py @@ -7,10 +7,16 @@ from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError from airbyte_cdk.test.catalog_builder import CatalogBuilder from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from unit_tests.sources.file_based.helpers import EmptySchemaParser, LowInferenceLimitDiscoveryPolicy +from unit_tests.sources.file_based.helpers import ( + EmptySchemaParser, + LowInferenceLimitDiscoveryPolicy, +) from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesSource from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario, TestScenarioBuilder +from unit_tests.sources.file_based.scenarios.scenario_builder import ( + TestScenario, + TestScenarioBuilder, +) single_csv_scenario: TestScenario[InMemoryFilesSource] = ( TestScenarioBuilder[InMemoryFilesSource]() @@ -71,7 +77,11 @@ "title": "FileBasedStreamConfig", "type": "object", "properties": { - "name": {"title": "Name", "description": "The name of the stream.", "type": "string"}, + "name": { + "title": "Name", + "description": "The name of the stream.", + "type": "string", + }, "globs": { "title": "Globs", "description": 'The pattern used to specify which files should be selected from the file system. For more information on glob pattern matching look here.', @@ -118,7 +128,12 @@ "title": "Avro Format", "type": "object", "properties": { - "filetype": {"title": "Filetype", "default": "avro", "const": "avro", "type": "string"}, + "filetype": { + "title": "Filetype", + "default": "avro", + "const": "avro", + "type": "string", + }, "double_as_string": { "title": "Convert Double Fields to Strings", "description": "Whether to convert double fields to strings. This is recommended if you have decimal numbers with a high degree of precision because there can be a loss precision when handling floating point numbers.", @@ -132,7 +147,12 @@ "title": "CSV Format", "type": "object", "properties": { - "filetype": {"title": "Filetype", "default": "csv", "const": "csv", "type": "string"}, + "filetype": { + "title": "Filetype", + "default": "csv", + "const": "csv", + "type": "string", + }, "delimiter": { "title": "Delimiter", "description": "The character delimiting individual cells in the CSV data. This may only be a 1-character string. For tab-delimited data enter '\\t'.", @@ -192,7 +212,9 @@ "title": "CSV Header Definition", "type": "object", "description": "How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.", - "default": {"header_definition_type": "From CSV"}, + "default": { + "header_definition_type": "From CSV" + }, "oneOf": [ { "title": "From CSV", @@ -237,7 +259,10 @@ "items": {"type": "string"}, }, }, - "required": ["column_names", "header_definition_type"], + "required": [ + "column_names", + "header_definition_type", + ], }, ], }, @@ -252,7 +277,14 @@ "false_values": { "title": "False Values", "description": "A set of case-sensitive strings that should be interpreted as false values.", - "default": ["n", "no", "f", "false", "off", "0"], + "default": [ + "n", + "no", + "f", + "false", + "off", + "0", + ], "type": "array", "items": {"type": "string"}, "uniqueItems": True, @@ -277,7 +309,12 @@ "title": "Jsonl Format", "type": "object", "properties": { - "filetype": {"title": "Filetype", "default": "jsonl", "const": "jsonl", "type": "string"} + "filetype": { + "title": "Filetype", + "default": "jsonl", + "const": "jsonl", + "type": "string", + } }, "required": ["filetype"], }, @@ -371,7 +408,9 @@ "description": "The URL of the unstructured API to use", "default": "https://api.unstructured.io", "always_show": True, - "examples": ["https://api.unstructured.com"], + "examples": [ + "https://api.unstructured.com" + ], "type": "string", }, "parameters": { @@ -387,17 +426,26 @@ "name": { "title": "Parameter name", "description": "The name of the unstructured API parameter to use", - "examples": ["combine_under_n_chars", "languages"], + "examples": [ + "combine_under_n_chars", + "languages", + ], "type": "string", }, "value": { "title": "Value", "description": "The value of the parameter", - "examples": ["true", "hi_res"], + "examples": [ + "true", + "hi_res", + ], "type": "string", }, }, - "required": ["name", "value"], + "required": [ + "name", + "value", + ], }, }, }, @@ -414,7 +462,12 @@ "title": "Excel Format", "type": "object", "properties": { - "filetype": {"title": "Filetype", "default": "excel", "const": "excel", "type": "string"} + "filetype": { + "title": "Filetype", + "default": "excel", + "const": "excel", + "type": "string", + } }, "required": ["filetype"], }, @@ -1065,7 +1118,9 @@ } ) .set_expected_records([]) - .set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value) + .set_expected_discover_error( + AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value + ) .set_expected_logs( { "read": [ @@ -1165,7 +1220,9 @@ } ) .set_expected_records([]) - .set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value) + .set_expected_discover_error( + AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value + ) .set_expected_logs( { "read": [ @@ -1180,7 +1237,9 @@ ] } ) - .set_expected_read_error(AirbyteTracedException, "Please check the logged errors for more information.") + .set_expected_read_error( + AirbyteTracedException, "Please check the logged errors for more information." + ) ).build() csv_single_stream_scenario: TestScenario[InMemoryFilesSource] = ( @@ -1371,19 +1430,35 @@ "stream": "stream1", }, { - "data": {"col3": "val13b", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.csv"}, + "data": { + "col3": "val13b", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.csv", + }, "stream": "stream1", }, { - "data": {"col3": "val23b", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.csv"}, + "data": { + "col3": "val23b", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.csv", + }, "stream": "stream1", }, { - "data": {"col3": "val13b", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.csv"}, + "data": { + "col3": "val13b", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.csv", + }, "stream": "stream2", }, { - "data": {"col3": "val23b", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.csv"}, + "data": { + "col3": "val23b", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.csv", + }, "stream": "stream2", }, ] @@ -1509,7 +1584,13 @@ "name": "stream1", "globs": ["*.csv"], "validation_policy": "Emit Record", - "format": {"filetype": "csv", "delimiter": "#", "escape_char": "!", "double_quote": True, "newlines_in_values": False}, + "format": { + "filetype": "csv", + "delimiter": "#", + "escape_char": "!", + "double_quote": True, + "newlines_in_values": False, + }, }, { "name": "stream2", @@ -1630,7 +1711,11 @@ "stream": "stream1", }, { - "data": {"col3": "val23b", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.csv"}, + "data": { + "col3": "val23b", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.csv", + }, "stream": "stream1", }, { @@ -1642,7 +1727,11 @@ "stream": "stream2", }, { - "data": {"col3": "val23b", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.csv"}, + "data": { + "col3": "val23b", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.csv", + }, "stream": "stream2", }, ] @@ -1703,7 +1792,9 @@ ] } ) - .set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value) + .set_expected_discover_error( + AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value + ) ).build() schemaless_csv_scenario: TestScenario[InMemoryFilesSource] = ( @@ -1906,18 +1997,28 @@ "stream": "stream1", }, { - "data": {"col3": "val13b", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.csv"}, + "data": { + "col3": "val13b", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.csv", + }, "stream": "stream2", }, { - "data": {"col3": "val23b", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.csv"}, + "data": { + "col3": "val23b", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.csv", + }, "stream": "stream2", }, ] ) ).build() -schemaless_with_user_input_schema_fails_connection_check_scenario: TestScenario[InMemoryFilesSource] = ( +schemaless_with_user_input_schema_fails_connection_check_scenario: TestScenario[ + InMemoryFilesSource +] = ( TestScenarioBuilder[InMemoryFilesSource]() .set_name("schemaless_with_user_input_schema_fails_connection_check_scenario") .set_config( @@ -1982,11 +2083,17 @@ ) .set_expected_check_status("FAILED") .set_expected_check_error(None, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) - .set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) - .set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) + .set_expected_discover_error( + ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value + ) + .set_expected_read_error( + ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value + ) ).build() -schemaless_with_user_input_schema_fails_connection_check_multi_stream_scenario: TestScenario[InMemoryFilesSource] = ( +schemaless_with_user_input_schema_fails_connection_check_multi_stream_scenario: TestScenario[ + InMemoryFilesSource +] = ( TestScenarioBuilder[InMemoryFilesSource]() .set_name("schemaless_with_user_input_schema_fails_connection_check_multi_stream_scenario") .set_config( @@ -2033,7 +2140,12 @@ ) .set_file_type("csv") ) - .set_catalog(CatalogBuilder().with_stream("stream1", SyncMode.full_refresh).with_stream("stream2", SyncMode.full_refresh).build()) + .set_catalog( + CatalogBuilder() + .with_stream("stream1", SyncMode.full_refresh) + .with_stream("stream2", SyncMode.full_refresh) + .build() + ) .set_expected_catalog( { "streams": [ @@ -2072,8 +2184,12 @@ ) .set_expected_check_status("FAILED") .set_expected_check_error(None, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) - .set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) - .set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value) + .set_expected_discover_error( + ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value + ) + .set_expected_read_error( + ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value + ) ).build() csv_string_can_be_null_with_input_schemas_scenario: TestScenario[InMemoryFilesSource] = ( @@ -2147,7 +2263,9 @@ ) ).build() -csv_string_are_not_null_if_strings_can_be_null_is_false_scenario: TestScenario[InMemoryFilesSource] = ( +csv_string_are_not_null_if_strings_can_be_null_is_false_scenario: TestScenario[ + InMemoryFilesSource +] = ( TestScenarioBuilder[InMemoryFilesSource]() .set_name("csv_string_are_not_null_if_strings_can_be_null_is_false") .set_config( @@ -2500,7 +2618,9 @@ AirbyteTracedException, f"{FileBasedSourceError.ERROR_PARSING_RECORD.value} stream=stream1 file=a.csv line_no=2 n_skipped=0", ) - .set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value) + .set_expected_discover_error( + AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value + ) .set_expected_read_error( AirbyteTracedException, "Please check the logged errors for more information.", @@ -2666,7 +2786,13 @@ "name": "stream1", "globs": ["*"], "validation_policy": "Emit Record", - "format": {"filetype": "csv", "double_quotes": True, "quote_char": "@", "delimiter": "|", "escape_char": "+"}, + "format": { + "filetype": "csv", + "double_quotes": True, + "quote_char": "@", + "delimiter": "|", + "escape_char": "+", + }, } ], "start_date": "2023-06-04T03:54:07.000000Z", diff --git a/unit_tests/sources/file_based/scenarios/excel_scenarios.py b/unit_tests/sources/file_based/scenarios/excel_scenarios.py index 66532965..94ccc676 100644 --- a/unit_tests/sources/file_based/scenarios/excel_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/excel_scenarios.py @@ -32,7 +32,11 @@ }, "b.xlsx": { "contents": [ - {"col_double": 1975.1975, "col_string": "It's Not Living (If It's Not with You)", "col_song": "Love It If We Made It"}, + { + "col_double": 1975.1975, + "col_string": "It's Not Living (If It's Not with You)", + "col_song": "Love It If We Made It", + }, {"col_double": 5791.5791, "col_string": "The 1975", "col_song": "About You"}, ], "last_modified": "2023-06-06T03:54:07.000Z", @@ -60,9 +64,24 @@ _multiple_excel_stream_file = { "odesza_songs.xlsx": { "contents": [ - {"col_title": "Late Night", "col_album": "A_MOMENT_APART", "col_year": 2017, "col_vocals": False}, - {"col_title": "White Lies", "col_album": "IN_RETURN", "col_year": 2014, "col_vocals": True}, - {"col_title": "Wide Awake", "col_album": "THE_LAST_GOODBYE", "col_year": 2022, "col_vocals": True}, + { + "col_title": "Late Night", + "col_album": "A_MOMENT_APART", + "col_year": 2017, + "col_vocals": False, + }, + { + "col_title": "White Lies", + "col_album": "IN_RETURN", + "col_year": 2014, + "col_vocals": True, + }, + { + "col_title": "Wide Awake", + "col_album": "THE_LAST_GOODBYE", + "col_year": 2022, + "col_vocals": True, + }, ], "last_modified": "2023-06-05T03:54:07.000Z", }, @@ -70,7 +89,11 @@ "contents": [ { "col_name": "Lightning in a Bottle", - "col_location": {"country": "USA", "state": "California", "city": "Buena Vista Lake"}, + "col_location": { + "country": "USA", + "state": "California", + "city": "Buena Vista Lake", + }, "col_attendance": 18000, }, { @@ -100,7 +123,9 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryExcelFilesStreamReader(files=_single_excel_file, file_type="excel")) + .set_stream_reader( + TemporaryExcelFilesStreamReader(files=_single_excel_file, file_type="excel") + ) .set_file_type("excel") ) .set_expected_check_status("SUCCEEDED") @@ -167,7 +192,11 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryExcelFilesStreamReader(files=_multiple_excel_combine_schema_file, file_type="excel")) + .set_stream_reader( + TemporaryExcelFilesStreamReader( + files=_multiple_excel_combine_schema_file, file_type="excel" + ) + ) .set_file_type("excel") ) .set_expected_records( @@ -257,7 +286,9 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryExcelFilesStreamReader(files=_excel_all_types_file, file_type="excel")) + .set_stream_reader( + TemporaryExcelFilesStreamReader(files=_excel_all_types_file, file_type="excel") + ) .set_file_type("excel") ) .set_expected_records( @@ -332,7 +363,9 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryExcelFilesStreamReader(files=_multiple_excel_stream_file, file_type="excel")) + .set_stream_reader( + TemporaryExcelFilesStreamReader(files=_multiple_excel_stream_file, file_type="excel") + ) .set_file_type("excel") ) .set_expected_records( diff --git a/unit_tests/sources/file_based/scenarios/file_based_source_builder.py b/unit_tests/sources/file_based/scenarios/file_based_source_builder.py index 6675df38..4c2939f6 100644 --- a/unit_tests/sources/file_based/scenarios/file_based_source_builder.py +++ b/unit_tests/sources/file_based/scenarios/file_based_source_builder.py @@ -8,7 +8,10 @@ from airbyte_cdk.sources.file_based.availability_strategy.abstract_file_based_availability_strategy import ( AbstractFileBasedAvailabilityStrategy, ) -from airbyte_cdk.sources.file_based.discovery_policy import AbstractDiscoveryPolicy, DefaultDiscoveryPolicy +from airbyte_cdk.sources.file_based.discovery_policy import ( + AbstractDiscoveryPolicy, + DefaultDiscoveryPolicy, +) from airbyte_cdk.sources.file_based.file_based_source import default_parsers from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser @@ -34,7 +37,10 @@ def __init__(self) -> None: self._state: Optional[TState] = None def build( - self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState] + self, + configured_catalog: Optional[Mapping[str, Any]], + config: Optional[Mapping[str, Any]], + state: Optional[TState], ) -> InMemoryFilesSource: if self._file_type is None: raise ValueError("file_type is not set") @@ -65,19 +71,27 @@ def set_parsers(self, parsers: Mapping[Type[Any], FileTypeParser]) -> "FileBased self._parsers = parsers return self - def set_availability_strategy(self, availability_strategy: AbstractFileBasedAvailabilityStrategy) -> "FileBasedSourceBuilder": + def set_availability_strategy( + self, availability_strategy: AbstractFileBasedAvailabilityStrategy + ) -> "FileBasedSourceBuilder": self._availability_strategy = availability_strategy return self - def set_discovery_policy(self, discovery_policy: AbstractDiscoveryPolicy) -> "FileBasedSourceBuilder": + def set_discovery_policy( + self, discovery_policy: AbstractDiscoveryPolicy + ) -> "FileBasedSourceBuilder": self._discovery_policy = discovery_policy return self - def set_validation_policies(self, validation_policies: Mapping[str, AbstractSchemaValidationPolicy]) -> "FileBasedSourceBuilder": + def set_validation_policies( + self, validation_policies: Mapping[str, AbstractSchemaValidationPolicy] + ) -> "FileBasedSourceBuilder": self._validation_policies = validation_policies return self - def set_stream_reader(self, stream_reader: AbstractFileBasedStreamReader) -> "FileBasedSourceBuilder": + def set_stream_reader( + self, stream_reader: AbstractFileBasedStreamReader + ) -> "FileBasedSourceBuilder": self._stream_reader = stream_reader return self @@ -85,7 +99,9 @@ def set_cursor_cls(self, cursor_cls: AbstractFileBasedCursor) -> "FileBasedSourc self._cursor_cls = cursor_cls return self - def set_file_write_options(self, file_write_options: Mapping[str, Any]) -> "FileBasedSourceBuilder": + def set_file_write_options( + self, file_write_options: Mapping[str, Any] + ) -> "FileBasedSourceBuilder": self._file_write_options = file_write_options return self diff --git a/unit_tests/sources/file_based/scenarios/incremental_scenarios.py b/unit_tests/sources/file_based/scenarios/incremental_scenarios.py index 95cde6b7..aea4b484 100644 --- a/unit_tests/sources/file_based/scenarios/incremental_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/incremental_scenarios.py @@ -6,7 +6,10 @@ from airbyte_cdk.test.state_builder import StateBuilder from unit_tests.sources.file_based.helpers import LowHistoryLimitCursor from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder -from unit_tests.sources.file_based.scenarios.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder +from unit_tests.sources.file_based.scenarios.scenario_builder import ( + IncrementalScenarioConfig, + TestScenarioBuilder, +) single_csv_input_state_is_earlier_scenario = ( TestScenarioBuilder() @@ -73,7 +76,10 @@ "stream": "stream1", }, { - "history": {"some_old_file.csv": "2023-06-01T03:54:07.000000Z", "a.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "some_old_file.csv": "2023-06-01T03:54:07.000000Z", + "a.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_a.csv", }, ] @@ -485,7 +491,10 @@ "stream": "stream1", }, { - "history": {"a.csv": "2023-06-05T03:54:07.000000Z", "b.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-05T03:54:07.000000Z", + "b.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_b.csv", }, ] @@ -714,7 +723,10 @@ "stream": "stream1", }, { - "history": {"a.csv": "2023-06-04T03:54:07.000000Z", "b.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-04T03:54:07.000000Z", + "b.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_b.csv", }, ] @@ -844,7 +856,10 @@ "stream": "stream1", }, { - "history": {"a.csv": "2023-06-05T03:54:07.000000Z", "b.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-05T03:54:07.000000Z", + "b.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_b.csv", }, { @@ -986,7 +1001,10 @@ "stream": "stream1", }, { - "history": {"a.csv": "2023-06-05T03:54:07.000000Z", "b.csv": "2023-06-05T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-05T03:54:07.000000Z", + "b.csv": "2023-06-05T03:54:07.000000Z", + }, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z_b.csv", }, { @@ -1152,7 +1170,10 @@ .with_stream_state( "stream1", { - "history": {"a.csv": "2023-06-05T03:54:07.000000Z", "c.csv": "2023-06-06T03:54:07.000000Z"}, + "history": { + "a.csv": "2023-06-05T03:54:07.000000Z", + "c.csv": "2023-06-06T03:54:07.000000Z", + }, }, ) .build(), @@ -1780,7 +1801,9 @@ multi_csv_sync_files_within_history_time_window_if_history_is_incomplete_different_timestamps_scenario = ( TestScenarioBuilder() - .set_name("multi_csv_sync_files_within_history_time_window_if_history_is_incomplete_different_timestamps") + .set_name( + "multi_csv_sync_files_within_history_time_window_if_history_is_incomplete_different_timestamps" + ) .set_config( { "streams": [ diff --git a/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py b/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py index 23e87930..c4ebafca 100644 --- a/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/jsonl_scenarios.py @@ -5,7 +5,10 @@ from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from unit_tests.sources.file_based.helpers import LowInferenceBytesJsonlParser, LowInferenceLimitDiscoveryPolicy +from unit_tests.sources.file_based.helpers import ( + LowInferenceBytesJsonlParser, + LowInferenceLimitDiscoveryPolicy, +) from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder @@ -490,8 +493,12 @@ } ) .set_expected_records([]) - .set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value) - .set_expected_read_error(AirbyteTracedException, "Please check the logged errors for more information.") + .set_expected_discover_error( + AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value + ) + .set_expected_read_error( + AirbyteTracedException, "Please check the logged errors for more information." + ) .set_expected_logs( { "read": [ @@ -617,19 +624,35 @@ "stream": "stream1", }, { - "data": {"col3": 1.1, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.jsonl"}, + "data": { + "col3": 1.1, + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.jsonl", + }, "stream": "stream1", }, { - "data": {"col3": 2.2, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.jsonl"}, + "data": { + "col3": 2.2, + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.jsonl", + }, "stream": "stream1", }, { - "data": {"col3": 1.1, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.jsonl"}, + "data": { + "col3": 1.1, + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.jsonl", + }, "stream": "stream2", }, { - "data": {"col3": 2.2, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.jsonl"}, + "data": { + "col3": 2.2, + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.jsonl", + }, "stream": "stream2", }, ] @@ -846,11 +869,19 @@ "stream": "stream1", }, { - "data": {"col3": 1.1, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.jsonl"}, + "data": { + "col3": 1.1, + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.jsonl", + }, "stream": "stream2", }, { - "data": {"col3": 2.2, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b.jsonl"}, + "data": { + "col3": 2.2, + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "b.jsonl", + }, "stream": "stream2", }, ] diff --git a/unit_tests/sources/file_based/scenarios/parquet_scenarios.py b/unit_tests/sources/file_based/scenarios/parquet_scenarios.py index 732660e5..5ddb8468 100644 --- a/unit_tests/sources/file_based/scenarios/parquet_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/parquet_scenarios.py @@ -166,7 +166,9 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryParquetFilesStreamReader(files=_single_parquet_file, file_type="parquet")) + .set_stream_reader( + TemporaryParquetFilesStreamReader(files=_single_parquet_file, file_type="parquet") + ) .set_file_type("parquet") ) .set_expected_records( @@ -232,7 +234,11 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryParquetFilesStreamReader(files=_single_partitioned_parquet_file, file_type="parquet")) + .set_stream_reader( + TemporaryParquetFilesStreamReader( + files=_single_partitioned_parquet_file, file_type="parquet" + ) + ) .set_file_type("parquet") ) .set_expected_records( @@ -305,7 +311,9 @@ .set_source_builder( FileBasedSourceBuilder() .set_file_type("parquet") - .set_stream_reader(TemporaryParquetFilesStreamReader(files=_multiple_parquet_file, file_type="parquet")) + .set_stream_reader( + TemporaryParquetFilesStreamReader(files=_multiple_parquet_file, file_type="parquet") + ) ) .set_expected_catalog( { @@ -391,7 +399,11 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryParquetFilesStreamReader(files=_parquet_file_with_various_types, file_type="parquet")) + .set_stream_reader( + TemporaryParquetFilesStreamReader( + files=_parquet_file_with_various_types, file_type="parquet" + ) + ) .set_file_type("parquet") ) .set_expected_catalog( @@ -437,8 +449,14 @@ }, "col_date32": {"type": ["null", "string"], "format": "date"}, "col_date64": {"type": ["null", "string"], "format": "date"}, - "col_timestamp_without_tz": {"type": ["null", "string"], "format": "date-time"}, - "col_timestamp_with_tz": {"type": ["null", "string"], "format": "date-time"}, + "col_timestamp_without_tz": { + "type": ["null", "string"], + "format": "date-time", + }, + "col_timestamp_with_tz": { + "type": ["null", "string"], + "format": "date-time", + }, "col_time32s": { "type": ["null", "string"], }, @@ -528,7 +546,9 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryParquetFilesStreamReader(files=_parquet_file_with_decimal, file_type="parquet")) + .set_stream_reader( + TemporaryParquetFilesStreamReader(files=_parquet_file_with_decimal, file_type="parquet") + ) .set_file_type("parquet") ) .set_expected_records( @@ -583,7 +603,9 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryParquetFilesStreamReader(files=_parquet_file_with_decimal, file_type="parquet")) + .set_stream_reader( + TemporaryParquetFilesStreamReader(files=_parquet_file_with_decimal, file_type="parquet") + ) .set_file_type("parquet") ) .set_expected_records( @@ -638,13 +660,19 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryParquetFilesStreamReader(files=_parquet_file_with_decimal, file_type="parquet")) + .set_stream_reader( + TemporaryParquetFilesStreamReader(files=_parquet_file_with_decimal, file_type="parquet") + ) .set_file_type("parquet") ) .set_expected_records( [ { - "data": {"col1": 13.00, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a.parquet"}, + "data": { + "col1": 13.00, + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "a.parquet", + }, "stream": "stream1", }, ] @@ -691,13 +719,19 @@ ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryParquetFilesStreamReader(files=_parquet_file_with_decimal, file_type="parquet")) + .set_stream_reader( + TemporaryParquetFilesStreamReader(files=_parquet_file_with_decimal, file_type="parquet") + ) .set_file_type("parquet") ) .set_expected_records( [ { - "data": {"col1": 13.00, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a.parquet"}, + "data": { + "col1": 13.00, + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "a.parquet", + }, "stream": "stream1", }, ] @@ -728,16 +762,31 @@ parquet_with_invalid_config_scenario = ( TestScenarioBuilder() .set_name("parquet_with_invalid_config") - .set_config({"streams": [{"name": "stream1", "globs": ["*"], "validation_policy": "Emit Record", "format": {"filetype": "csv"}}]}) + .set_config( + { + "streams": [ + { + "name": "stream1", + "globs": ["*"], + "validation_policy": "Emit Record", + "format": {"filetype": "csv"}, + } + ] + } + ) .set_source_builder( FileBasedSourceBuilder() - .set_stream_reader(TemporaryParquetFilesStreamReader(files=_single_parquet_file, file_type="parquet")) + .set_stream_reader( + TemporaryParquetFilesStreamReader(files=_single_parquet_file, file_type="parquet") + ) .set_file_type("parquet") ) .set_expected_records([]) .set_expected_logs({"read": [{"level": "ERROR", "message": "Error parsing record"}]}) .set_expected_discover_error(AirbyteTracedException, "Error inferring schema from files") - .set_expected_read_error(AirbyteTracedException, "Please check the logged errors for more information.") + .set_expected_read_error( + AirbyteTracedException, "Please check the logged errors for more information." + ) .set_expected_catalog( { "streams": [ diff --git a/unit_tests/sources/file_based/scenarios/scenario_builder.py b/unit_tests/sources/file_based/scenarios/scenario_builder.py index 8158225a..da8c7ba8 100644 --- a/unit_tests/sources/file_based/scenarios/scenario_builder.py +++ b/unit_tests/sources/file_based/scenarios/scenario_builder.py @@ -33,7 +33,10 @@ class SourceBuilder(ABC, Generic[SourceType]): @abstractmethod def build( - self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState] + self, + configured_catalog: Optional[Mapping[str, Any]], + config: Optional[Mapping[str, Any]], + state: Optional[TState], ) -> SourceType: raise NotImplementedError() @@ -140,7 +143,9 @@ def set_config(self, config: Mapping[str, Any]) -> "TestScenarioBuilder[SourceTy self._config = config return self - def set_expected_spec(self, expected_spec: Mapping[str, Any]) -> "TestScenarioBuilder[SourceType]": + def set_expected_spec( + self, expected_spec: Mapping[str, Any] + ) -> "TestScenarioBuilder[SourceType]": self._expected_spec = expected_spec return self @@ -148,35 +153,51 @@ def set_catalog(self, catalog: ConfiguredAirbyteCatalog) -> "TestScenarioBuilder self._catalog = catalog return self - def set_expected_check_status(self, expected_check_status: str) -> "TestScenarioBuilder[SourceType]": + def set_expected_check_status( + self, expected_check_status: str + ) -> "TestScenarioBuilder[SourceType]": self._expected_check_status = expected_check_status return self - def set_expected_catalog(self, expected_catalog: Mapping[str, Any]) -> "TestScenarioBuilder[SourceType]": + def set_expected_catalog( + self, expected_catalog: Mapping[str, Any] + ) -> "TestScenarioBuilder[SourceType]": self._expected_catalog = expected_catalog return self - def set_expected_logs(self, expected_logs: Mapping[str, List[Mapping[str, Any]]]) -> "TestScenarioBuilder[SourceType]": + def set_expected_logs( + self, expected_logs: Mapping[str, List[Mapping[str, Any]]] + ) -> "TestScenarioBuilder[SourceType]": self._expected_logs = expected_logs return self - def set_expected_records(self, expected_records: Optional[List[Mapping[str, Any]]]) -> "TestScenarioBuilder[SourceType]": + def set_expected_records( + self, expected_records: Optional[List[Mapping[str, Any]]] + ) -> "TestScenarioBuilder[SourceType]": self._expected_records = expected_records return self - def set_incremental_scenario_config(self, incremental_scenario_config: IncrementalScenarioConfig) -> "TestScenarioBuilder[SourceType]": + def set_incremental_scenario_config( + self, incremental_scenario_config: IncrementalScenarioConfig + ) -> "TestScenarioBuilder[SourceType]": self._incremental_scenario_config = incremental_scenario_config return self - def set_expected_check_error(self, error: Optional[Type[Exception]], message: str) -> "TestScenarioBuilder[SourceType]": + def set_expected_check_error( + self, error: Optional[Type[Exception]], message: str + ) -> "TestScenarioBuilder[SourceType]": self._expected_check_error = error, message return self - def set_expected_discover_error(self, error: Type[Exception], message: str) -> "TestScenarioBuilder[SourceType]": + def set_expected_discover_error( + self, error: Type[Exception], message: str + ) -> "TestScenarioBuilder[SourceType]": self._expected_discover_error = error, message return self - def set_expected_read_error(self, error: Type[Exception], message: str) -> "TestScenarioBuilder[SourceType]": + def set_expected_read_error( + self, error: Type[Exception], message: str + ) -> "TestScenarioBuilder[SourceType]": self._expected_read_error = error, message return self @@ -184,11 +205,15 @@ def set_log_levels(self, levels: Set[str]) -> "TestScenarioBuilder": self._log_levels = levels return self - def set_source_builder(self, source_builder: SourceBuilder[SourceType]) -> "TestScenarioBuilder[SourceType]": + def set_source_builder( + self, source_builder: SourceBuilder[SourceType] + ) -> "TestScenarioBuilder[SourceType]": self.source_builder = source_builder return self - def set_expected_analytics(self, expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]]) -> "TestScenarioBuilder[SourceType]": + def set_expected_analytics( + self, expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]] + ) -> "TestScenarioBuilder[SourceType]": self._expected_analytics = expected_analytics return self @@ -200,12 +225,15 @@ def build(self) -> "TestScenario[SourceType]": raise ValueError("source_builder is not set") if self._incremental_scenario_config and self._incremental_scenario_config.input_state: state = [ - AirbyteStateMessageSerializer.load(s) if isinstance(s, dict) else s for s in self._incremental_scenario_config.input_state + AirbyteStateMessageSerializer.load(s) if isinstance(s, dict) else s + for s in self._incremental_scenario_config.input_state ] else: state = None source = self.source_builder.build( - self._configured_catalog(SyncMode.incremental if self._incremental_scenario_config else SyncMode.full_refresh), + self._configured_catalog( + SyncMode.incremental if self._incremental_scenario_config else SyncMode.full_refresh + ), self._config, state, ) diff --git a/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py b/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py index 97c0c491..c8d3dae9 100644 --- a/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/unstructured_scenarios.py @@ -19,7 +19,10 @@ "type": ["null", "string"], "description": "Content of the file as markdown. Might be null if the file could not be parsed", }, - "document_key": {"type": ["null", "string"], "description": "Unique identifier of the document, e.g. the file path"}, + "document_key": { + "type": ["null", "string"], + "description": "Unique identifier of the document, e.g. the file path", + }, "_ab_source_file_parse_error": { "type": ["null", "string"], "description": "Error message if the file could not be parsed even though the file is supported", @@ -50,7 +53,8 @@ { "a.md": { "contents": bytes( - "# Title 1\n\n## Title 2\n\n### Title 3\n\n#### Title 4\n\n##### Title 5\n\n###### Title 6\n\n", "UTF-8" + "# Title 1\n\n## Title 2\n\n### Title 3\n\n#### Title 4\n\n##### Title 5\n\n###### Title 6\n\n", + "UTF-8", ), "last_modified": "2023-06-05T03:54:07.000Z", }, diff --git a/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py b/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py index 3c10e701..9d233921 100644 --- a/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/user_input_schema_scenarios.py @@ -122,8 +122,12 @@ .set_catalog(CatalogBuilder().with_stream("stream1", SyncMode.full_refresh).build()) .set_expected_check_status("FAILED") .set_expected_check_error(None, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value) - .set_expected_discover_error(ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value) - .set_expected_read_error(ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value) + .set_expected_discover_error( + ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value + ) + .set_expected_read_error( + ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value + ) ).build() @@ -374,11 +378,19 @@ "stream": "stream2", }, { - "data": {"col1": "val11c", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "c.csv"}, + "data": { + "col1": "val11c", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "c.csv", + }, "stream": "stream3", }, { - "data": {"col1": "val21c", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "c.csv"}, + "data": { + "col1": "val21c", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "c.csv", + }, "stream": "stream3", }, ] @@ -457,8 +469,12 @@ ) .set_expected_check_status("FAILED") .set_expected_check_error(None, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value) - .set_expected_discover_error(ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value) - .set_expected_read_error(ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value) + .set_expected_discover_error( + ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value + ) + .set_expected_read_error( + ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value + ) ).build() @@ -590,11 +606,19 @@ "stream": "stream2", }, { - "data": {"col1": "val11c", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "c.csv"}, + "data": { + "col1": "val11c", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "c.csv", + }, "stream": "stream3", }, { - "data": {"col1": "val21c", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "c.csv"}, + "data": { + "col1": "val21c", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "c.csv", + }, "stream": "stream3", }, ] @@ -728,11 +752,19 @@ # {"data": {"col1": "val21b", "col2": "val22b", "col3": "val23b", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", # "_ab_source_file_url": "b.csv"}, "stream": "stream2"}, { - "data": {"col1": "val11c", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "c.csv"}, + "data": { + "col1": "val11c", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "c.csv", + }, "stream": "stream3", }, { - "data": {"col1": "val21c", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "c.csv"}, + "data": { + "col1": "val21c", + "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", + "_ab_source_file_url": "c.csv", + }, "stream": "stream3", }, ] diff --git a/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py b/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py index 47e37dff..d6ed7e9a 100644 --- a/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py +++ b/unit_tests/sources/file_based/scenarios/validation_policy_scenarios.py @@ -664,7 +664,9 @@ ] } ) - .set_expected_records(None) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops + .set_expected_records( + None + ) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops .set_expected_logs( { "read": [ @@ -703,7 +705,9 @@ ] } ) - .set_expected_records(None) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops + .set_expected_records( + None + ) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops .set_expected_logs( { "read": [ diff --git a/unit_tests/sources/file_based/schema_validation_policies/test_default_schema_validation_policy.py b/unit_tests/sources/file_based/schema_validation_policies/test_default_schema_validation_policy.py index a7d2bfb7..9cbf33e5 100644 --- a/unit_tests/sources/file_based/schema_validation_policies/test_default_schema_validation_policy.py +++ b/unit_tests/sources/file_based/schema_validation_policies/test_default_schema_validation_policy.py @@ -7,7 +7,9 @@ import pytest from airbyte_cdk.sources.file_based.config.file_based_stream_config import ValidationPolicy from airbyte_cdk.sources.file_based.exceptions import StopSyncPerValidationPolicy -from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES +from airbyte_cdk.sources.file_based.schema_validation_policies import ( + DEFAULT_SCHEMA_VALIDATION_POLICIES, +) CONFORMING_RECORD = { "col1": "val1", @@ -32,19 +34,65 @@ @pytest.mark.parametrize( "record,schema,validation_policy,expected_result", [ - pytest.param(CONFORMING_RECORD, SCHEMA, ValidationPolicy.emit_record, True, id="record-conforms_emit_record"), - pytest.param(NONCONFORMING_RECORD, SCHEMA, ValidationPolicy.emit_record, True, id="nonconforming_emit_record"), - pytest.param(CONFORMING_RECORD, SCHEMA, ValidationPolicy.skip_record, True, id="record-conforms_skip_record"), - pytest.param(NONCONFORMING_RECORD, SCHEMA, ValidationPolicy.skip_record, False, id="nonconforming_skip_record"), - pytest.param(CONFORMING_RECORD, SCHEMA, ValidationPolicy.wait_for_discover, True, id="record-conforms_wait_for_discover"), - pytest.param(NONCONFORMING_RECORD, SCHEMA, ValidationPolicy.wait_for_discover, False, id="nonconforming_wait_for_discover"), + pytest.param( + CONFORMING_RECORD, + SCHEMA, + ValidationPolicy.emit_record, + True, + id="record-conforms_emit_record", + ), + pytest.param( + NONCONFORMING_RECORD, + SCHEMA, + ValidationPolicy.emit_record, + True, + id="nonconforming_emit_record", + ), + pytest.param( + CONFORMING_RECORD, + SCHEMA, + ValidationPolicy.skip_record, + True, + id="record-conforms_skip_record", + ), + pytest.param( + NONCONFORMING_RECORD, + SCHEMA, + ValidationPolicy.skip_record, + False, + id="nonconforming_skip_record", + ), + pytest.param( + CONFORMING_RECORD, + SCHEMA, + ValidationPolicy.wait_for_discover, + True, + id="record-conforms_wait_for_discover", + ), + pytest.param( + NONCONFORMING_RECORD, + SCHEMA, + ValidationPolicy.wait_for_discover, + False, + id="nonconforming_wait_for_discover", + ), ], ) def test_record_passes_validation_policy( - record: Mapping[str, Any], schema: Mapping[str, Any], validation_policy: ValidationPolicy, expected_result: bool + record: Mapping[str, Any], + schema: Mapping[str, Any], + validation_policy: ValidationPolicy, + expected_result: bool, ) -> None: if validation_policy == ValidationPolicy.wait_for_discover and expected_result is False: with pytest.raises(StopSyncPerValidationPolicy): - DEFAULT_SCHEMA_VALIDATION_POLICIES[validation_policy].record_passes_validation_policy(record, schema) + DEFAULT_SCHEMA_VALIDATION_POLICIES[validation_policy].record_passes_validation_policy( + record, schema + ) else: - assert DEFAULT_SCHEMA_VALIDATION_POLICIES[validation_policy].record_passes_validation_policy(record, schema) == expected_result + assert ( + DEFAULT_SCHEMA_VALIDATION_POLICIES[validation_policy].record_passes_validation_policy( + record, schema + ) + == expected_result + ) diff --git a/unit_tests/sources/file_based/stream/concurrent/test_adapters.py b/unit_tests/sources/file_based/stream/concurrent/test_adapters.py index fc4b5c22..3c271dfe 100644 --- a/unit_tests/sources/file_based/stream/concurrent/test_adapters.py +++ b/unit_tests/sources/file_based/stream/concurrent/test_adapters.py @@ -9,7 +9,9 @@ import pytest from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.sources.file_based.availability_strategy import DefaultFileBasedAvailabilityStrategy +from airbyte_cdk.sources.file_based.availability_strategy import ( + DefaultFileBasedAvailabilityStrategy, +) from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.discovery_policy import DefaultDiscoveryPolicy @@ -62,7 +64,9 @@ def test_file_based_stream_partition_generator(sync_mode): partitions = list(partition_generator.generate()) slices = [partition.to_slice() for partition in partitions] assert slices == stream_slices - stream.stream_slices.assert_called_once_with(sync_mode=_ANY_SYNC_MODE, cursor_field=_ANY_CURSOR_FIELD, stream_state=_ANY_STATE) + stream.stream_slices.assert_called_once_with( + sync_mode=_ANY_SYNC_MODE, cursor_field=_ANY_CURSOR_FIELD, stream_state=_ANY_STATE + ) @pytest.mark.parametrize( @@ -71,16 +75,36 @@ def test_file_based_stream_partition_generator(sync_mode): pytest.param( TypeTransformer(TransformConfig.NoTransform), [ - Record({"data": "1"}, Mock(spec=FileBasedStreamPartition, stream_name=Mock(return_value=_STREAM_NAME))), - Record({"data": "2"}, Mock(spec=FileBasedStreamPartition, stream_name=Mock(return_value=_STREAM_NAME))), + Record( + {"data": "1"}, + Mock( + spec=FileBasedStreamPartition, stream_name=Mock(return_value=_STREAM_NAME) + ), + ), + Record( + {"data": "2"}, + Mock( + spec=FileBasedStreamPartition, stream_name=Mock(return_value=_STREAM_NAME) + ), + ), ], id="test_no_transform", ), pytest.param( TypeTransformer(TransformConfig.DefaultSchemaNormalization), [ - Record({"data": 1}, Mock(spec=FileBasedStreamPartition, stream_name=Mock(return_value=_STREAM_NAME))), - Record({"data": 2}, Mock(spec=FileBasedStreamPartition, stream_name=Mock(return_value=_STREAM_NAME))), + Record( + {"data": 1}, + Mock( + spec=FileBasedStreamPartition, stream_name=Mock(return_value=_STREAM_NAME) + ), + ), + Record( + {"data": 2}, + Mock( + spec=FileBasedStreamPartition, stream_name=Mock(return_value=_STREAM_NAME) + ), + ), ], id="test_default_transform", ), @@ -89,14 +113,19 @@ def test_file_based_stream_partition_generator(sync_mode): def test_file_based_stream_partition(transformer, expected_records): stream = Mock() stream.name = _STREAM_NAME - stream.get_json_schema.return_value = {"type": "object", "properties": {"data": {"type": ["integer"]}}} + stream.get_json_schema.return_value = { + "type": "object", + "properties": {"data": {"type": ["integer"]}}, + } stream.transformer = transformer message_repository = InMemoryMessageRepository() _slice = None sync_mode = SyncMode.full_refresh cursor_field = None state = None - partition = FileBasedStreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR) + partition = FileBasedStreamPartition( + stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR + ) a_log_message = AirbyteMessage( type=MessageType.LOG, @@ -120,7 +149,9 @@ def test_file_based_stream_partition(transformer, expected_records): "exception_type, expected_display_message", [ pytest.param(Exception, None, id="test_exception_no_display_message"), - pytest.param(ExceptionWithDisplayMessage, "display_message", id="test_exception_no_display_message"), + pytest.param( + ExceptionWithDisplayMessage, "display_message", id="test_exception_no_display_message" + ), ], ) def test_file_based_stream_partition_raising_exception(exception_type, expected_display_message): @@ -130,7 +161,15 @@ def test_file_based_stream_partition_raising_exception(exception_type, expected_ message_repository = InMemoryMessageRepository() _slice = None - partition = FileBasedStreamPartition(stream, _slice, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR) + partition = FileBasedStreamPartition( + stream, + _slice, + message_repository, + _ANY_SYNC_MODE, + _ANY_CURSOR_FIELD, + _ANY_STATE, + _ANY_CURSOR, + ) stream.read_records.side_effect = Exception() @@ -145,7 +184,16 @@ def test_file_based_stream_partition_raising_exception(exception_type, expected_ "_slice, expected_hash", [ pytest.param( - {"files": [RemoteFile(uri="1", last_modified=datetime.strptime("2023-06-09T00:00:00Z", "%Y-%m-%dT%H:%M:%SZ"))]}, + { + "files": [ + RemoteFile( + uri="1", + last_modified=datetime.strptime( + "2023-06-09T00:00:00Z", "%Y-%m-%dT%H:%M:%SZ" + ), + ) + ] + }, hash(("stream", "2023-06-09T00:00:00.000000Z_1")), id="test_hash_with_slice", ), @@ -155,7 +203,9 @@ def test_file_based_stream_partition_raising_exception(exception_type, expected_ def test_file_based_stream_partition_hash(_slice, expected_hash): stream = Mock() stream.name = "stream" - partition = FileBasedStreamPartition(stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR) + partition = FileBasedStreamPartition( + stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR + ) _hash = partition.__hash__() assert _hash == expected_hash @@ -171,7 +221,9 @@ def setUp(self): supported_sync_modes=[SyncMode.full_refresh], ) self._legacy_stream = DefaultFileBasedStream( - cursor=FileBasedFinalStateCursor(stream_config=MagicMock(), stream_namespace=None, message_repository=Mock()), + cursor=FileBasedFinalStateCursor( + stream_config=MagicMock(), stream_namespace=None, message_repository=Mock() + ), config=FileBasedStreamConfig(name="stream", format=CsvFormat()), catalog_schema={}, stream_reader=MagicMock(), @@ -185,7 +237,13 @@ def setUp(self): self._logger = Mock() self._slice_logger = Mock() self._slice_logger.should_log_slice_message.return_value = False - self._facade = FileBasedStreamFacade(self._abstract_stream, self._legacy_stream, self._cursor, self._slice_logger, self._logger) + self._facade = FileBasedStreamFacade( + self._abstract_stream, + self._legacy_stream, + self._cursor, + self._slice_logger, + self._logger, + ) self._source = Mock() self._stream = Mock() @@ -207,17 +265,27 @@ def test_json_schema_is_delegated_to_wrapped_stream(self): assert self._facade.get_json_schema() == json_schema self._abstract_stream.get_json_schema.assert_called_once_with() - def test_given_cursor_is_noop_when_supports_incremental_then_return_legacy_stream_response(self): + def test_given_cursor_is_noop_when_supports_incremental_then_return_legacy_stream_response( + self, + ): assert ( FileBasedStreamFacade( - self._abstract_stream, self._legacy_stream, _ANY_CURSOR, Mock(spec=SliceLogger), Mock(spec=logging.Logger) + self._abstract_stream, + self._legacy_stream, + _ANY_CURSOR, + Mock(spec=SliceLogger), + Mock(spec=logging.Logger), ).supports_incremental == self._legacy_stream.supports_incremental ) def test_given_cursor_is_not_noop_when_supports_incremental_then_return_true(self): assert FileBasedStreamFacade( - self._abstract_stream, self._legacy_stream, Mock(spec=Cursor), Mock(spec=SliceLogger), Mock(spec=logging.Logger) + self._abstract_stream, + self._legacy_stream, + Mock(spec=Cursor), + Mock(spec=SliceLogger), + Mock(spec=logging.Logger), ).supports_incremental def test_full_refresh(self): @@ -249,7 +317,9 @@ def test_create_from_stream_stream(self): stream.primary_key = "id" stream.cursor_field = "cursor" - facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = FileBasedStreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade.name == "stream" assert facade.cursor_field == "cursor" @@ -261,7 +331,9 @@ def test_create_from_stream_stream_with_none_primary_key(self): stream.primary_key = None stream.cursor_field = [] - facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = FileBasedStreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade._abstract_stream._primary_key == [] def test_create_from_stream_with_composite_primary_key(self): @@ -270,7 +342,9 @@ def test_create_from_stream_with_composite_primary_key(self): stream.primary_key = ["id", "name"] stream.cursor_field = [] - facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = FileBasedStreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade._abstract_stream._primary_key == ["id", "name"] def test_create_from_stream_with_empty_list_cursor(self): @@ -278,7 +352,9 @@ def test_create_from_stream_with_empty_list_cursor(self): stream.primary_key = "id" stream.cursor_field = [] - facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = FileBasedStreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade.cursor_field == [] @@ -288,7 +364,9 @@ def test_create_from_stream_raises_exception_if_primary_key_is_nested(self): stream.primary_key = [["field", "id"]] with self.assertRaises(ValueError): - FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + FileBasedStreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) def test_create_from_stream_raises_exception_if_primary_key_has_invalid_type(self): stream = Mock() @@ -296,7 +374,9 @@ def test_create_from_stream_raises_exception_if_primary_key_has_invalid_type(sel stream.primary_key = 123 with self.assertRaises(ValueError): - FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + FileBasedStreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) def test_create_from_stream_raises_exception_if_cursor_field_is_nested(self): stream = Mock() @@ -305,7 +385,9 @@ def test_create_from_stream_raises_exception_if_cursor_field_is_nested(self): stream.cursor_field = ["field", "cursor"] with self.assertRaises(ValueError): - FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + FileBasedStreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) def test_create_from_stream_with_cursor_field_as_list(self): stream = Mock() @@ -313,7 +395,9 @@ def test_create_from_stream_with_cursor_field_as_list(self): stream.primary_key = "id" stream.cursor_field = ["cursor"] - facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = FileBasedStreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade.cursor_field == "cursor" def test_create_from_stream_none_message_repository(self): @@ -323,12 +407,16 @@ def test_create_from_stream_none_message_repository(self): self._source.message_repository = None with self.assertRaises(ValueError): - FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, {}, self._cursor) + FileBasedStreamFacade.create_from_stream( + self._stream, self._source, self._logger, {}, self._cursor + ) def test_get_error_display_message_no_display_message(self): self._stream.get_error_display_message.return_value = "display_message" - facade = FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = FileBasedStreamFacade.create_from_stream( + self._stream, self._source, self._logger, _ANY_STATE, self._cursor + ) expected_display_message = None e = Exception() @@ -340,7 +428,9 @@ def test_get_error_display_message_no_display_message(self): def test_get_error_display_message_with_display_message(self): self._stream.get_error_display_message.return_value = "display_message" - facade = FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = FileBasedStreamFacade.create_from_stream( + self._stream, self._source, self._logger, _ANY_STATE, self._cursor + ) expected_display_message = "display_message" e = ExceptionWithDisplayMessage("display_message") @@ -354,7 +444,9 @@ def test_get_error_display_message_with_display_message(self): "exception, expected_display_message", [ pytest.param(Exception("message"), None, id="test_no_display_message"), - pytest.param(ExceptionWithDisplayMessage("message"), "message", id="test_no_display_message"), + pytest.param( + ExceptionWithDisplayMessage("message"), "message", id="test_no_display_message" + ), ], ) def test_get_error_display_message(exception, expected_display_message): diff --git a/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py b/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py index 96c90790..ce48f845 100644 --- a/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py +++ b/unit_tests/sources/file_based/stream/concurrent/test_file_based_concurrent_cursor.py @@ -30,7 +30,9 @@ def _make_cursor(input_state: Optional[MutableMapping[str, Any]]) -> FileBasedCo None, input_state, MagicMock(), - ConnectorStateManager(state=[AirbyteStateMessage(input_state)] if input_state is not None else None), + ConnectorStateManager( + state=[AirbyteStateMessage(input_state)] if input_state is not None else None + ), CursorField(FileBasedConcurrentCursor.CURSOR_FIELD), ) return cursor @@ -40,19 +42,30 @@ def _make_cursor(input_state: Optional[MutableMapping[str, Any]]) -> FileBasedCo "input_state, expected_cursor_value", [ pytest.param({}, (datetime.min, ""), id="no-state-gives-min-cursor"), - pytest.param({"history": {}}, (datetime.min, ""), id="missing-cursor-field-gives-min-cursor"), pytest.param( - {"history": {"a.csv": "2021-01-01T00:00:00.000000Z"}, "_ab_source_file_last_modified": "2021-01-01T00:00:00.000000Z_a.csv"}, + {"history": {}}, (datetime.min, ""), id="missing-cursor-field-gives-min-cursor" + ), + pytest.param( + { + "history": {"a.csv": "2021-01-01T00:00:00.000000Z"}, + "_ab_source_file_last_modified": "2021-01-01T00:00:00.000000Z_a.csv", + }, (datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), "a.csv"), id="cursor-value-matches-earliest-file", ), pytest.param( - {"history": {"a.csv": "2021-01-01T00:00:00.000000Z"}, "_ab_source_file_last_modified": "2020-01-01T00:00:00.000000Z_a.csv"}, + { + "history": {"a.csv": "2021-01-01T00:00:00.000000Z"}, + "_ab_source_file_last_modified": "2020-01-01T00:00:00.000000Z_a.csv", + }, (datetime.strptime("2020-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), "a.csv"), id="cursor-value-is-earlier", ), pytest.param( - {"history": {"a.csv": "2022-01-01T00:00:00.000000Z"}, "_ab_source_file_last_modified": "2021-01-01T00:00:00.000000Z_a.csv"}, + { + "history": {"a.csv": "2022-01-01T00:00:00.000000Z"}, + "_ab_source_file_last_modified": "2021-01-01T00:00:00.000000Z_a.csv", + }, (datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), "a.csv"), id="cursor-value-is-later", ), @@ -69,18 +82,26 @@ def _make_cursor(input_state: Optional[MutableMapping[str, Any]]) -> FileBasedCo id="cursor-not-earliest", ), pytest.param( - {"history": {"b.csv": "2020-12-31T00:00:00.000000Z"}, "_ab_source_file_last_modified": "2021-01-01T00:00:00.000000Z_a.csv"}, + { + "history": {"b.csv": "2020-12-31T00:00:00.000000Z"}, + "_ab_source_file_last_modified": "2021-01-01T00:00:00.000000Z_a.csv", + }, (datetime.strptime("2020-12-31T00:00:00.000000Z", DATE_TIME_FORMAT), "b.csv"), id="state-with-cursor-and-earlier-history", ), pytest.param( - {"history": {"b.csv": "2021-01-02T00:00:00.000000Z"}, "_ab_source_file_last_modified": "2021-01-01T00:00:00.000000Z_a.csv"}, + { + "history": {"b.csv": "2021-01-02T00:00:00.000000Z"}, + "_ab_source_file_last_modified": "2021-01-01T00:00:00.000000Z_a.csv", + }, (datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), "a.csv"), id="state-with-cursor-and-later-history", ), ], ) -def test_compute_prev_sync_cursor(input_state: MutableMapping[str, Any], expected_cursor_value: Tuple[datetime, str]): +def test_compute_prev_sync_cursor( + input_state: MutableMapping[str, Any], expected_cursor_value: Tuple[datetime, str] +): cursor = _make_cursor(input_state) assert cursor._compute_prev_sync_cursor(input_state) == expected_cursor_value @@ -99,7 +120,10 @@ def test_compute_prev_sync_cursor(input_state: MutableMapping[str, Any], expecte ), pytest.param( {"history": {}}, - [("newfile.csv", "2021-01-05T00:00:00.000000Z"), ("pending.csv", "2020-01-05T00:00:00.000000Z")], + [ + ("newfile.csv", "2021-01-05T00:00:00.000000Z"), + ("pending.csv", "2020-01-05T00:00:00.000000Z"), + ], ("newfile.csv", "2021-01-05T00:00:00.000000Z"), {"newfile.csv": "2021-01-05T00:00:00.000000Z"}, [("pending.csv", "2020-01-05T00:00:00.000000Z")], @@ -108,7 +132,10 @@ def test_compute_prev_sync_cursor(input_state: MutableMapping[str, Any], expecte ), pytest.param( {"history": {}}, - [("newfile.csv", "2021-01-05T00:00:00.000000Z"), ("pending.csv", "2022-01-05T00:00:00.000000Z")], + [ + ("newfile.csv", "2021-01-05T00:00:00.000000Z"), + ("pending.csv", "2022-01-05T00:00:00.000000Z"), + ], ("newfile.csv", "2021-01-05T00:00:00.000000Z"), {"newfile.csv": "2021-01-05T00:00:00.000000Z"}, [("pending.csv", "2022-01-05T00:00:00.000000Z")], @@ -119,25 +146,40 @@ def test_compute_prev_sync_cursor(input_state: MutableMapping[str, Any], expecte {"history": {"existing.csv": "2021-01-04T00:00:00.000000Z"}}, [("newfile.csv", "2021-01-05T00:00:00.000000Z")], ("newfile.csv", "2021-01-05T00:00:00.000000Z"), - {"existing.csv": "2021-01-04T00:00:00.000000Z", "newfile.csv": "2021-01-05T00:00:00.000000Z"}, + { + "existing.csv": "2021-01-04T00:00:00.000000Z", + "newfile.csv": "2021-01-05T00:00:00.000000Z", + }, [], "2021-01-05T00:00:00.000000Z_newfile.csv", id="add-to-nonempty-history-single-pending-file", ), pytest.param( {"history": {"existing.csv": "2021-01-04T00:00:00.000000Z"}}, - [("newfile.csv", "2021-01-05T00:00:00.000000Z"), ("pending.csv", "2020-01-05T00:00:00.000000Z")], + [ + ("newfile.csv", "2021-01-05T00:00:00.000000Z"), + ("pending.csv", "2020-01-05T00:00:00.000000Z"), + ], ("newfile.csv", "2021-01-05T00:00:00.000000Z"), - {"existing.csv": "2021-01-04T00:00:00.000000Z", "newfile.csv": "2021-01-05T00:00:00.000000Z"}, + { + "existing.csv": "2021-01-04T00:00:00.000000Z", + "newfile.csv": "2021-01-05T00:00:00.000000Z", + }, [("pending.csv", "2020-01-05T00:00:00.000000Z")], "2020-01-05T00:00:00.000000Z_pending.csv", id="add-to-nonempty-history-pending-file-is-older", ), pytest.param( {"history": {"existing.csv": "2021-01-04T00:00:00.000000Z"}}, - [("newfile.csv", "2021-01-05T00:00:00.000000Z"), ("pending.csv", "2022-01-05T00:00:00.000000Z")], + [ + ("newfile.csv", "2021-01-05T00:00:00.000000Z"), + ("pending.csv", "2022-01-05T00:00:00.000000Z"), + ], ("newfile.csv", "2021-01-05T00:00:00.000000Z"), - {"existing.csv": "2021-01-04T00:00:00.000000Z", "newfile.csv": "2021-01-05T00:00:00.000000Z"}, + { + "existing.csv": "2021-01-04T00:00:00.000000Z", + "newfile.csv": "2021-01-05T00:00:00.000000Z", + }, [("pending.csv", "2022-01-05T00:00:00.000000Z")], "2022-01-05T00:00:00.000000Z_pending.csv", id="add-to-nonempty-history-pending-file-is-newer", @@ -161,7 +203,13 @@ def test_add_file( [ FileBasedStreamPartition( stream, - {"files": [RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT))]}, + { + "files": [ + RemoteFile( + uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT) + ) + ] + }, mock_message_repository, SyncMode.full_refresh, FileBasedConcurrentCursor.CURSOR_FIELD, @@ -173,13 +221,18 @@ def test_add_file( ) uri, timestamp = file_to_add - cursor.add_file(RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT))) + cursor.add_file( + RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT)) + ) assert cursor._file_to_datetime_history == expected_history assert cursor._pending_files == { - uri: RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT)) for uri, timestamp in expected_pending_files + uri: RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT)) + for uri, timestamp in expected_pending_files } assert ( - mock_message_repository.emit_message.call_args_list[0].args[0].state.stream.stream_state._ab_source_file_last_modified + mock_message_repository.emit_message.call_args_list[0] + .args[0] + .state.stream.stream_state._ab_source_file_last_modified == expected_cursor_value ) @@ -217,20 +270,26 @@ def test_add_file_invalid( ): cursor = _make_cursor(initial_state) cursor._pending_files = { - uri: RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT)) for uri, timestamp in pending_files + uri: RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT)) + for uri, timestamp in pending_files } mock_message_repository = MagicMock() cursor._message_repository = mock_message_repository uri, timestamp = file_to_add - cursor.add_file(RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT))) + cursor.add_file( + RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT)) + ) assert cursor._file_to_datetime_history == expected_history assert cursor._pending_files == { - uri: RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT)) for uri, timestamp in expected_pending_files + uri: RemoteFile(uri=uri, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT)) + for uri, timestamp in expected_pending_files } assert mock_message_repository.emit_message.call_args_list[0].args[0].log.level.value == "WARN" assert ( - mock_message_repository.emit_message.call_args_list[1].args[0].state.stream.stream_state._ab_source_file_last_modified + mock_message_repository.emit_message.call_args_list[1] + .args[0] + .state.stream.stream_state._ab_source_file_last_modified == expected_cursor_value ) @@ -238,12 +297,20 @@ def test_add_file_invalid( @pytest.mark.parametrize( "input_state, pending_files, expected_cursor_value", [ - pytest.param({}, [], f"{datetime.min.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_", id="no-state-no-pending"), pytest.param( - {"history": {"a.csv": "2021-01-01T00:00:00.000000Z"}}, [], "2021-01-01T00:00:00.000000Z_a.csv", id="no-pending-with-history" + {}, [], f"{datetime.min.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_", id="no-state-no-pending" + ), + pytest.param( + {"history": {"a.csv": "2021-01-01T00:00:00.000000Z"}}, + [], + "2021-01-01T00:00:00.000000Z_a.csv", + id="no-pending-with-history", ), pytest.param( - {"history": {}}, [("b.csv", "2021-01-02T00:00:00.000000Z")], "2021-01-02T00:00:00.000000Z_b.csv", id="pending-no-history" + {"history": {}}, + [("b.csv", "2021-01-02T00:00:00.000000Z")], + "2021-01-02T00:00:00.000000Z_b.csv", + id="pending-no-history", ), pytest.param( {"history": {"a.csv": "2022-01-01T00:00:00.000000Z"}}, @@ -259,13 +326,19 @@ def test_add_file_invalid( ), ], ) -def test_get_new_cursor_value(input_state: MutableMapping[str, Any], pending_files: List[Tuple[str, str]], expected_cursor_value: str): +def test_get_new_cursor_value( + input_state: MutableMapping[str, Any], + pending_files: List[Tuple[str, str]], + expected_cursor_value: str, +): cursor = _make_cursor(input_state) pending_partitions = [] for url, timestamp in pending_files: partition = MagicMock() partition.to_slice = lambda *args, **kwargs: { - "files": [RemoteFile(uri=url, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT))] + "files": [ + RemoteFile(uri=url, last_modified=datetime.strptime(timestamp, DATE_TIME_FORMAT)) + ] } pending_partitions.append(partition) @@ -276,7 +349,14 @@ def test_get_new_cursor_value(input_state: MutableMapping[str, Any], pending_fil "all_files, history, is_history_full, prev_cursor_value, expected_files_to_sync", [ pytest.param( - [RemoteFile(uri="new.csv", last_modified=datetime.strptime("2021-01-03T00:00:00.000000Z", "%Y-%m-%dT%H:%M:%S.%fZ"))], + [ + RemoteFile( + uri="new.csv", + last_modified=datetime.strptime( + "2021-01-03T00:00:00.000000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + ) + ], {}, False, (datetime.min, ""), @@ -284,7 +364,14 @@ def test_get_new_cursor_value(input_state: MutableMapping[str, Any], pending_fil id="empty-history-one-new-file", ), pytest.param( - [RemoteFile(uri="a.csv", last_modified=datetime.strptime("2021-01-02T00:00:00.000000Z", "%Y-%m-%dT%H:%M:%S.%fZ"))], + [ + RemoteFile( + uri="a.csv", + last_modified=datetime.strptime( + "2021-01-02T00:00:00.000000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + ) + ], {"a.csv": "2021-01-01T00:00:00.000000Z"}, False, (datetime.min, ""), @@ -292,7 +379,14 @@ def test_get_new_cursor_value(input_state: MutableMapping[str, Any], pending_fil id="non-empty-history-file-in-history-modified", ), pytest.param( - [RemoteFile(uri="a.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000000Z", "%Y-%m-%dT%H:%M:%S.%fZ"))], + [ + RemoteFile( + uri="a.csv", + last_modified=datetime.strptime( + "2021-01-01T00:00:00.000000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + ) + ], {"a.csv": "2021-01-01T00:00:00.000000Z"}, False, (datetime.min, ""), @@ -301,7 +395,9 @@ def test_get_new_cursor_value(input_state: MutableMapping[str, Any], pending_fil ), ], ) -def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_value, expected_files_to_sync): +def test_get_files_to_sync( + all_files, history, is_history_full, prev_cursor_value, expected_files_to_sync +): cursor = _make_cursor({}) cursor._file_to_datetime_history = history cursor._prev_cursor_value = prev_cursor_value @@ -315,7 +411,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu "file_to_check, history, is_history_full, prev_cursor_value, sync_start, expected_should_sync", [ pytest.param( - RemoteFile(uri="new.csv", last_modified=datetime.strptime("2021-01-03T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="new.csv", + last_modified=datetime.strptime("2021-01-03T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {}, False, (datetime.min, ""), @@ -324,7 +423,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu id="file-not-in-history-not-full-old-cursor", ), pytest.param( - RemoteFile(uri="new.csv", last_modified=datetime.strptime("2021-01-03T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="new.csv", + last_modified=datetime.strptime("2021-01-03T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {}, False, (datetime.strptime("2024-01-02T00:00:00.000000Z", DATE_TIME_FORMAT), ""), @@ -333,7 +435,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu id="file-not-in-history-not-full-new-cursor", ), pytest.param( - RemoteFile(uri="a.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="a.csv", + last_modified=datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {"a.csv": "2021-01-01T00:00:00.000000Z"}, False, (datetime.min, ""), @@ -342,7 +447,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu id="file-in-history-not-modified", ), pytest.param( - RemoteFile(uri="a.csv", last_modified=datetime.strptime("2020-01-01T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="a.csv", + last_modified=datetime.strptime("2020-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {"a.csv": "2021-01-01T00:00:00.000000Z"}, False, (datetime.min, ""), @@ -351,7 +459,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu id="file-in-history-modified-before", ), pytest.param( - RemoteFile(uri="a.csv", last_modified=datetime.strptime("2022-01-01T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="a.csv", + last_modified=datetime.strptime("2022-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {"a.csv": "2021-01-01T00:00:00.000000Z"}, False, (datetime.min, ""), @@ -360,7 +471,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu id="file-in-history-modified-after", ), pytest.param( - RemoteFile(uri="new.csv", last_modified=datetime.strptime("2022-01-01T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="new.csv", + last_modified=datetime.strptime("2022-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {}, True, (datetime.strptime("2021-01-02T00:00:00.000000Z", DATE_TIME_FORMAT), "a.csv"), @@ -369,7 +483,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu id="history-full-file-modified-after-cursor", ), pytest.param( - RemoteFile(uri="new1.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="new1.csv", + last_modified=datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {}, True, (datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), "new0.csv"), @@ -378,7 +495,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu id="history-full-modified-eq-cursor-uri-gt", ), pytest.param( - RemoteFile(uri="new0.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="new0.csv", + last_modified=datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {}, True, (datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), "new1.csv"), @@ -387,7 +507,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu id="history-full-modified-eq-cursor-uri-lt", ), pytest.param( - RemoteFile(uri="new.csv", last_modified=datetime.strptime("2020-01-01T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="new.csv", + last_modified=datetime.strptime("2020-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {}, True, (datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), "a.csv"), @@ -396,7 +519,10 @@ def test_get_files_to_sync(all_files, history, is_history_full, prev_cursor_valu id="history-full-modified-before-cursor-and-after-sync-start", ), pytest.param( - RemoteFile(uri="new.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT)), + RemoteFile( + uri="new.csv", + last_modified=datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), + ), {}, True, (datetime.strptime("2022-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), "a.csv"), @@ -435,15 +561,23 @@ def test_should_sync_file( id="non-full-history", ), pytest.param( - {f"file{i}.csv": f"2021-01-0{i}T00:00:00.000000Z" for i in range(1, 4)}, # all before the time window + { + f"file{i}.csv": f"2021-01-0{i}T00:00:00.000000Z" for i in range(1, 4) + }, # all before the time window True, - datetime.strptime("2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT), # Time window start time + datetime.strptime( + "2021-01-01T00:00:00.000000Z", DATE_TIME_FORMAT + ), # Time window start time id="full-history-earliest-before-window", ), pytest.param( - {f"file{i}.csv": f"2024-01-0{i}T00:00:00.000000Z" for i in range(1, 4)}, # all after the time window + { + f"file{i}.csv": f"2024-01-0{i}T00:00:00.000000Z" for i in range(1, 4) + }, # all after the time window True, - datetime.strptime("2023-06-13T00:00:00.000000Z", DATE_TIME_FORMAT), # Earliest file time + datetime.strptime( + "2023-06-13T00:00:00.000000Z", DATE_TIME_FORMAT + ), # Earliest file time id="full-history-earliest-after-window", ), ], diff --git a/unit_tests/sources/file_based/stream/test_default_file_based_cursor.py b/unit_tests/sources/file_based/stream/test_default_file_based_cursor.py index 957ed912..6cd3e20b 100644 --- a/unit_tests/sources/file_based/stream/test_default_file_based_cursor.py +++ b/unit_tests/sources/file_based/stream/test_default_file_based_cursor.py @@ -8,9 +8,14 @@ import pytest from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat -from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ValidationPolicy +from airbyte_cdk.sources.file_based.config.file_based_stream_config import ( + FileBasedStreamConfig, + ValidationPolicy, +) from airbyte_cdk.sources.file_based.remote_file import RemoteFile -from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import DefaultFileBasedCursor +from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import ( + DefaultFileBasedCursor, +) from freezegun import freeze_time @@ -20,13 +25,25 @@ pytest.param( [ RemoteFile( - uri="a.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="a.csv", + last_modified=datetime.strptime( + "2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="b.csv", last_modified=datetime.strptime("2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="b.csv", + last_modified=datetime.strptime( + "2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="c.csv", last_modified=datetime.strptime("2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="c.csv", + last_modified=datetime.strptime( + "2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), ], [datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2020, 12, 31)], @@ -43,19 +60,40 @@ pytest.param( [ RemoteFile( - uri="a.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="a.csv", + last_modified=datetime.strptime( + "2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="b.csv", last_modified=datetime.strptime("2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="b.csv", + last_modified=datetime.strptime( + "2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="c.csv", last_modified=datetime.strptime("2021-01-03T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="c.csv", + last_modified=datetime.strptime( + "2021-01-03T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="d.csv", last_modified=datetime.strptime("2021-01-04T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="d.csv", + last_modified=datetime.strptime( + "2021-01-04T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), ], - [datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 2)], + [ + datetime(2021, 1, 1), + datetime(2021, 1, 1), + datetime(2021, 1, 1), + datetime(2021, 1, 2), + ], { "history": { "b.csv": "2021-01-02T00:00:00.000000Z", @@ -69,21 +107,39 @@ pytest.param( [ RemoteFile( - uri="a.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="a.csv", + last_modified=datetime.strptime( + "2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( uri="file_with_same_timestamp_as_b.csv", - last_modified=datetime.strptime("2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), + last_modified=datetime.strptime( + "2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), file_type="csv", ), RemoteFile( - uri="b.csv", last_modified=datetime.strptime("2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="b.csv", + last_modified=datetime.strptime( + "2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="c.csv", last_modified=datetime.strptime("2021-01-03T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="c.csv", + last_modified=datetime.strptime( + "2021-01-03T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="d.csv", last_modified=datetime.strptime("2021-01-04T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="d.csv", + last_modified=datetime.strptime( + "2021-01-04T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), ], [ @@ -105,7 +161,11 @@ ), ], ) -def test_add_file(files_to_add: List[RemoteFile], expected_start_time: List[datetime], expected_state_dict: Mapping[str, Any]) -> None: +def test_add_file( + files_to_add: List[RemoteFile], + expected_start_time: List[datetime], + expected_state_dict: Mapping[str, Any], +) -> None: cursor = get_cursor(max_history_size=3, days_to_sync_if_history_is_full=3) assert cursor._compute_start_time() == datetime.min @@ -121,24 +181,48 @@ def test_add_file(files_to_add: List[RemoteFile], expected_start_time: List[date pytest.param( [ RemoteFile( - uri="a.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="a.csv", + last_modified=datetime.strptime( + "2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="b.csv", last_modified=datetime.strptime("2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="b.csv", + last_modified=datetime.strptime( + "2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="c.csv", last_modified=datetime.strptime("2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="c.csv", + last_modified=datetime.strptime( + "2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), ], [ RemoteFile( - uri="a.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="a.csv", + last_modified=datetime.strptime( + "2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="b.csv", last_modified=datetime.strptime("2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="b.csv", + last_modified=datetime.strptime( + "2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="c.csv", last_modified=datetime.strptime("2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="c.csv", + last_modified=datetime.strptime( + "2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), ], 3, @@ -148,24 +232,48 @@ def test_add_file(files_to_add: List[RemoteFile], expected_start_time: List[date pytest.param( [ RemoteFile( - uri="a.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="a.csv", + last_modified=datetime.strptime( + "2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="b.csv", last_modified=datetime.strptime("2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="b.csv", + last_modified=datetime.strptime( + "2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="c.csv", last_modified=datetime.strptime("2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="c.csv", + last_modified=datetime.strptime( + "2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), ], [ RemoteFile( - uri="a.csv", last_modified=datetime.strptime("2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="a.csv", + last_modified=datetime.strptime( + "2021-01-01T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="b.csv", last_modified=datetime.strptime("2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="b.csv", + last_modified=datetime.strptime( + "2021-01-02T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), RemoteFile( - uri="c.csv", last_modified=datetime.strptime("2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ"), file_type="csv" + uri="c.csv", + last_modified=datetime.strptime( + "2020-12-31T00:00:00.000Z", "%Y-%m-%dT%H:%M:%S.%fZ" + ), + file_type="csv", ), ], 2, @@ -175,7 +283,10 @@ def test_add_file(files_to_add: List[RemoteFile], expected_start_time: List[date ], ) def test_get_files_to_sync( - files: List[RemoteFile], expected_files_to_sync: List[RemoteFile], max_history_size: int, history_is_partial: bool + files: List[RemoteFile], + expected_files_to_sync: List[RemoteFile], + max_history_size: int, + history_is_partial: bool, ) -> None: logger = MagicMock() cursor = get_cursor(max_history_size, 3) @@ -199,7 +310,10 @@ def test_only_recent_files_are_synced_if_history_is_full() -> None: ] state = { - "history": {f.uri: f.last_modified.strftime(DefaultFileBasedCursor.DATE_TIME_FORMAT) for f in files_in_history}, + "history": { + f.uri: f.last_modified.strftime(DefaultFileBasedCursor.DATE_TIME_FORMAT) + for f in files_in_history + }, } cursor.set_initial_state(state) @@ -227,7 +341,9 @@ def test_only_recent_files_are_synced_if_history_is_full() -> None: pytest.param(timedelta(days=1), True, id="test_modified_at_is_more_recent"), ], ) -def test_sync_file_already_present_in_history(modified_at_delta: timedelta, should_sync_file: bool) -> None: +def test_sync_file_already_present_in_history( + modified_at_delta: timedelta, should_sync_file: bool +) -> None: logger = MagicMock() cursor = get_cursor(2, 3) original_modified_at = datetime(2021, 1, 2) @@ -237,12 +353,17 @@ def test_sync_file_already_present_in_history(modified_at_delta: timedelta, shou ] state = { - "history": {f.uri: f.last_modified.strftime(DefaultFileBasedCursor.DATE_TIME_FORMAT) for f in files_in_history}, + "history": { + f.uri: f.last_modified.strftime(DefaultFileBasedCursor.DATE_TIME_FORMAT) + for f in files_in_history + }, } cursor.set_initial_state(state) files = [ - RemoteFile(uri=filename, last_modified=original_modified_at + modified_at_delta, file_type="csv"), + RemoteFile( + uri=filename, last_modified=original_modified_at + modified_at_delta, file_type="csv" + ), ] files_to_sync = list(cursor.get_files_to_sync(files, logger)) @@ -253,9 +374,27 @@ def test_sync_file_already_present_in_history(modified_at_delta: timedelta, shou @pytest.mark.parametrize( "file_name, last_modified, earliest_dt_in_history, should_sync_file", [ - pytest.param("a.csv", datetime(2023, 6, 3), datetime(2023, 6, 6), True, id="test_last_modified_is_equal_to_time_buffer"), - pytest.param("b.csv", datetime(2023, 6, 6), datetime(2023, 6, 6), False, id="test_file_was_already_synced"), - pytest.param("b.csv", datetime(2023, 6, 7), datetime(2023, 6, 6), True, id="test_file_was_synced_in_the_past"), + pytest.param( + "a.csv", + datetime(2023, 6, 3), + datetime(2023, 6, 6), + True, + id="test_last_modified_is_equal_to_time_buffer", + ), + pytest.param( + "b.csv", + datetime(2023, 6, 6), + datetime(2023, 6, 6), + False, + id="test_file_was_already_synced", + ), + pytest.param( + "b.csv", + datetime(2023, 6, 7), + datetime(2023, 6, 6), + True, + id="test_file_was_synced_in_the_past", + ), pytest.param( "b.csv", datetime(2023, 6, 3), @@ -279,7 +418,12 @@ def test_sync_file_already_present_in_history(modified_at_delta: timedelta, shou ), ], ) -def test_should_sync_file(file_name: str, last_modified: datetime, earliest_dt_in_history: datetime, should_sync_file: bool) -> None: +def test_should_sync_file( + file_name: str, + last_modified: datetime, + earliest_dt_in_history: datetime, + should_sync_file: bool, +) -> None: logger = MagicMock() cursor = get_cursor(1, 3) @@ -288,7 +432,14 @@ def test_should_sync_file(file_name: str, last_modified: datetime, earliest_dt_i cursor._initial_earliest_file_in_history = cursor._compute_earliest_file_in_history() assert ( - bool(list(cursor.get_files_to_sync([RemoteFile(uri=file_name, last_modified=last_modified, file_type="csv")], logger))) + bool( + list( + cursor.get_files_to_sync( + [RemoteFile(uri=file_name, last_modified=last_modified, file_type="csv")], + logger, + ) + ) + ) == should_sync_file ) @@ -298,7 +449,9 @@ def test_set_initial_state_no_history() -> None: cursor.set_initial_state({}) -def get_cursor(max_history_size: int, days_to_sync_if_history_is_full: int) -> DefaultFileBasedCursor: +def get_cursor( + max_history_size: int, days_to_sync_if_history_is_full: int +) -> DefaultFileBasedCursor: cursor_cls = DefaultFileBasedCursor cursor_cls.DEFAULT_MAX_HISTORY_SIZE = max_history_size config = FileBasedStreamConfig( diff --git a/unit_tests/sources/file_based/stream/test_default_file_based_stream.py b/unit_tests/sources/file_based/stream/test_default_file_based_stream.py index 85318b52..8eea01d6 100644 --- a/unit_tests/sources/file_based/stream/test_default_file_based_stream.py +++ b/unit_tests/sources/file_based/stream/test_default_file_based_stream.py @@ -12,7 +12,9 @@ import pytest from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy +from airbyte_cdk.sources.file_based.availability_strategy import ( + AbstractFileBasedAvailabilityStrategy, +) from airbyte_cdk.sources.file_based.discovery_policy import AbstractDiscoveryPolicy from airbyte_cdk.sources.file_based.exceptions import FileBasedErrorsCollector, FileBasedSourceError from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader @@ -103,7 +105,11 @@ def setUp(self) -> None: def test_when_read_records_from_slice_then_return_records(self) -> None: self._parser.parse_records.return_value = [self._A_RECORD] - messages = list(self._stream.read_records_from_slice({"files": [RemoteFile(uri="uri", last_modified=self._NOW)]})) + messages = list( + self._stream.read_records_from_slice( + {"files": [RemoteFile(uri="uri", last_modified=self._NOW)]} + ) + ) assert list(map(lambda message: message.record.data["data"], messages)) == [self._A_RECORD] def test_when_transform_record_then_return_updated_record(self) -> None: @@ -166,7 +172,9 @@ def test_given_exception_after_skipping_records_when_read_records_from_slice_the ) -> None: self._stream_config.schemaless = False self._validation_policy.record_passes_validation_policy.return_value = False - self._parser.parse_records.side_effect = [self._iter([self._A_RECORD, ValueError("An error")])] + self._parser.parse_records.side_effect = [ + self._iter([self._A_RECORD, ValueError("An error")]) + ] messages = list( self._stream.read_records_from_slice( @@ -225,7 +233,9 @@ class TestFileBasedErrorCollector: "Multiple errors", ], ) - def test_collect_parsing_error(self, stream, file, line_no, n_skipped, collector_expected_len) -> None: + def test_collect_parsing_error( + self, stream, file, line_no, n_skipped, collector_expected_len + ) -> None: test_error_pattern = "Error parsing record." # format the error body test_error = ( @@ -251,12 +261,19 @@ def test_yield_and_raise_collected(self) -> None: with pytest.raises(AirbyteTracedException) as parse_error: list(self.test_error_collector.yield_and_raise_collected()) assert parse_error.value.message == "Some errors occured while reading from the source." - assert parse_error.value.internal_message == "Please check the logged errors for more information." + assert ( + parse_error.value.internal_message + == "Please check the logged errors for more information." + ) class DefaultFileBasedStreamFileTransferTest(unittest.TestCase): _NOW = datetime(2022, 10, 22, tzinfo=timezone.utc) - _A_RECORD = {"bytes": 10, "file_relative_path": "relative/path/file.csv", "file_url": "/absolute/path/file.csv"} + _A_RECORD = { + "bytes": 10, + "file_relative_path": "relative/path/file.csv", + "file_url": "/absolute/path/file.csv", + } def setUp(self) -> None: self._stream_config = Mock() @@ -287,7 +304,11 @@ def setUp(self) -> None: def test_when_read_records_from_slice_then_return_records(self) -> None: """Verify that we have the new file method and data is empty""" with mock.patch.object(FileTransfer, "get_file", return_value=[self._A_RECORD]): - messages = list(self._stream.read_records_from_slice({"files": [RemoteFile(uri="uri", last_modified=self._NOW)]})) + messages = list( + self._stream.read_records_from_slice( + {"files": [RemoteFile(uri="uri", last_modified=self._NOW)]} + ) + ) assert list(map(lambda message: message.record.file, messages)) == [self._A_RECORD] assert list(map(lambda message: message.record.data, messages)) == [{}] diff --git a/unit_tests/sources/file_based/test_file_based_scenarios.py b/unit_tests/sources/file_based/test_file_based_scenarios.py index 247a9f34..a930192f 100644 --- a/unit_tests/sources/file_based/test_file_based_scenarios.py +++ b/unit_tests/sources/file_based/test_file_based_scenarios.py @@ -159,7 +159,12 @@ wait_for_rediscovery_scenario_multi_stream, wait_for_rediscovery_scenario_single_stream, ) -from unit_tests.sources.file_based.test_scenarios import verify_check, verify_discover, verify_read, verify_spec +from unit_tests.sources.file_based.test_scenarios import ( + verify_check, + verify_discover, + verify_read, + verify_spec, +) discover_failure_scenarios = [ empty_schema_inference_scenario, @@ -316,7 +321,9 @@ @pytest.mark.parametrize("scenario", discover_scenarios, ids=[s.name for s in discover_scenarios]) -def test_file_based_discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: +def test_file_based_discover( + capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] +) -> None: verify_discover(capsys, tmp_path, scenario) @@ -327,10 +334,14 @@ def test_file_based_read(scenario: TestScenario[AbstractSource]) -> None: @pytest.mark.parametrize("scenario", spec_scenarios, ids=[c.name for c in spec_scenarios]) -def test_file_based_spec(capsys: CaptureFixture[str], scenario: TestScenario[AbstractSource]) -> None: +def test_file_based_spec( + capsys: CaptureFixture[str], scenario: TestScenario[AbstractSource] +) -> None: verify_spec(capsys, scenario) @pytest.mark.parametrize("scenario", check_scenarios, ids=[c.name for c in check_scenarios]) -def test_file_based_check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: +def test_file_based_check( + capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] +) -> None: verify_check(capsys, tmp_path, scenario) diff --git a/unit_tests/sources/file_based/test_file_based_stream_reader.py b/unit_tests/sources/file_based/test_file_based_stream_reader.py index b77bf2fd..66729af4 100644 --- a/unit_tests/sources/file_based/test_file_based_stream_reader.py +++ b/unit_tests/sources/file_based/test_file_based_stream_reader.py @@ -77,7 +77,9 @@ def open_file(self, file: RemoteFile) -> IOBase: def file_size(self, file: RemoteFile) -> int: return 0 - def get_file(self, file: RemoteFile, local_directory: str, logger: logging.Logger) -> Dict[str, Any]: + def get_file( + self, file: RemoteFile, local_directory: str, logger: logging.Logger + ) -> Dict[str, Any]: return {} @@ -94,7 +96,11 @@ def documentation_url(cls) -> AnyUrl: pytest.param([""], DEFAULT_CONFIG, set(), set(), id="empty-string"), pytest.param(["**"], DEFAULT_CONFIG, set(FILEPATHS), set(), id="**"), pytest.param( - ["**/*.csv"], DEFAULT_CONFIG, {"a.csv", "a/b.csv", "a/c.csv", "a/b/c.csv", "a/c/c.csv", "a/b/c/d.csv"}, set(), id="**/*.csv" + ["**/*.csv"], + DEFAULT_CONFIG, + {"a.csv", "a/b.csv", "a/c.csv", "a/b/c.csv", "a/c/c.csv", "a/b/c/d.csv"}, + set(), + id="**/*.csv", ), pytest.param( ["**/*.csv*"], @@ -122,12 +128,27 @@ def documentation_url(cls) -> AnyUrl: pytest.param( ["*/*"], DEFAULT_CONFIG, - {"a/b", "a/b.csv", "a/b.csv.gz", "a/b.jsonl", "a/c", "a/c.csv", "a/c.csv.gz", "a/c.jsonl"}, + { + "a/b", + "a/b.csv", + "a/b.csv.gz", + "a/b.jsonl", + "a/c", + "a/c.csv", + "a/c.csv.gz", + "a/c.jsonl", + }, set(), id="*/*", ), pytest.param(["*/*.csv"], DEFAULT_CONFIG, {"a/b.csv", "a/c.csv"}, set(), id="*/*.csv"), - pytest.param(["*/*.csv*"], DEFAULT_CONFIG, {"a/b.csv", "a/b.csv.gz", "a/c.csv", "a/c.csv.gz"}, set(), id="*/*.csv*"), + pytest.param( + ["*/*.csv*"], + DEFAULT_CONFIG, + {"a/b.csv", "a/b.csv.gz", "a/c.csv", "a/c.csv.gz"}, + set(), + id="*/*.csv*", + ), pytest.param( ["*/**"], DEFAULT_CONFIG, @@ -159,24 +180,64 @@ def documentation_url(cls) -> AnyUrl: pytest.param( ["a/*"], DEFAULT_CONFIG, - {"a/b", "a/b.csv", "a/b.csv.gz", "a/b.jsonl", "a/c", "a/c.csv", "a/c.csv.gz", "a/c.jsonl"}, + { + "a/b", + "a/b.csv", + "a/b.csv.gz", + "a/b.jsonl", + "a/c", + "a/c.csv", + "a/c.csv.gz", + "a/c.jsonl", + }, {"a/"}, id="a/*", ), pytest.param(["a/*.csv"], DEFAULT_CONFIG, {"a/b.csv", "a/c.csv"}, {"a/"}, id="a/*.csv"), - pytest.param(["a/*.csv*"], DEFAULT_CONFIG, {"a/b.csv", "a/b.csv.gz", "a/c.csv", "a/c.csv.gz"}, {"a/"}, id="a/*.csv*"), - pytest.param(["a/b/*"], DEFAULT_CONFIG, {"a/b/c", "a/b/c.csv", "a/b/c.csv.gz", "a/b/c.jsonl"}, {"a/b/"}, id="a/b/*"), + pytest.param( + ["a/*.csv*"], + DEFAULT_CONFIG, + {"a/b.csv", "a/b.csv.gz", "a/c.csv", "a/c.csv.gz"}, + {"a/"}, + id="a/*.csv*", + ), + pytest.param( + ["a/b/*"], + DEFAULT_CONFIG, + {"a/b/c", "a/b/c.csv", "a/b/c.csv.gz", "a/b/c.jsonl"}, + {"a/b/"}, + id="a/b/*", + ), pytest.param(["a/b/*.csv"], DEFAULT_CONFIG, {"a/b/c.csv"}, {"a/b/"}, id="a/b/*.csv"), - pytest.param(["a/b/*.csv*"], DEFAULT_CONFIG, {"a/b/c.csv", "a/b/c.csv.gz"}, {"a/b/"}, id="a/b/*.csv*"), + pytest.param( + ["a/b/*.csv*"], DEFAULT_CONFIG, {"a/b/c.csv", "a/b/c.csv.gz"}, {"a/b/"}, id="a/b/*.csv*" + ), pytest.param( ["a/*/*"], DEFAULT_CONFIG, - {"a/b/c", "a/b/c.csv", "a/b/c.csv.gz", "a/b/c.jsonl", "a/c/c", "a/c/c.csv", "a/c/c.csv.gz", "a/c/c.jsonl"}, + { + "a/b/c", + "a/b/c.csv", + "a/b/c.csv.gz", + "a/b/c.jsonl", + "a/c/c", + "a/c/c.csv", + "a/c/c.csv.gz", + "a/c/c.jsonl", + }, {"a/"}, id="a/*/*", ), - pytest.param(["a/*/*.csv"], DEFAULT_CONFIG, {"a/b/c.csv", "a/c/c.csv"}, {"a/"}, id="a/*/*.csv"), - pytest.param(["a/*/*.csv*"], DEFAULT_CONFIG, {"a/b/c.csv", "a/b/c.csv.gz", "a/c/c.csv", "a/c/c.csv.gz"}, {"a/"}, id="a/*/*.csv*"), + pytest.param( + ["a/*/*.csv"], DEFAULT_CONFIG, {"a/b/c.csv", "a/c/c.csv"}, {"a/"}, id="a/*/*.csv" + ), + pytest.param( + ["a/*/*.csv*"], + DEFAULT_CONFIG, + {"a/b/c.csv", "a/b/c.csv.gz", "a/c/c.csv", "a/c/c.csv.gz"}, + {"a/"}, + id="a/*/*.csv*", + ), pytest.param( ["a/**/*"], DEFAULT_CONFIG, @@ -206,7 +267,11 @@ def documentation_url(cls) -> AnyUrl: id="a/**/*", ), pytest.param( - ["a/**/*.csv"], DEFAULT_CONFIG, {"a/b.csv", "a/c.csv", "a/b/c.csv", "a/c/c.csv", "a/b/c/d.csv"}, {"a/"}, id="a/**/*.csv" + ["a/**/*.csv"], + DEFAULT_CONFIG, + {"a/b.csv", "a/c.csv", "a/b/c.csv", "a/c/c.csv", "a/b/c/d.csv"}, + {"a/"}, + id="a/**/*.csv", ), pytest.param( ["a/**/*.csv*"], @@ -246,11 +311,23 @@ def documentation_url(cls) -> AnyUrl: set(), id="**/*.csv,**/*.gz", ), - pytest.param(["*.csv", "*.gz"], DEFAULT_CONFIG, {"a.csv", "a.csv.gz"}, set(), id="*.csv,*.gz"), pytest.param( - ["a/*.csv", "a/*/*.csv"], DEFAULT_CONFIG, {"a/b.csv", "a/c.csv", "a/b/c.csv", "a/c/c.csv"}, {"a/"}, id="a/*.csv,a/*/*.csv" + ["*.csv", "*.gz"], DEFAULT_CONFIG, {"a.csv", "a.csv.gz"}, set(), id="*.csv,*.gz" + ), + pytest.param( + ["a/*.csv", "a/*/*.csv"], + DEFAULT_CONFIG, + {"a/b.csv", "a/c.csv", "a/b/c.csv", "a/c/c.csv"}, + {"a/"}, + id="a/*.csv,a/*/*.csv", + ), + pytest.param( + ["a/*.csv", "a/b/*.csv"], + DEFAULT_CONFIG, + {"a/b.csv", "a/c.csv", "a/b/c.csv"}, + {"a/", "a/b/"}, + id="a/*.csv,a/b/*.csv", ), - pytest.param(["a/*.csv", "a/b/*.csv"], DEFAULT_CONFIG, {"a/b.csv", "a/c.csv", "a/b/c.csv"}, {"a/", "a/b/"}, id="a/*.csv,a/b/*.csv"), pytest.param( ["**/*.csv"], {"start_date": "2023-06-01T03:54:07.000Z", "streams": []}, @@ -259,7 +336,11 @@ def documentation_url(cls) -> AnyUrl: id="all_csvs_modified_after_start_date", ), pytest.param( - ["**/*.csv"], {"start_date": "2023-06-10T03:54:07.000Z", "streams": []}, set(), set(), id="all_csvs_modified_before_start_date" + ["**/*.csv"], + {"start_date": "2023-06-10T03:54:07.000Z", "streams": []}, + set(), + set(), + id="all_csvs_modified_before_start_date", ), pytest.param( ["**/*.csv"], @@ -271,9 +352,15 @@ def documentation_url(cls) -> AnyUrl: ], ) def test_globs_and_prefixes_from_globs( - globs: List[str], config: Mapping[str, Any], expected_matches: Set[str], expected_path_prefixes: Set[str] + globs: List[str], + config: Mapping[str, Any], + expected_matches: Set[str], + expected_path_prefixes: Set[str], ) -> None: reader = TestStreamReader() reader.config = TestSpec(**config) - assert set([f.uri for f in reader.filter_files_by_globs_and_start_date(FILES, globs)]) == expected_matches + assert ( + set([f.uri for f in reader.filter_files_by_globs_and_start_date(FILES, globs)]) + == expected_matches + ) assert set(reader.get_prefixes_from_globs(globs)) == expected_path_prefixes diff --git a/unit_tests/sources/file_based/test_scenarios.py b/unit_tests/sources/file_based/test_scenarios.py index b381e688..14da7176 100644 --- a/unit_tests/sources/file_based/test_scenarios.py +++ b/unit_tests/sources/file_based/test_scenarios.py @@ -11,9 +11,17 @@ from _pytest.capture import CaptureFixture from _pytest.reports import ExceptionInfo from airbyte_cdk.entrypoint import launch -from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, AirbyteLogMessage, AirbyteMessage, ConfiguredAirbyteCatalogSerializer, SyncMode +from airbyte_cdk.models import ( + AirbyteAnalyticsTraceMessage, + AirbyteLogMessage, + AirbyteMessage, + ConfiguredAirbyteCatalogSerializer, + SyncMode, +) from airbyte_cdk.sources import AbstractSource -from airbyte_cdk.sources.file_based.stream.concurrent.cursor import AbstractConcurrentFileBasedCursor +from airbyte_cdk.sources.file_based.stream.concurrent.cursor import ( + AbstractConcurrentFileBasedCursor, +) from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput from airbyte_cdk.test.entrypoint_wrapper import read as entrypoint_read from airbyte_cdk.utils import message_utils @@ -21,7 +29,9 @@ from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario -def verify_discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: +def verify_discover( + capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] +) -> None: expected_exc, expected_msg = scenario.expected_discover_error expected_logs = scenario.expected_logs if expected_exc: @@ -72,7 +82,9 @@ def assert_exception(expected_exception: type[BaseException], output: Entrypoint def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[AbstractSource]) -> None: records_and_state_messages, log_messages = output.records_and_state_messages, output.logs - logs = [message.log for message in log_messages if message.log.level.value in scenario.log_levels] + logs = [ + message.log for message in log_messages if message.log.level.value in scenario.log_levels + ] if scenario.expected_records is None: return @@ -81,13 +93,17 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac sorted_expected_records = sorted( filter(lambda e: "data" in e, expected_records), key=lambda record: ",".join( - f"{k}={v}" for k, v in sorted(record["data"].items(), key=lambda items: (items[0], items[1])) if k != "emitted_at" + f"{k}={v}" + for k, v in sorted(record["data"].items(), key=lambda items: (items[0], items[1])) + if k != "emitted_at" ), ) sorted_records = sorted( filter(lambda r: r.record, records_and_state_messages), key=lambda record: ",".join( - f"{k}={v}" for k, v in sorted(record.record.data.items(), key=lambda items: (items[0], items[1])) if k != "emitted_at" + f"{k}={v}" + for k, v in sorted(record.record.data.items(), key=lambda items: (items[0], items[1])) + if k != "emitted_at" ), ) @@ -105,15 +121,23 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac expected_states = list(filter(lambda e: "data" not in e, expected_records)) states = list(filter(lambda r: r.state, records_and_state_messages)) - assert len(states) > 0, "No state messages emitted. Successful syncs should emit at least one stream state." + assert ( + len(states) > 0 + ), "No state messages emitted. Successful syncs should emit at least one stream state." _verify_state_record_counts(sorted_records, states) - if hasattr(scenario.source, "cursor_cls") and issubclass(scenario.source.cursor_cls, AbstractConcurrentFileBasedCursor): + if hasattr(scenario.source, "cursor_cls") and issubclass( + scenario.source.cursor_cls, AbstractConcurrentFileBasedCursor + ): # Only check the last state emitted because we don't know the order the others will be in. # This may be needed for non-file-based concurrent scenarios too. - assert {k: v for k, v in states[-1].state.stream.stream_state.__dict__.items()} == expected_states[-1] + assert { + k: v for k, v in states[-1].state.stream.stream_state.__dict__.items() + } == expected_states[-1] else: - for actual, expected in zip(states, expected_states): # states should be emitted in sorted order + for actual, expected in zip( + states, expected_states + ): # states should be emitted in sorted order assert {k: v for k, v in actual.state.stream.stream_state.__dict__.items()} == expected if scenario.expected_logs: @@ -127,7 +151,9 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac _verify_analytics(analytics, scenario.expected_analytics) -def _verify_state_record_counts(records: List[AirbyteMessage], states: List[AirbyteMessage]) -> None: +def _verify_state_record_counts( + records: List[AirbyteMessage], states: List[AirbyteMessage] +) -> None: actual_record_counts = {} for record in records: stream_descriptor = message_utils.get_stream_descriptor(record) @@ -137,7 +163,8 @@ def _verify_state_record_counts(records: List[AirbyteMessage], states: List[Airb for state_message in states: stream_descriptor = message_utils.get_stream_descriptor(state_message) state_record_count_sums[stream_descriptor] = ( - state_record_count_sums.get(stream_descriptor, 0) + state_message.state.sourceStats.recordCount + state_record_count_sums.get(stream_descriptor, 0) + + state_message.state.sourceStats.recordCount ) for stream, actual_count in actual_record_counts.items(): @@ -149,10 +176,13 @@ def _verify_state_record_counts(records: List[AirbyteMessage], states: List[Airb assert state_record_count_sums[stream] == 0 -def _verify_analytics(analytics: List[AirbyteMessage], expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]]) -> None: +def _verify_analytics( + analytics: List[AirbyteMessage], + expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]], +) -> None: if expected_analytics: - assert len(analytics) == len( - expected_analytics + assert ( + len(analytics) == len(expected_analytics) ), f"Number of actual analytics messages ({len(analytics)}) did not match expected ({len(expected_analytics)})" for actual, expected in zip(analytics, expected_analytics): actual_type, actual_value = actual.trace.analytics.type, actual.trace.analytics.value @@ -162,7 +192,9 @@ def _verify_analytics(analytics: List[AirbyteMessage], expected_analytics: Optio assert actual_value == expected_value -def _verify_expected_logs(logs: List[AirbyteLogMessage], expected_logs: Optional[List[Mapping[str, Any]]]) -> None: +def _verify_expected_logs( + logs: List[AirbyteLogMessage], expected_logs: Optional[List[Mapping[str, Any]]] +) -> None: if expected_logs: for actual, expected in zip(logs, expected_logs): actual_level, actual_message = actual.level.value, actual.message @@ -176,7 +208,9 @@ def verify_spec(capsys: CaptureFixture[str], scenario: TestScenario[AbstractSour assert spec(capsys, scenario) == scenario.expected_spec -def verify_check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> None: +def verify_check( + capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] +) -> None: expected_exc, expected_msg = scenario.expected_check_error if expected_exc: @@ -200,7 +234,9 @@ def spec(capsys: CaptureFixture[str], scenario: TestScenario[AbstractSource]) -> return json.loads(captured.out.splitlines()[0])["spec"] # type: ignore -def check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> Dict[str, Any]: +def check( + capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] +) -> Dict[str, Any]: launch( scenario.source, ["check", "--config", make_file(tmp_path / "config.json", scenario.config)], @@ -217,7 +253,9 @@ def _find_connection_status(output: List[str]) -> Mapping[str, Any]: raise ValueError("No valid connectionStatus found in output") -def discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> Dict[str, Any]: +def discover( + capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] +) -> Dict[str, Any]: launch( scenario.source, ["discover", "--config", make_file(tmp_path / "config.json", scenario.config)], @@ -247,7 +285,9 @@ def read_with_state(scenario: TestScenario[AbstractSource]) -> EntrypointOutput: ) -def make_file(path: Path, file_contents: Optional[Union[Mapping[str, Any], List[Mapping[str, Any]]]]) -> str: +def make_file( + path: Path, file_contents: Optional[Union[Mapping[str, Any], List[Mapping[str, Any]]]] +) -> str: path.write_text(json.dumps(file_contents)) return str(path) diff --git a/unit_tests/sources/file_based/test_schema_helpers.py b/unit_tests/sources/file_based/test_schema_helpers.py index 90e01942..20bcadcc 100644 --- a/unit_tests/sources/file_based/test_schema_helpers.py +++ b/unit_tests/sources/file_based/test_schema_helpers.py @@ -181,19 +181,60 @@ "record,schema,expected_result", [ pytest.param(COMPLETE_CONFORMING_RECORD, SCHEMA, True, id="record-conforms"), - pytest.param(NONCONFORMING_EXTRA_COLUMN_RECORD, SCHEMA, False, id="nonconforming-extra-column"), - pytest.param(CONFORMING_WITH_MISSING_COLUMN_RECORD, SCHEMA, True, id="record-conforms-with-missing-column"), - pytest.param(CONFORMING_WITH_NARROWER_TYPE_RECORD, SCHEMA, True, id="record-conforms-with-narrower-type"), + pytest.param( + NONCONFORMING_EXTRA_COLUMN_RECORD, SCHEMA, False, id="nonconforming-extra-column" + ), + pytest.param( + CONFORMING_WITH_MISSING_COLUMN_RECORD, + SCHEMA, + True, + id="record-conforms-with-missing-column", + ), + pytest.param( + CONFORMING_WITH_NARROWER_TYPE_RECORD, + SCHEMA, + True, + id="record-conforms-with-narrower-type", + ), pytest.param(NONCONFORMING_WIDER_TYPE_RECORD, SCHEMA, False, id="nonconforming-wider-type"), - pytest.param(NONCONFORMING_NON_OBJECT_RECORD, SCHEMA, False, id="nonconforming-string-is-not-an-object"), - pytest.param(NONCONFORMING_NON_ARRAY_RECORD, SCHEMA, False, id="nonconforming-string-is-not-an-array"), - pytest.param(NONCONFORMING_TOO_WIDE_ARRAY_RECORD, SCHEMA, False, id="nonconforming-array-values-too-wide"), - pytest.param(CONFORMING_NARROWER_ARRAY_RECORD, SCHEMA, True, id="conforming-array-values-narrower-than-schema"), - pytest.param(NONCONFORMING_INVALID_ARRAY_RECORD, SCHEMA, False, id="nonconforming-array-is-not-a-string"), - pytest.param(NONCONFORMING_INVALID_OBJECT_RECORD, SCHEMA, False, id="nonconforming-object-is-not-a-string"), + pytest.param( + NONCONFORMING_NON_OBJECT_RECORD, + SCHEMA, + False, + id="nonconforming-string-is-not-an-object", + ), + pytest.param( + NONCONFORMING_NON_ARRAY_RECORD, SCHEMA, False, id="nonconforming-string-is-not-an-array" + ), + pytest.param( + NONCONFORMING_TOO_WIDE_ARRAY_RECORD, + SCHEMA, + False, + id="nonconforming-array-values-too-wide", + ), + pytest.param( + CONFORMING_NARROWER_ARRAY_RECORD, + SCHEMA, + True, + id="conforming-array-values-narrower-than-schema", + ), + pytest.param( + NONCONFORMING_INVALID_ARRAY_RECORD, + SCHEMA, + False, + id="nonconforming-array-is-not-a-string", + ), + pytest.param( + NONCONFORMING_INVALID_OBJECT_RECORD, + SCHEMA, + False, + id="nonconforming-object-is-not-a-string", + ), ], ) -def test_conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any], expected_result: bool) -> None: +def test_conforms_to_schema( + record: Mapping[str, Any], schema: Mapping[str, Any], expected_result: bool +) -> None: assert conforms_to_schema(record, schema) == expected_result @@ -210,13 +251,42 @@ def test_comparable_types() -> None: [ pytest.param({}, {}, {}, id="empty-schemas"), pytest.param({"a": None}, {}, None, id="null-value-in-schema"), - pytest.param({"a": {"type": "integer"}}, {}, {"a": {"type": "integer"}}, id="single-key-schema1"), - pytest.param({}, {"a": {"type": "integer"}}, {"a": {"type": "integer"}}, id="single-key-schema2"), - pytest.param({"a": {"type": "integer"}}, {"a": {"type": "integer"}}, {"a": {"type": "integer"}}, id="single-key-both-schemas"), - pytest.param({"a": {"type": "integer"}}, {"a": {"type": "number"}}, {"a": {"type": "number"}}, id="single-key-schema2-is-wider"), - pytest.param({"a": {"type": "number"}}, {"a": {"type": "integer"}}, {"a": {"type": "number"}}, id="single-key-schema1-is-wider"), - pytest.param({"a": {"type": "array"}}, {"a": {"type": "integer"}}, None, id="single-key-with-array-schema1"), - pytest.param({"a": {"type": "integer"}}, {"a": {"type": "array"}}, None, id="single-key-with-array-schema2"), + pytest.param( + {"a": {"type": "integer"}}, {}, {"a": {"type": "integer"}}, id="single-key-schema1" + ), + pytest.param( + {}, {"a": {"type": "integer"}}, {"a": {"type": "integer"}}, id="single-key-schema2" + ), + pytest.param( + {"a": {"type": "integer"}}, + {"a": {"type": "integer"}}, + {"a": {"type": "integer"}}, + id="single-key-both-schemas", + ), + pytest.param( + {"a": {"type": "integer"}}, + {"a": {"type": "number"}}, + {"a": {"type": "number"}}, + id="single-key-schema2-is-wider", + ), + pytest.param( + {"a": {"type": "number"}}, + {"a": {"type": "integer"}}, + {"a": {"type": "number"}}, + id="single-key-schema1-is-wider", + ), + pytest.param( + {"a": {"type": "array"}}, + {"a": {"type": "integer"}}, + None, + id="single-key-with-array-schema1", + ), + pytest.param( + {"a": {"type": "integer"}}, + {"a": {"type": "array"}}, + None, + id="single-key-with-array-schema2", + ), pytest.param( {"a": {"type": "object", "properties": {"b": {"type": "integer"}}}}, {"a": {"type": "object", "properties": {"b": {"type": "integer"}}}}, @@ -259,7 +329,9 @@ def test_comparable_types() -> None: {"a": {"type": "integer"}, "b": {"type": "string"}, "c": {"type": "number"}}, id="", ), - pytest.param({"a": {"type": "invalid_type"}}, {"b": {"type": "integer"}}, None, id="invalid-type"), + pytest.param( + {"a": {"type": "invalid_type"}}, {"b": {"type": "integer"}}, None, id="invalid-type" + ), pytest.param( {"a": {"type": "object"}}, {"a": {"type": "null"}}, @@ -280,7 +352,9 @@ def test_comparable_types() -> None: ), ], ) -def test_merge_schemas(schema1: SchemaType, schema2: SchemaType, expected_result: Optional[SchemaType]) -> None: +def test_merge_schemas( + schema1: SchemaType, schema2: SchemaType, expected_result: Optional[SchemaType] +) -> None: if expected_result is not None: assert merge_schemas(schema1, schema2) == expected_result else: @@ -311,7 +385,10 @@ def test_merge_schemas(schema1: SchemaType, schema2: SchemaType, expected_result ), pytest.param( '{"col1 ": " string", "col2": " integer"}', - {"type": "object", "properties": {"col1": {"type": "string"}, "col2": {"type": "integer"}}}, + { + "type": "object", + "properties": {"col1": {"type": "string"}, "col2": {"type": "integer"}}, + }, None, id="valid_extra_spaces", ), @@ -354,7 +431,9 @@ def test_merge_schemas(schema1: SchemaType, schema2: SchemaType, expected_result ], ) def test_type_mapping_to_jsonschema( - type_mapping: Mapping[str, Any], expected_schema: Optional[Mapping[str, Any]], expected_exc_msg: Optional[str] + type_mapping: Mapping[str, Any], + expected_schema: Optional[Mapping[str, Any]], + expected_exc_msg: Optional[str], ) -> None: if expected_exc_msg: with pytest.raises(ConfigValidationError) as exc: diff --git a/unit_tests/sources/fixtures/source_test_fixture.py b/unit_tests/sources/fixtures/source_test_fixture.py index 6f3cd57b..620ad8c4 100644 --- a/unit_tests/sources/fixtures/source_test_fixture.py +++ b/unit_tests/sources/fixtures/source_test_fixture.py @@ -30,7 +30,9 @@ class SourceTestFixture(AbstractSource): the need to load static files (ex. spec.yaml, config.json, configured_catalog.json) into the unit-test package. """ - def __init__(self, streams: Optional[List[Stream]] = None, authenticator: Optional[AuthBase] = None): + def __init__( + self, streams: Optional[List[Stream]] = None, authenticator: Optional[AuthBase] = None + ): self._streams = streams self._authenticator = authenticator @@ -151,4 +153,7 @@ class SourceFixtureOauthAuthenticator(Oauth2Authenticator): def refresh_access_token(self) -> Tuple[str, int]: response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), params={}) response.raise_for_status() - return "some_access_token", 1800 # Mock oauth response values to be used during the data retrieval step + return ( + "some_access_token", + 1800, + ) # Mock oauth response values to be used during the data retrieval step diff --git a/unit_tests/sources/message/test_repository.py b/unit_tests/sources/message/test_repository.py index 45502e45..6d637a6b 100644 --- a/unit_tests/sources/message/test_repository.py +++ b/unit_tests/sources/message/test_repository.py @@ -5,7 +5,14 @@ from unittest.mock import Mock import pytest -from airbyte_cdk.models import AirbyteControlConnectorConfigMessage, AirbyteControlMessage, AirbyteMessage, Level, OrchestratorType, Type +from airbyte_cdk.models import ( + AirbyteControlConnectorConfigMessage, + AirbyteControlMessage, + AirbyteMessage, + Level, + OrchestratorType, + Type, +) from airbyte_cdk.sources.message import ( InMemoryMessageRepository, LogAppenderMessageRepositoryDecorator, @@ -29,7 +36,9 @@ ANOTHER_CONTROL = AirbyteControlMessage( type=OrchestratorType.CONNECTOR_CONFIG, emitted_at=0, - connectorConfig=AirbyteControlConnectorConfigMessage(config={"another config": "another value"}), + connectorConfig=AirbyteControlConnectorConfigMessage( + config={"another config": "another value"} + ), ) UNKNOWN_LEVEL = "potato" @@ -65,26 +74,34 @@ def test_given_message_is_consumed_when_consume_queue_then_remove_message_from_q second_message_generator = repo.consume_queue() assert list(second_message_generator) == [second_message] - def test_given_log_level_is_severe_enough_when_log_message_then_allow_message_to_be_consumed(self): + def test_given_log_level_is_severe_enough_when_log_message_then_allow_message_to_be_consumed( + self, + ): repo = InMemoryMessageRepository(Level.DEBUG) repo.log_message(Level.INFO, lambda: {"message": "this is a log message"}) assert list(repo.consume_queue()) def test_given_log_level_is_severe_enough_when_log_message_then_filter_secrets(self, mocker): filtered_message = "a filtered message" - mocker.patch("airbyte_cdk.sources.message.repository.filter_secrets", return_value=filtered_message) + mocker.patch( + "airbyte_cdk.sources.message.repository.filter_secrets", return_value=filtered_message + ) repo = InMemoryMessageRepository(Level.DEBUG) repo.log_message(Level.INFO, lambda: {"message": "this is a log message"}) assert list(repo.consume_queue())[0].log.message == filtered_message - def test_given_log_level_not_severe_enough_when_log_message_then_do_not_allow_message_to_be_consumed(self): + def test_given_log_level_not_severe_enough_when_log_message_then_do_not_allow_message_to_be_consumed( + self, + ): repo = InMemoryMessageRepository(Level.ERROR) repo.log_message(Level.INFO, lambda: {"message": "this is a log message"}) assert not list(repo.consume_queue()) - def test_given_unknown_log_level_as_threshold_when_log_message_then_allow_message_to_be_consumed(self): + def test_given_unknown_log_level_as_threshold_when_log_message_then_allow_message_to_be_consumed( + self, + ): repo = InMemoryMessageRepository(UNKNOWN_LEVEL) repo.log_message(Level.DEBUG, lambda: {"message": "this is a log message"}) assert list(repo.consume_queue()) @@ -112,23 +129,31 @@ def test_when_emit_message_then_delegate_call(self, decorated): decorated.emit_message.assert_called_once_with(ANY_MESSAGE) def test_when_log_message_then_append(self, decorated): - repo = LogAppenderMessageRepositoryDecorator({"a": {"dict_to_append": "appended value"}}, decorated, Level.DEBUG) + repo = LogAppenderMessageRepositoryDecorator( + {"a": {"dict_to_append": "appended value"}}, decorated, Level.DEBUG + ) repo.log_message(Level.INFO, lambda: {"a": {"original": "original value"}}) assert decorated.log_message.call_args_list[0].args[1]() == { "a": {"dict_to_append": "appended value", "original": "original value"} } def test_given_value_clash_when_log_message_then_overwrite_value(self, decorated): - repo = LogAppenderMessageRepositoryDecorator({"clash": "appended value"}, decorated, Level.DEBUG) + repo = LogAppenderMessageRepositoryDecorator( + {"clash": "appended value"}, decorated, Level.DEBUG + ) repo.log_message(Level.INFO, lambda: {"clash": "original value"}) assert decorated.log_message.call_args_list[0].args[1]() == {"clash": "appended value"} - def test_given_log_level_is_severe_enough_when_log_message_then_allow_message_to_be_consumed(self, decorated): + def test_given_log_level_is_severe_enough_when_log_message_then_allow_message_to_be_consumed( + self, decorated + ): repo = LogAppenderMessageRepositoryDecorator(self._DICT_TO_APPEND, decorated, Level.DEBUG) repo.log_message(Level.INFO, lambda: {}) assert decorated.log_message.call_count == 1 - def test_given_log_level_not_severe_enough_when_log_message_then_do_not_allow_message_to_be_consumed(self, decorated): + def test_given_log_level_not_severe_enough_when_log_message_then_do_not_allow_message_to_be_consumed( + self, decorated + ): repo = LogAppenderMessageRepositoryDecorator(self._DICT_TO_APPEND, decorated, Level.ERROR) repo.log_message(Level.INFO, lambda: {}) assert decorated.log_message.call_count == 0 diff --git a/unit_tests/sources/mock_server_tests/mock_source_fixture.py b/unit_tests/sources/mock_server_tests/mock_source_fixture.py index 20631854..b5927219 100644 --- a/unit_tests/sources/mock_server_tests/mock_source_fixture.py +++ b/unit_tests/sources/mock_server_tests/mock_source_fixture.py @@ -131,10 +131,17 @@ def request_params( stream_slice: Optional[Mapping[str, Any]] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> MutableMapping[str, Any]: - return {"start_date": stream_slice.get("start_date"), "end_date": stream_slice.get("end_date")} + return { + "start_date": stream_slice.get("start_date"), + "end_date": stream_slice.get("end_date"), + } def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: start_date = pendulum.parse(self.start_date) @@ -206,10 +213,17 @@ def request_params( stream_slice: Optional[Mapping[str, Any]] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> MutableMapping[str, Any]: - return {"start_date": stream_slice.get("start_date"), "end_date": stream_slice.get("end_date")} + return { + "start_date": stream_slice.get("start_date"), + "end_date": stream_slice.get("end_date"), + } def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: start_date = pendulum.parse(self.start_date) @@ -251,7 +265,11 @@ def get_json_schema(self) -> Mapping[str, Any]: } def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: return [{"divide_category": "dukes"}, {"divide_category": "mentats"}] @@ -329,20 +347,40 @@ def _read_single_page( stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[StreamData]: next_page_token = stream_slice - request_headers = self.request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) - request_params = self.request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token) + request_headers = self.request_headers( + stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + ) + request_params = self.request_params( + stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token + ) request, response = self._http_client.send_request( http_method=self.http_method, url=self._join_url( self.url_base, - self.path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + self.path( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + ), + request_kwargs=self.request_kwargs( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, ), - request_kwargs=self.request_kwargs(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), headers=request_headers, params=request_params, - json=self.request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), - data=self.request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token), + json=self.request_body_json( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + data=self.request_body_data( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), dedupe_query_params=True, ) yield from self.parse_response(response=response) @@ -359,7 +397,9 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, class SourceFixture(AbstractSource): - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, any]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, any]: return True, None def streams(self, config: Mapping[str, Any]) -> List[Stream]: diff --git a/unit_tests/sources/mock_server_tests/test_helpers/airbyte_message_assertions.py b/unit_tests/sources/mock_server_tests/test_helpers/airbyte_message_assertions.py index 04b65594..4e2f0051 100644 --- a/unit_tests/sources/mock_server_tests/test_helpers/airbyte_message_assertions.py +++ b/unit_tests/sources/mock_server_tests/test_helpers/airbyte_message_assertions.py @@ -19,7 +19,9 @@ def emits_successful_sync_status_messages(status_messages: List[AirbyteStreamSta def validate_message_order(expected_message_order: List[Type], messages: List[AirbyteMessage]): if len(expected_message_order) != len(messages): - pytest.fail(f"Expected message order count {len(expected_message_order)} did not match actual messages {len(messages)}") + pytest.fail( + f"Expected message order count {len(expected_message_order)} did not match actual messages {len(messages)}" + ) for i, message in enumerate(messages): if message.type != expected_message_order[i]: diff --git a/unit_tests/sources/mock_server_tests/test_mock_server_abstract_source.py b/unit_tests/sources/mock_server_tests/test_mock_server_abstract_source.py index c7fd2cef..0670b28c 100644 --- a/unit_tests/sources/mock_server_tests/test_mock_server_abstract_source.py +++ b/unit_tests/sources/mock_server_tests/test_mock_server_abstract_source.py @@ -21,7 +21,10 @@ ) from airbyte_cdk.test.state_builder import StateBuilder from unit_tests.sources.mock_server_tests.mock_source_fixture import SourceFixture -from unit_tests.sources.mock_server_tests.test_helpers import emits_successful_sync_status_messages, validate_message_order +from unit_tests.sources.mock_server_tests.test_helpers import ( + emits_successful_sync_status_messages, + validate_message_order, +) _NOW = datetime.now(timezone.utc) @@ -114,7 +117,11 @@ def _create_justice_songs_request() -> RequestBuilder: return RequestBuilder.justice_songs_endpoint() -RESPONSE_TEMPLATE = {"object": "list", "has_more": False, "data": [{"id": "123", "created_at": "2024-01-01T07:04:28.000Z"}]} +RESPONSE_TEMPLATE = { + "object": "list", + "has_more": False, + "data": [{"id": "123", "created_at": "2024-01-01T07:04:28.000Z"}], +} USER_TEMPLATE = { "object": "list", @@ -199,7 +206,9 @@ def _create_response(pagination_has_more: bool = False) -> HttpResponseBuilder: return create_response_builder( response_template=RESPONSE_TEMPLATE, records_path=FieldPath("data"), - pagination_strategy=FieldUpdatePaginationStrategy(FieldPath("has_more"), pagination_has_more), + pagination_strategy=FieldUpdatePaginationStrategy( + FieldPath("has_more"), pagination_has_more + ), ) @@ -220,18 +229,27 @@ def test_full_refresh_sync(self, http_mocker): http_mocker.get( _create_users_request().build(), - _create_response().with_record(record=_create_record("users")).with_record(record=_create_record("users")).build(), + _create_response() + .with_record(record=_create_record("users")) + .with_record(record=_create_record("users")) + .build(), ) source = SourceFixture() - actual_messages = read(source, config=config, catalog=_create_catalog([("users", SyncMode.full_refresh)])) + actual_messages = read( + source, config=config, catalog=_create_catalog([("users", SyncMode.full_refresh)]) + ) assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("users")) assert len(actual_messages.records) == 2 assert len(actual_messages.state_messages) == 1 - validate_message_order([Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages) + validate_message_order( + [Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages + ) assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "users" - assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(__ab_full_refresh_sync_complete=True) + assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob( + __ab_full_refresh_sync_complete=True + ) assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2.0 @HttpMocker() @@ -240,32 +258,52 @@ def test_substream_resumable_full_refresh_with_parent_slices(self, http_mocker): config = {"start_date": start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")} expected_first_substream_per_stream_state = [ - {"partition": {"divide_category": "dukes"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"divide_category": "dukes"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ] expected_second_substream_per_stream_state = [ - {"partition": {"divide_category": "dukes"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"divide_category": "mentats"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"divide_category": "dukes"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"divide_category": "mentats"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ] http_mocker.get( _create_dividers_request().with_category("dukes").build(), - _create_response().with_record(record=_create_record("dividers")).with_record(record=_create_record("dividers")).build(), + _create_response() + .with_record(record=_create_record("dividers")) + .with_record(record=_create_record("dividers")) + .build(), ) http_mocker.get( _create_dividers_request().with_category("mentats").build(), - _create_response().with_record(record=_create_record("dividers")).with_record(record=_create_record("dividers")).build(), + _create_response() + .with_record(record=_create_record("dividers")) + .with_record(record=_create_record("dividers")) + .build(), ) source = SourceFixture() - actual_messages = read(source, config=config, catalog=_create_catalog([("dividers", SyncMode.full_refresh)])) + actual_messages = read( + source, config=config, catalog=_create_catalog([("dividers", SyncMode.full_refresh)]) + ) - assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("dividers")) + assert emits_successful_sync_status_messages( + actual_messages.get_stream_statuses("dividers") + ) assert len(actual_messages.records) == 4 assert len(actual_messages.state_messages) == 2 validate_message_order( - [Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages + [Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], + actual_messages.records_and_state_messages, ) assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob( states=expected_first_substream_per_stream_state @@ -286,7 +324,10 @@ def test_incremental_sync(self, http_mocker): last_record_date_0 = (start_datetime + timedelta(days=4)).strftime("%Y-%m-%dT%H:%M:%SZ") http_mocker.get( - _create_planets_request().with_start_date(start_datetime).with_end_date(start_datetime + timedelta(days=7)).build(), + _create_planets_request() + .with_start_date(start_datetime) + .with_end_date(start_datetime + timedelta(days=7)) + .build(), _create_response() .with_record(record=_create_record("planets").with_cursor(last_record_date_0)) .with_record(record=_create_record("planets").with_cursor(last_record_date_0)) @@ -296,7 +337,10 @@ def test_incremental_sync(self, http_mocker): last_record_date_1 = (_NOW - timedelta(days=1)).strftime("%Y-%m-%dT%H:%M:%SZ") http_mocker.get( - _create_planets_request().with_start_date(start_datetime + timedelta(days=7)).with_end_date(_NOW).build(), + _create_planets_request() + .with_start_date(start_datetime + timedelta(days=7)) + .with_end_date(_NOW) + .build(), _create_response() .with_record(record=_create_record("planets").with_cursor(last_record_date_1)) .with_record(record=_create_record("planets").with_cursor(last_record_date_1)) @@ -304,20 +348,34 @@ def test_incremental_sync(self, http_mocker): ) source = SourceFixture() - actual_messages = read(source, config=config, catalog=_create_catalog([("planets", SyncMode.incremental)])) + actual_messages = read( + source, config=config, catalog=_create_catalog([("planets", SyncMode.incremental)]) + ) assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("planets")) assert len(actual_messages.records) == 5 assert len(actual_messages.state_messages) == 2 validate_message_order( - [Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], + [ + Type.RECORD, + Type.RECORD, + Type.RECORD, + Type.STATE, + Type.RECORD, + Type.RECORD, + Type.STATE, + ], actual_messages.records_and_state_messages, ) assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "planets" - assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_0) + assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob( + created_at=last_record_date_0 + ) assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3.0 assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "planets" - assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_1) + assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob( + created_at=last_record_date_1 + ) assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2.0 @HttpMocker() @@ -327,7 +385,10 @@ def test_incremental_running_as_full_refresh(self, http_mocker): last_record_date_0 = (start_datetime + timedelta(days=4)).strftime("%Y-%m-%dT%H:%M:%SZ") http_mocker.get( - _create_planets_request().with_start_date(start_datetime).with_end_date(start_datetime + timedelta(days=7)).build(), + _create_planets_request() + .with_start_date(start_datetime) + .with_end_date(start_datetime + timedelta(days=7)) + .build(), _create_response() .with_record(record=_create_record("planets").with_cursor(last_record_date_0)) .with_record(record=_create_record("planets").with_cursor(last_record_date_0)) @@ -337,7 +398,10 @@ def test_incremental_running_as_full_refresh(self, http_mocker): last_record_date_1 = (_NOW - timedelta(days=1)).strftime("%Y-%m-%dT%H:%M:%SZ") http_mocker.get( - _create_planets_request().with_start_date(start_datetime + timedelta(days=7)).with_end_date(_NOW).build(), + _create_planets_request() + .with_start_date(start_datetime + timedelta(days=7)) + .with_end_date(_NOW) + .build(), _create_response() .with_record(record=_create_record("planets").with_cursor(last_record_date_1)) .with_record(record=_create_record("planets").with_cursor(last_record_date_1)) @@ -345,21 +409,35 @@ def test_incremental_running_as_full_refresh(self, http_mocker): ) source = SourceFixture() - actual_messages = read(source, config=config, catalog=_create_catalog([("planets", SyncMode.full_refresh)])) + actual_messages = read( + source, config=config, catalog=_create_catalog([("planets", SyncMode.full_refresh)]) + ) assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("planets")) assert len(actual_messages.records) == 5 assert len(actual_messages.state_messages) == 2 validate_message_order( - [Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], + [ + Type.RECORD, + Type.RECORD, + Type.RECORD, + Type.STATE, + Type.RECORD, + Type.RECORD, + Type.STATE, + ], actual_messages.records_and_state_messages, ) assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "planets" - assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_0) + assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob( + created_at=last_record_date_0 + ) assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3.0 assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "planets" - assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_1) + assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob( + created_at=last_record_date_1 + ) assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2.0 @HttpMocker() @@ -369,7 +447,10 @@ def test_legacy_incremental_sync(self, http_mocker): last_record_date_0 = (start_datetime + timedelta(days=4)).strftime("%Y-%m-%dT%H:%M:%SZ") http_mocker.get( - _create_legacies_request().with_start_date(start_datetime).with_end_date(start_datetime + timedelta(days=7)).build(), + _create_legacies_request() + .with_start_date(start_datetime) + .with_end_date(start_datetime + timedelta(days=7)) + .build(), _create_response() .with_record(record=_create_record("legacies").with_cursor(last_record_date_0)) .with_record(record=_create_record("legacies").with_cursor(last_record_date_0)) @@ -379,7 +460,10 @@ def test_legacy_incremental_sync(self, http_mocker): last_record_date_1 = (_NOW - timedelta(days=1)).strftime("%Y-%m-%dT%H:%M:%SZ") http_mocker.get( - _create_legacies_request().with_start_date(start_datetime + timedelta(days=7)).with_end_date(_NOW).build(), + _create_legacies_request() + .with_start_date(start_datetime + timedelta(days=7)) + .with_end_date(_NOW) + .build(), _create_response() .with_record(record=_create_record("legacies").with_cursor(last_record_date_1)) .with_record(record=_create_record("legacies").with_cursor(last_record_date_1)) @@ -387,20 +471,36 @@ def test_legacy_incremental_sync(self, http_mocker): ) source = SourceFixture() - actual_messages = read(source, config=config, catalog=_create_catalog([("legacies", SyncMode.incremental)])) + actual_messages = read( + source, config=config, catalog=_create_catalog([("legacies", SyncMode.incremental)]) + ) - assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("legacies")) + assert emits_successful_sync_status_messages( + actual_messages.get_stream_statuses("legacies") + ) assert len(actual_messages.records) == 5 assert len(actual_messages.state_messages) == 2 validate_message_order( - [Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], + [ + Type.RECORD, + Type.RECORD, + Type.RECORD, + Type.STATE, + Type.RECORD, + Type.RECORD, + Type.STATE, + ], actual_messages.records_and_state_messages, ) assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "legacies" - assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_0) + assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob( + created_at=last_record_date_0 + ) assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3.0 assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "legacies" - assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_1) + assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob( + created_at=last_record_date_1 + ) assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2.0 @HttpMocker() @@ -410,7 +510,10 @@ def test_legacy_no_records_retains_incoming_state(self, http_mocker): last_record_date_1 = (_NOW - timedelta(days=1)).strftime("%Y-%m-%dT%H:%M:%SZ") http_mocker.get( - _create_legacies_request().with_start_date(_NOW - timedelta(days=1)).with_end_date(_NOW).build(), + _create_legacies_request() + .with_start_date(_NOW - timedelta(days=1)) + .with_end_date(_NOW) + .build(), _create_response().build(), ) @@ -418,7 +521,12 @@ def test_legacy_no_records_retains_incoming_state(self, http_mocker): state = StateBuilder().with_stream_state("legacies", incoming_state).build() source = SourceFixture() - actual_messages = read(source, config=config, catalog=_create_catalog([("legacies", SyncMode.incremental)]), state=state) + actual_messages = read( + source, + config=config, + catalog=_create_catalog([("legacies", SyncMode.incremental)]), + state=state, + ) assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "legacies" assert actual_messages.state_messages[0].state.stream.stream_state == incoming_state @@ -435,7 +543,12 @@ def test_legacy_no_slices_retains_incoming_state(self, http_mocker): state = StateBuilder().with_stream_state("legacies", incoming_state).build() source = SourceFixture() - actual_messages = read(source, config=config, catalog=_create_catalog([("legacies", SyncMode.incremental)]), state=state) + actual_messages = read( + source, + config=config, + catalog=_create_catalog([("legacies", SyncMode.incremental)]), + state=state, + ) assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "legacies" assert actual_messages.state_messages[0].state.stream.stream_state == incoming_state @@ -450,24 +563,39 @@ def test_incremental_and_full_refresh_streams(self, http_mocker): config = {"start_date": start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")} expected_first_substream_per_stream_state = [ - {"partition": {"divide_category": "dukes"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"divide_category": "dukes"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ] expected_second_substream_per_stream_state = [ - {"partition": {"divide_category": "dukes"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"divide_category": "mentats"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"divide_category": "dukes"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"divide_category": "mentats"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ] # Mocks for users full refresh stream http_mocker.get( _create_users_request().build(), - _create_response().with_record(record=_create_record("users")).with_record(record=_create_record("users")).build(), + _create_response() + .with_record(record=_create_record("users")) + .with_record(record=_create_record("users")) + .build(), ) # Mocks for planets incremental stream last_record_date_0 = (start_datetime + timedelta(days=4)).strftime("%Y-%m-%dT%H:%M:%SZ") http_mocker.get( - _create_planets_request().with_start_date(start_datetime).with_end_date(start_datetime + timedelta(days=7)).build(), + _create_planets_request() + .with_start_date(start_datetime) + .with_end_date(start_datetime + timedelta(days=7)) + .build(), _create_response() .with_record(record=_create_record("planets").with_cursor(last_record_date_0)) .with_record(record=_create_record("planets").with_cursor(last_record_date_0)) @@ -477,7 +605,10 @@ def test_incremental_and_full_refresh_streams(self, http_mocker): last_record_date_1 = (_NOW - timedelta(days=1)).strftime("%Y-%m-%dT%H:%M:%SZ") http_mocker.get( - _create_planets_request().with_start_date(start_datetime + timedelta(days=7)).with_end_date(_NOW).build(), + _create_planets_request() + .with_start_date(start_datetime + timedelta(days=7)) + .with_end_date(_NOW) + .build(), _create_response() .with_record(record=_create_record("planets").with_cursor(last_record_date_1)) .with_record(record=_create_record("planets").with_cursor(last_record_date_1)) @@ -487,12 +618,18 @@ def test_incremental_and_full_refresh_streams(self, http_mocker): # Mocks for dividers full refresh stream http_mocker.get( _create_dividers_request().with_category("dukes").build(), - _create_response().with_record(record=_create_record("dividers")).with_record(record=_create_record("dividers")).build(), + _create_response() + .with_record(record=_create_record("dividers")) + .with_record(record=_create_record("dividers")) + .build(), ) http_mocker.get( _create_dividers_request().with_category("mentats").build(), - _create_response().with_record(record=_create_record("dividers")).with_record(record=_create_record("dividers")).build(), + _create_response() + .with_record(record=_create_record("dividers")) + .with_record(record=_create_record("dividers")) + .build(), ) source = SourceFixture() @@ -500,13 +637,19 @@ def test_incremental_and_full_refresh_streams(self, http_mocker): source, config=config, catalog=_create_catalog( - [("users", SyncMode.full_refresh), ("planets", SyncMode.incremental), ("dividers", SyncMode.full_refresh)] + [ + ("users", SyncMode.full_refresh), + ("planets", SyncMode.incremental), + ("dividers", SyncMode.full_refresh), + ] ), ) assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("users")) assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("planets")) - assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("dividers")) + assert emits_successful_sync_status_messages( + actual_messages.get_stream_statuses("dividers") + ) assert len(actual_messages.records) == 11 assert len(actual_messages.state_messages) == 5 @@ -532,13 +675,19 @@ def test_incremental_and_full_refresh_streams(self, http_mocker): actual_messages.records_and_state_messages, ) assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "users" - assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(__ab_full_refresh_sync_complete=True) + assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob( + __ab_full_refresh_sync_complete=True + ) assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2.0 assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "planets" - assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_0) + assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob( + created_at=last_record_date_0 + ) assert actual_messages.state_messages[1].state.sourceStats.recordCount == 3.0 assert actual_messages.state_messages[2].state.stream.stream_descriptor.name == "planets" - assert actual_messages.state_messages[2].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_1) + assert actual_messages.state_messages[2].state.stream.stream_state == AirbyteStateBlob( + created_at=last_record_date_1 + ) assert actual_messages.state_messages[2].state.sourceStats.recordCount == 2.0 assert actual_messages.state_messages[3].state.stream.stream_descriptor.name == "dividers" assert actual_messages.state_messages[3].state.stream.stream_state == AirbyteStateBlob( diff --git a/unit_tests/sources/mock_server_tests/test_resumable_full_refresh.py b/unit_tests/sources/mock_server_tests/test_resumable_full_refresh.py index f5a9e857..b3e3c2ac 100644 --- a/unit_tests/sources/mock_server_tests/test_resumable_full_refresh.py +++ b/unit_tests/sources/mock_server_tests/test_resumable_full_refresh.py @@ -7,7 +7,14 @@ from unittest import TestCase import freezegun -from airbyte_cdk.models import AirbyteStateBlob, AirbyteStreamStatus, ConfiguredAirbyteCatalog, FailureType, SyncMode, Type +from airbyte_cdk.models import ( + AirbyteStateBlob, + AirbyteStreamStatus, + ConfiguredAirbyteCatalog, + FailureType, + SyncMode, + Type, +) from airbyte_cdk.test.catalog_builder import ConfiguredAirbyteStreamBuilder from airbyte_cdk.test.entrypoint_wrapper import read from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest @@ -21,7 +28,10 @@ ) from airbyte_cdk.test.state_builder import StateBuilder from unit_tests.sources.mock_server_tests.mock_source_fixture import SourceFixture -from unit_tests.sources.mock_server_tests.test_helpers import emits_successful_sync_status_messages, validate_message_order +from unit_tests.sources.mock_server_tests.test_helpers import ( + emits_successful_sync_status_messages, + validate_message_order, +) _NOW = datetime.now(timezone.utc) @@ -50,11 +60,17 @@ def build(self) -> HttpRequest: ) -def _create_catalog(names_and_sync_modes: List[tuple[str, SyncMode, Dict[str, Any]]]) -> ConfiguredAirbyteCatalog: +def _create_catalog( + names_and_sync_modes: List[tuple[str, SyncMode, Dict[str, Any]]], +) -> ConfiguredAirbyteCatalog: stream_builder = ConfiguredAirbyteStreamBuilder() streams = [] for stream_name, sync_mode, json_schema in names_and_sync_modes: - streams.append(stream_builder.with_name(stream_name).with_sync_mode(sync_mode).with_json_schema(json_schema or {})) + streams.append( + stream_builder.with_name(stream_name) + .with_sync_mode(sync_mode) + .with_json_schema(json_schema or {}) + ) return ConfiguredAirbyteCatalog(streams=list(map(lambda builder: builder.build(), streams))) @@ -63,7 +79,11 @@ def _create_justice_songs_request() -> RequestBuilder: return RequestBuilder.justice_songs_endpoint() -RESPONSE_TEMPLATE = {"object": "list", "has_more": False, "data": [{"id": "123", "created_at": "2024-01-01T07:04:28.000Z"}]} +RESPONSE_TEMPLATE = { + "object": "list", + "has_more": False, + "data": [{"id": "123", "created_at": "2024-01-01T07:04:28.000Z"}], +} JUSTICE_SONGS_TEMPLATE = { @@ -95,7 +115,9 @@ def _create_response(pagination_has_more: bool = False) -> HttpResponseBuilder: return create_response_builder( response_template=RESPONSE_TEMPLATE, records_path=FieldPath("data"), - pagination_strategy=FieldUpdatePaginationStrategy(FieldPath("has_more"), pagination_has_more), + pagination_strategy=FieldUpdatePaginationStrategy( + FieldPath("has_more"), pagination_has_more + ), ) @@ -134,30 +156,65 @@ def test_resumable_full_refresh_sync(self, http_mocker): http_mocker.get( _create_justice_songs_request().with_page(2).build(), - _create_response(pagination_has_more=False).with_pagination().with_record(record=_create_record("justice_songs")).build(), + _create_response(pagination_has_more=False) + .with_pagination() + .with_record(record=_create_record("justice_songs")) + .build(), ) source = SourceFixture() - actual_messages = read(source, config=config, catalog=_create_catalog([("justice_songs", SyncMode.full_refresh, {})])) + actual_messages = read( + source, + config=config, + catalog=_create_catalog([("justice_songs", SyncMode.full_refresh, {})]), + ) - assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("justice_songs")) + assert emits_successful_sync_status_messages( + actual_messages.get_stream_statuses("justice_songs") + ) assert len(actual_messages.records) == 5 assert len(actual_messages.state_messages) == 4 validate_message_order( - [Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.STATE, Type.STATE], + [ + Type.RECORD, + Type.RECORD, + Type.STATE, + Type.RECORD, + Type.RECORD, + Type.STATE, + Type.RECORD, + Type.STATE, + Type.STATE, + ], actual_messages.records_and_state_messages, ) - assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(page=1) + assert ( + actual_messages.state_messages[0].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob( + page=1 + ) assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2.0 - assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob(page=2) + assert ( + actual_messages.state_messages[1].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob( + page=2 + ) assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2.0 - assert actual_messages.state_messages[2].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[2].state.stream.stream_state == AirbyteStateBlob(__ab_full_refresh_sync_complete=True) + assert ( + actual_messages.state_messages[2].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[2].state.stream.stream_state == AirbyteStateBlob( + __ab_full_refresh_sync_complete=True + ) assert actual_messages.state_messages[2].state.sourceStats.recordCount == 1.0 - assert actual_messages.state_messages[3].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[3].state.stream.stream_state == AirbyteStateBlob(__ab_full_refresh_sync_complete=True) + assert ( + actual_messages.state_messages[3].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[3].state.stream.stream_state == AirbyteStateBlob( + __ab_full_refresh_sync_complete=True + ) assert actual_messages.state_messages[3].state.sourceStats.recordCount == 0.0 @HttpMocker() @@ -196,9 +253,16 @@ def test_resumable_full_refresh_second_attempt(self, http_mocker): ) source = SourceFixture() - actual_messages = read(source, config=config, catalog=_create_catalog([("justice_songs", SyncMode.full_refresh, {})]), state=state) + actual_messages = read( + source, + config=config, + catalog=_create_catalog([("justice_songs", SyncMode.full_refresh, {})]), + state=state, + ) - assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("justice_songs")) + assert emits_successful_sync_status_messages( + actual_messages.get_stream_statuses("justice_songs") + ) assert len(actual_messages.records) == 8 assert len(actual_messages.state_messages) == 4 validate_message_order( @@ -218,17 +282,33 @@ def test_resumable_full_refresh_second_attempt(self, http_mocker): ], actual_messages.records_and_state_messages, ) - assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(page=101) + assert ( + actual_messages.state_messages[0].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob( + page=101 + ) assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3.0 - assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob(page=102) + assert ( + actual_messages.state_messages[1].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob( + page=102 + ) assert actual_messages.state_messages[1].state.sourceStats.recordCount == 3.0 - assert actual_messages.state_messages[2].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[2].state.stream.stream_state == AirbyteStateBlob(__ab_full_refresh_sync_complete=True) + assert ( + actual_messages.state_messages[2].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[2].state.stream.stream_state == AirbyteStateBlob( + __ab_full_refresh_sync_complete=True + ) assert actual_messages.state_messages[2].state.sourceStats.recordCount == 2.0 - assert actual_messages.state_messages[3].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[3].state.stream.stream_state == AirbyteStateBlob(__ab_full_refresh_sync_complete=True) + assert ( + actual_messages.state_messages[3].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[3].state.stream.stream_state == AirbyteStateBlob( + __ab_full_refresh_sync_complete=True + ) assert actual_messages.state_messages[3].state.sourceStats.recordCount == 0.0 @HttpMocker() @@ -253,11 +333,17 @@ def test_resumable_full_refresh_failure(self, http_mocker): .build(), ) - http_mocker.get(_create_justice_songs_request().with_page(2).build(), _create_response().with_status_code(status_code=400).build()) + http_mocker.get( + _create_justice_songs_request().with_page(2).build(), + _create_response().with_status_code(status_code=400).build(), + ) source = SourceFixture() actual_messages = read( - source, config=config, catalog=_create_catalog([("justice_songs", SyncMode.full_refresh, {})]), expecting_exception=True + source, + config=config, + catalog=_create_catalog([("justice_songs", SyncMode.full_refresh, {})]), + expecting_exception=True, ) status_messages = actual_messages.get_stream_statuses("justice_songs") @@ -266,12 +352,21 @@ def test_resumable_full_refresh_failure(self, http_mocker): assert len(actual_messages.state_messages) == 2 validate_message_order( - [Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages + [Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], + actual_messages.records_and_state_messages, + ) + assert ( + actual_messages.state_messages[0].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob( + page=1 + ) + assert ( + actual_messages.state_messages[1].state.stream.stream_descriptor.name == "justice_songs" + ) + assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob( + page=2 ) - assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(page=1) - assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "justice_songs" - assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob(page=2) assert actual_messages.errors[0].trace.error.failure_type == FailureType.system_error assert actual_messages.errors[0].trace.error.stream_descriptor.name == "justice_songs" diff --git a/unit_tests/sources/streams/checkpoint/test_checkpoint_reader.py b/unit_tests/sources/streams/checkpoint/test_checkpoint_reader.py index 01ddd363..35db9c24 100644 --- a/unit_tests/sources/streams/checkpoint/test_checkpoint_reader.py +++ b/unit_tests/sources/streams/checkpoint/test_checkpoint_reader.py @@ -50,7 +50,9 @@ def test_incremental_checkpoint_reader_incoming_state(): def test_resumable_full_refresh_checkpoint_reader_next(): - checkpoint_reader = ResumableFullRefreshCheckpointReader(stream_state={"synthetic_page_number": 55}) + checkpoint_reader = ResumableFullRefreshCheckpointReader( + stream_state={"synthetic_page_number": 55} + ) checkpoint_reader.observe({"synthetic_page_number": 56}) assert checkpoint_reader.next() == {"synthetic_page_number": 56} @@ -97,9 +99,15 @@ def test_full_refresh_checkpoint_reader_substream(): def test_cursor_based_checkpoint_reader_incremental(): expected_slices = [ - StreamSlice(cursor_slice={"start_date": "2024-01-01", "end_date": "2024-02-01"}, partition={}), - StreamSlice(cursor_slice={"start_date": "2024-02-01", "end_date": "2024-03-01"}, partition={}), - StreamSlice(cursor_slice={"start_date": "2024-03-01", "end_date": "2024-04-01"}, partition={}), + StreamSlice( + cursor_slice={"start_date": "2024-01-01", "end_date": "2024-02-01"}, partition={} + ), + StreamSlice( + cursor_slice={"start_date": "2024-02-01", "end_date": "2024-03-01"}, partition={} + ), + StreamSlice( + cursor_slice={"start_date": "2024-03-01", "end_date": "2024-04-01"}, partition={} + ), ] expected_stream_state = {"end_date": "2024-02-01"} @@ -110,7 +118,9 @@ def test_cursor_based_checkpoint_reader_incremental(): incremental_cursor.get_stream_state.return_value = expected_stream_state checkpoint_reader = CursorBasedCheckpointReader( - cursor=incremental_cursor, stream_slices=incremental_cursor.stream_slices(), read_state_from_cursor=False + cursor=incremental_cursor, + stream_slices=incremental_cursor.stream_slices(), + read_state_from_cursor=False, ) assert checkpoint_reader.next() == expected_slices[0] @@ -265,7 +275,9 @@ def test_cursor_based_checkpoint_reader_sync_first_parent_slice(): StreamSlice(cursor_slice={}, partition={"parent_id": "naga"}), ] rfr_cursor.select_state.side_effect = [ - {"next_page_token": 3}, # Accounts for the first invocation when checking if partition was already successful + { + "next_page_token": 3 + }, # Accounts for the first invocation when checking if partition was already successful {"next_page_token": 4}, {"next_page_token": 4}, {"__ab_full_refresh_sync_complete": True}, @@ -293,7 +305,9 @@ def test_cursor_based_checkpoint_reader_sync_first_parent_slice(): def test_cursor_based_checkpoint_reader_resumable_full_refresh_invalid_slice(): rfr_cursor = Mock() rfr_cursor.stream_slices.return_value = [{"invalid": "stream_slice"}] - rfr_cursor.select_state.side_effect = [StreamSlice(cursor_slice={"invalid": "stream_slice"}, partition={})] + rfr_cursor.select_state.side_effect = [ + StreamSlice(cursor_slice={"invalid": "stream_slice"}, partition={}) + ] checkpoint_reader = CursorBasedCheckpointReader( cursor=rfr_cursor, stream_slices=rfr_cursor.stream_slices(), read_state_from_cursor=True @@ -306,10 +320,30 @@ def test_cursor_based_checkpoint_reader_resumable_full_refresh_invalid_slice(): def test_legacy_cursor_based_checkpoint_reader_resumable_full_refresh(): expected_mapping_slices = [ {"parent_id": 400, "partition": {"parent_id": 400}, "cursor_slice": {}}, - {"parent_id": 400, "next_page_token": 2, "partition": {"parent_id": 400}, "cursor_slice": {"next_page_token": 2}}, - {"parent_id": 400, "next_page_token": 2, "partition": {"parent_id": 400}, "cursor_slice": {"next_page_token": 2}}, - {"parent_id": 400, "next_page_token": 3, "partition": {"parent_id": 400}, "cursor_slice": {"next_page_token": 3}}, - {"parent_id": 400, "next_page_token": 4, "partition": {"parent_id": 400}, "cursor_slice": {"next_page_token": 4}}, + { + "parent_id": 400, + "next_page_token": 2, + "partition": {"parent_id": 400}, + "cursor_slice": {"next_page_token": 2}, + }, + { + "parent_id": 400, + "next_page_token": 2, + "partition": {"parent_id": 400}, + "cursor_slice": {"next_page_token": 2}, + }, + { + "parent_id": 400, + "next_page_token": 3, + "partition": {"parent_id": 400}, + "cursor_slice": {"next_page_token": 3}, + }, + { + "parent_id": 400, + "next_page_token": 4, + "partition": {"parent_id": 400}, + "cursor_slice": {"next_page_token": 4}, + }, { "parent_id": 400, "__ab_full_refresh_sync_complete": True, diff --git a/unit_tests/sources/streams/checkpoint/test_substream_resumable_full_refresh_cursor.py b/unit_tests/sources/streams/checkpoint/test_substream_resumable_full_refresh_cursor.py index 49445185..f023b038 100644 --- a/unit_tests/sources/streams/checkpoint/test_substream_resumable_full_refresh_cursor.py +++ b/unit_tests/sources/streams/checkpoint/test_substream_resumable_full_refresh_cursor.py @@ -1,7 +1,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. import pytest -from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import SubstreamResumableFullRefreshCursor +from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import ( + SubstreamResumableFullRefreshCursor, +) from airbyte_cdk.sources.types import StreamSlice from airbyte_cdk.utils import AirbyteTracedException @@ -14,8 +16,14 @@ def test_substream_resumable_full_refresh_cursor(): expected_ending_state = { "states": [ - {"partition": {"musician_id": "kousei_arima"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"musician_id": "kaori_miyazono"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"musician_id": "kousei_arima"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"musician_id": "kaori_miyazono"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ] } @@ -44,18 +52,36 @@ def test_substream_resumable_full_refresh_cursor_with_state(): """ initial_state = { "states": [ - {"partition": {"musician_id": "kousei_arima"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"musician_id": "kaori_miyazono"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"musician_id": "kousei_arima"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"musician_id": "kaori_miyazono"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, {"partition": {"musician_id": "takeshi_aiza"}, "cursor": {}}, ] } expected_ending_state = { "states": [ - {"partition": {"musician_id": "kousei_arima"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"musician_id": "kaori_miyazono"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"musician_id": "takeshi_aiza"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, - {"partition": {"musician_id": "emi_igawa"}, "cursor": {"__ab_full_refresh_sync_complete": True}}, + { + "partition": {"musician_id": "kousei_arima"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"musician_id": "kaori_miyazono"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"musician_id": "takeshi_aiza"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, + { + "partition": {"musician_id": "emi_igawa"}, + "cursor": {"__ab_full_refresh_sync_complete": True}, + }, ] } diff --git a/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py b/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py index f3a4df14..c8bd0429 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py +++ b/unit_tests/sources/streams/concurrent/scenarios/incremental_scenarios.py @@ -2,11 +2,18 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # from airbyte_cdk.sources.streams.concurrent.cursor import CursorField -from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ConcurrencyCompatibleStateType +from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ( + ConcurrencyCompatibleStateType, +) from airbyte_cdk.test.state_builder import StateBuilder from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from unit_tests.sources.file_based.scenarios.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder -from unit_tests.sources.streams.concurrent.scenarios.stream_facade_builder import StreamFacadeSourceBuilder +from unit_tests.sources.file_based.scenarios.scenario_builder import ( + IncrementalScenarioConfig, + TestScenarioBuilder, +) +from unit_tests.sources.streams.concurrent.scenarios.stream_facade_builder import ( + StreamFacadeSourceBuilder, +) from unit_tests.sources.streams.concurrent.scenarios.utils import MockStream _NO_SLICE_BOUNDARIES = None @@ -21,8 +28,14 @@ [ MockStream( [ - ({"from": 0, "to": 1}, [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}]), - ({"from": 1, "to": 2}, [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}]), + ( + {"from": 0, "to": 1}, + [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}], + ), + ( + {"from": 1, "to": 2}, + [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}], + ), ], "stream1", cursor_field="cursor_field", @@ -55,8 +68,14 @@ [ MockStream( [ - ({"from": 0, "to": 1}, [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}]), - ({"from": 1, "to": 2}, [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}]), + ( + {"from": 0, "to": 1}, + [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}], + ), + ( + {"from": 1, "to": 2}, + [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}], + ), ], "stream1", cursor_field="cursor_field", @@ -100,8 +119,14 @@ [ MockStream( [ - ({"from": 0, "to": 1}, [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}]), - ({"from": 1, "to": 2}, [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}]), + ( + {"from": 0, "to": 1}, + [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}], + ), + ( + {"from": 1, "to": 2}, + [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}], + ), ], "stream1", cursor_field="cursor_field", @@ -134,8 +159,14 @@ [ MockStream( [ - ({"from": 0, "to": 1}, [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}]), - ({"from": 1, "to": 2}, [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}]), + ( + {"from": 0, "to": 1}, + [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}], + ), + ( + {"from": 1, "to": 2}, + [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}], + ), ], "stream1", cursor_field="cursor_field", @@ -189,8 +220,14 @@ [ MockStream( [ - ({"from": 0, "to": 1}, [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}]), - ({"from": 1, "to": 2}, [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}]), + ( + {"from": 0, "to": 1}, + [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}], + ), + ( + {"from": 1, "to": 2}, + [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}], + ), ], "stream1", cursor_field="cursor_field", @@ -223,8 +260,14 @@ [ MockStream( [ - ({"from": 0, "to": 1}, [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}]), - ({"from": 1, "to": 2}, [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}]), + ( + {"from": 0, "to": 1}, + [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}], + ), + ( + {"from": 1, "to": 2}, + [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}], + ), ], "stream1", cursor_field="cursor_field", diff --git a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py index 38c44c08..50695ba1 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py +++ b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py @@ -22,9 +22,13 @@ from airbyte_cdk.sources.source import TState from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.concurrent.cursor import CursorField -from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import EpochValueConcurrentStreamStateConverter +from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( + EpochValueConcurrentStreamStateConverter, +) from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder -from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import NeverLogSliceLogger +from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import ( + NeverLogSliceLogger, +) _CURSOR_FIELD = "cursor_field" _NO_STATE = None @@ -45,7 +49,9 @@ def __init__( ): self._message_repository = InMemoryMessageRepository() threadpool_manager = ThreadPoolManager(threadpool, streams[0].logger) - concurrent_source = ConcurrentSource(threadpool_manager, streams[0].logger, NeverLogSliceLogger(), self._message_repository) + concurrent_source = ConcurrentSource( + threadpool_manager, streams[0].logger, NeverLogSliceLogger(), self._message_repository + ) super().__init__(concurrent_source) self._streams = streams self._threadpool = threadpool_manager @@ -53,7 +59,9 @@ def __init__( self._cursor_boundaries = cursor_boundaries self._state = [AirbyteStateMessage(s) for s in input_state] if input_state else None - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: return True, None def streams(self, config: Mapping[str, Any]) -> List[Stream]: @@ -117,7 +125,9 @@ def set_max_workers(self, max_workers: int) -> "StreamFacadeSourceBuilder": self._max_workers = max_workers return self - def set_incremental(self, cursor_field: CursorField, cursor_boundaries: Optional[Tuple[str, str]]) -> "StreamFacadeSourceBuilder": + def set_incremental( + self, cursor_field: CursorField, cursor_boundaries: Optional[Tuple[str, str]] + ) -> "StreamFacadeSourceBuilder": self._cursor_field = cursor_field self._cursor_boundaries = cursor_boundaries return self @@ -127,7 +137,14 @@ def set_input_state(self, state: List[Mapping[str, Any]]) -> "StreamFacadeSource return self def build( - self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState] + self, + configured_catalog: Optional[Mapping[str, Any]], + config: Optional[Mapping[str, Any]], + state: Optional[TState], ) -> StreamFacadeSource: - threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="workerpool") - return StreamFacadeSource(self._streams, threadpool, self._cursor_field, self._cursor_boundaries, state) + threadpool = concurrent.futures.ThreadPoolExecutor( + max_workers=self._max_workers, thread_name_prefix="workerpool" + ) + return StreamFacadeSource( + self._streams, threadpool, self._cursor_field, self._cursor_boundaries, state + ) diff --git a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py index 41483282..36fc90e9 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py +++ b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py @@ -3,8 +3,13 @@ # from airbyte_cdk.sources.streams.concurrent.cursor import CursorField from airbyte_cdk.utils.traced_exception import AirbyteTracedException -from unit_tests.sources.file_based.scenarios.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder -from unit_tests.sources.streams.concurrent.scenarios.stream_facade_builder import StreamFacadeSourceBuilder +from unit_tests.sources.file_based.scenarios.scenario_builder import ( + IncrementalScenarioConfig, + TestScenarioBuilder, +) +from unit_tests.sources.streams.concurrent.scenarios.stream_facade_builder import ( + StreamFacadeSourceBuilder, +) from unit_tests.sources.streams.concurrent.scenarios.utils import MockStream _stream1 = MockStream( @@ -339,8 +344,14 @@ [ MockStream( [ - ({"from": 0, "to": 1}, [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}]), - ({"from": 1, "to": 2}, [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}]), + ( + {"from": 0, "to": 1}, + [{"id": "1", "cursor_field": 0}, {"id": "2", "cursor_field": 1}], + ), + ( + {"from": 1, "to": 2}, + [{"id": "3", "cursor_field": 2}, {"id": "4", "cursor_field": 3}], + ), ], "stream1", cursor_field="cursor_field", diff --git a/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py b/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py index af224987..a0abaec0 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py +++ b/unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py @@ -72,5 +72,7 @@ def test_concurrent_read(scenario: TestScenario) -> None: @pytest.mark.parametrize("scenario", scenarios, ids=[s.name for s in scenarios]) -def test_concurrent_discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario) -> None: +def test_concurrent_discover( + capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario +) -> None: verify_discover(capsys, tmp_path, scenario) diff --git a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py index 05f3adbe..2de8bfd0 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py +++ b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py @@ -5,7 +5,9 @@ import logging from airbyte_cdk.sources.message import InMemoryMessageRepository -from airbyte_cdk.sources.streams.concurrent.availability_strategy import AlwaysAvailableAvailabilityStrategy +from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( + AlwaysAvailableAvailabilityStrategy, +) from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.sources.streams.concurrent.partitions.record import Record @@ -44,12 +46,21 @@ primary_key=[], cursor_field=None, logger=logging.getLogger("test_logger"), - cursor=FinalStateCursor(stream_name="stream1", stream_namespace=None, message_repository=_message_repository), + cursor=FinalStateCursor( + stream_name="stream1", stream_namespace=None, message_repository=_message_repository + ), ) _id_only_stream_with_slice_logger = DefaultStream( partition_generator=InMemoryPartitionGenerator( - [InMemoryPartition("partition1", "stream1", None, [Record({"id": "1"}, "stream1"), Record({"id": "2"}, "stream1")])] + [ + InMemoryPartition( + "partition1", + "stream1", + None, + [Record({"id": "1"}, "stream1"), Record({"id": "2"}, "stream1")], + ) + ] ), name="stream1", json_schema={ @@ -62,7 +73,9 @@ primary_key=[], cursor_field=None, logger=logging.getLogger("test_logger"), - cursor=FinalStateCursor(stream_name="stream1", stream_namespace=None, message_repository=_message_repository), + cursor=FinalStateCursor( + stream_name="stream1", stream_namespace=None, message_repository=_message_repository + ), ) _id_only_stream_with_primary_key = DefaultStream( @@ -90,7 +103,9 @@ primary_key=["id"], cursor_field=None, logger=logging.getLogger("test_logger"), - cursor=FinalStateCursor(stream_name="stream1", stream_namespace=None, message_repository=_message_repository), + cursor=FinalStateCursor( + stream_name="stream1", stream_namespace=None, message_repository=_message_repository + ), ) _id_only_stream_multiple_partitions = DefaultStream( @@ -127,7 +142,9 @@ primary_key=[], cursor_field=None, logger=logging.getLogger("test_logger"), - cursor=FinalStateCursor(stream_name="stream1", stream_namespace=None, message_repository=_message_repository), + cursor=FinalStateCursor( + stream_name="stream1", stream_namespace=None, message_repository=_message_repository + ), ) _id_only_stream_multiple_partitions_concurrency_level_two = DefaultStream( @@ -164,7 +181,9 @@ primary_key=[], cursor_field=None, logger=logging.getLogger("test_logger"), - cursor=FinalStateCursor(stream_name="stream1", stream_namespace=None, message_repository=_message_repository), + cursor=FinalStateCursor( + stream_name="stream1", stream_namespace=None, message_repository=_message_repository + ), ) _stream_raising_exception = DefaultStream( @@ -174,7 +193,10 @@ "partition1", "stream1", None, - [Record({"id": "1"}, InMemoryPartition("partition1", "stream1", None, [])), ValueError("test exception")], + [ + Record({"id": "1"}, InMemoryPartition("partition1", "stream1", None, [])), + ValueError("test exception"), + ], ) ] ), @@ -189,7 +211,9 @@ primary_key=[], cursor_field=None, logger=logging.getLogger("test_logger"), - cursor=FinalStateCursor(stream_name="stream1", stream_namespace=None, message_repository=_message_repository), + cursor=FinalStateCursor( + stream_name="stream1", stream_namespace=None, message_repository=_message_repository + ), ) test_concurrent_cdk_single_stream = ( @@ -301,8 +325,14 @@ "stream2", None, [ - Record({"id": "10", "key": "v1"}, InMemoryPartition("partition1", "stream2", None, [])), - Record({"id": "20", "key": "v2"}, InMemoryPartition("partition1", "stream2", None, [])), + Record( + {"id": "10", "key": "v1"}, + InMemoryPartition("partition1", "stream2", None, []), + ), + Record( + {"id": "20", "key": "v2"}, + InMemoryPartition("partition1", "stream2", None, []), + ), ], ) ] @@ -319,7 +349,11 @@ primary_key=[], cursor_field=None, logger=logging.getLogger("test_logger"), - cursor=FinalStateCursor(stream_name="stream2", stream_namespace=None, message_repository=_message_repository), + cursor=FinalStateCursor( + stream_name="stream2", + stream_namespace=None, + message_repository=_message_repository, + ), ), ] ) diff --git a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py index 17a4b395..98633daf 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py +++ b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py @@ -5,7 +5,13 @@ import logging from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union -from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, ConnectorSpecification, DestinationSyncMode, SyncMode +from airbyte_cdk.models import ( + ConfiguredAirbyteCatalog, + ConfiguredAirbyteStream, + ConnectorSpecification, + DestinationSyncMode, + SyncMode, +) from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository @@ -36,13 +42,23 @@ def read_records( class ConcurrentCdkSource(ConcurrentSourceAdapter): - def __init__(self, streams: List[DefaultStream], message_repository: Optional[MessageRepository], max_workers, timeout_in_seconds): - concurrent_source = ConcurrentSource.create(1, 1, streams[0]._logger, NeverLogSliceLogger(), message_repository) + def __init__( + self, + streams: List[DefaultStream], + message_repository: Optional[MessageRepository], + max_workers, + timeout_in_seconds, + ): + concurrent_source = ConcurrentSource.create( + 1, 1, streams[0]._logger, NeverLogSliceLogger(), message_repository + ) super().__init__(concurrent_source) self._streams = streams self._message_repository = message_repository - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: # Check is not verified because it is up to the source to implement this method return True, None @@ -51,7 +67,11 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: StreamFacade( s, LegacyStream(), - FinalStateCursor(stream_name=s.name, stream_namespace=s.namespace, message_repository=self.message_repository), + FinalStateCursor( + stream_name=s.name, + stream_namespace=s.namespace, + message_repository=self.message_repository, + ), NeverLogSliceLogger(), s._logger, ) @@ -68,7 +88,11 @@ def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog: stream=StreamFacade( s, LegacyStream(), - FinalStateCursor(stream_name=s.name, stream_namespace=s.namespace, message_repository=InMemoryMessageRepository()), + FinalStateCursor( + stream_name=s.name, + stream_namespace=s.namespace, + message_repository=InMemoryMessageRepository(), + ), NeverLogSliceLogger(), s._logger, ).as_airbyte_stream(), @@ -140,7 +164,9 @@ def set_streams(self, streams: List[DefaultStream]) -> "ConcurrentSourceBuilder" self._streams = streams return self - def set_message_repository(self, message_repository: MessageRepository) -> "ConcurrentSourceBuilder": + def set_message_repository( + self, message_repository: MessageRepository + ) -> "ConcurrentSourceBuilder": self._message_repository = message_repository return self diff --git a/unit_tests/sources/streams/concurrent/scenarios/utils.py b/unit_tests/sources/streams/concurrent/scenarios/utils.py index 85f6a1f7..627891ee 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/utils.py +++ b/unit_tests/sources/streams/concurrent/scenarios/utils.py @@ -11,7 +11,9 @@ class MockStream(Stream): def __init__( self, - slices_and_records_or_exception: Iterable[Tuple[Optional[Mapping[str, Any]], Iterable[Union[Exception, Mapping[str, Any]]]]], + slices_and_records_or_exception: Iterable[ + Tuple[Optional[Mapping[str, Any]], Iterable[Union[Exception, Mapping[str, Any]]]] + ], name, json_schema, primary_key=None, @@ -53,9 +55,15 @@ def get_json_schema(self) -> Mapping[str, Any]: return self._json_schema def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: if self._slices_and_records_or_exception: - yield from [_slice for _slice, records_or_exception in self._slices_and_records_or_exception] + yield from [ + _slice for _slice, records_or_exception in self._slices_and_records_or_exception + ] else: yield None diff --git a/unit_tests/sources/streams/concurrent/test_adapters.py b/unit_tests/sources/streams/concurrent/test_adapters.py index d1dffcdd..cbebfe7c 100644 --- a/unit_tests/sources/streams/concurrent/test_adapters.py +++ b/unit_tests/sources/streams/concurrent/test_adapters.py @@ -17,7 +17,11 @@ StreamPartition, StreamPartitionGenerator, ) -from airbyte_cdk.sources.streams.concurrent.availability_strategy import STREAM_AVAILABLE, StreamAvailable, StreamUnavailable +from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( + STREAM_AVAILABLE, + StreamAvailable, + StreamUnavailable, +) from airbyte_cdk.sources.streams.concurrent.cursor import Cursor from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage from airbyte_cdk.sources.streams.concurrent.partitions.record import Record @@ -71,12 +75,16 @@ def test_stream_partition_generator(sync_mode): stream_slices = [{"slice": 1}, {"slice": 2}] stream.stream_slices.return_value = stream_slices - partition_generator = StreamPartitionGenerator(stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR) + partition_generator = StreamPartitionGenerator( + stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR + ) partitions = list(partition_generator.generate()) slices = [partition.to_slice() for partition in partitions] assert slices == stream_slices - stream.stream_slices.assert_called_once_with(sync_mode=_ANY_SYNC_MODE, cursor_field=_ANY_CURSOR_FIELD, stream_state=_ANY_STATE) + stream.stream_slices.assert_called_once_with( + sync_mode=_ANY_SYNC_MODE, cursor_field=_ANY_CURSOR_FIELD, stream_state=_ANY_STATE + ) @pytest.mark.parametrize( @@ -97,14 +105,19 @@ def test_stream_partition_generator(sync_mode): def test_stream_partition(transformer, expected_records): stream = Mock() stream.name = _STREAM_NAME - stream.get_json_schema.return_value = {"type": "object", "properties": {"data": {"type": ["integer"]}}} + stream.get_json_schema.return_value = { + "type": "object", + "properties": {"data": {"type": ["integer"]}}, + } stream.transformer = transformer message_repository = InMemoryMessageRepository() _slice = None sync_mode = SyncMode.full_refresh cursor_field = None state = None - partition = StreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR) + partition = StreamPartition( + stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR + ) a_log_message = AirbyteMessage( type=MessageType.LOG, @@ -130,7 +143,9 @@ def test_stream_partition(transformer, expected_records): "exception_type, expected_display_message", [ pytest.param(Exception, None, id="test_exception_no_display_message"), - pytest.param(ExceptionWithDisplayMessage, "display_message", id="test_exception_no_display_message"), + pytest.param( + ExceptionWithDisplayMessage, "display_message", id="test_exception_no_display_message" + ), ], ) def test_stream_partition_raising_exception(exception_type, expected_display_message): @@ -140,7 +155,15 @@ def test_stream_partition_raising_exception(exception_type, expected_display_mes message_repository = InMemoryMessageRepository() _slice = None - partition = StreamPartition(stream, _slice, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR) + partition = StreamPartition( + stream, + _slice, + message_repository, + _ANY_SYNC_MODE, + _ANY_CURSOR_FIELD, + _ANY_STATE, + _ANY_CURSOR, + ) stream.read_records.side_effect = Exception() @@ -153,14 +176,20 @@ def test_stream_partition_raising_exception(exception_type, expected_display_mes @pytest.mark.parametrize( "_slice, expected_hash", [ - pytest.param({"partition": 1, "k": "v"}, hash(("stream", '{"k": "v", "partition": 1}')), id="test_hash_with_slice"), + pytest.param( + {"partition": 1, "k": "v"}, + hash(("stream", '{"k": "v", "partition": 1}')), + id="test_hash_with_slice", + ), pytest.param(None, hash("stream"), id="test_hash_no_slice"), ], ) def test_stream_partition_hash(_slice, expected_hash): stream = Mock() stream.name = "stream" - partition = StreamPartition(stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR) + partition = StreamPartition( + stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR + ) _hash = partition.__hash__() assert _hash == expected_hash @@ -180,7 +209,13 @@ def setUp(self): self._logger = Mock() self._slice_logger = Mock() self._slice_logger.should_log_slice_message.return_value = False - self._facade = StreamFacade(self._abstract_stream, self._legacy_stream, self._cursor, self._slice_logger, self._logger) + self._facade = StreamFacade( + self._abstract_stream, + self._legacy_stream, + self._cursor, + self._slice_logger, + self._logger, + ) self._source = Mock() self._stream = Mock() @@ -206,23 +241,36 @@ def test_json_schema_is_delegated_to_wrapped_stream(self): assert self._facade.get_json_schema() == json_schema self._abstract_stream.get_json_schema.assert_called_once_with() - def test_given_cursor_is_noop_when_supports_incremental_then_return_legacy_stream_response(self): + def test_given_cursor_is_noop_when_supports_incremental_then_return_legacy_stream_response( + self, + ): assert ( StreamFacade( - self._abstract_stream, self._legacy_stream, _ANY_CURSOR, Mock(spec=SliceLogger), Mock(spec=logging.Logger) + self._abstract_stream, + self._legacy_stream, + _ANY_CURSOR, + Mock(spec=SliceLogger), + Mock(spec=logging.Logger), ).supports_incremental == self._legacy_stream.supports_incremental ) def test_given_cursor_is_not_noop_when_supports_incremental_then_return_true(self): assert StreamFacade( - self._abstract_stream, self._legacy_stream, Mock(spec=Cursor), Mock(spec=SliceLogger), Mock(spec=logging.Logger) + self._abstract_stream, + self._legacy_stream, + Mock(spec=Cursor), + Mock(spec=SliceLogger), + Mock(spec=logging.Logger), ).supports_incremental def test_check_availability_is_delegated_to_wrapped_stream(self): availability = StreamAvailable() self._abstract_stream.check_availability.return_value = availability - assert self._facade.check_availability(Mock(), Mock()) == (availability.is_available(), availability.message()) + assert self._facade.check_availability(Mock(), Mock()) == ( + availability.is_available(), + availability.message(), + ) self._abstract_stream.check_availability.assert_called_once_with() def test_full_refresh(self): @@ -233,7 +281,9 @@ def test_full_refresh(self): partition.read.return_value = records self._abstract_stream.generate_partitions.return_value = [partition] - actual_stream_data = list(self._facade.read_records(SyncMode.full_refresh, None, None, None)) + actual_stream_data = list( + self._facade.read_records(SyncMode.full_refresh, None, None, None) + ) assert actual_stream_data == expected_stream_data @@ -254,7 +304,9 @@ def test_create_from_stream_stream(self): stream.primary_key = "id" stream.cursor_field = "cursor" - facade = StreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = StreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade.name == "stream" assert facade.cursor_field == "cursor" @@ -266,7 +318,9 @@ def test_create_from_stream_stream_with_none_primary_key(self): stream.primary_key = None stream.cursor_field = [] - facade = StreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = StreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade._abstract_stream._primary_key == [] def test_create_from_stream_with_composite_primary_key(self): @@ -275,7 +329,9 @@ def test_create_from_stream_with_composite_primary_key(self): stream.primary_key = ["id", "name"] stream.cursor_field = [] - facade = StreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = StreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade._abstract_stream._primary_key == ["id", "name"] def test_create_from_stream_with_empty_list_cursor(self): @@ -283,7 +339,9 @@ def test_create_from_stream_with_empty_list_cursor(self): stream.primary_key = "id" stream.cursor_field = [] - facade = StreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = StreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade.cursor_field == [] @@ -293,7 +351,9 @@ def test_create_from_stream_raises_exception_if_primary_key_is_nested(self): stream.primary_key = [["field", "id"]] with self.assertRaises(ValueError): - StreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + StreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) def test_create_from_stream_raises_exception_if_primary_key_has_invalid_type(self): stream = Mock() @@ -301,7 +361,9 @@ def test_create_from_stream_raises_exception_if_primary_key_has_invalid_type(sel stream.primary_key = 123 with self.assertRaises(ValueError): - StreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + StreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) def test_create_from_stream_raises_exception_if_cursor_field_is_nested(self): stream = Mock() @@ -310,7 +372,9 @@ def test_create_from_stream_raises_exception_if_cursor_field_is_nested(self): stream.cursor_field = ["field", "cursor"] with self.assertRaises(ValueError): - StreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + StreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) def test_create_from_stream_with_cursor_field_as_list(self): stream = Mock() @@ -318,7 +382,9 @@ def test_create_from_stream_with_cursor_field_as_list(self): stream.primary_key = "id" stream.cursor_field = ["cursor"] - facade = StreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = StreamFacade.create_from_stream( + stream, self._source, self._logger, _ANY_STATE, self._cursor + ) assert facade.cursor_field == "cursor" def test_create_from_stream_none_message_repository(self): @@ -328,12 +394,16 @@ def test_create_from_stream_none_message_repository(self): self._source.message_repository = None with self.assertRaises(ValueError): - StreamFacade.create_from_stream(self._stream, self._source, self._logger, {}, self._cursor) + StreamFacade.create_from_stream( + self._stream, self._source, self._logger, {}, self._cursor + ) def test_get_error_display_message_no_display_message(self): self._stream.get_error_display_message.return_value = "display_message" - facade = StreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = StreamFacade.create_from_stream( + self._stream, self._source, self._logger, _ANY_STATE, self._cursor + ) expected_display_message = None e = Exception() @@ -345,7 +415,9 @@ def test_get_error_display_message_no_display_message(self): def test_get_error_display_message_with_display_message(self): self._stream.get_error_display_message.return_value = "display_message" - facade = StreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor) + facade = StreamFacade.create_from_stream( + self._stream, self._source, self._logger, _ANY_STATE, self._cursor + ) expected_display_message = "display_message" e = ExceptionWithDisplayMessage("display_message") @@ -359,7 +431,9 @@ def test_get_error_display_message_with_display_message(self): "exception, expected_display_message", [ pytest.param(Exception("message"), None, id="test_no_display_message"), - pytest.param(ExceptionWithDisplayMessage("message"), "message", id="test_no_display_message"), + pytest.param( + ExceptionWithDisplayMessage("message"), "message", id="test_no_display_message" + ), ], ) def test_get_error_display_message(exception, expected_display_message): @@ -377,19 +451,37 @@ def test_cursor_partition_generator(): stream = Mock() cursor = Mock() message_repository = Mock() - connector_state_converter = CustomFormatConcurrentStreamStateConverter(datetime_format="%Y-%m-%dT%H:%M:%S") + connector_state_converter = CustomFormatConcurrentStreamStateConverter( + datetime_format="%Y-%m-%dT%H:%M:%S" + ) cursor_field = Mock() slice_boundary_fields = ("start", "end") - expected_slices = [StreamSlice(partition={}, cursor_slice={"start": "2024-01-01T00:00:00", "end": "2024-01-02T00:00:00"})] - cursor.generate_slices.return_value = [(datetime.datetime(year=2024, month=1, day=1), datetime.datetime(year=2024, month=1, day=2))] + expected_slices = [ + StreamSlice( + partition={}, + cursor_slice={"start": "2024-01-01T00:00:00", "end": "2024-01-02T00:00:00"}, + ) + ] + cursor.generate_slices.return_value = [ + (datetime.datetime(year=2024, month=1, day=1), datetime.datetime(year=2024, month=1, day=2)) + ] partition_generator = CursorPartitionGenerator( - stream, message_repository, cursor, connector_state_converter, cursor_field, slice_boundary_fields + stream, + message_repository, + cursor, + connector_state_converter, + cursor_field, + slice_boundary_fields, ) partitions = list(partition_generator.generate()) generated_slices = [partition.to_slice() for partition in partitions] - assert all(isinstance(partition, StreamPartition) for partition in partitions), "Not all partitions are instances of StreamPartition" - assert generated_slices == expected_slices, f"Expected {expected_slices}, but got {generated_slices}" + assert all( + isinstance(partition, StreamPartition) for partition in partitions + ), "Not all partitions are instances of StreamPartition" + assert ( + generated_slices == expected_slices + ), f"Expected {expected_slices}, but got {generated_slices}" diff --git a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py index bd2d4b1e..68c7d797 100644 --- a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +++ b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py @@ -20,7 +20,9 @@ from airbyte_cdk.models import StreamDescriptor, SyncMode, TraceType from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor -from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( + PartitionGenerationCompletedSentinel, +) from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.message import LogMessage, MessageRepository @@ -172,8 +174,12 @@ def test_handle_partition(self): handler.on_partition(self._a_closed_partition) - self._thread_pool_manager.submit.assert_called_with(self._partition_reader.process_partition, self._a_closed_partition) - assert self._a_closed_partition in handler._streams_to_running_partitions[_ANOTHER_STREAM_NAME] + self._thread_pool_manager.submit.assert_called_with( + self._partition_reader.process_partition, self._a_closed_partition + ) + assert ( + self._a_closed_partition in handler._streams_to_running_partitions[_ANOTHER_STREAM_NAME] + ) def test_handle_partition_emits_log_message_if_it_should_be_logged(self): stream_instances_to_read_from = [self._stream] @@ -193,7 +199,9 @@ def test_handle_partition_emits_log_message_if_it_should_be_logged(self): handler.on_partition(self._an_open_partition) - self._thread_pool_manager.submit.assert_called_with(self._partition_reader.process_partition, self._an_open_partition) + self._thread_pool_manager.submit.assert_called_with( + self._partition_reader.process_partition, self._an_open_partition + ) self._message_repository.emit_message.assert_called_with(self._log_message) assert self._an_open_partition in handler._streams_to_running_partitions[_STREAM_NAME] @@ -221,20 +229,32 @@ def test_handle_on_partition_complete_sentinel_with_messages_from_repository(sel sentinel = PartitionCompleteSentinel(partition) self._message_repository.consume_queue.return_value = [ - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=LogLevel.INFO, message="message emitted from the repository")) + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=LogLevel.INFO, message="message emitted from the repository" + ), + ) ] messages = list(handler.on_partition_complete_sentinel(sentinel)) expected_messages = [ - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=LogLevel.INFO, message="message emitted from the repository")) + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=LogLevel.INFO, message="message emitted from the repository" + ), + ) ] assert messages == expected_messages partition.close.assert_called_once() @freezegun.freeze_time("2020-01-01T00:00:00") - def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stream_is_done(self): + def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stream_is_done( + self, + ): self._streams_currently_generating_partitions = [self._another_stream] stream_instances_to_read_from = [self._another_stream] log_message = Mock(spec=LogMessage) @@ -252,7 +272,11 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre ) handler.start_next_partition_generator() handler.on_partition(self._a_closed_partition) - list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._another_stream))) + list( + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._another_stream) + ) + ) sentinel = PartitionCompleteSentinel(self._a_closed_partition) @@ -277,7 +301,9 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre self._a_closed_partition.close.assert_called_once() @freezegun.freeze_time("2020-01-01T00:00:00") - def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_message_and_stream_is_incomplete(self) -> None: + def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_message_and_stream_is_incomplete( + self, + ) -> None: self._a_closed_partition.stream_name.return_value = self._stream.name self._a_closed_partition.close.side_effect = ValueError @@ -292,8 +318,16 @@ def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_m ) handler.start_next_partition_generator() handler.on_partition(self._a_closed_partition) - list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._stream))) - messages = list(handler.on_partition_complete_sentinel(PartitionCompleteSentinel(self._a_closed_partition))) + list( + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._stream) + ) + ) + messages = list( + handler.on_partition_complete_sentinel( + PartitionCompleteSentinel(self._a_closed_partition) + ) + ) expected_status_message = AirbyteMessage( type=MessageType.TRACE, @@ -308,11 +342,16 @@ def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_m emitted_at=1577836800000.0, ), ) - assert list(map(lambda message: message.trace.type, messages)) == [TraceType.ERROR, TraceType.STREAM_STATUS] + assert list(map(lambda message: message.trace.type, messages)) == [ + TraceType.ERROR, + TraceType.STREAM_STATUS, + ] assert messages[1] == expected_status_message @freezegun.freeze_time("2020-01-01T00:00:00") - def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_stream_is_not_done(self): + def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_stream_is_not_done( + self, + ): stream_instances_to_read_from = [self._stream] partition = Mock(spec=Partition) log_message = Mock(spec=LogMessage) @@ -385,7 +424,12 @@ def test_on_record_with_repository_messge(self): slice_logger.should_log_slice_message.return_value = True slice_logger.create_slice_log_message.return_value = log_message self._message_repository.consume_queue.return_value = [ - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=LogLevel.INFO, message="message emitted from the repository")) + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=LogLevel.INFO, message="message emitted from the repository" + ), + ) ] handler = ConcurrentReadProcessor( @@ -420,7 +464,12 @@ def test_on_record_with_repository_messge(self): emitted_at=1577836800000, ), ), - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=LogLevel.INFO, message="message emitted from the repository")), + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=LogLevel.INFO, message="message emitted from the repository" + ), + ), ] assert messages == expected_messages assert handler._record_counter[_STREAM_NAME] == 2 @@ -451,7 +500,8 @@ def test_on_record_emits_status_message_on_first_record_no_repository_message(se type=TraceType.STREAM_STATUS, emitted_at=1577836800000.0, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name=_STREAM_NAME), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING) + stream_descriptor=StreamDescriptor(name=_STREAM_NAME), + status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING), ), ), ), @@ -474,7 +524,12 @@ def test_on_record_emits_status_message_on_first_record_with_repository_message( partition.to_slice.return_value = log_message partition.stream_name.return_value = _STREAM_NAME self._message_repository.consume_queue.return_value = [ - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=LogLevel.INFO, message="message emitted from the repository")) + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=LogLevel.INFO, message="message emitted from the repository" + ), + ) ] handler = ConcurrentReadProcessor( @@ -504,7 +559,8 @@ def test_on_record_emits_status_message_on_first_record_with_repository_message( type=TraceType.STREAM_STATUS, emitted_at=1577836800000.0, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name=_STREAM_NAME), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING) + stream_descriptor=StreamDescriptor(name=_STREAM_NAME), + status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING), ), ), ), @@ -516,7 +572,12 @@ def test_on_record_emits_status_message_on_first_record_with_repository_message( emitted_at=1577836800000, ), ), - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=LogLevel.INFO, message="message emitted from the repository")), + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=LogLevel.INFO, message="message emitted from the repository" + ), + ), ] assert messages == expected_messages @@ -536,8 +597,16 @@ def test_on_exception_return_trace_message_and_on_stream_complete_return_stream_ handler.start_next_partition_generator() handler.on_partition(self._an_open_partition) - list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._stream))) - list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._another_stream))) + list( + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._stream) + ) + ) + list( + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._another_stream) + ) + ) another_stream = Mock(spec=AbstractStream) another_stream.name = _STREAM_NAME @@ -553,14 +622,19 @@ def test_on_exception_return_trace_message_and_on_stream_complete_return_stream_ assert len(exception_messages) == 1 assert "StreamThreadException" in exception_messages[0].trace.error.stack_trace - assert list(handler.on_partition_complete_sentinel(PartitionCompleteSentinel(self._an_open_partition))) == [ + assert list( + handler.on_partition_complete_sentinel( + PartitionCompleteSentinel(self._an_open_partition) + ) + ) == [ AirbyteMessage( type=MessageType.TRACE, trace=AirbyteTraceMessage( type=TraceType.STREAM_STATUS, emitted_at=1577836800000.0, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name=_STREAM_NAME), status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE) + stream_descriptor=StreamDescriptor(name=_STREAM_NAME), + status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE), ), ), ) @@ -586,8 +660,16 @@ def test_given_underlying_exception_is_traced_exception_on_exception_return_trac handler.start_next_partition_generator() handler.on_partition(self._an_open_partition) - list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._stream))) - list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._another_stream))) + list( + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._stream) + ) + ) + list( + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._another_stream) + ) + ) another_stream = Mock(spec=AbstractStream) another_stream.name = _STREAM_NAME @@ -604,14 +686,19 @@ def test_given_underlying_exception_is_traced_exception_on_exception_return_trac assert len(exception_messages) == 1 assert "AirbyteTracedException" in exception_messages[0].trace.error.stack_trace - assert list(handler.on_partition_complete_sentinel(PartitionCompleteSentinel(self._an_open_partition))) == [ + assert list( + handler.on_partition_complete_sentinel( + PartitionCompleteSentinel(self._an_open_partition) + ) + ) == [ AirbyteMessage( type=MessageType.TRACE, trace=AirbyteTraceMessage( type=TraceType.STREAM_STATUS, emitted_at=1577836800000.0, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name=_STREAM_NAME), status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE) + stream_descriptor=StreamDescriptor(name=_STREAM_NAME), + status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE), ), ), ) @@ -634,9 +721,17 @@ def test_given_partition_completion_is_not_success_then_do_not_close_partition(s handler.start_next_partition_generator() handler.on_partition(self._an_open_partition) - list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._stream))) + list( + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._stream) + ) + ) - list(handler.on_partition_complete_sentinel(PartitionCompleteSentinel(self._an_open_partition, not _IS_SUCCESSFUL))) + list( + handler.on_partition_complete_sentinel( + PartitionCompleteSentinel(self._an_open_partition, not _IS_SUCCESSFUL) + ) + ) assert self._an_open_partition.close.call_count == 0 @@ -655,9 +750,17 @@ def test_given_partition_completion_is_not_success_then_do_not_close_partition(s handler.start_next_partition_generator() handler.on_partition(self._an_open_partition) - list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._stream))) + list( + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._stream) + ) + ) - list(handler.on_partition_complete_sentinel(PartitionCompleteSentinel(self._an_open_partition, not _IS_SUCCESSFUL))) + list( + handler.on_partition_complete_sentinel( + PartitionCompleteSentinel(self._an_open_partition, not _IS_SUCCESSFUL) + ) + ) assert self._an_open_partition.close.call_count == 0 @@ -708,11 +811,15 @@ def test_is_done_is_false_if_all_partitions_are_not_closed(self): handler.start_next_partition_generator() handler.on_partition(self._an_open_partition) - handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._stream)) + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._stream) + ) assert not handler.is_done() - def test_is_done_is_true_if_all_partitions_are_closed_and_no_streams_are_generating_partitions_and_none_are_still_to_run(self): + def test_is_done_is_true_if_all_partitions_are_closed_and_no_streams_are_generating_partitions_and_none_are_still_to_run( + self, + ): stream_instances_to_read_from = [] handler = ConcurrentReadProcessor( @@ -748,10 +855,13 @@ def test_start_next_partition_generator(self): type=TraceType.STREAM_STATUS, emitted_at=1577836800000.0, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name=_STREAM_NAME), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED) + stream_descriptor=StreamDescriptor(name=_STREAM_NAME), + status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED), ), ), ) assert _STREAM_NAME in handler._streams_currently_generating_partitions - self._thread_pool_manager.submit.assert_called_with(self._partition_enqueuer.generate_partitions, self._stream) + self._thread_pool_manager.submit.assert_called_with( + self._partition_enqueuer.generate_partitions, self._stream + ) diff --git a/unit_tests/sources/streams/concurrent/test_cursor.py b/unit_tests/sources/streams/concurrent/test_cursor.py index 5ef8c60c..883f2418 100644 --- a/unit_tests/sources/streams/concurrent/test_cursor.py +++ b/unit_tests/sources/streams/concurrent/test_cursor.py @@ -14,10 +14,16 @@ from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.message import MessageRepository -from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField, CursorValueType +from airbyte_cdk.sources.streams.concurrent.cursor import ( + ConcurrentCursor, + CursorField, + CursorValueType, +) from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record -from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ConcurrencyCompatibleStateType +from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ( + ConcurrencyCompatibleStateType, +) from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( EpochValueConcurrentStreamStateConverter, IsoMillisConcurrentStreamStateConverter, @@ -38,14 +44,18 @@ _NO_LOOKBACK_WINDOW = timedelta(seconds=0) -def _partition(_slice: Optional[Mapping[str, Any]], _stream_name: Optional[str] = Mock()) -> Partition: +def _partition( + _slice: Optional[Mapping[str, Any]], _stream_name: Optional[str] = Mock() +) -> Partition: partition = Mock(spec=Partition) partition.to_slice.return_value = _slice partition.stream_name.return_value = _stream_name return partition -def _record(cursor_value: CursorValueType, partition: Optional[Partition] = Mock(spec=Partition)) -> Record: +def _record( + cursor_value: CursorValueType, partition: Optional[Partition] = Mock(spec=Partition) +) -> Record: return Record(data={_A_CURSOR_FIELD_KEY: cursor_value}, partition=partition) @@ -92,11 +102,15 @@ def test_given_boundary_fields_when_close_partition_then_emit_state(self) -> Non ) ) - self._message_repository.emit_message.assert_called_once_with(self._state_manager.create_state_message.return_value) + self._message_repository.emit_message.assert_called_once_with( + self._state_manager.create_state_message.return_value + ) self._state_manager.update_state_for_stream.assert_called_once_with( _A_STREAM_NAME, _A_STREAM_NAMESPACE, - {_A_CURSOR_FIELD_KEY: 0}, # State message is updated to the legacy format before being emitted + { + _A_CURSOR_FIELD_KEY: 0 + }, # State message is updated to the legacy format before being emitted ) def test_given_state_not_sequential_when_close_partition_then_emit_state(self) -> None: @@ -107,34 +121,52 @@ def test_given_state_not_sequential_when_close_partition_then_emit_state(self) - ) ) - self._message_repository.emit_message.assert_called_once_with(self._state_manager.create_state_message.return_value) + self._message_repository.emit_message.assert_called_once_with( + self._state_manager.create_state_message.return_value + ) self._state_manager.update_state_for_stream.assert_called_once_with( _A_STREAM_NAME, _A_STREAM_NAMESPACE, - {"slices": [{"end": 0, "start": 0}, {"end": 30, "start": 12}], "state_type": "date-range"}, + { + "slices": [{"end": 0, "start": 0}, {"end": 30, "start": 12}], + "state_type": "date-range", + }, ) - def test_close_partition_emits_message_to_lower_boundary_when_no_prior_state_exists(self) -> None: + def test_close_partition_emits_message_to_lower_boundary_when_no_prior_state_exists( + self, + ) -> None: self._cursor_with_slice_boundary_fields().close_partition( _partition( {_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 30}, ) ) - self._message_repository.emit_message.assert_called_once_with(self._state_manager.create_state_message.return_value) + self._message_repository.emit_message.assert_called_once_with( + self._state_manager.create_state_message.return_value + ) self._state_manager.update_state_for_stream.assert_called_once_with( _A_STREAM_NAME, _A_STREAM_NAMESPACE, {_A_CURSOR_FIELD_KEY: 0}, # State message is updated to the lower slice boundary ) - def test_given_boundary_fields_and_record_observed_when_close_partition_then_ignore_records(self) -> None: + def test_given_boundary_fields_and_record_observed_when_close_partition_then_ignore_records( + self, + ) -> None: cursor = self._cursor_with_slice_boundary_fields() cursor.observe(_record(_A_VERY_HIGH_CURSOR_VALUE)) - cursor.close_partition(_partition({_LOWER_SLICE_BOUNDARY_FIELD: 12, _UPPER_SLICE_BOUNDARY_FIELD: 30})) + cursor.close_partition( + _partition({_LOWER_SLICE_BOUNDARY_FIELD: 12, _UPPER_SLICE_BOUNDARY_FIELD: 30}) + ) - assert self._state_manager.update_state_for_stream.call_args_list[0].args[2][_A_CURSOR_FIELD_KEY] != _A_VERY_HIGH_CURSOR_VALUE + assert ( + self._state_manager.update_state_for_stream.call_args_list[0].args[2][ + _A_CURSOR_FIELD_KEY + ] + != _A_VERY_HIGH_CURSOR_VALUE + ) def test_given_no_boundary_fields_when_close_partition_then_emit_state(self) -> None: cursor = self._cursor_without_slice_boundary_fields() @@ -148,7 +180,9 @@ def test_given_no_boundary_fields_when_close_partition_then_emit_state(self) -> {"a_cursor_field_key": 10}, ) - def test_given_no_boundary_fields_when_close_multiple_partitions_then_raise_exception(self) -> None: + def test_given_no_boundary_fields_when_close_multiple_partitions_then_raise_exception( + self, + ) -> None: cursor = self._cursor_without_slice_boundary_fields() partition = _partition(_NO_SLICE) cursor.observe(_record(10, partition=partition)) @@ -162,12 +196,16 @@ def test_given_no_records_observed_when_close_partition_then_do_not_emit_state(s cursor.close_partition(_partition(_NO_SLICE)) assert self._message_repository.emit_message.call_count == 0 - def test_given_slice_boundaries_and_no_slice_when_close_partition_then_raise_error(self) -> None: + def test_given_slice_boundaries_and_no_slice_when_close_partition_then_raise_error( + self, + ) -> None: cursor = self._cursor_with_slice_boundary_fields() with pytest.raises(KeyError): cursor.close_partition(_partition(_NO_SLICE)) - def test_given_slice_boundaries_not_matching_slice_when_close_partition_then_raise_error(self) -> None: + def test_given_slice_boundaries_not_matching_slice_when_close_partition_then_raise_error( + self, + ) -> None: cursor = self._cursor_with_slice_boundary_fields() with pytest.raises(KeyError): cursor.close_partition(_partition({"not_matching_key": "value"})) @@ -196,7 +234,9 @@ def test_given_no_state_when_generate_slices_then_create_slice_from_start_to_end ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) - def test_given_one_slice_when_generate_slices_then_create_slice_from_slice_upper_boundary_to_end(self): + def test_given_one_slice_when_generate_slices_then_create_slice_from_slice_upper_boundary_to_end( + self, + ): start = datetime.fromtimestamp(0, timezone.utc) cursor = ConcurrentCursor( _A_STREAM_NAME, @@ -204,7 +244,10 @@ def test_given_one_slice_when_generate_slices_then_create_slice_from_slice_upper { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 0, + EpochValueConcurrentStreamStateConverter.END_KEY: 20, + }, ], }, self._message_repository, @@ -232,7 +275,10 @@ def test_given_start_after_slices_when_generate_slices_then_generate_from_start( { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 0, + EpochValueConcurrentStreamStateConverter.END_KEY: 20, + }, ], }, self._message_repository, @@ -252,7 +298,9 @@ def test_given_start_after_slices_when_generate_slices_then_generate_from_start( ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) - def test_given_state_with_gap_and_start_after_slices_when_generate_slices_then_generate_from_start(self): + def test_given_state_with_gap_and_start_after_slices_when_generate_slices_then_generate_from_start( + self, + ): start = datetime.fromtimestamp(30, timezone.utc) cursor = ConcurrentCursor( _A_STREAM_NAME, @@ -260,8 +308,14 @@ def test_given_state_with_gap_and_start_after_slices_when_generate_slices_then_g { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 10}, - {EpochValueConcurrentStreamStateConverter.START_KEY: 15, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 0, + EpochValueConcurrentStreamStateConverter.END_KEY: 10, + }, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 15, + EpochValueConcurrentStreamStateConverter.END_KEY: 20, + }, ], }, self._message_repository, @@ -290,7 +344,10 @@ def test_given_small_slice_range_when_generate_slices_then_create_many_slices(se { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 0, + EpochValueConcurrentStreamStateConverter.END_KEY: 20, + }, ], }, self._message_repository, @@ -313,7 +370,9 @@ def test_given_small_slice_range_when_generate_slices_then_create_many_slices(se ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) - def test_given_difference_between_slices_match_slice_range_when_generate_slices_then_create_one_slice(self): + def test_given_difference_between_slices_match_slice_range_when_generate_slices_then_create_one_slice( + self, + ): start = datetime.fromtimestamp(0, timezone.utc) small_slice_range = timedelta(seconds=10) cursor = ConcurrentCursor( @@ -322,8 +381,14 @@ def test_given_difference_between_slices_match_slice_range_when_generate_slices_ { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 30}, - {EpochValueConcurrentStreamStateConverter.START_KEY: 40, EpochValueConcurrentStreamStateConverter.END_KEY: 50}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 0, + EpochValueConcurrentStreamStateConverter.END_KEY: 30, + }, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 40, + EpochValueConcurrentStreamStateConverter.END_KEY: 50, + }, ], }, self._message_repository, @@ -344,7 +409,9 @@ def test_given_difference_between_slices_match_slice_range_when_generate_slices_ ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) - def test_given_small_slice_range_with_granularity_when_generate_slices_then_create_many_slices(self): + def test_given_small_slice_range_with_granularity_when_generate_slices_then_create_many_slices( + self, + ): start = datetime.fromtimestamp(1, timezone.utc) small_slice_range = timedelta(seconds=10) granularity = timedelta(seconds=1) @@ -354,7 +421,10 @@ def test_given_small_slice_range_with_granularity_when_generate_slices_then_crea { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 1, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 1, + EpochValueConcurrentStreamStateConverter.END_KEY: 20, + }, ], }, self._message_repository, @@ -378,7 +448,9 @@ def test_given_small_slice_range_with_granularity_when_generate_slices_then_crea ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) - def test_given_difference_between_slices_match_slice_range_and_cursor_granularity_when_generate_slices_then_create_one_slice(self): + def test_given_difference_between_slices_match_slice_range_and_cursor_granularity_when_generate_slices_then_create_one_slice( + self, + ): start = datetime.fromtimestamp(1, timezone.utc) small_slice_range = timedelta(seconds=10) granularity = timedelta(seconds=1) @@ -388,8 +460,14 @@ def test_given_difference_between_slices_match_slice_range_and_cursor_granularit { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 1, EpochValueConcurrentStreamStateConverter.END_KEY: 30}, - {EpochValueConcurrentStreamStateConverter.START_KEY: 41, EpochValueConcurrentStreamStateConverter.END_KEY: 50}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 1, + EpochValueConcurrentStreamStateConverter.END_KEY: 30, + }, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 41, + EpochValueConcurrentStreamStateConverter.END_KEY: 50, + }, ], }, self._message_repository, @@ -414,16 +492,27 @@ def test_given_difference_between_slices_match_slice_range_and_cursor_granularit ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) - def test_given_non_continuous_state_when_generate_slices_then_create_slices_between_gaps_and_after(self): + def test_given_non_continuous_state_when_generate_slices_then_create_slices_between_gaps_and_after( + self, + ): cursor = ConcurrentCursor( _A_STREAM_NAME, _A_STREAM_NAMESPACE, { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 10}, - {EpochValueConcurrentStreamStateConverter.START_KEY: 20, EpochValueConcurrentStreamStateConverter.END_KEY: 25}, - {EpochValueConcurrentStreamStateConverter.START_KEY: 30, EpochValueConcurrentStreamStateConverter.END_KEY: 40}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 0, + EpochValueConcurrentStreamStateConverter.END_KEY: 10, + }, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 20, + EpochValueConcurrentStreamStateConverter.END_KEY: 25, + }, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 30, + EpochValueConcurrentStreamStateConverter.END_KEY: 40, + }, ], }, self._message_repository, @@ -445,7 +534,9 @@ def test_given_non_continuous_state_when_generate_slices_then_create_slices_betw ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) - def test_given_lookback_window_when_generate_slices_then_apply_lookback_on_most_recent_slice(self): + def test_given_lookback_window_when_generate_slices_then_apply_lookback_on_most_recent_slice( + self, + ): start = datetime.fromtimestamp(0, timezone.utc) lookback_window = timedelta(seconds=10) cursor = ConcurrentCursor( @@ -454,8 +545,14 @@ def test_given_lookback_window_when_generate_slices_then_apply_lookback_on_most_ { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, - {EpochValueConcurrentStreamStateConverter.START_KEY: 30, EpochValueConcurrentStreamStateConverter.END_KEY: 40}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 0, + EpochValueConcurrentStreamStateConverter.END_KEY: 20, + }, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 30, + EpochValueConcurrentStreamStateConverter.END_KEY: 40, + }, ], }, self._message_repository, @@ -476,7 +573,9 @@ def test_given_lookback_window_when_generate_slices_then_apply_lookback_on_most_ ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) - def test_given_start_is_before_first_slice_lower_boundary_when_generate_slices_then_generate_slice_before(self): + def test_given_start_is_before_first_slice_lower_boundary_when_generate_slices_then_generate_slice_before( + self, + ): start = datetime.fromtimestamp(0, timezone.utc) cursor = ConcurrentCursor( _A_STREAM_NAME, @@ -484,7 +583,10 @@ def test_given_start_is_before_first_slice_lower_boundary_when_generate_slices_t { "state_type": ConcurrencyCompatibleStateType.date_range.value, "slices": [ - {EpochValueConcurrentStreamStateConverter.START_KEY: 10, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + { + EpochValueConcurrentStreamStateConverter.START_KEY: 10, + EpochValueConcurrentStreamStateConverter.END_KEY: 20, + }, ], }, self._message_repository, @@ -504,10 +606,16 @@ def test_given_start_is_before_first_slice_lower_boundary_when_generate_slices_t (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), ] - def test_slices_with_records_when_close_then_most_recent_cursor_value_from_most_recent_slice(self) -> None: + def test_slices_with_records_when_close_then_most_recent_cursor_value_from_most_recent_slice( + self, + ) -> None: cursor = self._cursor_with_slice_boundary_fields(is_sequential_state=False) - first_partition = _partition({_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 10}) - second_partition = _partition({_LOWER_SLICE_BOUNDARY_FIELD: 10, _UPPER_SLICE_BOUNDARY_FIELD: 20}) + first_partition = _partition( + {_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 10} + ) + second_partition = _partition( + {_LOWER_SLICE_BOUNDARY_FIELD: 10, _UPPER_SLICE_BOUNDARY_FIELD: 20} + ) cursor.observe(_record(5, partition=first_partition)) cursor.close_partition(first_partition) @@ -519,10 +627,16 @@ def test_slices_with_records_when_close_then_most_recent_cursor_value_from_most_ "state_type": "date-range", } - def test_last_slice_without_records_when_close_then_most_recent_cursor_value_is_from_previous_slice(self) -> None: + def test_last_slice_without_records_when_close_then_most_recent_cursor_value_is_from_previous_slice( + self, + ) -> None: cursor = self._cursor_with_slice_boundary_fields(is_sequential_state=False) - first_partition = _partition({_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 10}) - second_partition = _partition({_LOWER_SLICE_BOUNDARY_FIELD: 10, _UPPER_SLICE_BOUNDARY_FIELD: 20}) + first_partition = _partition( + {_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 10} + ) + second_partition = _partition( + {_LOWER_SLICE_BOUNDARY_FIELD: 10, _UPPER_SLICE_BOUNDARY_FIELD: 20} + ) cursor.observe(_record(5, partition=first_partition)) cursor.close_partition(first_partition) @@ -533,7 +647,9 @@ def test_last_slice_without_records_when_close_then_most_recent_cursor_value_is_ "state_type": "date-range", } - def test_most_recent_cursor_value_outside_of_boundaries_when_close_then_most_recent_cursor_value_still_considered(self) -> None: + def test_most_recent_cursor_value_outside_of_boundaries_when_close_then_most_recent_cursor_value_still_considered( + self, + ) -> None: """ Not sure what is the value of this behavior but I'm simply documenting how it is today """ @@ -547,29 +663,41 @@ def test_most_recent_cursor_value_outside_of_boundaries_when_close_then_most_rec "state_type": "date-range", } - def test_most_recent_cursor_value_on_sequential_state_when_close_then_cursor_value_is_most_recent_cursor_value(self) -> None: + def test_most_recent_cursor_value_on_sequential_state_when_close_then_cursor_value_is_most_recent_cursor_value( + self, + ) -> None: cursor = self._cursor_with_slice_boundary_fields(is_sequential_state=True) partition = _partition({_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 10}) cursor.observe(_record(7, partition=partition)) cursor.close_partition(partition) - assert self._state_manager.update_state_for_stream.call_args_list[-1].args[2] == {_A_CURSOR_FIELD_KEY: 7} + assert self._state_manager.update_state_for_stream.call_args_list[-1].args[2] == { + _A_CURSOR_FIELD_KEY: 7 + } def test_non_continuous_slices_on_sequential_state_when_close_then_cursor_value_is_most_recent_cursor_value_of_first_slice( self, ) -> None: cursor = self._cursor_with_slice_boundary_fields(is_sequential_state=True) - first_partition = _partition({_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 10}) - third_partition = _partition({_LOWER_SLICE_BOUNDARY_FIELD: 20, _UPPER_SLICE_BOUNDARY_FIELD: 30}) # second partition has failed + first_partition = _partition( + {_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 10} + ) + third_partition = _partition( + {_LOWER_SLICE_BOUNDARY_FIELD: 20, _UPPER_SLICE_BOUNDARY_FIELD: 30} + ) # second partition has failed cursor.observe(_record(7, partition=first_partition)) cursor.close_partition(first_partition) cursor.close_partition(third_partition) - assert self._state_manager.update_state_for_stream.call_args_list[-1].args[2] == {_A_CURSOR_FIELD_KEY: 7} + assert self._state_manager.update_state_for_stream.call_args_list[-1].args[2] == { + _A_CURSOR_FIELD_KEY: 7 + } @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(10, timezone.utc)) - def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_to_end_provider(self) -> None: + def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_to_end_provider( + self, + ) -> None: a_very_big_slice_range = timedelta.max cursor = ConcurrentCursor( _A_STREAM_NAME, @@ -588,7 +716,9 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t slices = list(cursor.generate_slices()) - assert slices == [(datetime.fromtimestamp(0, timezone.utc), datetime.fromtimestamp(10, timezone.utc))] + assert slices == [ + (datetime.fromtimestamp(0, timezone.utc), datetime.fromtimestamp(10, timezone.utc)) + ] @freezegun.freeze_time(time_to_freeze=datetime(2024, 4, 1, 0, 0, 0, 0, tzinfo=timezone.utc)) @@ -603,12 +733,30 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t "P5D", {}, [ - (datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), datetime(2024, 1, 10, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 1, 11, 0, 0, tzinfo=timezone.utc), datetime(2024, 1, 20, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 1, 21, 0, 0, tzinfo=timezone.utc), datetime(2024, 1, 30, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 1, 31, 0, 0, tzinfo=timezone.utc), datetime(2024, 2, 9, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 2, 10, 0, 0, tzinfo=timezone.utc), datetime(2024, 2, 19, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 2, 20, 0, 0, tzinfo=timezone.utc), datetime(2024, 3, 1, 0, 0, 0, tzinfo=timezone.utc)), + ( + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 10, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 11, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 20, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 21, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 30, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 31, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 9, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 2, 10, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 19, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 2, 20, 0, 0, tzinfo=timezone.utc), + datetime(2024, 3, 1, 0, 0, 0, tzinfo=timezone.utc), + ), ], id="test_datetime_based_cursor_all_fields", ), @@ -628,9 +776,18 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t "state_type": "date-range", }, [ - (datetime(2024, 2, 5, 0, 0, 0, tzinfo=timezone.utc), datetime(2024, 2, 14, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 2, 15, 0, 0, 0, tzinfo=timezone.utc), datetime(2024, 2, 24, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 2, 25, 0, 0, 0, tzinfo=timezone.utc), datetime(2024, 3, 1, 0, 0, 0, tzinfo=timezone.utc)), + ( + datetime(2024, 2, 5, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 14, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 2, 15, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 24, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 2, 25, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 3, 1, 0, 0, 0, tzinfo=timezone.utc), + ), ], id="test_datetime_based_cursor_with_state", ), @@ -650,10 +807,22 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t "state_type": "date-range", }, [ - (datetime(2024, 1, 20, 0, 0, tzinfo=timezone.utc), datetime(2024, 2, 8, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 2, 9, 0, 0, tzinfo=timezone.utc), datetime(2024, 2, 28, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 2, 29, 0, 0, tzinfo=timezone.utc), datetime(2024, 3, 19, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 3, 20, 0, 0, tzinfo=timezone.utc), datetime(2024, 4, 1, 0, 0, 0, tzinfo=timezone.utc)), + ( + datetime(2024, 1, 20, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 8, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 2, 9, 0, 0, tzinfo=timezone.utc), + datetime(2024, 2, 28, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 2, 29, 0, 0, tzinfo=timezone.utc), + datetime(2024, 3, 19, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 3, 20, 0, 0, tzinfo=timezone.utc), + datetime(2024, 4, 1, 0, 0, 0, tzinfo=timezone.utc), + ), ], id="test_datetime_based_cursor_with_state_and_end_date", ), @@ -665,8 +834,14 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t "P5D", {}, [ - (datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), datetime(2024, 1, 31, 23, 59, 59, tzinfo=timezone.utc)), - (datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), datetime(2024, 3, 1, 0, 0, 0, tzinfo=timezone.utc)), + ( + datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 31, 23, 59, 59, tzinfo=timezone.utc), + ), + ( + datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2024, 3, 1, 0, 0, 0, tzinfo=timezone.utc), + ), ], id="test_datetime_based_cursor_using_large_step_duration", ), @@ -724,9 +899,17 @@ def test_generate_slices_concurrent_cursor_from_datetime_based_cursor( # may need to convert a Duration to timedelta by multiplying month by 30 (but could lose precision). step_length = datetime_based_cursor._step - lookback_window = parse_duration(datetime_based_cursor.lookback_window) if datetime_based_cursor.lookback_window else None + lookback_window = ( + parse_duration(datetime_based_cursor.lookback_window) + if datetime_based_cursor.lookback_window + else None + ) - cursor_granularity = parse_duration(datetime_based_cursor.cursor_granularity) if datetime_based_cursor.cursor_granularity else None + cursor_granularity = ( + parse_duration(datetime_based_cursor.cursor_granularity) + if datetime_based_cursor.cursor_granularity + else None + ) cursor = ConcurrentCursor( stream_name=_A_STREAM_NAME, @@ -786,21 +969,39 @@ def test_observe_concurrent_cursor_from_datetime_based_cursor(): ) partition = _partition( - {_LOWER_SLICE_BOUNDARY_FIELD: "2024-08-01T00:00:00.000000+0000", _UPPER_SLICE_BOUNDARY_FIELD: "2024-09-01T00:00:00.000000+0000"}, + { + _LOWER_SLICE_BOUNDARY_FIELD: "2024-08-01T00:00:00.000000+0000", + _UPPER_SLICE_BOUNDARY_FIELD: "2024-09-01T00:00:00.000000+0000", + }, _stream_name="gods", ) record_1 = Record( partition=partition, - data={"id": "999", "updated_at": "2024-08-23T00:00:00.000000+0000", "name": "kratos", "mythology": "greek"}, + data={ + "id": "999", + "updated_at": "2024-08-23T00:00:00.000000+0000", + "name": "kratos", + "mythology": "greek", + }, ) record_2 = Record( partition=partition, - data={"id": "1000", "updated_at": "2024-08-22T00:00:00.000000+0000", "name": "odin", "mythology": "norse"}, + data={ + "id": "1000", + "updated_at": "2024-08-22T00:00:00.000000+0000", + "name": "odin", + "mythology": "norse", + }, ) record_3 = Record( partition=partition, - data={"id": "500", "updated_at": "2024-08-24T00:00:00.000000+0000", "name": "freya", "mythology": "norse"}, + data={ + "id": "500", + "updated_at": "2024-08-24T00:00:00.000000+0000", + "name": "freya", + "mythology": "norse", + }, ) concurrent_cursor.observe(record_1) @@ -845,7 +1046,9 @@ def test_close_partition_concurrent_cursor_from_datetime_based_cursor(): stream_state={}, message_repository=message_repository, connector_state_manager=state_manager, - connector_state_converter=IsoMillisConcurrentStreamStateConverter(is_sequential_state=False), + connector_state_converter=IsoMillisConcurrentStreamStateConverter( + is_sequential_state=False + ), cursor_field=cursor_field, slice_boundary_fields=None, start=start_date, @@ -854,19 +1057,29 @@ def test_close_partition_concurrent_cursor_from_datetime_based_cursor(): ) partition = _partition( - {_LOWER_SLICE_BOUNDARY_FIELD: "2024-08-01T00:00:00.000000+0000", _UPPER_SLICE_BOUNDARY_FIELD: "2024-09-01T00:00:00.000000+0000"}, + { + _LOWER_SLICE_BOUNDARY_FIELD: "2024-08-01T00:00:00.000000+0000", + _UPPER_SLICE_BOUNDARY_FIELD: "2024-09-01T00:00:00.000000+0000", + }, _stream_name="gods", ) record_1 = Record( partition=partition, - data={"id": "999", "updated_at": "2024-08-23T00:00:00.000000+0000", "name": "kratos", "mythology": "greek"}, + data={ + "id": "999", + "updated_at": "2024-08-23T00:00:00.000000+0000", + "name": "kratos", + "mythology": "greek", + }, ) concurrent_cursor.observe(record_1) concurrent_cursor.close_partition(partition) - message_repository.emit_message.assert_called_once_with(state_manager.create_state_message.return_value) + message_repository.emit_message.assert_called_once_with( + state_manager.create_state_message.return_value + ) state_manager.update_state_for_stream.assert_called_once_with( "gods", _A_STREAM_NAMESPACE, @@ -918,7 +1131,9 @@ def test_close_partition_with_slice_range_concurrent_cursor_from_datetime_based_ stream_state={}, message_repository=message_repository, connector_state_manager=state_manager, - connector_state_converter=IsoMillisConcurrentStreamStateConverter(is_sequential_state=False, cursor_granularity=None), + connector_state_converter=IsoMillisConcurrentStreamStateConverter( + is_sequential_state=False, cursor_granularity=None + ), cursor_field=cursor_field, slice_boundary_fields=slice_boundary_fields, start=start_date, @@ -928,20 +1143,36 @@ def test_close_partition_with_slice_range_concurrent_cursor_from_datetime_based_ ) partition_0 = _partition( - {"start_time": "2024-07-01T00:00:00.000000+0000", "end_time": "2024-07-16T00:00:00.000000+0000"}, + { + "start_time": "2024-07-01T00:00:00.000000+0000", + "end_time": "2024-07-16T00:00:00.000000+0000", + }, _stream_name="gods", ) partition_3 = _partition( - {"start_time": "2024-08-15T00:00:00.000000+0000", "end_time": "2024-08-30T00:00:00.000000+0000"}, + { + "start_time": "2024-08-15T00:00:00.000000+0000", + "end_time": "2024-08-30T00:00:00.000000+0000", + }, _stream_name="gods", ) record_1 = Record( partition=partition_0, - data={"id": "1000", "updated_at": "2024-07-05T00:00:00.000000+0000", "name": "loki", "mythology": "norse"}, + data={ + "id": "1000", + "updated_at": "2024-07-05T00:00:00.000000+0000", + "name": "loki", + "mythology": "norse", + }, ) record_2 = Record( partition=partition_3, - data={"id": "999", "updated_at": "2024-08-20T00:00:00.000000+0000", "name": "kratos", "mythology": "greek"}, + data={ + "id": "999", + "updated_at": "2024-08-20T00:00:00.000000+0000", + "name": "kratos", + "mythology": "greek", + }, ) concurrent_cursor.observe(record_1) @@ -949,7 +1180,9 @@ def test_close_partition_with_slice_range_concurrent_cursor_from_datetime_based_ concurrent_cursor.observe(record_2) concurrent_cursor.close_partition(partition_3) - message_repository.emit_message.assert_called_with(state_manager.create_state_message.return_value) + message_repository.emit_message.assert_called_with( + state_manager.create_state_message.return_value + ) assert message_repository.emit_message.call_count == 2 state_manager.update_state_for_stream.assert_called_with( "gods", @@ -1002,7 +1235,11 @@ def test_close_partition_with_slice_range_granularity_concurrent_cursor_from_dat step_length = datetime_based_cursor._step - cursor_granularity = parse_duration(datetime_based_cursor.cursor_granularity) if datetime_based_cursor.cursor_granularity else None + cursor_granularity = ( + parse_duration(datetime_based_cursor.cursor_granularity) + if datetime_based_cursor.cursor_granularity + else None + ) concurrent_cursor = ConcurrentCursor( stream_name="gods", @@ -1010,7 +1247,9 @@ def test_close_partition_with_slice_range_granularity_concurrent_cursor_from_dat stream_state={}, message_repository=message_repository, connector_state_manager=state_manager, - connector_state_converter=IsoMillisConcurrentStreamStateConverter(is_sequential_state=False, cursor_granularity=cursor_granularity), + connector_state_converter=IsoMillisConcurrentStreamStateConverter( + is_sequential_state=False, cursor_granularity=cursor_granularity + ), cursor_field=cursor_field, slice_boundary_fields=slice_boundary_fields, start=start_date, @@ -1020,28 +1259,52 @@ def test_close_partition_with_slice_range_granularity_concurrent_cursor_from_dat ) partition_0 = _partition( - {"start_time": "2024-07-01T00:00:00.000000+0000", "end_time": "2024-07-15T00:00:00.000000+0000"}, + { + "start_time": "2024-07-01T00:00:00.000000+0000", + "end_time": "2024-07-15T00:00:00.000000+0000", + }, _stream_name="gods", ) partition_1 = _partition( - {"start_time": "2024-07-16T00:00:00.000000+0000", "end_time": "2024-07-31T00:00:00.000000+0000"}, + { + "start_time": "2024-07-16T00:00:00.000000+0000", + "end_time": "2024-07-31T00:00:00.000000+0000", + }, _stream_name="gods", ) partition_3 = _partition( - {"start_time": "2024-08-15T00:00:00.000000+0000", "end_time": "2024-08-29T00:00:00.000000+0000"}, + { + "start_time": "2024-08-15T00:00:00.000000+0000", + "end_time": "2024-08-29T00:00:00.000000+0000", + }, _stream_name="gods", ) record_1 = Record( partition=partition_0, - data={"id": "1000", "updated_at": "2024-07-05T00:00:00.000000+0000", "name": "loki", "mythology": "norse"}, + data={ + "id": "1000", + "updated_at": "2024-07-05T00:00:00.000000+0000", + "name": "loki", + "mythology": "norse", + }, ) record_2 = Record( partition=partition_1, - data={"id": "2000", "updated_at": "2024-07-25T00:00:00.000000+0000", "name": "freya", "mythology": "norse"}, + data={ + "id": "2000", + "updated_at": "2024-07-25T00:00:00.000000+0000", + "name": "freya", + "mythology": "norse", + }, ) record_3 = Record( partition=partition_3, - data={"id": "999", "updated_at": "2024-08-20T00:00:00.000000+0000", "name": "kratos", "mythology": "greek"}, + data={ + "id": "999", + "updated_at": "2024-08-20T00:00:00.000000+0000", + "name": "kratos", + "mythology": "greek", + }, ) concurrent_cursor.observe(record_1) @@ -1051,7 +1314,9 @@ def test_close_partition_with_slice_range_granularity_concurrent_cursor_from_dat concurrent_cursor.observe(record_3) concurrent_cursor.close_partition(partition_3) - message_repository.emit_message.assert_called_with(state_manager.create_state_message.return_value) + message_repository.emit_message.assert_called_with( + state_manager.create_state_message.return_value + ) assert message_repository.emit_message.call_count == 3 state_manager.update_state_for_stream.assert_called_with( "gods", diff --git a/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py b/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py index d139656f..32272973 100644 --- a/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py +++ b/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py @@ -6,7 +6,9 @@ import pytest from airbyte_cdk.sources.streams.concurrent.cursor import CursorField -from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ConcurrencyCompatibleStateType +from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ( + ConcurrencyCompatibleStateType, +) from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( CustomFormatConcurrentStreamStateConverter, EpochValueConcurrentStreamStateConverter, @@ -83,7 +85,9 @@ ), ], ) -def test_concurrent_stream_state_converter_is_state_message_compatible(converter, input_state, is_compatible): +def test_concurrent_stream_state_converter_is_state_message_compatible( + converter, input_state, is_compatible +): assert converter.is_state_message_compatible(input_state) == is_compatible @@ -222,14 +226,22 @@ def test_get_sync_start(converter, start, state, expected_start): def test_convert_from_sequential_state(converter, start, sequential_state, expected_output_state): comparison_format = "%Y-%m-%dT%H:%M:%S.%f" if expected_output_state["slices"]: - _, conversion = converter.convert_from_sequential_state(CursorField("created"), sequential_state, start) + _, conversion = converter.convert_from_sequential_state( + CursorField("created"), sequential_state, start + ) assert conversion["state_type"] == expected_output_state["state_type"] assert conversion["legacy"] == expected_output_state["legacy"] for actual, expected in zip(conversion["slices"], expected_output_state["slices"]): - assert actual["start"].strftime(comparison_format) == expected["start"].strftime(comparison_format) - assert actual["end"].strftime(comparison_format) == expected["end"].strftime(comparison_format) + assert actual["start"].strftime(comparison_format) == expected["start"].strftime( + comparison_format + ) + assert actual["end"].strftime(comparison_format) == expected["end"].strftime( + comparison_format + ) else: - _, conversion = converter.convert_from_sequential_state(CursorField("created"), sequential_state, start) + _, conversion = converter.convert_from_sequential_state( + CursorField("created"), sequential_state, start + ) assert conversion == expected_output_state @@ -339,7 +351,10 @@ def test_convert_from_sequential_state(converter, start, sequential_state, expec ], ) def test_convert_to_sequential_state(converter, concurrent_state, expected_output_state): - assert converter.convert_to_state_message(CursorField("created"), concurrent_state) == expected_output_state + assert ( + converter.convert_to_state_message(CursorField("created"), concurrent_state) + == expected_output_state + ) @pytest.mark.parametrize( @@ -365,7 +380,9 @@ def test_convert_to_sequential_state(converter, concurrent_state, expected_outpu ), ], ) -def test_convert_to_sequential_state_no_slices_returns_legacy_state(converter, concurrent_state, expected_output_state): +def test_convert_to_sequential_state_no_slices_returns_legacy_state( + converter, concurrent_state, expected_output_state +): with pytest.raises(RuntimeError): converter.convert_to_state_message(CursorField("created"), concurrent_state) diff --git a/unit_tests/sources/streams/concurrent/test_default_stream.py b/unit_tests/sources/streams/concurrent/test_default_stream.py index bb06a7b7..25b15ca2 100644 --- a/unit_tests/sources/streams/concurrent/test_default_stream.py +++ b/unit_tests/sources/streams/concurrent/test_default_stream.py @@ -30,7 +30,11 @@ def setUp(self): self._primary_key, self._cursor_field, self._logger, - FinalStateCursor(stream_name=self._name, stream_namespace=None, message_repository=self._message_repository), + FinalStateCursor( + stream_name=self._name, + stream_namespace=None, + message_repository=self._message_repository, + ), ) def test_get_json_schema(self): @@ -91,7 +95,11 @@ def test_as_airbyte_stream_with_primary_key(self): ["composite_key_1", "composite_key_2"], self._cursor_field, self._logger, - FinalStateCursor(stream_name=self._name, stream_namespace=None, message_repository=self._message_repository), + FinalStateCursor( + stream_name=self._name, + stream_namespace=None, + message_repository=self._message_repository, + ), ) expected_airbyte_stream = AirbyteStream( @@ -123,7 +131,11 @@ def test_as_airbyte_stream_with_composite_primary_key(self): ["id_a", "id_b"], self._cursor_field, self._logger, - FinalStateCursor(stream_name=self._name, stream_namespace=None, message_repository=self._message_repository), + FinalStateCursor( + stream_name=self._name, + stream_namespace=None, + message_repository=self._message_repository, + ), ) expected_airbyte_stream = AirbyteStream( @@ -155,7 +167,11 @@ def test_as_airbyte_stream_with_a_cursor(self): self._primary_key, "date", self._logger, - FinalStateCursor(stream_name=self._name, stream_namespace=None, message_repository=self._message_repository), + FinalStateCursor( + stream_name=self._name, + stream_namespace=None, + message_repository=self._message_repository, + ), ) expected_airbyte_stream = AirbyteStream( @@ -181,7 +197,11 @@ def test_as_airbyte_stream_with_namespace(self): self._primary_key, self._cursor_field, self._logger, - FinalStateCursor(stream_name=self._name, stream_namespace=None, message_repository=self._message_repository), + FinalStateCursor( + stream_name=self._name, + stream_namespace=None, + message_repository=self._message_repository, + ), namespace="test", ) expected_airbyte_stream = AirbyteStream( diff --git a/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py b/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py index da67ff82..02c1bdd1 100644 --- a/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py +++ b/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py @@ -6,7 +6,9 @@ from typing import Callable, Iterable, List from unittest.mock import Mock, patch -from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( + PartitionGenerationCompletedSentinel, +) from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream @@ -27,7 +29,9 @@ def setUp(self) -> None: @patch("airbyte_cdk.sources.streams.concurrent.partition_enqueuer.time.sleep") def test_given_no_partitions_when_generate_partitions_then_do_not_wait(self, mocked_sleep): - self._thread_pool_manager.prune_to_validate_has_reached_futures_limit.return_value = True # shouldn't be called but just in case + self._thread_pool_manager.prune_to_validate_has_reached_futures_limit.return_value = ( + True # shouldn't be called but just in case + ) stream = self._a_stream([]) self._partition_generator.generate_partitions(stream) @@ -48,11 +52,19 @@ def test_given_partitions_when_generate_partitions_then_return_partitions_before self._partition_generator.generate_partitions(stream) - assert self._consume_queue() == _SOME_PARTITIONS + [PartitionGenerationCompletedSentinel(stream)] + assert self._consume_queue() == _SOME_PARTITIONS + [ + PartitionGenerationCompletedSentinel(stream) + ] @patch("airbyte_cdk.sources.streams.concurrent.partition_enqueuer.time.sleep") - def test_given_partition_but_limit_reached_when_generate_partitions_then_wait_until_not_hitting_limit(self, mocked_sleep): - self._thread_pool_manager.prune_to_validate_has_reached_futures_limit.side_effect = [True, True, False] + def test_given_partition_but_limit_reached_when_generate_partitions_then_wait_until_not_hitting_limit( + self, mocked_sleep + ): + self._thread_pool_manager.prune_to_validate_has_reached_futures_limit.side_effect = [ + True, + True, + False, + ] stream = self._a_stream([Mock(spec=Partition)]) self._partition_generator.generate_partitions(stream) @@ -63,7 +75,9 @@ def test_given_exception_when_generate_partitions_then_return_exception_and_sent stream = Mock(spec=AbstractStream) stream.name = _A_STREAM_NAME exception = ValueError() - stream.generate_partitions.side_effect = self._partitions_before_raising(_SOME_PARTITIONS, exception) + stream.generate_partitions.side_effect = self._partitions_before_raising( + _SOME_PARTITIONS, exception + ) self._partition_generator.generate_partitions(stream) @@ -73,7 +87,9 @@ def test_given_exception_when_generate_partitions_then_return_exception_and_sent PartitionGenerationCompletedSentinel(stream), ] - def _partitions_before_raising(self, partitions: List[Partition], exception: Exception) -> Callable[[], Iterable[Partition]]: + def _partitions_before_raising( + self, partitions: List[Partition], exception: Exception + ) -> Callable[[], Iterable[Partition]]: def inner_function() -> Iterable[Partition]: for partition in partitions: yield partition diff --git a/unit_tests/sources/streams/concurrent/test_partition_reader.py b/unit_tests/sources/streams/concurrent/test_partition_reader.py index 226652be..16d70719 100644 --- a/unit_tests/sources/streams/concurrent/test_partition_reader.py +++ b/unit_tests/sources/streams/concurrent/test_partition_reader.py @@ -11,7 +11,10 @@ from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record -from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel, QueueItem +from airbyte_cdk.sources.streams.concurrent.partitions.types import ( + PartitionCompleteSentinel, + QueueItem, +) _RECORDS = [ Record({"id": 1, "name": "Jack"}, "stream"), @@ -32,7 +35,9 @@ def test_given_no_records_when_process_partition_then_only_emit_sentinel(self): pytest.fail("Only one PartitionCompleteSentinel is expected") break - def test_given_read_partition_successful_when_process_partition_then_queue_records_and_sentinel(self): + def test_given_read_partition_successful_when_process_partition_then_queue_records_and_sentinel( + self, + ): partition = self._a_partition(_RECORDS) self._partition_reader.process_partition(partition) @@ -40,7 +45,9 @@ def test_given_read_partition_successful_when_process_partition_then_queue_recor assert queue_content == _RECORDS + [PartitionCompleteSentinel(partition)] - def test_given_exception_when_process_partition_then_queue_records_and_exception_and_sentinel(self): + def test_given_exception_when_process_partition_then_queue_records_and_exception_and_sentinel( + self, + ): partition = Mock() exception = ValueError() partition.read.side_effect = self._read_with_exception(_RECORDS, exception) @@ -48,7 +55,10 @@ def test_given_exception_when_process_partition_then_queue_records_and_exception queue_content = self._consume_queue() - assert queue_content == _RECORDS + [StreamThreadException(exception, partition.stream_name()), PartitionCompleteSentinel(partition)] + assert queue_content == _RECORDS + [ + StreamThreadException(exception, partition.stream_name()), + PartitionCompleteSentinel(partition), + ] def _a_partition(self, records: List[Record]) -> Partition: partition = Mock(spec=Partition) @@ -56,7 +66,9 @@ def _a_partition(self, records: List[Record]) -> Partition: return partition @staticmethod - def _read_with_exception(records: List[Record], exception: Exception) -> Callable[[], Iterable[Record]]: + def _read_with_exception( + records: List[Record], exception: Exception + ) -> Callable[[], Iterable[Record]]: def mocked_function() -> Iterable[Record]: yield from records raise exception diff --git a/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py b/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py index 197f9b34..d4820db9 100644 --- a/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py +++ b/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py @@ -11,7 +11,9 @@ class ThreadPoolManagerTest(TestCase): def setUp(self): self._threadpool = Mock(spec=ThreadPoolExecutor) - self._thread_pool_manager = ThreadPoolManager(self._threadpool, Mock(), max_concurrent_tasks=1) + self._thread_pool_manager = ThreadPoolManager( + self._threadpool, Mock(), max_concurrent_tasks=1 + ) self._fn = lambda x: x self._arg = "arg" @@ -21,7 +23,9 @@ def test_submit_calls_underlying_thread_pool(self): assert len(self._thread_pool_manager._futures) == 1 - def test_given_exception_during_pruning_when_check_for_errors_and_shutdown_then_shutdown_and_raise(self): + def test_given_exception_during_pruning_when_check_for_errors_and_shutdown_then_shutdown_and_raise( + self, + ): future = Mock(spec=Future) future.exception.return_value = RuntimeError future.done.side_effect = [True, True] diff --git a/unit_tests/sources/streams/http/error_handlers/test_default_backoff_strategy.py b/unit_tests/sources/streams/http/error_handlers/test_default_backoff_strategy.py index 7bb02c55..ee487e97 100644 --- a/unit_tests/sources/streams/http/error_handlers/test_default_backoff_strategy.py +++ b/unit_tests/sources/streams/http/error_handlers/test_default_backoff_strategy.py @@ -16,7 +16,9 @@ def test_given_no_arguments_default_backoff_strategy_returns_default_values(): class CustomBackoffStrategy(BackoffStrategy): def backoff_time( - self, response_or_exception: Optional[Union[requests.Response, requests.RequestException]], attempt_count: int + self, + response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + attempt_count: int, ) -> Optional[float]: return response_or_exception.headers["Retry-After"] diff --git a/unit_tests/sources/streams/http/error_handlers/test_http_status_error_handler.py b/unit_tests/sources/streams/http/error_handlers/test_http_status_error_handler.py index 3ec1cc1d..355d20b8 100644 --- a/unit_tests/sources/streams/http/error_handlers/test_http_status_error_handler.py +++ b/unit_tests/sources/streams/http/error_handlers/test_http_status_error_handler.py @@ -6,7 +6,11 @@ import pytest import requests from airbyte_cdk.models import FailureType -from airbyte_cdk.sources.streams.http.error_handlers import ErrorResolution, HttpStatusErrorHandler, ResponseAction +from airbyte_cdk.sources.streams.http.error_handlers import ( + ErrorResolution, + HttpStatusErrorHandler, + ResponseAction, +) logger = MagicMock() @@ -25,8 +29,18 @@ def test_given_ok_response_http_status_error_handler_returns_success_action(mock @pytest.mark.parametrize( "error, expected_action, expected_failure_type, expected_error_message", [ - (403, ResponseAction.FAIL, FailureType.config_error, "Forbidden. You don't have permission to access this resource."), - (404, ResponseAction.FAIL, FailureType.system_error, "Not found. The requested resource was not found on the server."), + ( + 403, + ResponseAction.FAIL, + FailureType.config_error, + "Forbidden. You don't have permission to access this resource.", + ), + ( + 404, + ResponseAction.FAIL, + FailureType.system_error, + "Not found. The requested resource was not found on the server.", + ), ], ) def test_given_error_code_in_response_http_status_error_handler_returns_expected_actions( @@ -59,7 +73,9 @@ def test_given_unmapped_status_error_returns_retry_action_as_transient_error(): def test_given_requests_exception_returns_retry_action_as_transient_error(): - error_resolution = HttpStatusErrorHandler(logger).interpret_response(requests.RequestException()) + error_resolution = HttpStatusErrorHandler(logger).interpret_response( + requests.RequestException() + ) assert error_resolution.response_action == ResponseAction.RETRY assert error_resolution.failure_type @@ -91,15 +107,22 @@ def test_given_injected_error_mapping_returns_expected_action(): assert default_error_resolution.response_action == ResponseAction.RETRY assert default_error_resolution.failure_type == FailureType.system_error - assert default_error_resolution.error_message == f"Unexpected HTTP Status Code in error handler: {mock_response.status_code}" + assert ( + default_error_resolution.error_message + == f"Unexpected HTTP Status Code in error handler: {mock_response.status_code}" + ) mapped_error_resolution = ErrorResolution( - response_action=ResponseAction.IGNORE, failure_type=FailureType.transient_error, error_message="Injected mapping" + response_action=ResponseAction.IGNORE, + failure_type=FailureType.transient_error, + error_message="Injected mapping", ) error_mapping = {509: mapped_error_resolution} - actual_error_resolution = HttpStatusErrorHandler(logger, error_mapping).interpret_response(mock_response) + actual_error_resolution = HttpStatusErrorHandler(logger, error_mapping).interpret_response( + mock_response + ) assert actual_error_resolution.response_action == mapped_error_resolution.response_action assert actual_error_resolution.failure_type == mapped_error_resolution.failure_type diff --git a/unit_tests/sources/streams/http/error_handlers/test_json_error_message_parser.py b/unit_tests/sources/streams/http/error_handlers/test_json_error_message_parser.py index 72a37722..fecaa13f 100644 --- a/unit_tests/sources/streams/http/error_handlers/test_json_error_message_parser.py +++ b/unit_tests/sources/streams/http/error_handlers/test_json_error_message_parser.py @@ -12,9 +12,15 @@ [ (b'{"message": "json error message"}', "json error message"), (b'[{"message": "list error message"}]', "list error message"), - (b'[{"message": "list error message 1"}, {"message": "list error message 2"}]', "list error message 1, list error message 2"), + ( + b'[{"message": "list error message 1"}, {"message": "list error message 2"}]', + "list error message 1, list error message 2", + ), (b'{"error": "messages error message"}', "messages error message"), - (b'[{"errors": "list error message 1"}, {"errors": "list error message 2"}]', "list error message 1, list error message 2"), + ( + b'[{"errors": "list error message 1"}, {"errors": "list error message 2"}]', + "list error message 1, list error message 2", + ), (b'{"failures": "failures error message"}', "failures error message"), (b'{"failure": "failure error message"}', "failure error message"), (b'{"detail": "detail error message"}', "detail error message"), @@ -25,7 +31,9 @@ (b'{"status_message": "status_message error message"}', "status_message error message"), ], ) -def test_given_error_message_in_response_body_parse_response_error_message_returns_error_message(response_body, expected_error_message): +def test_given_error_message_in_response_body_parse_response_error_message_returns_error_message( + response_body, expected_error_message +): response = requests.Response() response._content = response_body error_message = JsonErrorMessageParser().parse_response_error_message(response) diff --git a/unit_tests/sources/streams/http/error_handlers/test_response_models.py b/unit_tests/sources/streams/http/error_handlers/test_response_models.py index a19d3c8d..7d0eb776 100644 --- a/unit_tests/sources/streams/http/error_handlers/test_response_models.py +++ b/unit_tests/sources/streams/http/error_handlers/test_response_models.py @@ -5,7 +5,10 @@ import requests import requests_mock from airbyte_cdk.models import FailureType -from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction, create_fallback_error_resolution +from airbyte_cdk.sources.streams.http.error_handlers.response_models import ( + ResponseAction, + create_fallback_error_resolution, +) from airbyte_cdk.utils.airbyte_secrets_utils import update_secrets _A_SECRET = "a-secret" @@ -20,7 +23,9 @@ def tearDown(self) -> None: # to avoid other tests being impacted by added secrets update_secrets([]) - def test_given_none_when_create_fallback_error_resolution_then_return_error_resolution(self) -> None: + def test_given_none_when_create_fallback_error_resolution_then_return_error_resolution( + self, + ) -> None: error_resolution = create_fallback_error_resolution(None) assert error_resolution.failure_type == FailureType.system_error @@ -30,7 +35,9 @@ def test_given_none_when_create_fallback_error_resolution_then_return_error_reso == "Error handler did not receive a valid response or exception. This is unexpected please contact Airbyte Support" ) - def test_given_exception_when_create_fallback_error_resolution_then_return_error_resolution(self) -> None: + def test_given_exception_when_create_fallback_error_resolution_then_return_error_resolution( + self, + ) -> None: exception = ValueError("This is an exception") error_resolution = create_fallback_error_resolution(exception) @@ -41,23 +48,34 @@ def test_given_exception_when_create_fallback_error_resolution_then_return_error assert "ValueError" in error_resolution.error_message assert str(exception) in error_resolution.error_message - def test_given_response_can_raise_for_status_when_create_fallback_error_resolution_then_error_resolution(self) -> None: + def test_given_response_can_raise_for_status_when_create_fallback_error_resolution_then_error_resolution( + self, + ) -> None: response = self._create_response(512) error_resolution = create_fallback_error_resolution(response) assert error_resolution.failure_type == FailureType.system_error assert error_resolution.response_action == ResponseAction.RETRY - assert error_resolution.error_message and "512 Server Error: None for url: https://a-url.com/" in error_resolution.error_message + assert ( + error_resolution.error_message + and "512 Server Error: None for url: https://a-url.com/" + in error_resolution.error_message + ) - def test_given_response_is_ok_when_create_fallback_error_resolution_then_error_resolution(self) -> None: + def test_given_response_is_ok_when_create_fallback_error_resolution_then_error_resolution( + self, + ) -> None: response = self._create_response(205) error_resolution = create_fallback_error_resolution(response) assert error_resolution.failure_type == FailureType.system_error assert error_resolution.response_action == ResponseAction.RETRY - assert error_resolution.error_message and str(response.status_code) in error_resolution.error_message + assert ( + error_resolution.error_message + and str(response.status_code) in error_resolution.error_message + ) def _create_response(self, status_code: int) -> requests.Response: with requests_mock.Mocker() as http_mocker: diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index 50bd3d8f..c8345c42 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -99,7 +99,9 @@ def test_get_auth_header_fresh(self, mocker): refresh_token=TestOauth2Authenticator.refresh_token, ) - mocker.patch.object(Oauth2Authenticator, "refresh_access_token", return_value=("access_token", 1000)) + mocker.patch.object( + Oauth2Authenticator, "refresh_access_token", return_value=("access_token", 1000) + ) header = oauth.get_auth_header() assert {"Authorization": "Bearer access_token"} == header @@ -115,11 +117,19 @@ def test_get_auth_header_expired(self, mocker): ) expire_immediately = 0 - mocker.patch.object(Oauth2Authenticator, "refresh_access_token", return_value=("access_token_1", expire_immediately)) + mocker.patch.object( + Oauth2Authenticator, + "refresh_access_token", + return_value=("access_token_1", expire_immediately), + ) oauth.get_auth_header() # Set the first expired token. valid_100_secs = 100 - mocker.patch.object(Oauth2Authenticator, "refresh_access_token", return_value=("access_token_2", valid_100_secs)) + mocker.patch.object( + Oauth2Authenticator, + "refresh_access_token", + return_value=("access_token_2", valid_100_secs), + ) header = oauth.get_auth_header() assert {"Authorization": "Bearer access_token_2"} == header @@ -136,7 +146,11 @@ def test_refresh_request_body(self): scopes=["scope1", "scope2"], token_expiry_date=pendulum.now().add(days=3), grant_type="some_grant_type", - refresh_request_body={"custom_field": "in_outbound_request", "another_field": "exists_in_body", "scopes": ["no_override"]}, + refresh_request_body={ + "custom_field": "in_outbound_request", + "another_field": "exists_in_body", + "scopes": ["no_override"], + }, ) body = oauth.build_refresh_request_body() expected = { @@ -158,11 +172,17 @@ def test_refresh_access_token(self, mocker): refresh_token="some_refresh_token", scopes=["scope1", "scope2"], token_expiry_date=pendulum.now().add(days=3), - refresh_request_body={"custom_field": "in_outbound_request", "another_field": "exists_in_body", "scopes": ["no_override"]}, + refresh_request_body={ + "custom_field": "in_outbound_request", + "another_field": "exists_in_body", + "scopes": ["no_override"], + }, ) resp.status_code = 200 - mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}) + mocker.patch.object( + resp, "json", return_value={"access_token": "access_token", "expires_in": 1000} + ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) token, expires_in = oauth.refresh_access_token() @@ -170,14 +190,20 @@ def test_refresh_access_token(self, mocker): assert ("access_token", 1000) == (token, expires_in) # Test with expires_in as str - mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": "2000"}) + mocker.patch.object( + resp, "json", return_value={"access_token": "access_token", "expires_in": "2000"} + ) token, expires_in = oauth.refresh_access_token() assert isinstance(expires_in, str) assert ("access_token", "2000") == (token, expires_in) # Test with expires_in as str - mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": "2022-04-24T00:00:00Z"}) + mocker.patch.object( + resp, + "json", + return_value={"access_token": "access_token", "expires_in": "2022-04-24T00:00:00Z"}, + ) token, expires_in = oauth.refresh_access_token() assert isinstance(expires_in, str) @@ -189,7 +215,11 @@ def test_refresh_access_token(self, mocker): (3600, None, pendulum.datetime(year=2022, month=1, day=1, hour=1)), ("90012", None, pendulum.datetime(year=2022, month=1, day=2, hour=1, second=12)), ("2024-02-28", "YYYY-MM-DD", pendulum.datetime(year=2024, month=2, day=28)), - ("2022-02-12T00:00:00.000000+00:00", "YYYY-MM-DDTHH:mm:ss.SSSSSSZ", pendulum.datetime(year=2022, month=2, day=12)), + ( + "2022-02-12T00:00:00.000000+00:00", + "YYYY-MM-DDTHH:mm:ss.SSSSSSZ", + pendulum.datetime(year=2022, month=2, day=12), + ), ], ids=["seconds", "string_of_seconds", "simple_date", "simple_datetime"], ) @@ -210,11 +240,19 @@ def test_parse_refresh_token_lifespan( token_expiry_date=pendulum.now().subtract(days=3), token_expiry_date_format=token_expiry_date_format, token_expiry_is_time_of_expiration=bool(token_expiry_date_format), - refresh_request_body={"custom_field": "in_outbound_request", "another_field": "exists_in_body", "scopes": ["no_override"]}, + refresh_request_body={ + "custom_field": "in_outbound_request", + "another_field": "exists_in_body", + "scopes": ["no_override"], + }, ) resp.status_code = 200 - mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": expires_in_response}) + mocker.patch.object( + resp, + "json", + return_value={"access_token": "access_token", "expires_in": expires_in_response}, + ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) token, expire_in = oauth.refresh_access_token() expires_datetime = oauth._parse_token_expiration_date(expire_in) @@ -233,7 +271,11 @@ def test_refresh_access_token_retry(self, error_code, requests_mock): ) requests_mock.post( f"https://{TestOauth2Authenticator.refresh_endpoint}", - [{"status_code": error_code}, {"status_code": error_code}, {"json": {"access_token": "token", "expires_in": 10}}], + [ + {"status_code": error_code}, + {"status_code": error_code}, + {"json": {"access_token": "token", "expires_in": 10}}, + ], ) token, expires_in = oauth.refresh_access_token() assert isinstance(expires_in, int) @@ -248,7 +290,9 @@ def test_auth_call_method(self, mocker): refresh_token=TestOauth2Authenticator.refresh_token, ) - mocker.patch.object(Oauth2Authenticator, "refresh_access_token", return_value=("access_token", 1000)) + mocker.patch.object( + Oauth2Authenticator, "refresh_access_token", return_value=("access_token", 1000) + ) prepared_request = requests.PreparedRequest() prepared_request.headers = {} oauth(prepared_request) @@ -256,7 +300,15 @@ def test_auth_call_method(self, mocker): assert {"Authorization": "Bearer access_token"} == prepared_request.headers @pytest.mark.parametrize( - ("config_codes", "response_code", "config_key", "response_key", "config_values", "response_value", "wrapped"), + ( + "config_codes", + "response_code", + "config_key", + "response_key", + "config_values", + "response_value", + "wrapped", + ), ( ((400,), 400, "error", "error", ("invalid_grant",), "invalid_grant", True), ((401,), 400, "error", "error", ("invalid_grant",), "invalid_grant", False), @@ -266,7 +318,15 @@ def test_auth_call_method(self, mocker): ), ) def test_refresh_access_token_wrapped( - self, requests_mock, config_codes, response_code, config_key, response_key, config_values, response_value, wrapped + self, + requests_mock, + config_codes, + response_code, + config_key, + response_key, + config_values, + response_value, + wrapped, ): oauth = Oauth2Authenticator( f"https://{TestOauth2Authenticator.refresh_endpoint}", @@ -278,7 +338,11 @@ def test_refresh_access_token_wrapped( refresh_token_error_values=config_values, ) error_content = {response_key: response_value} - requests_mock.post(f"https://{TestOauth2Authenticator.refresh_endpoint}", status_code=response_code, json=error_content) + requests_mock.post( + f"https://{TestOauth2Authenticator.refresh_endpoint}", + status_code=response_code, + json=error_content, + ) exception_to_raise = AirbyteTracedException if wrapped else RequestException with pytest.raises(exception_to_raise) as exc_info: @@ -317,7 +381,9 @@ def test_init(self, connector_config): ) assert authenticator.access_token == connector_config["credentials"]["access_token"] assert authenticator.get_refresh_token() == connector_config["credentials"]["refresh_token"] - assert authenticator.get_token_expiry_date() == pendulum.parse(connector_config["credentials"]["token_expiry_date"]) + assert authenticator.get_token_expiry_date() == pendulum.parse( + connector_config["credentials"]["token_expiry_date"] + ) @freezegun.freeze_time("2022-12-31") @pytest.mark.parametrize( @@ -329,7 +395,14 @@ def test_init(self, connector_config): ], ) def test_given_no_message_repository_get_access_token( - self, test_name, expires_in_value, expiry_date_format, expected_expiry_date, capsys, mocker, connector_config + self, + test_name, + expires_in_value, + expiry_date_format, + expected_expiry_date, + capsys, + mocker, + connector_config, ): authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, @@ -338,7 +411,9 @@ def test_given_no_message_repository_get_access_token( client_secret=connector_config["credentials"]["client_secret"], token_expiry_date_format=expiry_date_format, ) - authenticator.refresh_access_token = mocker.Mock(return_value=("new_access_token", expires_in_value, "new_refresh_token")) + authenticator.refresh_access_token = mocker.Mock( + return_value=("new_access_token", expires_in_value, "new_refresh_token") + ) authenticator.token_has_expired = mocker.Mock(return_value=True) access_token = authenticator.get_access_token() captured = capsys.readouterr() @@ -357,7 +432,9 @@ def test_given_no_message_repository_get_access_token( assert not captured.out assert authenticator.access_token == access_token == "new_access_token" - def test_given_message_repository_when_get_access_token_then_emit_message(self, mocker, connector_config): + def test_given_message_repository_when_get_access_token_then_emit_message( + self, mocker, connector_config + ): message_repository = Mock() authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, @@ -367,7 +444,9 @@ def test_given_message_repository_when_get_access_token_then_emit_message(self, token_expiry_date_format="YYYY-MM-DD", message_repository=message_repository, ) - authenticator.refresh_access_token = mocker.Mock(return_value=("new_access_token", "2023-04-04", "new_refresh_token")) + authenticator.refresh_access_token = mocker.Mock( + return_value=("new_access_token", "2023-04-04", "new_refresh_token") + ) authenticator.token_has_expired = mocker.Mock(return_value=True) authenticator.get_access_token() @@ -375,13 +454,30 @@ def test_given_message_repository_when_get_access_token_then_emit_message(self, emitted_message = message_repository.emit_message.call_args_list[0].args[0] assert emitted_message.type == Type.CONTROL assert emitted_message.control.type == OrchestratorType.CONNECTOR_CONFIG - assert emitted_message.control.connectorConfig.config["credentials"]["access_token"] == "new_access_token" - assert emitted_message.control.connectorConfig.config["credentials"]["refresh_token"] == "new_refresh_token" - assert emitted_message.control.connectorConfig.config["credentials"]["token_expiry_date"] == "2023-04-04T00:00:00+00:00" - assert emitted_message.control.connectorConfig.config["credentials"]["client_id"] == "my_client_id" - assert emitted_message.control.connectorConfig.config["credentials"]["client_secret"] == "my_client_secret" + assert ( + emitted_message.control.connectorConfig.config["credentials"]["access_token"] + == "new_access_token" + ) + assert ( + emitted_message.control.connectorConfig.config["credentials"]["refresh_token"] + == "new_refresh_token" + ) + assert ( + emitted_message.control.connectorConfig.config["credentials"]["token_expiry_date"] + == "2023-04-04T00:00:00+00:00" + ) + assert ( + emitted_message.control.connectorConfig.config["credentials"]["client_id"] + == "my_client_id" + ) + assert ( + emitted_message.control.connectorConfig.config["credentials"]["client_secret"] + == "my_client_secret" + ) - def test_given_message_repository_when_get_access_token_then_log_request(self, mocker, connector_config): + def test_given_message_repository_when_get_access_token_then_log_request( + self, mocker, connector_config + ): message_repository = Mock() authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, @@ -390,9 +486,12 @@ def test_given_message_repository_when_get_access_token_then_log_request(self, m client_secret=connector_config["credentials"]["client_secret"], message_repository=message_repository, ) - mocker.patch("airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.requests.request") mocker.patch( - "airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.format_http_message", return_value="formatted json" + "airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.requests.request" + ) + mocker.patch( + "airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.format_http_message", + return_value="formatted json", ) authenticator.token_has_expired = mocker.Mock(return_value=True) @@ -415,7 +514,11 @@ def test_refresh_access_token(self, mocker, connector_config): authenticator.get_refresh_token_name(): "new_refresh_token", } ) - assert authenticator.refresh_access_token() == ("new_access_token", "42", "new_refresh_token") + assert authenticator.refresh_access_token() == ( + "new_access_token", + "42", + "new_refresh_token", + ) def mock_request(method, url, data): diff --git a/unit_tests/sources/streams/http/test_availability_strategy.py b/unit_tests/sources/streams/http/test_availability_strategy.py index 42975d8e..766d0e35 100644 --- a/unit_tests/sources/streams/http/test_availability_strategy.py +++ b/unit_tests/sources/streams/http/test_availability_strategy.py @@ -104,9 +104,10 @@ def test_http_availability_raises_unhandled_error(mocker): req.status_code = 404 mocker.patch.object(requests.Session, "send", return_value=req) - assert (False, "Not found. The requested resource was not found on the server.") == HttpAvailabilityStrategy().check_availability( - http_stream, logger - ) + assert ( + False, + "Not found. The requested resource was not found on the server.", + ) == HttpAvailabilityStrategy().check_availability(http_stream, logger) def test_send_handles_retries_when_checking_availability(mocker, caplog): @@ -122,7 +123,9 @@ def test_send_handles_retries_when_checking_availability(mocker, caplog): mock_send = mocker.patch.object(requests.Session, "send", side_effect=[req_1, req_2, req_3]) with caplog.at_level(logging.INFO): - stream_is_available, _ = HttpAvailabilityStrategy().check_availability(stream=http_stream, logger=logger) + stream_is_available, _ = HttpAvailabilityStrategy().check_availability( + stream=http_stream, logger=logger + ) assert stream_is_available assert mock_send.call_count == 3 @@ -147,7 +150,9 @@ def __init__(self, *args, **kvargs): empty_stream.read_records.return_value = iter([]) logger = logging.getLogger("airbyte.test-source") - stream_is_available, _ = HttpAvailabilityStrategy().check_availability(stream=empty_stream, logger=logger) + stream_is_available, _ = HttpAvailabilityStrategy().check_availability( + stream=empty_stream, logger=logger + ) assert stream_is_available assert empty_stream.read_records.called diff --git a/unit_tests/sources/streams/http/test_http.py b/unit_tests/sources/streams/http/test_http.py index 3aeb2186..77a5fca3 100644 --- a/unit_tests/sources/streams/http/test_http.py +++ b/unit_tests/sources/streams/http/test_http.py @@ -14,12 +14,18 @@ from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type from airbyte_cdk.sources.streams import CheckpointMixin from airbyte_cdk.sources.streams.checkpoint import ResumableFullRefreshCursor -from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import SubstreamResumableFullRefreshCursor +from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import ( + SubstreamResumableFullRefreshCursor, +) from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler, HttpStatusErrorHandler from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction -from airbyte_cdk.sources.streams.http.exceptions import DefaultBackoffException, RequestBodyException, UserDefinedBackoffException +from airbyte_cdk.sources.streams.http.exceptions import ( + DefaultBackoffException, + RequestBodyException, + UserDefinedBackoffException, +) from airbyte_cdk.sources.streams.http.http_client import MessageRepresentationAirbyteTracedErrors from airbyte_cdk.sources.streams.http.requests_native_auth import TokenAuthenticator @@ -66,7 +72,9 @@ def test_request_kwargs_used(mocker, requests_mock): stream = StubBasicReadHttpStream() request_kwargs = {"cert": None, "proxies": "google.com"} mocker.patch.object(stream, "request_kwargs", return_value=request_kwargs) - send_mock = mocker.patch.object(stream._http_client._session, "send", wraps=stream._http_client._session.send) + send_mock = mocker.patch.object( + stream._http_client._session, "send", wraps=stream._http_client._session.send + ) requests_mock.register_uri("GET", stream.url_base) list(stream.read_records(sync_mode=SyncMode.full_refresh)) @@ -118,10 +126,14 @@ def test_next_page_token_is_input_to_other_methods(mocker): expected_next_page_tokens = [{"page": i} for i in range(pages)] for method in methods: # First assert that they were called with no next_page_token. This is the first call in the pagination loop. - getattr(stream, method).assert_any_call(next_page_token=None, stream_slice=None, stream_state={}) + getattr(stream, method).assert_any_call( + next_page_token=None, stream_slice=None, stream_state={} + ) for token in expected_next_page_tokens: # Then verify that each method - getattr(stream, method).assert_any_call(next_page_token=token, stream_slice=None, stream_state={}) + getattr(stream, method).assert_any_call( + next_page_token=token, stream_slice=None, stream_state={} + ) expected = [{"data": 1}, {"data": 2}, {"data": 3}, {"data": 4}, {"data": 5}, {"data": 6}] @@ -298,7 +310,9 @@ class TestRequestBody: urlencoded_form_body = "key1=value1&key2=1234" def request2response(self, request, context): - return json.dumps({"body": request.text, "content_type": request.headers.get("Content-Type")}) + return json.dumps( + {"body": request.text, "content_type": request.headers.get("Content-Type")} + ) def test_json_body(self, mocker, requests_mock): stream = PostHttpStream() @@ -492,7 +506,11 @@ def should_retry(self, *args, **kwargs): ], ) def test_http_stream_adapter_http_status_error_handler_should_retry_false_raise_on_http_errors( - mocker, response_status_code: int, should_retry: bool, raise_on_http_errors: bool, expected_response_action: ResponseAction + mocker, + response_status_code: int, + should_retry: bool, + raise_on_http_errors: bool, + expected_response_action: ResponseAction, ): stream = AutoFailTrueHttpStream() mocker.patch.object(stream, "should_retry", return_value=should_retry) @@ -525,13 +543,25 @@ def test_send_raise_on_http_errors_logs(mocker, status_code): ({"error": {"message": "something broke"}}, "something broke"), ({"error": "err-001", "message": "something broke"}, "something broke"), ({"failure": {"message": "something broke"}}, "something broke"), - ({"error": {"errors": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}}, "one, two, three"), + ( + {"error": {"errors": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}}, + "one, two, three", + ), ({"errors": ["one", "two", "three"]}, "one, two, three"), ({"messages": ["one", "two", "three"]}, "one, two, three"), - ({"errors": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}, "one, two, three"), - ({"error": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}, "one, two, three"), + ( + {"errors": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}, + "one, two, three", + ), + ( + {"error": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}, + "one, two, three", + ), ({"errors": [{"error": "one"}, {"error": "two"}, {"error": "three"}]}, "one, two, three"), - ({"failures": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}, "one, two, three"), + ( + {"failures": [{"message": "one"}, {"message": "two"}, {"message": "three"}]}, + "one, two, three", + ), (["one", "two", "three"], "one, two, three"), ([{"error": "one"}, {"error": "two"}, {"error": "three"}], "one, two, three"), ({"error": True}, None), @@ -574,17 +604,42 @@ def test_default_get_error_display_message_handles_http_error(mocker): "test_name, base_url, path, expected_full_url", [ ("test_no_slashes", "https://airbyte.io", "my_endpoint", "https://airbyte.io/my_endpoint"), - ("test_trailing_slash_on_base_url", "https://airbyte.io/", "my_endpoint", "https://airbyte.io/my_endpoint"), + ( + "test_trailing_slash_on_base_url", + "https://airbyte.io/", + "my_endpoint", + "https://airbyte.io/my_endpoint", + ), ( "test_trailing_slash_on_base_url_and_leading_slash_on_path", "https://airbyte.io/", "/my_endpoint", "https://airbyte.io/my_endpoint", ), - ("test_leading_slash_on_path", "https://airbyte.io", "/my_endpoint", "https://airbyte.io/my_endpoint"), - ("test_trailing_slash_on_path", "https://airbyte.io", "/my_endpoint/", "https://airbyte.io/my_endpoint/"), - ("test_nested_path_no_leading_slash", "https://airbyte.io", "v1/my_endpoint", "https://airbyte.io/v1/my_endpoint"), - ("test_nested_path_with_leading_slash", "https://airbyte.io", "/v1/my_endpoint", "https://airbyte.io/v1/my_endpoint"), + ( + "test_leading_slash_on_path", + "https://airbyte.io", + "/my_endpoint", + "https://airbyte.io/my_endpoint", + ), + ( + "test_trailing_slash_on_path", + "https://airbyte.io", + "/my_endpoint/", + "https://airbyte.io/my_endpoint/", + ), + ( + "test_nested_path_no_leading_slash", + "https://airbyte.io", + "v1/my_endpoint", + "https://airbyte.io/v1/my_endpoint", + ), + ( + "test_nested_path_with_leading_slash", + "https://airbyte.io", + "/v1/my_endpoint", + "https://airbyte.io/v1/my_endpoint", + ), ], ) def test_join_url(test_name, base_url, path, expected_full_url): @@ -596,12 +651,26 @@ def test_join_url(test_name, base_url, path, expected_full_url): "deduplicate_query_params, path, params, expected_url", [ pytest.param( - True, "v1/endpoint?param1=value1", {}, "https://test_base_url.com/v1/endpoint?param1=value1", id="test_params_only_in_path" + True, + "v1/endpoint?param1=value1", + {}, + "https://test_base_url.com/v1/endpoint?param1=value1", + id="test_params_only_in_path", ), pytest.param( - True, "v1/endpoint", {"param1": "value1"}, "https://test_base_url.com/v1/endpoint?param1=value1", id="test_params_only_in_path" + True, + "v1/endpoint", + {"param1": "value1"}, + "https://test_base_url.com/v1/endpoint?param1=value1", + id="test_params_only_in_path", + ), + pytest.param( + True, + "v1/endpoint", + None, + "https://test_base_url.com/v1/endpoint", + id="test_params_is_none_and_no_params_in_path", ), - pytest.param(True, "v1/endpoint", None, "https://test_base_url.com/v1/endpoint", id="test_params_is_none_and_no_params_in_path"), pytest.param( True, "v1/endpoint?param1=value1", @@ -708,7 +777,13 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, def _read_single_page( self, records_generator_fn: Callable[ - [requests.PreparedRequest, requests.Response, Mapping[str, Any], Optional[Mapping[str, Any]]], Iterable[StreamData] + [ + requests.PreparedRequest, + requests.Response, + Mapping[str, Any], + Optional[Mapping[str, Any]], + ], + Iterable[StreamData], ], stream_slice: Optional[Mapping[str, Any]] = None, stream_state: Optional[Mapping[str, Any]] = None, @@ -808,7 +883,13 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, def _read_pages( self, records_generator_fn: Callable[ - [requests.PreparedRequest, requests.Response, Mapping[str, Any], Optional[Mapping[str, Any]]], Iterable[StreamData] + [ + requests.PreparedRequest, + requests.Response, + Mapping[str, Any], + Optional[Mapping[str, Any]], + ], + Iterable[StreamData], ], stream_slice: Optional[Mapping[str, Any]] = None, stream_state: Optional[Mapping[str, Any]] = None, @@ -888,7 +969,10 @@ def test_substream_skips_non_record_messages(): parent_records = [ {"id": "abc"}, - AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="should_not_be_parent_record")), + AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage(level=Level.INFO, message="should_not_be_parent_record"), + ), {"id": "def"}, {"id": "ghi"}, ] @@ -934,7 +1018,11 @@ def must_deduplicate_query_params(self) -> bool: class StubFullRefreshLegacySliceHttpStream(StubFullRefreshHttpStream): def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: yield from [{}] @@ -957,15 +1045,26 @@ def test_resumable_full_refresh_read_from_start(mocker): mocker.patch.object(stream, method, wraps=getattr(stream, method)) checkpoint_reader = stream._get_checkpoint_reader( - cursor_field=[], logger=logging.getLogger("airbyte"), sync_mode=SyncMode.full_refresh, stream_state={} + cursor_field=[], + logger=logging.getLogger("airbyte"), + sync_mode=SyncMode.full_refresh, + stream_state={}, ) next_stream_slice = checkpoint_reader.next() records = [] - expected_checkpoints = [{"page": 2}, {"page": 3}, {"page": 4}, {"page": 5}, {"__ab_full_refresh_sync_complete": True}] + expected_checkpoints = [ + {"page": 2}, + {"page": 3}, + {"page": 4}, + {"page": 5}, + {"__ab_full_refresh_sync_complete": True}, + ] i = 0 while next_stream_slice is not None: - next_records = list(stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice)) + next_records = list( + stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice) + ) records.extend(next_records) checkpoint_reader.observe(stream.state) assert checkpoint_reader.get_checkpoint() == expected_checkpoints[i] @@ -978,10 +1077,14 @@ def test_resumable_full_refresh_read_from_start(mocker): expected_next_page_tokens = expected_checkpoints[:4] for method in methods: # First assert that they were called with no next_page_token. This is the first call in the pagination loop. - getattr(stream, method).assert_any_call(next_page_token=None, stream_slice={}, stream_state={}) + getattr(stream, method).assert_any_call( + next_page_token=None, stream_slice={}, stream_state={} + ) for token in expected_next_page_tokens: # Then verify that each method - getattr(stream, method).assert_any_call(next_page_token=token, stream_slice=token, stream_state={}) + getattr(stream, method).assert_any_call( + next_page_token=token, stream_slice=token, stream_state={} + ) expected = [{"data": 1}, {"data": 2}, {"data": 3}, {"data": 4}, {"data": 5}] @@ -1006,7 +1109,10 @@ def test_resumable_full_refresh_read_from_state(mocker): mocker.patch.object(stream, method, wraps=getattr(stream, method)) checkpoint_reader = stream._get_checkpoint_reader( - cursor_field=[], logger=logging.getLogger("airbyte"), sync_mode=SyncMode.full_refresh, stream_state={"page": 3} + cursor_field=[], + logger=logging.getLogger("airbyte"), + sync_mode=SyncMode.full_refresh, + stream_state={"page": 3}, ) next_stream_slice = checkpoint_reader.next() records = [] @@ -1014,7 +1120,9 @@ def test_resumable_full_refresh_read_from_state(mocker): expected_checkpoints = [{"page": 4}, {"page": 5}, {"__ab_full_refresh_sync_complete": True}] i = 0 while next_stream_slice is not None: - next_records = list(stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice)) + next_records = list( + stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice) + ) records.extend(next_records) checkpoint_reader.observe(stream.state) assert checkpoint_reader.get_checkpoint() == expected_checkpoints[i] @@ -1028,7 +1136,9 @@ def test_resumable_full_refresh_read_from_state(mocker): for method in methods: for token in expected_next_page_tokens: # Then verify that each method - getattr(stream, method).assert_any_call(next_page_token=token, stream_slice=token, stream_state={}) + getattr(stream, method).assert_any_call( + next_page_token=token, stream_slice=token, stream_state={} + ) expected = [{"data": 1}, {"data": 2}, {"data": 3}] @@ -1052,15 +1162,25 @@ def test_resumable_full_refresh_legacy_stream_slice(mocker): mocker.patch.object(stream, method, wraps=getattr(stream, method)) checkpoint_reader = stream._get_checkpoint_reader( - cursor_field=[], logger=logging.getLogger("airbyte"), sync_mode=SyncMode.full_refresh, stream_state={"page": 2} + cursor_field=[], + logger=logging.getLogger("airbyte"), + sync_mode=SyncMode.full_refresh, + stream_state={"page": 2}, ) next_stream_slice = checkpoint_reader.next() records = [] - expected_checkpoints = [{"page": 3}, {"page": 4}, {"page": 5}, {"__ab_full_refresh_sync_complete": True}] + expected_checkpoints = [ + {"page": 3}, + {"page": 4}, + {"page": 5}, + {"__ab_full_refresh_sync_complete": True}, + ] i = 0 while next_stream_slice is not None: - next_records = list(stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice)) + next_records = list( + stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice) + ) records.extend(next_records) checkpoint_reader.observe(stream.state) assert checkpoint_reader.get_checkpoint() == expected_checkpoints[i] @@ -1074,7 +1194,9 @@ def test_resumable_full_refresh_legacy_stream_slice(mocker): for method in methods: for token in expected_next_page_tokens: # Then verify that each method - getattr(stream, method).assert_any_call(next_page_token=token, stream_slice=token, stream_state={}) + getattr(stream, method).assert_any_call( + next_page_token=token, stream_slice=token, stream_state={} + ) expected = [{"data": 1}, {"data": 2}, {"data": 3}, {"data": 4}] @@ -1086,7 +1208,11 @@ class StubSubstreamResumableFullRefreshStream(HttpSubStream, CheckpointMixin): counter = 0 - def __init__(self, parent: HttpStream, partition_id_to_child_records: Mapping[str, List[Mapping[str, Any]]]): + def __init__( + self, + parent: HttpStream, + partition_id_to_child_records: Mapping[str, List[Mapping[str, Any]]], + ): super().__init__(parent=parent) self._partition_id_to_child_records = partition_id_to_child_records # self._state: MutableMapping[str, Any] = {} @@ -1168,14 +1294,19 @@ def test_substream_resumable_full_refresh_read_from_start(mocker): {"id": "a201", "parent_id": "100", "film": "oppenheimer"}, {"id": "a202", "parent_id": "100", "film": "inception"}, ], - "101": [{"id": "b200", "parent_id": "101", "film": "past_lives"}, {"id": "b201", "parent_id": "101", "film": "materialists"}], + "101": [ + {"id": "b200", "parent_id": "101", "film": "past_lives"}, + {"id": "b201", "parent_id": "101", "film": "materialists"}, + ], "102": [ {"id": "c200", "parent_id": "102", "film": "the_social_network"}, {"id": "c201", "parent_id": "102", "film": "gone_girl"}, {"id": "c202", "parent_id": "102", "film": "the_curious_case_of_benjamin_button"}, ], } - stream = StubSubstreamResumableFullRefreshStream(parent=parent_stream, partition_id_to_child_records=parents_to_children_records) + stream = StubSubstreamResumableFullRefreshStream( + parent=parent_stream, partition_id_to_child_records=parents_to_children_records + ) blank_response = {} # Send a blank response is fine as we ignore the response in `parse_response anyway. mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response)) @@ -1184,7 +1315,10 @@ def test_substream_resumable_full_refresh_read_from_start(mocker): mocker.patch.object(stream, "_read_pages", wraps=getattr(stream, "_read_pages")) checkpoint_reader = stream._get_checkpoint_reader( - cursor_field=[], logger=logging.getLogger("airbyte"), sync_mode=SyncMode.full_refresh, stream_state={} + cursor_field=[], + logger=logging.getLogger("airbyte"), + sync_mode=SyncMode.full_refresh, + stream_state={}, ) next_stream_slice = checkpoint_reader.next() records = [] @@ -1204,7 +1338,10 @@ def test_substream_resumable_full_refresh_read_from_start(mocker): "cursor": {"__ab_full_refresh_sync_complete": True}, "partition": {"parent": {"name": "christopher_nolan", "parent_id": "100"}}, }, - {"cursor": {"__ab_full_refresh_sync_complete": True}, "partition": {"parent": {"name": "celine_song", "parent_id": "101"}}}, + { + "cursor": {"__ab_full_refresh_sync_complete": True}, + "partition": {"parent": {"name": "celine_song", "parent_id": "101"}}, + }, ] }, { @@ -1213,7 +1350,10 @@ def test_substream_resumable_full_refresh_read_from_start(mocker): "cursor": {"__ab_full_refresh_sync_complete": True}, "partition": {"parent": {"name": "christopher_nolan", "parent_id": "100"}}, }, - {"cursor": {"__ab_full_refresh_sync_complete": True}, "partition": {"parent": {"name": "celine_song", "parent_id": "101"}}}, + { + "cursor": {"__ab_full_refresh_sync_complete": True}, + "partition": {"parent": {"name": "celine_song", "parent_id": "101"}}, + }, { "cursor": {"__ab_full_refresh_sync_complete": True}, "partition": {"parent": {"name": "david_fincher", "parent_id": "102"}}, @@ -1224,7 +1364,9 @@ def test_substream_resumable_full_refresh_read_from_start(mocker): i = 0 while next_stream_slice is not None: - next_records = list(stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice)) + next_records = list( + stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice) + ) records.extend(next_records) checkpoint_reader.observe(stream.state) assert checkpoint_reader.get_checkpoint() == expected_checkpoints[i] @@ -1266,9 +1408,14 @@ def test_substream_resumable_full_refresh_read_from_state(mocker): {"id": "a201", "parent_id": "100", "film": "oppenheimer"}, {"id": "a202", "parent_id": "100", "film": "inception"}, ], - "101": [{"id": "b200", "parent_id": "101", "film": "past_lives"}, {"id": "b201", "parent_id": "101", "film": "materialists"}], + "101": [ + {"id": "b200", "parent_id": "101", "film": "past_lives"}, + {"id": "b201", "parent_id": "101", "film": "materialists"}, + ], } - stream = StubSubstreamResumableFullRefreshStream(parent=parent_stream, partition_id_to_child_records=parents_to_children_records) + stream = StubSubstreamResumableFullRefreshStream( + parent=parent_stream, partition_id_to_child_records=parents_to_children_records + ) blank_response = {} # Send a blank response is fine as we ignore the response in `parse_response anyway. mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response)) @@ -1299,14 +1446,19 @@ def test_substream_resumable_full_refresh_read_from_state(mocker): "cursor": {"__ab_full_refresh_sync_complete": True}, "partition": {"parent": {"name": "christopher_nolan", "parent_id": "100"}}, }, - {"cursor": {"__ab_full_refresh_sync_complete": True}, "partition": {"parent": {"name": "celine_song", "parent_id": "101"}}}, + { + "cursor": {"__ab_full_refresh_sync_complete": True}, + "partition": {"parent": {"name": "celine_song", "parent_id": "101"}}, + }, ] }, ] i = 0 while next_stream_slice is not None: - next_records = list(stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice)) + next_records = list( + stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice) + ) records.extend(next_records) checkpoint_reader.observe(stream.state) assert checkpoint_reader.get_checkpoint() == expected_checkpoints[i] @@ -1324,7 +1476,13 @@ def test_substream_resumable_full_refresh_read_from_state(mocker): class StubWithCursorFields(StubBasicReadHttpStream): - def __init__(self, has_multiple_slices: bool, set_cursor_field: List[str], deduplicate_query_params: bool = False, **kwargs): + def __init__( + self, + has_multiple_slices: bool, + set_cursor_field: List[str], + deduplicate_query_params: bool = False, + **kwargs, + ): self.has_multiple_slices = has_multiple_slices self._cursor_field = set_cursor_field super().__init__() @@ -1337,9 +1495,16 @@ def cursor_field(self) -> Union[str, List[str]]: @pytest.mark.parametrize( "cursor_field, is_substream, expected_cursor", [ - pytest.param([], False, ResumableFullRefreshCursor(), id="test_stream_supports_resumable_full_refresh_cursor"), + pytest.param( + [], + False, + ResumableFullRefreshCursor(), + id="test_stream_supports_resumable_full_refresh_cursor", + ), pytest.param(["updated_at"], False, None, id="test_incremental_stream_does_not_use_cursor"), - pytest.param(["updated_at"], True, None, id="test_incremental_substream_does_not_use_cursor"), + pytest.param( + ["updated_at"], True, None, id="test_incremental_substream_does_not_use_cursor" + ), pytest.param( [], True, diff --git a/unit_tests/sources/streams/http/test_http_client.py b/unit_tests/sources/streams/http/test_http_client.py index 0c0f3c62..4ef9e968 100644 --- a/unit_tests/sources/streams/http/test_http_client.py +++ b/unit_tests/sources/streams/http/test_http_client.py @@ -9,8 +9,17 @@ from airbyte_cdk.models import FailureType from airbyte_cdk.sources.streams.call_rate import CachedLimiterSession, LimiterSession from airbyte_cdk.sources.streams.http import HttpClient -from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy, ErrorResolution, HttpStatusErrorHandler, ResponseAction -from airbyte_cdk.sources.streams.http.exceptions import DefaultBackoffException, RequestBodyException, UserDefinedBackoffException +from airbyte_cdk.sources.streams.http.error_handlers import ( + BackoffStrategy, + ErrorResolution, + HttpStatusErrorHandler, + ResponseAction, +) +from airbyte_cdk.sources.streams.http.exceptions import ( + DefaultBackoffException, + RequestBodyException, + UserDefinedBackoffException, +) from airbyte_cdk.sources.streams.http.requests_native_auth import TokenAuthenticator from airbyte_cdk.utils.traced_exception import AirbyteTracedException from requests_cache import CachedRequest @@ -121,7 +130,12 @@ def test_duplicate_request_params_are_deduped(deduplicate_query_params, url, par if expected_url is None: with pytest.raises(ValueError): - http_client._create_prepared_request(http_method="get", url=url, dedupe_query_params=deduplicate_query_params, params=params) + http_client._create_prepared_request( + http_method="get", + url=url, + dedupe_query_params=deduplicate_query_params, + params=params, + ) else: prepared_request = http_client._create_prepared_request( http_method="get", url=url, dedupe_query_params=deduplicate_query_params, params=params @@ -134,7 +148,10 @@ def test_create_prepared_response_given_given_both_json_and_data_raises_request_ with pytest.raises(RequestBodyException): http_client._create_prepared_request( - http_method="get", url="https://test_base_url.com/v1/endpoint", json={"test": "json"}, data={"test": "data"} + http_method="get", + url="https://test_base_url.com/v1/endpoint", + json={"test": "json"}, + data={"test": "data"}, ) @@ -155,7 +172,9 @@ def test_create_prepared_response_given_either_json_or_data_returns_valid_reques def test_connection_pool(): - http_client = HttpClient(name="test", logger=MagicMock(), authenticator=TokenAuthenticator("test-token")) + http_client = HttpClient( + name="test", logger=MagicMock(), authenticator=TokenAuthenticator("test-token") + ) assert http_client._session.adapters["https://"]._pool_connections == 20 @@ -179,7 +198,12 @@ def test_send_raises_airbyte_traced_exception_with_fail_response_action(): name="test", logger=MagicMock(), error_handler=HttpStatusErrorHandler( - logger=MagicMock(), error_mapping={400: ErrorResolution(ResponseAction.FAIL, FailureType.system_error, "test error message")} + logger=MagicMock(), + error_mapping={ + 400: ErrorResolution( + ResponseAction.FAIL, FailureType.system_error, "test error message" + ) + }, ), session=mocked_session, ) @@ -205,12 +229,19 @@ def test_send_ignores_with_ignore_reponse_action_and_returns_response(): name="test", logger=mocked_logger, error_handler=HttpStatusErrorHandler( - logger=MagicMock(), error_mapping={300: ErrorResolution(ResponseAction.IGNORE, FailureType.system_error, "test ignore message")} + logger=MagicMock(), + error_mapping={ + 300: ErrorResolution( + ResponseAction.IGNORE, FailureType.system_error, "test ignore message" + ) + }, ), session=mocked_session, ) - prepared_request = http_client._create_prepared_request(http_method="get", url="https://test_base_url.com/v1/endpoint") + prepared_request = http_client._create_prepared_request( + http_method="get", url="https://test_base_url.com/v1/endpoint" + ) returned_response = http_client._send(prepared_request, {}) @@ -227,17 +258,29 @@ def backoff_time(self, *args, **kwargs) -> float: return self._backoff_time_value -@pytest.mark.parametrize("backoff_time_value, exception_type", [(0.1, UserDefinedBackoffException), (None, DefaultBackoffException)]) -def test_raises_backoff_exception_with_retry_response_action(mocker, backoff_time_value, exception_type): +@pytest.mark.parametrize( + "backoff_time_value, exception_type", + [(0.1, UserDefinedBackoffException), (None, DefaultBackoffException)], +) +def test_raises_backoff_exception_with_retry_response_action( + mocker, backoff_time_value, exception_type +): http_client = HttpClient( name="test", logger=MagicMock(), error_handler=HttpStatusErrorHandler( - logger=MagicMock(), error_mapping={408: ErrorResolution(ResponseAction.FAIL, FailureType.system_error, "test retry message")} + logger=MagicMock(), + error_mapping={ + 408: ErrorResolution( + ResponseAction.FAIL, FailureType.system_error, "test retry message" + ) + }, ), backoff_strategy=CustomBackoffStrategy(backoff_time_value=backoff_time_value), ) - prepared_request = http_client._create_prepared_request(http_method="get", url="https://test_base_url.com/v1/endpoint") + prepared_request = http_client._create_prepared_request( + http_method="get", url="https://test_base_url.com/v1/endpoint" + ) mocked_response = MagicMock(spec=requests.Response) mocked_response.status_code = 408 mocked_response.headers = {} @@ -247,20 +290,32 @@ def test_raises_backoff_exception_with_retry_response_action(mocker, backoff_tim mocker.patch.object( http_client._error_handler, "interpret_response", - return_value=ErrorResolution(ResponseAction.RETRY, FailureType.system_error, "test retry message"), + return_value=ErrorResolution( + ResponseAction.RETRY, FailureType.system_error, "test retry message" + ), ) with pytest.raises(exception_type): http_client._send(prepared_request, {}) -@pytest.mark.parametrize("backoff_time_value, exception_type", [(0.1, UserDefinedBackoffException), (None, DefaultBackoffException)]) -def test_raises_backoff_exception_with_response_with_unmapped_error(mocker, backoff_time_value, exception_type): +@pytest.mark.parametrize( + "backoff_time_value, exception_type", + [(0.1, UserDefinedBackoffException), (None, DefaultBackoffException)], +) +def test_raises_backoff_exception_with_response_with_unmapped_error( + mocker, backoff_time_value, exception_type +): http_client = HttpClient( name="test", logger=MagicMock(), error_handler=HttpStatusErrorHandler( - logger=MagicMock(), error_mapping={408: ErrorResolution(ResponseAction.FAIL, FailureType.system_error, "test retry message")} + logger=MagicMock(), + error_mapping={ + 408: ErrorResolution( + ResponseAction.FAIL, FailureType.system_error, "test retry message" + ) + }, ), backoff_strategy=CustomBackoffStrategy(backoff_time_value=backoff_time_value), ) @@ -301,7 +356,12 @@ def update_response(*args, **kwargs): name="test", logger=MagicMock(), error_handler=HttpStatusErrorHandler( - logger=MagicMock(), error_mapping={408: ErrorResolution(ResponseAction.RETRY, FailureType.system_error, "test retry message")} + logger=MagicMock(), + error_mapping={ + 408: ErrorResolution( + ResponseAction.RETRY, FailureType.system_error, "test retry message" + ) + }, ), session=mocked_session, ) @@ -318,12 +378,16 @@ def test_session_request_exception_raises_backoff_exception(): error_handler = HttpStatusErrorHandler( logger=MagicMock(), error_mapping={ - requests.exceptions.RequestException: ErrorResolution(ResponseAction.RETRY, FailureType.system_error, "test retry message") + requests.exceptions.RequestException: ErrorResolution( + ResponseAction.RETRY, FailureType.system_error, "test retry message" + ) }, ) mocked_session = MagicMock(spec=requests.Session) mocked_session.send.side_effect = requests.RequestException - http_client = HttpClient(name="test", logger=MagicMock(), error_handler=error_handler, session=mocked_session) + http_client = HttpClient( + name="test", logger=MagicMock(), error_handler=error_handler, session=mocked_session + ) prepared_request = requests.PreparedRequest() with pytest.raises(DefaultBackoffException): @@ -337,7 +401,9 @@ def test_that_response_was_cached(requests_mock): cached_http_client._session.cache.clear() - prepared_request = cached_http_client._create_prepared_request(http_method="GET", url="https://google.com/") + prepared_request = cached_http_client._create_prepared_request( + http_method="GET", url="https://google.com/" + ) requests_mock.register_uri("GET", "https://google.com/", json='{"test": "response"}') @@ -353,14 +419,20 @@ def test_that_response_was_cached(requests_mock): def test_send_handles_response_action_given_session_send_raises_request_exception(): - error_resolution = ErrorResolution(ResponseAction.FAIL, FailureType.system_error, "test fail message") + error_resolution = ErrorResolution( + ResponseAction.FAIL, FailureType.system_error, "test fail message" + ) - custom_error_handler = HttpStatusErrorHandler(logger=MagicMock(), error_mapping={requests.RequestException: error_resolution}) + custom_error_handler = HttpStatusErrorHandler( + logger=MagicMock(), error_mapping={requests.RequestException: error_resolution} + ) mocked_session = MagicMock(spec=requests.Session) mocked_session.send.side_effect = requests.RequestException - http_client = HttpClient(name="test", logger=MagicMock(), error_handler=custom_error_handler, session=mocked_session) + http_client = HttpClient( + name="test", logger=MagicMock(), error_handler=custom_error_handler, session=mocked_session + ) prepared_request = requests.PreparedRequest() with pytest.raises(AirbyteTracedException) as e: @@ -392,7 +464,12 @@ def update_response(*args, **kwargs): name="test", logger=MagicMock(), error_handler=HttpStatusErrorHandler( - logger=MagicMock(), error_mapping={408: ErrorResolution(ResponseAction.RETRY, FailureType.system_error, "test retry message")} + logger=MagicMock(), + error_mapping={ + 408: ErrorResolution( + ResponseAction.RETRY, FailureType.system_error, "test retry message" + ) + }, ), session=mocked_session, ) @@ -427,7 +504,9 @@ def backoff_time(self, *args, **kwargs): with patch.object(requests.Session, "send", return_value=mocked_response) as mocked_send: with pytest.raises(UserDefinedBackoffException): - http_client.send_request(http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={}) + http_client.send_request( + http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={} + ) assert mocked_send.call_count == 1 @@ -438,7 +517,10 @@ def backoff_time(self, *args, **kwargs): return 0.001 http_client = HttpClient( - name="test", logger=MagicMock(), error_handler=HttpStatusErrorHandler(logger=MagicMock()), backoff_strategy=BackoffStrategy() + name="test", + logger=MagicMock(), + error_handler=HttpStatusErrorHandler(logger=MagicMock()), + backoff_strategy=BackoffStrategy(), ) mocked_response = MagicMock(spec=requests.Response) @@ -450,7 +532,9 @@ def backoff_time(self, *args, **kwargs): with patch.object(requests.Session, "send", return_value=mocked_response) as mocked_send: with pytest.raises(UserDefinedBackoffException): - http_client.send_request(http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={}) + http_client.send_request( + http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={} + ) assert mocked_send.call_count == 6 @@ -478,7 +562,9 @@ def backoff_time(self, *args, **kwargs): with patch.object(requests.Session, "send", return_value=mocked_response) as mocked_send: with pytest.raises(UserDefinedBackoffException): - http_client.send_request(http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={}) + http_client.send_request( + http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={} + ) assert mocked_send.call_count == retries + 1 @@ -486,7 +572,11 @@ def backoff_time(self, *args, **kwargs): def test_backoff_strategy_max_time(): error_handler = HttpStatusErrorHandler( logger=MagicMock(), - error_mapping={requests.RequestException: ErrorResolution(ResponseAction.RETRY, FailureType.system_error, "test retry message")}, + error_mapping={ + requests.RequestException: ErrorResolution( + ResponseAction.RETRY, FailureType.system_error, "test retry message" + ) + }, max_retries=10, max_time=timedelta(seconds=2), ) @@ -495,7 +585,12 @@ class BackoffStrategy: def backoff_time(self, *args, **kwargs): return 1 - http_client = HttpClient(name="test", logger=MagicMock(), error_handler=error_handler, backoff_strategy=BackoffStrategy()) + http_client = HttpClient( + name="test", + logger=MagicMock(), + error_handler=error_handler, + backoff_strategy=BackoffStrategy(), + ) mocked_response = MagicMock(spec=requests.Response) mocked_response.status_code = 429 @@ -506,7 +601,9 @@ def backoff_time(self, *args, **kwargs): with patch.object(requests.Session, "send", return_value=mocked_response) as mocked_send: with pytest.raises(UserDefinedBackoffException): - http_client.send_request(http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={}) + http_client.send_request( + http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={} + ) assert mocked_send.call_count == 2 @@ -517,7 +614,10 @@ def backoff_time(self, *args, **kwargs): return 0.001 http_client = HttpClient( - name="test", logger=MagicMock(), error_handler=HttpStatusErrorHandler(logger=MagicMock()), backoff_strategy=BackoffStrategy() + name="test", + logger=MagicMock(), + error_handler=HttpStatusErrorHandler(logger=MagicMock()), + backoff_strategy=BackoffStrategy(), ) mocked_response = MagicMock(spec=requests.Response) @@ -529,18 +629,23 @@ def backoff_time(self, *args, **kwargs): with patch.object(requests.Session, "send", return_value=mocked_response) as mocked_send: with pytest.raises(UserDefinedBackoffException): - http_client.send_request(http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={}) + http_client.send_request( + http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={} + ) trace_messages = capsys.readouterr().out.split() assert len(trace_messages) == mocked_send.call_count @pytest.mark.parametrize( - "exit_on_rate_limit, expected_call_count, expected_error", [[True, 6, DefaultBackoffException], [False, 38, OverflowError]] + "exit_on_rate_limit, expected_call_count, expected_error", + [[True, 6, DefaultBackoffException], [False, 38, OverflowError]], ) @pytest.mark.usefixtures("mock_sleep") def test_backoff_strategy_endless(exit_on_rate_limit, expected_call_count, expected_error): - http_client = HttpClient(name="test", logger=MagicMock(), error_handler=HttpStatusErrorHandler(logger=MagicMock())) + http_client = HttpClient( + name="test", logger=MagicMock(), error_handler=HttpStatusErrorHandler(logger=MagicMock()) + ) mocked_response = MagicMock(spec=requests.Response) mocked_response.status_code = 429 @@ -552,7 +657,10 @@ def test_backoff_strategy_endless(exit_on_rate_limit, expected_call_count, expec with patch.object(requests.Session, "send", return_value=mocked_response) as mocked_send: with pytest.raises(expected_error): http_client.send_request( - http_method="get", url="https://test_base_url.com/v1/endpoint", request_kwargs={}, exit_on_rate_limit=exit_on_rate_limit + http_method="get", + url="https://test_base_url.com/v1/endpoint", + request_kwargs={}, + exit_on_rate_limit=exit_on_rate_limit, ) assert mocked_send.call_count == expected_call_count @@ -561,10 +669,24 @@ def test_given_different_headers_then_response_is_not_cached(requests_mock): http_client = HttpClient(name="test", logger=MagicMock(), use_cache=True) first_request_headers = {"header_key": "first"} second_request_headers = {"header_key": "second"} - requests_mock.register_uri("GET", "https://google.com/", request_headers=first_request_headers, json={"test": "first response"}) - requests_mock.register_uri("GET", "https://google.com/", request_headers=second_request_headers, json={"test": "second response"}) + requests_mock.register_uri( + "GET", + "https://google.com/", + request_headers=first_request_headers, + json={"test": "first response"}, + ) + requests_mock.register_uri( + "GET", + "https://google.com/", + request_headers=second_request_headers, + json={"test": "second response"}, + ) - http_client.send_request("GET", "https://google.com/", headers=first_request_headers, request_kwargs={}) - _, second_response = http_client.send_request("GET", "https://google.com/", headers=second_request_headers, request_kwargs={}) + http_client.send_request( + "GET", "https://google.com/", headers=first_request_headers, request_kwargs={} + ) + _, second_response = http_client.send_request( + "GET", "https://google.com/", headers=second_request_headers, request_kwargs={} + ) assert second_response.json()["test"] == "second response" diff --git a/unit_tests/sources/streams/test_call_rate.py b/unit_tests/sources/streams/test_call_rate.py index c78d494b..518951cd 100644 --- a/unit_tests/sources/streams/test_call_rate.py +++ b/unit_tests/sources/streams/test_call_rate.py @@ -77,23 +77,50 @@ def test_method(self, request_factory): def test_params(self, request_factory): matcher = HttpRequestMatcher(params={"param1": 10, "param2": 15}) assert not matcher(request_factory(url="http://some_url/")) - assert not matcher(request_factory(url="http://some_url/", params={"param1": 10, "param3": 100})) - assert not matcher(request_factory(url="http://some_url/", params={"param1": 10, "param2": 10})) - assert matcher(request_factory(url="http://some_url/", params={"param1": 10, "param2": 15, "param3": 100})) + assert not matcher( + request_factory(url="http://some_url/", params={"param1": 10, "param3": 100}) + ) + assert not matcher( + request_factory(url="http://some_url/", params={"param1": 10, "param2": 10}) + ) + assert matcher( + request_factory( + url="http://some_url/", params={"param1": 10, "param2": 15, "param3": 100} + ) + ) @try_all_types_of_requests def test_header(self, request_factory): matcher = HttpRequestMatcher(headers={"header1": 10, "header2": 15}) assert not matcher(request_factory(url="http://some_url")) - assert not matcher(request_factory(url="http://some_url", headers={"header1": "10", "header3": "100"})) - assert not matcher(request_factory(url="http://some_url", headers={"header1": "10", "header2": "10"})) - assert matcher(request_factory(url="http://some_url", headers={"header1": "10", "header2": "15", "header3": "100"})) + assert not matcher( + request_factory(url="http://some_url", headers={"header1": "10", "header3": "100"}) + ) + assert not matcher( + request_factory(url="http://some_url", headers={"header1": "10", "header2": "10"}) + ) + assert matcher( + request_factory( + url="http://some_url", headers={"header1": "10", "header2": "15", "header3": "100"} + ) + ) @try_all_types_of_requests def test_combination(self, request_factory): - matcher = HttpRequestMatcher(method="GET", url="http://some_url/", headers={"header1": 10}, params={"param2": "test"}) - assert matcher(request_factory(method="GET", url="http://some_url", headers={"header1": "10"}, params={"param2": "test"})) - assert not matcher(request_factory(method="GET", url="http://some_url", headers={"header1": "10"})) + matcher = HttpRequestMatcher( + method="GET", url="http://some_url/", headers={"header1": 10}, params={"param2": "test"} + ) + assert matcher( + request_factory( + method="GET", + url="http://some_url", + headers={"header1": "10"}, + params={"param2": "test"}, + ) + ) + assert not matcher( + request_factory(method="GET", url="http://some_url", headers={"header1": "10"}) + ) assert not matcher(request_factory(method="GET", url="http://some_url")) assert not matcher(request_factory(url="http://some_url")) @@ -104,8 +131,12 @@ def test_http_request_matching(mocker): groups_policy = mocker.Mock(spec=MovingWindowCallRatePolicy) root_policy = mocker.Mock(spec=MovingWindowCallRatePolicy) - users_policy.matches.side_effect = HttpRequestMatcher(url="http://domain/api/users", method="GET") - groups_policy.matches.side_effect = HttpRequestMatcher(url="http://domain/api/groups", method="POST") + users_policy.matches.side_effect = HttpRequestMatcher( + url="http://domain/api/users", method="GET" + ) + groups_policy.matches.side_effect = HttpRequestMatcher( + url="http://domain/api/groups", method="POST" + ) root_policy.matches.side_effect = HttpRequestMatcher(method="GET") api_budget = APIBudget( policies=[ @@ -115,7 +146,12 @@ def test_http_request_matching(mocker): ] ) - api_budget.acquire_call(Request("POST", url="http://domain/unmatched_endpoint"), block=False), "unrestricted call" + ( + api_budget.acquire_call( + Request("POST", url="http://domain/unmatched_endpoint"), block=False + ), + "unrestricted call", + ) users_policy.try_acquire.assert_not_called() groups_policy.try_acquire.assert_not_called() root_policy.try_acquire.assert_not_called() @@ -126,7 +162,10 @@ def test_http_request_matching(mocker): groups_policy.try_acquire.assert_not_called() root_policy.try_acquire.assert_not_called() - api_budget.acquire_call(Request("GET", url="http://domain/api/users"), block=False), "second call, first matcher" + ( + api_budget.acquire_call(Request("GET", url="http://domain/api/users"), block=False), + "second call, first matcher", + ) assert users_policy.try_acquire.call_count == 2 groups_policy.try_acquire.assert_not_called() root_policy.try_acquire.assert_not_called() @@ -137,7 +176,10 @@ def test_http_request_matching(mocker): groups_policy.try_acquire.assert_called_once_with(group_request, weight=1) root_policy.try_acquire.assert_not_called() - api_budget.acquire_call(Request("POST", url="http://domain/api/groups"), block=False), "second call, second matcher" + ( + api_budget.acquire_call(Request("POST", url="http://domain/api/groups"), block=False), + "second call, second matcher", + ) assert users_policy.try_acquire.call_count == 2 assert groups_policy.try_acquire.call_count == 2 root_policy.try_acquire.assert_not_called() @@ -165,7 +207,9 @@ def test_update(self): class TestFixedWindowCallRatePolicy: def test_limit_rate(self, mocker): - policy = FixedWindowCallRatePolicy(matchers=[], next_reset_ts=datetime.now(), period=timedelta(hours=1), call_limit=100) + policy = FixedWindowCallRatePolicy( + matchers=[], next_reset_ts=datetime.now(), period=timedelta(hours=1), call_limit=100 + ) policy.try_acquire(mocker.Mock(), weight=1) policy.try_acquire(mocker.Mock(), weight=20) with pytest.raises(ValueError, match="Weight can not exceed the call limit"): @@ -179,7 +223,9 @@ def test_limit_rate(self, mocker): assert exc.value.item def test_update_available_calls(self, mocker): - policy = FixedWindowCallRatePolicy(matchers=[], next_reset_ts=datetime.now(), period=timedelta(hours=1), call_limit=100) + policy = FixedWindowCallRatePolicy( + matchers=[], next_reset_ts=datetime.now(), period=timedelta(hours=1), call_limit=100 + ) # update to decrease number of calls available policy.update(available_calls=2, call_reset_ts=None) # hit the limit with weight=3 @@ -216,7 +262,9 @@ def test_limit_rate(self): with pytest.raises(CallRateLimitHit) as excinfo2: policy.try_acquire("call", weight=1), "call over limit" - assert excinfo2.value.time_to_wait < excinfo1.value.time_to_wait, "time to wait must decrease over time" + assert ( + excinfo2.value.time_to_wait < excinfo1.value.time_to_wait + ), "time to wait must decrease over time" def test_limit_rate_support_custom_weight(self): """try_acquire must take into account provided weight and throw CallRateLimitHit when hit the limit.""" @@ -225,7 +273,9 @@ def test_limit_rate_support_custom_weight(self): policy.try_acquire("call", weight=2), "1st call with weight of 2" with pytest.raises(CallRateLimitHit) as excinfo: policy.try_acquire("call", weight=9), "2nd call, over limit since 2 + 9 = 11 > 10" - assert excinfo.value.time_to_wait.total_seconds() == pytest.approx(60, 0.1), "should wait 1 minute before next call" + assert excinfo.value.time_to_wait.total_seconds() == pytest.approx( + 60, 0.1 + ), "should wait 1 minute before next call" def test_multiple_limit_rates(self): """try_acquire must take into all call rates and apply stricter.""" @@ -257,7 +307,9 @@ def test_without_cache(self, mocker, requests_mock): api_budget = APIBudget( policies=[ MovingWindowCallRatePolicy( - matchers=[HttpRequestMatcher(url=f"{StubDummyHttpStream.url_base}/", method="GET")], + matchers=[ + HttpRequestMatcher(url=f"{StubDummyHttpStream.url_base}/", method="GET") + ], rates=[ Rate(2, timedelta(minutes=1)), ], @@ -265,7 +317,9 @@ def test_without_cache(self, mocker, requests_mock): ] ) - stream = StubDummyHttpStream(api_budget=api_budget, authenticator=TokenAuthenticator(token="ABCD")) + stream = StubDummyHttpStream( + api_budget=api_budget, authenticator=TokenAuthenticator(token="ABCD") + ) for i in range(10): records = stream.read_records(SyncMode.full_refresh) assert next(records) == {"data": "some_data"} diff --git a/unit_tests/sources/streams/test_stream_read.py b/unit_tests/sources/streams/test_stream_read.py index 5b82ab11..f079fe99 100644 --- a/unit_tests/sources/streams/test_stream_read.py +++ b/unit_tests/sources/streams/test_stream_read.py @@ -45,7 +45,11 @@ class _MockStream(Stream): - def __init__(self, slice_to_records: Mapping[str, List[Mapping[str, Any]]], json_schema: Dict[str, Any] = None): + def __init__( + self, + slice_to_records: Mapping[str, List[Mapping[str, Any]]], + json_schema: Dict[str, Any] = None, + ): self._slice_to_records = slice_to_records self._mocked_json_schema = json_schema or {} @@ -54,7 +58,11 @@ def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: return None def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: for partition in self._slice_to_records.keys(): yield {"partition_key": partition} @@ -143,9 +151,19 @@ def _stream(slice_to_partition_mapping, slice_logger, logger, message_repository return _MockStream(slice_to_partition_mapping, json_schema=json_schema) -def _concurrent_stream(slice_to_partition_mapping, slice_logger, logger, message_repository, cursor: Optional[Cursor] = None): +def _concurrent_stream( + slice_to_partition_mapping, + slice_logger, + logger, + message_repository, + cursor: Optional[Cursor] = None, +): stream = _stream(slice_to_partition_mapping, slice_logger, logger, message_repository) - cursor = cursor or FinalStateCursor(stream_name=stream.name, stream_namespace=stream.namespace, message_repository=message_repository) + cursor = cursor or FinalStateCursor( + stream_name=stream.name, + stream_namespace=stream.namespace, + message_repository=message_repository, + ) source = Mock() source._slice_logger = slice_logger source.message_repository = message_repository @@ -154,18 +172,28 @@ def _concurrent_stream(slice_to_partition_mapping, slice_logger, logger, message return stream -def _incremental_stream(slice_to_partition_mapping, slice_logger, logger, message_repository, timestamp): +def _incremental_stream( + slice_to_partition_mapping, slice_logger, logger, message_repository, timestamp +): stream = _MockIncrementalStream(slice_to_partition_mapping) return stream -def _incremental_concurrent_stream(slice_to_partition_mapping, slice_logger, logger, message_repository, cursor): - stream = _concurrent_stream(slice_to_partition_mapping, slice_logger, logger, message_repository, cursor) +def _incremental_concurrent_stream( + slice_to_partition_mapping, slice_logger, logger, message_repository, cursor +): + stream = _concurrent_stream( + slice_to_partition_mapping, slice_logger, logger, message_repository, cursor + ) return stream -def _stream_with_no_cursor_field(slice_to_partition_mapping, slice_logger, logger, message_repository): - def get_updated_state(current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]) -> MutableMapping[str, Any]: +def _stream_with_no_cursor_field( + slice_to_partition_mapping, slice_logger, logger, message_repository +): + def get_updated_state( + current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any] + ) -> MutableMapping[str, Any]: raise Exception("I shouldn't be invoked by a full_refresh stream") mock_stream = _MockStream(slice_to_partition_mapping) @@ -184,7 +212,9 @@ def test_full_refresh_read_a_single_slice_with_debug(constructor): # This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object. # It is done by running the same test cases on both streams configured_stream = ConfiguredAirbyteStream( - stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}), + stream=AirbyteStream( + name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={} + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.overwrite, ) @@ -228,7 +258,15 @@ def test_full_refresh_read_a_single_slice_with_debug(constructor): ), ) - actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config) + actual_records = _read( + stream, + configured_stream, + logger, + slice_logger, + message_repository, + state_manager, + internal_config, + ) if constructor == _concurrent_stream: assert hasattr(stream._cursor, "state") @@ -248,7 +286,9 @@ def test_full_refresh_read_a_single_slice(constructor): # This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object. # It is done by running the same test cases on both streams configured_stream = ConfiguredAirbyteStream( - stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}), + stream=AirbyteStream( + name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={} + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.overwrite, ) @@ -284,7 +324,15 @@ def test_full_refresh_read_a_single_slice(constructor): ), ) - actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config) + actual_records = _read( + stream, + configured_stream, + logger, + slice_logger, + message_repository, + state_manager, + internal_config, + ) if constructor == _concurrent_stream: assert hasattr(stream._cursor, "state") @@ -305,7 +353,9 @@ def test_full_refresh_read_two_slices(constructor): # This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object # It is done by running the same test cases on both streams configured_stream = ConfiguredAirbyteStream( - stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}), + stream=AirbyteStream( + name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={} + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.overwrite, ) @@ -348,7 +398,15 @@ def test_full_refresh_read_two_slices(constructor): ), ) - actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config) + actual_records = _read( + stream, + configured_stream, + logger, + slice_logger, + message_repository, + state_manager, + internal_config, + ) if constructor == _concurrent_stream: assert hasattr(stream._cursor, "state") @@ -362,7 +420,11 @@ def test_full_refresh_read_two_slices(constructor): def test_incremental_read_two_slices(): # This test verifies that a stream running in incremental mode emits state messages correctly configured_stream = ConfiguredAirbyteStream( - stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], json_schema={}), + stream=AirbyteStream( + name="mock_stream", + supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], + json_schema={}, + ), sync_mode=SyncMode.incremental, cursor_field=["created_at"], destination_sync_mode=DestinationSyncMode.overwrite, @@ -383,7 +445,9 @@ def test_incremental_read_two_slices(): {"id": 4, "partition": 2, "created_at": "1708899427"}, ] slice_to_partition = {1: records_partition_1, 2: records_partition_2} - stream = _incremental_stream(slice_to_partition, slice_logger, logger, message_repository, timestamp) + stream = _incremental_stream( + slice_to_partition, slice_logger, logger, message_repository, timestamp + ) expected_records = [ *records_partition_1, @@ -392,7 +456,15 @@ def test_incremental_read_two_slices(): _create_state_message("__mock_incremental_stream", {"created_at": timestamp}), ] - actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config) + actual_records = _read( + stream, + configured_stream, + logger, + slice_logger, + message_repository, + state_manager, + internal_config, + ) for record in expected_records: assert record in actual_records @@ -402,7 +474,11 @@ def test_incremental_read_two_slices(): def test_concurrent_incremental_read_two_slices(): # This test verifies that an incremental concurrent stream manages state correctly for multiple slices syncing concurrently configured_stream = ConfiguredAirbyteStream( - stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], json_schema={}), + stream=AirbyteStream( + name="mock_stream", + supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], + json_schema={}, + ), sync_mode=SyncMode.incremental, destination_sync_mode=DestinationSyncMode.overwrite, ) @@ -424,7 +500,9 @@ def test_concurrent_incremental_read_two_slices(): {"id": 4, "partition": 2, "created_at": slice_timestamp_2}, ] slice_to_partition = {1: records_partition_1, 2: records_partition_2} - stream = _incremental_concurrent_stream(slice_to_partition, slice_logger, logger, message_repository, cursor) + stream = _incremental_concurrent_stream( + slice_to_partition, slice_logger, logger, message_repository, cursor + ) expected_records = [ *records_partition_1, @@ -432,10 +510,19 @@ def test_concurrent_incremental_read_two_slices(): ] expected_state = _create_state_message( - "__mock_stream", {"1": {"created_at": slice_timestamp_1}, "2": {"created_at": slice_timestamp_2}} + "__mock_stream", + {"1": {"created_at": slice_timestamp_1}, "2": {"created_at": slice_timestamp_2}}, ) - actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config) + actual_records = _read( + stream, + configured_stream, + logger, + slice_logger, + message_repository, + state_manager, + internal_config, + ) handler = ConcurrentReadProcessor( [stream], @@ -452,7 +539,13 @@ def test_concurrent_incremental_read_two_slices(): # We need run on_record to update cursor with record cursor value for record in actual_records: - list(handler.on_record(Record(record, Mock(spec=Partition, **{"stream_name.return_value": "__mock_stream"})))) + list( + handler.on_record( + Record( + record, Mock(spec=Partition, **{"stream_name.return_value": "__mock_stream"}) + ) + ) + ) assert len(actual_records) == len(expected_records) @@ -467,7 +560,11 @@ def test_concurrent_incremental_read_two_slices(): def setup_stream_dependencies(configured_json_schema): configured_stream = ConfiguredAirbyteStream( - stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema=configured_json_schema), + stream=AirbyteStream( + name="mock_stream", + supported_sync_modes=[SyncMode.full_refresh], + json_schema=configured_json_schema, + ), sync_mode=SyncMode.full_refresh, destination_sync_mode=DestinationSyncMode.overwrite, ) @@ -476,7 +573,14 @@ def setup_stream_dependencies(configured_json_schema): slice_logger = DebugSliceLogger() message_repository = InMemoryMessageRepository(Level.INFO) state_manager = ConnectorStateManager() - return configured_stream, internal_config, logger, slice_logger, message_repository, state_manager + return ( + configured_stream, + internal_config, + logger, + slice_logger, + message_repository, + state_manager, + ) def test_configured_json_schema(): @@ -489,8 +593,8 @@ def test_configured_json_schema(): }, } - configured_stream, internal_config, logger, slice_logger, message_repository, state_manager = setup_stream_dependencies( - current_json_schema + configured_stream, internal_config, logger, slice_logger, message_repository, state_manager = ( + setup_stream_dependencies(current_json_schema) ) records = [ {"id": 1, "partition": 1}, @@ -498,9 +602,23 @@ def test_configured_json_schema(): ] slice_to_partition = {1: records} - stream = _stream(slice_to_partition, slice_logger, logger, message_repository, json_schema=current_json_schema) + stream = _stream( + slice_to_partition, + slice_logger, + logger, + message_repository, + json_schema=current_json_schema, + ) assert not stream.configured_json_schema - _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config) + _read( + stream, + configured_stream, + logger, + slice_logger, + message_repository, + state_manager, + internal_config, + ) assert stream.configured_json_schema == current_json_schema @@ -527,8 +645,8 @@ def test_configured_json_schema_with_invalid_properties(): del stream_schema["properties"][old_user_insights] del stream_schema["properties"][old_feature_info] - configured_stream, internal_config, logger, slice_logger, message_repository, state_manager = setup_stream_dependencies( - configured_json_schema + configured_stream, internal_config, logger, slice_logger, message_repository, state_manager = ( + setup_stream_dependencies(configured_json_schema) ) records = [ {"id": 1, "partition": 1}, @@ -536,9 +654,19 @@ def test_configured_json_schema_with_invalid_properties(): ] slice_to_partition = {1: records} - stream = _stream(slice_to_partition, slice_logger, logger, message_repository, json_schema=stream_schema) + stream = _stream( + slice_to_partition, slice_logger, logger, message_repository, json_schema=stream_schema + ) assert not stream.configured_json_schema - _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config) + _read( + stream, + configured_stream, + logger, + slice_logger, + message_repository, + state_manager, + internal_config, + ) assert stream.configured_json_schema != configured_json_schema configured_json_schema_properties = stream.configured_json_schema["properties"] assert old_user_insights not in configured_json_schema_properties @@ -547,19 +675,34 @@ def test_configured_json_schema_with_invalid_properties(): assert ( stream_schema_property in configured_json_schema_properties ), f"Stream schema property: {stream_schema_property} missing in configured schema" - assert stream_schema["properties"][stream_schema_property] == configured_json_schema_properties[stream_schema_property] + assert ( + stream_schema["properties"][stream_schema_property] + == configured_json_schema_properties[stream_schema_property] + ) -def _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config): +def _read( + stream, + configured_stream, + logger, + slice_logger, + message_repository, + state_manager, + internal_config, +): records = [] - for record in stream.read(configured_stream, logger, slice_logger, {}, state_manager, internal_config): + for record in stream.read( + configured_stream, logger, slice_logger, {}, state_manager, internal_config + ): for message in message_repository.consume_queue(): records.append(message) records.append(record) return records -def _mock_partition_generator(name: str, slices, records_per_partition, *, available=True, debug_log=False): +def _mock_partition_generator( + name: str, slices, records_per_partition, *, available=True, debug_log=False +): stream = Mock() stream.name = name stream.get_json_schema.return_value = {} diff --git a/unit_tests/sources/streams/test_streams_core.py b/unit_tests/sources/streams/test_streams_core.py index 9f356b5c..9f21ebab 100644 --- a/unit_tests/sources/streams/test_streams_core.py +++ b/unit_tests/sources/streams/test_streams_core.py @@ -184,7 +184,11 @@ def get_cursor(self) -> Optional[Cursor]: class LegacyCursorBasedStreamStubFullRefresh(CursorBasedStreamStubFullRefresh): def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: yield from [{}] @@ -201,7 +205,11 @@ def url_base(self) -> str: return "https://airbyte.io/api/v1" def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: yield from [ StreamSlice(partition={"parent_id": "korra"}, cursor_slice={}), @@ -277,7 +285,12 @@ def test_as_airbyte_stream_full_refresh(mocker): mocker.patch.object(StreamStubFullRefresh, "get_json_schema", return_value={}) airbyte_stream = test_stream.as_airbyte_stream() - exp = AirbyteStream(name="stream_stub_full_refresh", json_schema={}, supported_sync_modes=[SyncMode.full_refresh], is_resumable=False) + exp = AirbyteStream( + name="stream_stub_full_refresh", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + is_resumable=False, + ) assert airbyte_stream == exp @@ -367,7 +380,11 @@ def test_namespace_not_set(): @pytest.mark.parametrize( "test_input, expected", - [("key", [["key"]]), (["key1", "key2"], [["key1"], ["key2"]]), ([["key1", "key2"], ["key3"]], [["key1", "key2"], ["key3"]])], + [ + ("key", [["key"]]), + (["key1", "key2"], [["key1"], ["key2"]]), + ([["key1", "key2"], ["key3"]], [["key1", "key2"], ["key3"]]), + ], ) def test_wrapped_primary_key_various_argument(test_input, expected): """ @@ -390,13 +407,29 @@ def test_get_json_schema_is_cached(mocked_method): @pytest.mark.parametrize( "stream, stream_state, expected_checkpoint_reader_type", [ - pytest.param(StreamStubIncremental(), {}, IncrementalCheckpointReader, id="test_incremental_checkpoint_reader"), - pytest.param(StreamStubFullRefresh(), {}, FullRefreshCheckpointReader, id="test_full_refresh_checkpoint_reader"), pytest.param( - StreamStubResumableFullRefresh(), {}, ResumableFullRefreshCheckpointReader, id="test_resumable_full_refresh_checkpoint_reader" + StreamStubIncremental(), + {}, + IncrementalCheckpointReader, + id="test_incremental_checkpoint_reader", + ), + pytest.param( + StreamStubFullRefresh(), + {}, + FullRefreshCheckpointReader, + id="test_full_refresh_checkpoint_reader", ), pytest.param( - StreamStubLegacyStateInterface(), {}, IncrementalCheckpointReader, id="test_incremental_checkpoint_reader_with_legacy_state" + StreamStubResumableFullRefresh(), + {}, + ResumableFullRefreshCheckpointReader, + id="test_resumable_full_refresh_checkpoint_reader", + ), + pytest.param( + StreamStubLegacyStateInterface(), + {}, + IncrementalCheckpointReader, + id="test_incremental_checkpoint_reader_with_legacy_state", ), pytest.param( CursorBasedStreamStubFullRefresh(), diff --git a/unit_tests/sources/streams/utils/test_stream_helper.py b/unit_tests/sources/streams/utils/test_stream_helper.py index da76a787..39b642cb 100644 --- a/unit_tests/sources/streams/utils/test_stream_helper.py +++ b/unit_tests/sources/streams/utils/test_stream_helper.py @@ -11,7 +11,8 @@ def __init__(self, records, exit_on_rate_limit=True): self.records = records self._exit_on_rate_limit = exit_on_rate_limit type(self).exit_on_rate_limit = property( - lambda self: self._get_exit_on_rate_limit(), lambda self, value: self._set_exit_on_rate_limit(value) + lambda self: self._get_exit_on_rate_limit(), + lambda self, value: self._set_exit_on_rate_limit(value), ) def _get_exit_on_rate_limit(self): @@ -32,7 +33,9 @@ def read_records(self, sync_mode, stream_slice): ([], None, True, None, True), # No records, with setter ], ) -def test_get_first_record_for_slice(records, stream_slice, exit_on_rate_limit, expected_result, raises_exception): +def test_get_first_record_for_slice( + records, stream_slice, exit_on_rate_limit, expected_result, raises_exception +): stream = MockStream(records, exit_on_rate_limit) if raises_exception: diff --git a/unit_tests/sources/test_abstract_source.py b/unit_tests/sources/test_abstract_source.py index 9de46b9e..2cc0db54 100644 --- a/unit_tests/sources/test_abstract_source.py +++ b/unit_tests/sources/test_abstract_source.py @@ -5,7 +5,18 @@ import copy import datetime import logging -from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) from unittest.mock import Mock import pytest @@ -62,7 +73,9 @@ def __init__( self._message_repository = message_repository self._stop_sync_on_stream_failure = stop_sync_on_stream_failure - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: if self.check_lambda: return self.check_lambda() return False, "Missing callable." @@ -135,7 +148,9 @@ def read_records(self, *args, **kwargs) -> Iterable[Mapping[str, Any]]: @fixture def message_repository(): message_repository = Mock(spec=MessageRepository) - message_repository.consume_queue.return_value = [message for message in [MESSAGE_FROM_REPOSITORY]] + message_repository.consume_queue.return_value = [ + message for message in [MESSAGE_FROM_REPOSITORY] + ] return message_repository @@ -161,7 +176,9 @@ def test_raising_check(mocker): class MockStream(Stream): def __init__( self, - inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]] = None, + inputs_and_mocked_outputs: List[ + Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]] + ] = None, name: str = None, ): self._inputs_and_mocked_outputs = inputs_and_mocked_outputs @@ -179,7 +196,9 @@ def read_records(self, **kwargs) -> Iterable[Mapping[str, Any]]: # type: ignore if kwargs == _input: return output - raise Exception(f"No mocked output supplied for input: {kwargs}. Mocked inputs/outputs: {self._inputs_and_mocked_outputs}") + raise Exception( + f"No mocked output supplied for input: {kwargs}. Mocked inputs/outputs: {self._inputs_and_mocked_outputs}" + ) @property def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: @@ -193,12 +212,21 @@ def cursor_field(self) -> Union[str, List[str]]: class MockStreamWithCursor(MockStream): cursor_field = "cursor" - def __init__(self, inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], name: str): + def __init__( + self, + inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], + name: str, + ): super().__init__(inputs_and_mocked_outputs, name) class MockStreamWithState(MockStreamWithCursor): - def __init__(self, inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], name: str, state=None): + def __init__( + self, + inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], + name: str, + state=None, + ): super().__init__(inputs_and_mocked_outputs, name) self._state = state @@ -213,7 +241,10 @@ def state(self, value): class MockStreamEmittingAirbyteMessages(MockStreamWithState): def __init__( - self, inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[AirbyteMessage]]] = None, name: str = None, state=None + self, + inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[AirbyteMessage]]] = None, + name: str = None, + state=None, ): super().__init__(inputs_and_mocked_outputs, name, state) self._inputs_and_mocked_outputs = inputs_and_mocked_outputs @@ -264,7 +295,9 @@ def read_records(self, **kwargs) -> Iterable[Mapping[str, Any]]: # type: ignore output = mocked_output.get("records") if output is None: - raise Exception(f"No mocked output supplied for input: {kwargs}. Mocked inputs/outputs: {self._inputs_and_mocked_outputs}") + raise Exception( + f"No mocked output supplied for input: {kwargs}. Mocked inputs/outputs: {self._inputs_and_mocked_outputs}" + ) self.state = next_page_token or {"__ab_full_refresh_sync_complete": True} yield from output @@ -292,7 +325,9 @@ def test_discover(mocker): source_defined_cursor=True, source_defined_primary_key=[["pk"]], ) - airbyte_stream2 = AirbyteStream(name="2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]) + airbyte_stream2 = AirbyteStream( + name="2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ) stream1 = MockStream() stream2 = MockStream() @@ -334,7 +369,9 @@ def test_read_nonexistent_stream_without_raises_exception(mocker, as_stream_stat messages = list(src.read(logger, {}, catalog)) messages = _fix_emitted_at(messages) - expected = _fix_emitted_at([as_stream_status("this_stream_doesnt_exist_in_the_source", AirbyteStreamStatus.INCOMPLETE)]) + expected = _fix_emitted_at( + [as_stream_status("this_stream_doesnt_exist_in_the_source", AirbyteStreamStatus.INCOMPLETE)] + ) assert messages == expected @@ -342,28 +379,55 @@ def test_read_nonexistent_stream_without_raises_exception(mocker, as_stream_stat def test_read_stream_emits_repository_message_before_record(mocker, message_repository): stream = MockStream(name="my_stream") mocker.patch.object(MockStream, "get_json_schema", return_value={}) - mocker.patch.object(MockStream, "read_records", side_effect=[[{"a record": "a value"}, {"another record": "another value"}]]) - message_repository.consume_queue.side_effect = [[message for message in [MESSAGE_FROM_REPOSITORY]], [], []] + mocker.patch.object( + MockStream, + "read_records", + side_effect=[[{"a record": "a value"}, {"another record": "another value"}]], + ) + message_repository.consume_queue.side_effect = [ + [message for message in [MESSAGE_FROM_REPOSITORY]], + [], + [], + ] source = MockSource(streams=[stream], message_repository=message_repository) - messages = list(source.read(logger, {}, ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)]))) + messages = list( + source.read( + logger, + {}, + ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)]), + ) + ) assert messages.count(MESSAGE_FROM_REPOSITORY) == 1 record_messages = (message for message in messages if message.type == Type.RECORD) - assert all(messages.index(MESSAGE_FROM_REPOSITORY) < messages.index(record) for record in record_messages) + assert all( + messages.index(MESSAGE_FROM_REPOSITORY) < messages.index(record) + for record in record_messages + ) def test_read_stream_emits_repository_message_on_error(mocker, message_repository): stream = MockStream(name="my_stream") mocker.patch.object(MockStream, "get_json_schema", return_value={}) mocker.patch.object(MockStream, "read_records", side_effect=RuntimeError("error")) - message_repository.consume_queue.return_value = [message for message in [MESSAGE_FROM_REPOSITORY]] + message_repository.consume_queue.return_value = [ + message for message in [MESSAGE_FROM_REPOSITORY] + ] source = MockSource(streams=[stream], message_repository=message_repository) with pytest.raises(AirbyteTracedException): - messages = list(source.read(logger, {}, ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)]))) + messages = list( + source.read( + logger, + {}, + ConfiguredAirbyteCatalog( + streams=[_configured_stream(stream, SyncMode.full_refresh)] + ), + ) + ) assert MESSAGE_FROM_REPOSITORY in messages @@ -421,14 +485,19 @@ def _as_state(stream_name: str = "", per_stream_state: Dict[str, Any] = None): state=AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name=stream_name), stream_state=AirbyteStateBlob(per_stream_state) + stream_descriptor=StreamDescriptor(name=stream_name), + stream_state=AirbyteStateBlob(per_stream_state), ), ), ) def _as_error_trace( - stream: str, error_message: str, internal_message: Optional[str], failure_type: Optional[FailureType], stack_trace: Optional[str] + stream: str, + error_message: str, + internal_message: Optional[str], + failure_type: Optional[FailureType], + stack_trace: Optional[str], ) -> AirbyteMessage: trace_message = AirbyteTraceMessage( emitted_at=datetime.datetime.now().timestamp() * 1000.0, @@ -465,8 +534,24 @@ def _fix_emitted_at(messages: List[AirbyteMessage]) -> List[AirbyteMessage]: def test_valid_full_refresh_read_no_slices(mocker): """Tests that running a full refresh sync on streams which don't specify slices produces the expected AirbyteMessages""" stream_output = [{"k1": "v1"}, {"k2": "v2"}] - s1 = MockStream([({"stream_slice": {}, "stream_state": {}, "sync_mode": SyncMode.full_refresh}, stream_output)], name="s1") - s2 = MockStream([({"stream_slice": {}, "stream_state": {}, "sync_mode": SyncMode.full_refresh}, stream_output)], name="s2") + s1 = MockStream( + [ + ( + {"stream_slice": {}, "stream_state": {}, "sync_mode": SyncMode.full_refresh}, + stream_output, + ) + ], + name="s1", + ) + s2 = MockStream( + [ + ( + {"stream_slice": {}, "stream_state": {}, "sync_mode": SyncMode.full_refresh}, + stream_output, + ) + ], + name="s2", + ) mocker.patch.object(MockStream, "get_json_schema", return_value={}) mocker.patch.object(MockStream, "cursor_field", return_value=[]) @@ -503,11 +588,17 @@ def test_valid_full_refresh_read_with_slices(mocker): slices = [{"1": "1"}, {"2": "2"}] # When attempting to sync a slice, just output that slice as a record s1 = MockStream( - [({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices], + [ + ({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) + for s in slices + ], name="s1", ) s2 = MockStream( - [({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices], + [ + ({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) + for s in slices + ], name="s2", ) @@ -545,14 +636,23 @@ def test_valid_full_refresh_read_with_slices(mocker): @pytest.mark.parametrize( "slices", - [[{"1": "1"}, {"2": "2"}], [{"date": datetime.date(year=2023, month=1, day=1)}, {"date": datetime.date(year=2023, month=1, day=1)}]], + [ + [{"1": "1"}, {"2": "2"}], + [ + {"date": datetime.date(year=2023, month=1, day=1)}, + {"date": datetime.date(year=2023, month=1, day=1)}, + ], + ], ) def test_read_full_refresh_with_slices_sends_slice_messages(mocker, slices): """Given the logger is debug and a full refresh, AirbyteMessages are sent for slices""" debug_logger = logging.getLogger("airbyte.debug") debug_logger.setLevel(logging.DEBUG) stream = MockStream( - [({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices], + [ + ({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) + for s in slices + ], name="s1", ) @@ -568,7 +668,13 @@ def test_read_full_refresh_with_slices_sends_slice_messages(mocker, slices): messages = src.read(debug_logger, {}, catalog) - assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages))) + assert 2 == len( + list( + filter( + lambda message: message.log and message.log.message.startswith("slice:"), messages + ) + ) + ) def test_read_incremental_with_slices_sends_slice_messages(mocker): @@ -577,7 +683,10 @@ def test_read_incremental_with_slices_sends_slice_messages(mocker): debug_logger.setLevel(logging.DEBUG) slices = [{"1": "1"}, {"2": "2"}] stream = MockStream( - [({"sync_mode": SyncMode.incremental, "stream_slice": s, "stream_state": {}}, [s]) for s in slices], + [ + ({"sync_mode": SyncMode.incremental, "stream_slice": s, "stream_state": {}}, [s]) + for s in slices + ], name="s1", ) @@ -594,7 +703,13 @@ def test_read_incremental_with_slices_sends_slice_messages(mocker): messages = src.read(debug_logger, {}, catalog) - assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages))) + assert 2 == len( + list( + filter( + lambda message: message.log and message.log.message.startswith("slice:"), messages + ) + ) + ) class TestIncrementalRead: @@ -605,7 +720,10 @@ def test_with_state_attribute(self, mocker): input_state = [ AirbyteStateMessage( type=AirbyteStateType.STREAM, - stream=AirbyteStreamState(stream_descriptor=StreamDescriptor(name="s1"), stream_state=AirbyteStateBlob(old_state)), + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="s1"), + stream_state=AirbyteStateBlob(old_state), + ), ), ] new_state_from_connector = {"cursor": "new_value"} @@ -613,14 +731,23 @@ def test_with_state_attribute(self, mocker): stream_1 = MockStreamWithState( [ ( - {"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": old_state}, + { + "sync_mode": SyncMode.incremental, + "stream_slice": {}, + "stream_state": old_state, + }, stream_output, ) ], name="s1", ) stream_2 = MockStreamWithState( - [({"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, stream_output)], + [ + ( + {"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, + stream_output, + ) + ], name="s2", ) @@ -684,11 +811,21 @@ def test_with_checkpoint_interval(self, mocker): stream_output = [{"k1": "v1"}, {"k2": "v2"}] stream_1 = MockStream( - [({"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, stream_output)], + [ + ( + {"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, + stream_output, + ) + ], name="s1", ) stream_2 = MockStream( - [({"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, stream_output)], + [ + ( + {"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, + stream_output, + ) + ], name="s2", ) state = {"cursor": "value"} @@ -743,11 +880,21 @@ def test_with_no_interval(self, mocker): stream_output = [{"k1": "v1"}, {"k2": "v2"}] stream_1 = MockStream( - [({"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, stream_output)], + [ + ( + {"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, + stream_output, + ) + ], name="s1", ) stream_2 = MockStream( - [({"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, stream_output)], + [ + ( + {"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": {}}, + stream_output, + ) + ], name="s2", ) state = {"cursor": "value"} @@ -857,7 +1004,13 @@ def test_with_slices(self, mocker): assert messages == expected - @pytest.mark.parametrize("slices", [pytest.param([], id="test_slices_as_list"), pytest.param(iter([]), id="test_slices_as_iterator")]) + @pytest.mark.parametrize( + "slices", + [ + pytest.param([], id="test_slices_as_list"), + pytest.param(iter([]), id="test_slices_as_iterator"), + ], + ) def test_no_slices(self, mocker, slices): """ Tests that an incremental read returns at least one state messages even if no records were read: @@ -867,11 +1020,17 @@ def test_no_slices(self, mocker, slices): input_state = [ AirbyteStateMessage( type=AirbyteStateType.STREAM, - stream=AirbyteStreamState(stream_descriptor=StreamDescriptor(name="s1"), stream_state=AirbyteStateBlob(state)), + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="s1"), + stream_state=AirbyteStateBlob(state), + ), ), AirbyteStateMessage( type=AirbyteStateType.STREAM, - stream=AirbyteStreamState(stream_descriptor=StreamDescriptor(name="s2"), stream_state=AirbyteStateBlob(state)), + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="s2"), + stream_state=AirbyteStateBlob(state), + ), ), ] @@ -1165,19 +1324,32 @@ def test_without_state_attribute_for_stream_with_desc_records(self, mocker): In this scenario records are returned in descending order, but we keep the "highest" cursor in the state. """ stream_cursor = MockStreamWithCursor.cursor_field - stream_output = [{f"k{cursor_id}": f"v{cursor_id}", stream_cursor: cursor_id} for cursor_id in range(5, 1, -1)] + stream_output = [ + {f"k{cursor_id}": f"v{cursor_id}", stream_cursor: cursor_id} + for cursor_id in range(5, 1, -1) + ] initial_state = {stream_cursor: 1} stream_name = "stream_with_cursor" input_state = [ AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name=stream_name), stream_state=AirbyteStateBlob(initial_state) + stream_descriptor=StreamDescriptor(name=stream_name), + stream_state=AirbyteStateBlob(initial_state), ), ), ] stream_with_cursor = MockStreamWithCursor( - [({"sync_mode": SyncMode.incremental, "stream_slice": {}, "stream_state": initial_state}, stream_output)], + [ + ( + { + "sync_mode": SyncMode.incremental, + "stream_slice": {}, + "stream_state": initial_state, + }, + stream_output, + ) + ], name=stream_name, ) @@ -1227,9 +1399,26 @@ def test_resumable_full_refresh_multiple_pages(self, mocker): # So in reality we can probably get rid of this test entirely s1 = MockResumableFullRefreshStream( [ - ({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": {}}, responses[0]), - ({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 1}}, responses[1]), - ({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 2}}, responses[2]), + ( + {"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": {}}, + responses[0], + ), + ( + { + "stream_state": {}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 1}, + }, + responses[1], + ), + ( + { + "stream_state": {}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 2}, + }, + responses[2], + ), ], name="s1", ) @@ -1276,10 +1465,38 @@ def test_resumable_full_refresh_with_incoming_state(self, mocker): # So in reality we can probably get rid of this test entirely s1 = MockResumableFullRefreshStream( [ - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 10}}, responses[0]), - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 11}}, responses[1]), - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 12}}, responses[2]), - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 13}}, responses[3]), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 10}, + }, + responses[0], + ), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 11}, + }, + responses[1], + ), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 12}, + }, + responses[2], + ), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 13}, + }, + responses[3], + ), ], name="s1", ) @@ -1338,9 +1555,26 @@ def test_resumable_full_refresh_partial_failure(self, mocker): # So in reality we can probably get rid of this test entirely s1 = MockResumableFullRefreshStream( [ - ({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": {}}, responses[0]), - ({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 1}}, responses[1]), - ({"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 2}}, responses[2]), + ( + {"stream_state": {}, "sync_mode": SyncMode.full_refresh, "stream_slice": {}}, + responses[0], + ), + ( + { + "stream_state": {}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 1}, + }, + responses[1], + ), + ( + { + "stream_state": {}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 2}, + }, + responses[2], + ), ], name="s1", ) @@ -1393,20 +1627,76 @@ def test_resumable_full_refresh_skip_prior_successful_streams(self, mocker): # So in reality we can probably get rid of this test entirely s1 = MockResumableFullRefreshStream( [ - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 10}}, responses[0]), - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 11}}, responses[1]), - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 12}}, responses[2]), - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 13}}, responses[3]), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 10}, + }, + responses[0], + ), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 11}, + }, + responses[1], + ), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 12}, + }, + responses[2], + ), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 13}, + }, + responses[3], + ), ], name="s1", ) s2 = MockResumableFullRefreshStream( [ - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 10}}, responses[0]), - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 11}}, responses[1]), - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 12}}, responses[2]), - ({"stream_state": {"page": 10}, "sync_mode": SyncMode.full_refresh, "stream_slice": {"page": 13}}, responses[3]), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 10}, + }, + responses[0], + ), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 11}, + }, + responses[1], + ), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 12}, + }, + responses[2], + ), + ( + { + "stream_state": {"page": 10}, + "sync_mode": SyncMode.full_refresh, + "stream_slice": {"page": 13}, + }, + responses[3], + ), ], name="s2", ) @@ -1480,7 +1770,9 @@ def test_resumable_full_refresh_skip_prior_successful_streams(self, mocker): ), ], ) -def test_continue_sync_with_failed_streams(mocker, exception_to_raise, expected_error_message, expected_internal_message): +def test_continue_sync_with_failed_streams( + mocker, exception_to_raise, expected_error_message, expected_internal_message +): """ Tests that running a sync for a connector with multiple streams will continue syncing when one stream fails with an error. This source does not override the default behavior defined in the AbstractSource class. @@ -1510,7 +1802,13 @@ def test_continue_sync_with_failed_streams(mocker, exception_to_raise, expected_ _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), _as_stream_status("lamentations", AirbyteStreamStatus.STARTED), _as_stream_status("lamentations", AirbyteStreamStatus.INCOMPLETE), - _as_error_trace("lamentations", expected_error_message, expected_internal_message, FailureType.system_error, None), + _as_error_trace( + "lamentations", + expected_error_message, + expected_internal_message, + FailureType.system_error, + None, + ), _as_stream_status("s3", AirbyteStreamStatus.STARTED), _as_stream_status("s3", AirbyteStreamStatus.RUNNING), *_as_records("s3", stream_output), @@ -1537,7 +1835,9 @@ def test_continue_sync_source_override_false(mocker): stream_output = [{"k1": "v1"}, {"k2": "v2"}] s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1") - s2 = StreamRaisesException(exception_to_raise=AirbyteTracedException(message="I was born only to crash like Icarus")) + s2 = StreamRaisesException( + exception_to_raise=AirbyteTracedException(message="I was born only to crash like Icarus") + ) s3 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s3") mocker.patch.object(MockStream, "get_json_schema", return_value={}) @@ -1560,7 +1860,13 @@ def test_continue_sync_source_override_false(mocker): _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), _as_stream_status("lamentations", AirbyteStreamStatus.STARTED), _as_stream_status("lamentations", AirbyteStreamStatus.INCOMPLETE), - _as_error_trace("lamentations", "I was born only to crash like Icarus", None, FailureType.system_error, None), + _as_error_trace( + "lamentations", + "I was born only to crash like Icarus", + None, + FailureType.system_error, + None, + ), _as_stream_status("s3", AirbyteStreamStatus.STARTED), _as_stream_status("s3", AirbyteStreamStatus.RUNNING), *_as_records("s3", stream_output), @@ -1587,7 +1893,9 @@ def test_sync_error_trace_messages_obfuscate_secrets(mocker): stream_output = [{"k1": "v1"}, {"k2": "v2"}] s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1") s2 = StreamRaisesException( - exception_to_raise=AirbyteTracedException(message="My api_key value API_KEY_VALUE flew too close to the sun.") + exception_to_raise=AirbyteTracedException( + message="My api_key value API_KEY_VALUE flew too close to the sun." + ) ) s3 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s3") @@ -1611,7 +1919,13 @@ def test_sync_error_trace_messages_obfuscate_secrets(mocker): _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), _as_stream_status("lamentations", AirbyteStreamStatus.STARTED), _as_stream_status("lamentations", AirbyteStreamStatus.INCOMPLETE), - _as_error_trace("lamentations", "My api_key value **** flew too close to the sun.", None, FailureType.system_error, None), + _as_error_trace( + "lamentations", + "My api_key value **** flew too close to the sun.", + None, + FailureType.system_error, + None, + ), _as_stream_status("s3", AirbyteStreamStatus.STARTED), _as_stream_status("s3", AirbyteStreamStatus.RUNNING), *_as_records("s3", stream_output), @@ -1635,9 +1949,27 @@ def test_continue_sync_with_failed_streams_with_override_false(mocker): the sync when one stream fails with an error. """ stream_output = [{"k1": "v1"}, {"k2": "v2"}] - s1 = MockStream([({"stream_state": {}, "stream_slice": {}, "sync_mode": SyncMode.full_refresh}, stream_output)], name="s1") - s2 = StreamRaisesException(AirbyteTracedException(message="I was born only to crash like Icarus")) - s3 = MockStream([({"stream_state": {}, "stream_slice": {}, "sync_mode": SyncMode.full_refresh}, stream_output)], name="s3") + s1 = MockStream( + [ + ( + {"stream_state": {}, "stream_slice": {}, "sync_mode": SyncMode.full_refresh}, + stream_output, + ) + ], + name="s1", + ) + s2 = StreamRaisesException( + AirbyteTracedException(message="I was born only to crash like Icarus") + ) + s3 = MockStream( + [ + ( + {"stream_state": {}, "stream_slice": {}, "sync_mode": SyncMode.full_refresh}, + stream_output, + ) + ], + name="s3", + ) mocker.patch.object(MockStream, "get_json_schema", return_value={}) mocker.patch.object(StreamRaisesException, "get_json_schema", return_value={}) @@ -1660,7 +1992,13 @@ def test_continue_sync_with_failed_streams_with_override_false(mocker): _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), _as_stream_status("lamentations", AirbyteStreamStatus.STARTED), _as_stream_status("lamentations", AirbyteStreamStatus.INCOMPLETE), - _as_error_trace("lamentations", "I was born only to crash like Icarus", None, FailureType.system_error, None), + _as_error_trace( + "lamentations", + "I was born only to crash like Icarus", + None, + FailureType.system_error, + None, + ), ] ) @@ -1684,7 +2022,9 @@ def _remove_stack_trace(message: AirbyteMessage) -> AirbyteMessage: return message -def test_read_nonexistent_stream_emit_incomplete_stream_status(mocker, remove_stack_trace, as_stream_status): +def test_read_nonexistent_stream_emit_incomplete_stream_status( + mocker, remove_stack_trace, as_stream_status +): """ Tests that attempting to sync a stream which the source does not return from the `streams` method emit incomplete stream status """ @@ -1696,7 +2036,9 @@ def test_read_nonexistent_stream_emit_incomplete_stream_status(mocker, remove_st src = MockSource(streams=[s1]) catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s2, SyncMode.full_refresh)]) - expected = _fix_emitted_at([as_stream_status("this_stream_doesnt_exist_in_the_source", AirbyteStreamStatus.INCOMPLETE)]) + expected = _fix_emitted_at( + [as_stream_status("this_stream_doesnt_exist_in_the_source", AirbyteStreamStatus.INCOMPLETE)] + ) expected_error_message = ( "The stream 'this_stream_doesnt_exist_in_the_source' in your connection configuration was not found in the " diff --git a/unit_tests/sources/test_config.py b/unit_tests/sources/test_config.py index 8cac7992..933bdc8f 100644 --- a/unit_tests/sources/test_config.py +++ b/unit_tests/sources/test_config.py @@ -62,7 +62,11 @@ class TestBaseConfig: "type": "string", "default": "option2", }, - "sequence": {"items": {"type": "string"}, "title": "Sequence", "type": "array"}, + "sequence": { + "items": {"type": "string"}, + "title": "Sequence", + "type": "array", + }, }, "required": ["sequence"], "title": "Choice2", @@ -73,7 +77,10 @@ class TestBaseConfig: }, "items": { "items": { - "properties": {"field1": {"title": "Field1", "type": "string"}, "field2": {"title": "Field2", "type": "integer"}}, + "properties": { + "field1": {"title": "Field1", "type": "string"}, + "field2": {"title": "Field2", "type": "integer"}, + }, "required": ["field1", "field2"], "title": "InnerClass", "type": "object", diff --git a/unit_tests/sources/test_connector_state_manager.py b/unit_tests/sources/test_connector_state_manager.py index 1a5526b1..9e53f2e6 100644 --- a/unit_tests/sources/test_connector_state_manager.py +++ b/unit_tests/sources/test_connector_state_manager.py @@ -16,7 +16,10 @@ StreamDescriptor, ) from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager, HashableStreamDescriptor +from airbyte_cdk.sources.connector_state_manager import ( + ConnectorStateManager, + HashableStreamDescriptor, +) @pytest.mark.parametrize( @@ -26,23 +29,38 @@ [ { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "actors", "namespace": "public"}, "stream_state": {"id": "mando_michael"}}, + "stream": { + "stream_descriptor": {"name": "actors", "namespace": "public"}, + "stream_state": {"id": "mando_michael"}, + }, }, { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "actresses", "namespace": "public"}, "stream_state": {"id": "seehorn_rhea"}}, + "stream": { + "stream_descriptor": {"name": "actresses", "namespace": "public"}, + "stream_state": {"id": "seehorn_rhea"}, + }, }, ], { - HashableStreamDescriptor(name="actors", namespace="public"): AirbyteStateBlob({"id": "mando_michael"}), - HashableStreamDescriptor(name="actresses", namespace="public"): AirbyteStateBlob({"id": "seehorn_rhea"}), + HashableStreamDescriptor(name="actors", namespace="public"): AirbyteStateBlob( + {"id": "mando_michael"} + ), + HashableStreamDescriptor(name="actresses", namespace="public"): AirbyteStateBlob( + {"id": "seehorn_rhea"} + ), }, does_not_raise(), id="test_incoming_per_stream_state", ), pytest.param([], {}, does_not_raise(), id="test_incoming_empty_stream_state"), pytest.param( - [{"type": "STREAM", "stream": {"stream_descriptor": {"name": "actresses", "namespace": "public"}}}], + [ + { + "type": "STREAM", + "stream": {"stream_descriptor": {"name": "actresses", "namespace": "public"}}, + } + ], {HashableStreamDescriptor(name="actresses", namespace="public"): None}, does_not_raise(), id="test_stream_states_that_have_none_state_blob", @@ -67,8 +85,12 @@ }, ], { - HashableStreamDescriptor(name="actors", namespace="public"): AirbyteStateBlob({"id": "mando_michael"}), - HashableStreamDescriptor(name="actresses", namespace="public"): AirbyteStateBlob({"id": "seehorn_rhea"}), + HashableStreamDescriptor(name="actors", namespace="public"): AirbyteStateBlob( + {"id": "mando_michael"} + ), + HashableStreamDescriptor(name="actresses", namespace="public"): AirbyteStateBlob( + {"id": "seehorn_rhea"} + ), }, pytest.raises(ValueError), id="test_incoming_global_state_with_shared_state_throws_error", @@ -79,13 +101,18 @@ "type": "GLOBAL", "global": { "stream_states": [ - {"stream_descriptor": {"name": "actors", "namespace": "public"}, "stream_state": {"id": "mando_michael"}}, + { + "stream_descriptor": {"name": "actors", "namespace": "public"}, + "stream_state": {"id": "mando_michael"}, + }, ], }, }, ], { - HashableStreamDescriptor(name="actors", namespace="public"): AirbyteStateBlob({"id": "mando_michael"}), + HashableStreamDescriptor(name="actors", namespace="public"): AirbyteStateBlob( + {"id": "mando_michael"} + ), }, does_not_raise(), id="test_incoming_global_state_without_shared", @@ -106,7 +133,9 @@ }, ], { - HashableStreamDescriptor(name="actors", namespace="public"): AirbyteStateBlob({"id": "mando_michael"}), + HashableStreamDescriptor(name="actors", namespace="public"): AirbyteStateBlob( + {"id": "mando_michael"} + ), }, does_not_raise(), id="test_incoming_global_state_with_none_shared", @@ -130,7 +159,9 @@ ) def test_initialize_state_manager(input_stream_state, expected_stream_state, expected_error): if isinstance(input_stream_state, List): - input_stream_state = [AirbyteStateMessageSerializer.load(state_obj) for state_obj in list(input_stream_state)] + input_stream_state = [ + AirbyteStateMessageSerializer.load(state_obj) for state_obj in list(input_stream_state) + ] with expected_error: state_manager = ConnectorStateManager(input_stream_state) @@ -145,11 +176,17 @@ def test_initialize_state_manager(input_stream_state, expected_stream_state, exp [ { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "users", "namespace": "public"}, "stream_state": {"created_at": 12345}}, + "stream": { + "stream_descriptor": {"name": "users", "namespace": "public"}, + "stream_state": {"created_at": 12345}, + }, }, { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "accounts", "namespace": "public"}, "stream_state": {"id": "abc"}}, + "stream": { + "stream_descriptor": {"name": "accounts", "namespace": "public"}, + "stream_state": {"id": "abc"}, + }, }, ], "users", @@ -161,9 +198,18 @@ def test_initialize_state_manager(input_stream_state, expected_stream_state, exp [ { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "users"}, "stream_state": {"created_at": 12345}}, + "stream": { + "stream_descriptor": {"name": "users"}, + "stream_state": {"created_at": 12345}, + }, + }, + { + "type": "STREAM", + "stream": { + "stream_descriptor": {"name": "accounts"}, + "stream_state": {"id": "abc"}, + }, }, - {"type": "STREAM", "stream": {"stream_descriptor": {"name": "accounts"}, "stream_state": {"id": "abc"}}}, ], "users", None, @@ -173,7 +219,13 @@ def test_initialize_state_manager(input_stream_state, expected_stream_state, exp pytest.param( [ {"type": "STREAM", "stream": {"stream_descriptor": {"name": "users"}}}, - {"type": "STREAM", "stream": {"stream_descriptor": {"name": "accounts"}, "stream_state": {"id": "abc"}}}, + { + "type": "STREAM", + "stream": { + "stream_descriptor": {"name": "accounts"}, + "stream_state": {"id": "abc"}, + }, + }, ], "users", None, @@ -184,11 +236,17 @@ def test_initialize_state_manager(input_stream_state, expected_stream_state, exp [ { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "users", "namespace": "public"}, "stream_state": {"created_at": 12345}}, + "stream": { + "stream_descriptor": {"name": "users", "namespace": "public"}, + "stream_state": {"created_at": 12345}, + }, }, { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "accounts", "namespace": "public"}, "stream_state": {"id": "abc"}}, + "stream": { + "stream_descriptor": {"name": "accounts", "namespace": "public"}, + "stream_state": {"id": "abc"}, + }, }, ], "missing", @@ -200,11 +258,17 @@ def test_initialize_state_manager(input_stream_state, expected_stream_state, exp [ { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "users", "namespace": "public"}, "stream_state": {"created_at": 12345}}, + "stream": { + "stream_descriptor": {"name": "users", "namespace": "public"}, + "stream_state": {"created_at": 12345}, + }, }, { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "accounts", "namespace": "public"}, "stream_state": {"id": "abc"}}, + "stream": { + "stream_descriptor": {"name": "accounts", "namespace": "public"}, + "stream_state": {"id": "abc"}, + }, }, ], "users", @@ -212,12 +276,17 @@ def test_initialize_state_manager(input_stream_state, expected_stream_state, exp {}, id="test_get_stream_wrong_namespace", ), - pytest.param([], "users", "public", {}, id="test_get_empty_stream_state_defaults_to_empty_dictionary"), + pytest.param( + [], "users", "public", {}, id="test_get_empty_stream_state_defaults_to_empty_dictionary" + ), pytest.param( [ { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "users", "namespace": "public"}, "stream_state": None}, + "stream": { + "stream_descriptor": {"name": "users", "namespace": "public"}, + "stream_state": None, + }, }, ], "users", @@ -228,7 +297,9 @@ def test_initialize_state_manager(input_stream_state, expected_stream_state, exp ], ) def test_get_stream_state(input_state, stream_name, namespace, expected_state): - state_messages = [AirbyteStateMessageSerializer.load(state_obj) for state_obj in list(input_state)] + state_messages = [ + AirbyteStateMessageSerializer.load(state_obj) for state_obj in list(input_state) + ] state_manager = ConnectorStateManager(state_messages) actual_state = state_manager.get_stream_state(stream_name, namespace) @@ -261,11 +332,17 @@ def test_get_state_returns_deep_copy(): [ { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "actors", "namespace": "public"}, "stream_state": {"id": "mckean_michael"}}, + "stream": { + "stream_descriptor": {"name": "actors", "namespace": "public"}, + "stream_state": {"id": "mckean_michael"}, + }, }, { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "actresses", "namespace": "public"}, "stream_state": {"id": "seehorn_rhea"}}, + "stream": { + "stream_descriptor": {"name": "actresses", "namespace": "public"}, + "stream_state": {"id": "seehorn_rhea"}, + }, }, ], "actors", @@ -284,7 +361,10 @@ def test_get_state_returns_deep_copy(): [ { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "actresses", "namespace": "public"}, "stream_state": {"id": "seehorn_rhea"}}, + "stream": { + "stream_descriptor": {"name": "actresses", "namespace": "public"}, + "stream_state": {"id": "seehorn_rhea"}, + }, } ], "actors", @@ -296,7 +376,10 @@ def test_get_state_returns_deep_copy(): [ { "type": "STREAM", - "stream": {"stream_descriptor": {"name": "actresses", "namespace": "public"}, "stream_state": {"id": "seehorn_rhea"}}, + "stream": { + "stream_descriptor": {"name": "actresses", "namespace": "public"}, + "stream_state": {"id": "seehorn_rhea"}, + }, } ], "actors", @@ -312,9 +395,9 @@ def test_update_state_for_stream(start_state, update_name, update_namespace, upd state_manager.update_state_for_stream(update_name, update_namespace, update_value) - assert state_manager.per_stream_states[HashableStreamDescriptor(name=update_name, namespace=update_namespace)] == AirbyteStateBlob( - update_value - ) + assert state_manager.per_stream_states[ + HashableStreamDescriptor(name=update_name, namespace=update_namespace) + ] == AirbyteStateBlob(update_value) @pytest.mark.parametrize( @@ -392,7 +475,8 @@ def test_update_state_for_stream(start_state, update_name, update_namespace, upd state=AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="missing", namespace="public"), stream_state=AirbyteStateBlob() + stream_descriptor=StreamDescriptor(name="missing", namespace="public"), + stream_state=AirbyteStateBlob(), ), ), ), @@ -415,7 +499,10 @@ def test_update_state_for_stream(start_state, update_name, update_namespace, upd state=AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="episodes", namespace="nonexistent"), stream_state=AirbyteStateBlob() + stream_descriptor=StreamDescriptor( + name="episodes", namespace="nonexistent" + ), + stream_state=AirbyteStateBlob(), ), ), ), @@ -426,5 +513,7 @@ def test_update_state_for_stream(start_state, update_name, update_namespace, upd def test_create_state_message(start_state, update_name, update_namespace, expected_state_message): state_manager = ConnectorStateManager(start_state) - actual_state_message = state_manager.create_state_message(stream_name=update_name, namespace=update_namespace) + actual_state_message = state_manager.create_state_message( + stream_name=update_name, namespace=update_namespace + ) assert actual_state_message == expected_state_message diff --git a/unit_tests/sources/test_http_logger.py b/unit_tests/sources/test_http_logger.py index 5711d352..29f73e69 100644 --- a/unit_tests/sources/test_http_logger.py +++ b/unit_tests/sources/test_http_logger.py @@ -9,7 +9,9 @@ A_TITLE = "a title" A_DESCRIPTION = "a description" A_STREAM_NAME = "a stream name" -ANY_REQUEST = requests.Request(method="POST", url="http://a-url.com", headers={}, params={}).prepare() +ANY_REQUEST = requests.Request( + method="POST", url="http://a-url.com", headers={}, params={} +).prepare() class ResponseBuilder: @@ -83,7 +85,11 @@ def build(self): "http": { "title": A_TITLE, "description": A_DESCRIPTION, - "request": {"method": "GET", "body": {"content": None}, "headers": {"h1": "v1", "h2": "v2"}}, + "request": { + "method": "GET", + "body": {"content": None}, + "headers": {"h1": "v1", "h2": "v2"}, + }, "response": EMPTY_RESPONSE, }, "log": {"level": "debug"}, @@ -150,7 +156,11 @@ def build(self): "request": { "method": "GET", "body": {"content": '{"b1": "v1", "b2": "v2"}'}, - "headers": {"Content-Type": "application/json", "Content-Length": "24", "h1": "v1"}, + "headers": { + "Content-Type": "application/json", + "Content-Length": "24", + "h1": "v1", + }, }, "response": EMPTY_RESPONSE, }, @@ -174,7 +184,10 @@ def build(self): "request": { "method": "GET", "body": {"content": "b1=v1&b2=v2"}, - "headers": {"Content-Type": "application/x-www-form-urlencoded", "Content-Length": "11"}, + "headers": { + "Content-Type": "application/x-www-form-urlencoded", + "Content-Length": "11", + }, }, "response": EMPTY_RESPONSE, }, @@ -195,7 +208,11 @@ def build(self): "http": { "title": A_TITLE, "description": A_DESCRIPTION, - "request": {"method": "POST", "body": {"content": None}, "headers": {"Content-Length": "0"}}, + "request": { + "method": "POST", + "body": {"content": None}, + "headers": {"Content-Length": "0"}, + }, "response": EMPTY_RESPONSE, }, "log": {"level": "debug"}, @@ -204,7 +221,9 @@ def build(self): ), ], ) -def test_prepared_request_to_airbyte_message(test_name, http_method, url, headers, params, body_json, body_data, expected_airbyte_message): +def test_prepared_request_to_airbyte_message( + test_name, http_method, url, headers, params, body_json, body_data, expected_airbyte_message +): request = requests.Request(method=http_method, url=url, headers=headers, params=params) if body_json: request.json = body_json @@ -212,7 +231,9 @@ def test_prepared_request_to_airbyte_message(test_name, http_method, url, header request.data = body_data prepared_request = request.prepare() - actual_airbyte_message = format_http_message(ResponseBuilder().request(prepared_request).build(), A_TITLE, A_DESCRIPTION, A_STREAM_NAME) + actual_airbyte_message = format_http_message( + ResponseBuilder().request(prepared_request).build(), A_TITLE, A_DESCRIPTION, A_STREAM_NAME + ) assert actual_airbyte_message == expected_airbyte_message @@ -220,7 +241,13 @@ def test_prepared_request_to_airbyte_message(test_name, http_method, url, header @pytest.mark.parametrize( "test_name, response_body, response_headers, status_code, expected_airbyte_message", [ - ("test_response_no_body_no_headers", b"", {}, 200, {"body": {"content": ""}, "headers": {}, "status_code": 200}), + ( + "test_response_no_body_no_headers", + b"", + {}, + 200, + {"body": {"content": ""}, "headers": {}, "status_code": 200}, + ), ( "test_response_no_body_with_headers", b"", @@ -240,12 +267,24 @@ def test_prepared_request_to_airbyte_message(test_name, http_method, url, header b'{"b1": "v1", "b2": "v2"}', {"h1": "v1", "h2": "v2"}, 200, - {"body": {"content": '{"b1": "v1", "b2": "v2"}'}, "headers": {"h1": "v1", "h2": "v2"}, "status_code": 200}, + { + "body": {"content": '{"b1": "v1", "b2": "v2"}'}, + "headers": {"h1": "v1", "h2": "v2"}, + "status_code": 200, + }, ), ], ) -def test_response_to_airbyte_message(test_name, response_body, response_headers, status_code, expected_airbyte_message): - response = ResponseBuilder().body_content(response_body).headers(response_headers).status_code(status_code).build() +def test_response_to_airbyte_message( + test_name, response_body, response_headers, status_code, expected_airbyte_message +): + response = ( + ResponseBuilder() + .body_content(response_body) + .headers(response_headers) + .status_code(status_code) + .build() + ) actual_airbyte_message = format_http_message(response, A_TITLE, A_DESCRIPTION, A_STREAM_NAME) diff --git a/unit_tests/sources/test_integration_source.py b/unit_tests/sources/test_integration_source.py index 17628a02..1f86d1e7 100644 --- a/unit_tests/sources/test_integration_source.py +++ b/unit_tests/sources/test_integration_source.py @@ -23,19 +23,55 @@ @pytest.mark.parametrize( "deployment_mode, url_base, expected_records, expected_error", [ - pytest.param("CLOUD", "https://airbyte.com/api/v1/", [], None, id="test_cloud_read_with_public_endpoint"), - pytest.param("CLOUD", "http://unsecured.com/api/v1/", [], "system_error", id="test_cloud_read_with_unsecured_url"), - pytest.param("CLOUD", "https://172.20.105.99/api/v1/", [], "config_error", id="test_cloud_read_with_private_endpoint"), - pytest.param("CLOUD", "https://localhost:80/api/v1/", [], "config_error", id="test_cloud_read_with_localhost"), - pytest.param("OSS", "https://airbyte.com/api/v1/", [], None, id="test_oss_read_with_public_endpoint"), - pytest.param("OSS", "https://172.20.105.99/api/v1/", [], None, id="test_oss_read_with_private_endpoint"), + pytest.param( + "CLOUD", + "https://airbyte.com/api/v1/", + [], + None, + id="test_cloud_read_with_public_endpoint", + ), + pytest.param( + "CLOUD", + "http://unsecured.com/api/v1/", + [], + "system_error", + id="test_cloud_read_with_unsecured_url", + ), + pytest.param( + "CLOUD", + "https://172.20.105.99/api/v1/", + [], + "config_error", + id="test_cloud_read_with_private_endpoint", + ), + pytest.param( + "CLOUD", + "https://localhost:80/api/v1/", + [], + "config_error", + id="test_cloud_read_with_localhost", + ), + pytest.param( + "OSS", "https://airbyte.com/api/v1/", [], None, id="test_oss_read_with_public_endpoint" + ), + pytest.param( + "OSS", + "https://172.20.105.99/api/v1/", + [], + None, + id="test_oss_read_with_private_endpoint", + ), ], ) @patch.object(requests.Session, "send", fixture_mock_send) -def test_external_request_source(capsys, deployment_mode, url_base, expected_records, expected_error): +def test_external_request_source( + capsys, deployment_mode, url_base, expected_records, expected_error +): source = SourceTestFixture() - with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False): # clear=True clears the existing os.environ dict + with mock.patch.dict( + os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False + ): # clear=True clears the existing os.environ dict with mock.patch.object(HttpTestStream, "url_base", url_base): args = ["read", "--config", "config.json", "--catalog", "configured_catalog.json"] if expected_error: @@ -50,21 +86,54 @@ def test_external_request_source(capsys, deployment_mode, url_base, expected_rec @pytest.mark.parametrize( "deployment_mode, token_refresh_url, expected_records, expected_error", [ - pytest.param("CLOUD", "https://airbyte.com/api/v1/", [], None, id="test_cloud_read_with_public_endpoint"), - pytest.param("CLOUD", "http://unsecured.com/api/v1/", [], "system_error", id="test_cloud_read_with_unsecured_url"), - pytest.param("CLOUD", "https://172.20.105.99/api/v1/", [], "config_error", id="test_cloud_read_with_private_endpoint"), - pytest.param("OSS", "https://airbyte.com/api/v1/", [], None, id="test_oss_read_with_public_endpoint"), - pytest.param("OSS", "https://172.20.105.99/api/v1/", [], None, id="test_oss_read_with_private_endpoint"), + pytest.param( + "CLOUD", + "https://airbyte.com/api/v1/", + [], + None, + id="test_cloud_read_with_public_endpoint", + ), + pytest.param( + "CLOUD", + "http://unsecured.com/api/v1/", + [], + "system_error", + id="test_cloud_read_with_unsecured_url", + ), + pytest.param( + "CLOUD", + "https://172.20.105.99/api/v1/", + [], + "config_error", + id="test_cloud_read_with_private_endpoint", + ), + pytest.param( + "OSS", "https://airbyte.com/api/v1/", [], None, id="test_oss_read_with_public_endpoint" + ), + pytest.param( + "OSS", + "https://172.20.105.99/api/v1/", + [], + None, + id="test_oss_read_with_private_endpoint", + ), ], ) @patch.object(requests.Session, "send", fixture_mock_send) -def test_external_oauth_request_source(capsys, deployment_mode, token_refresh_url, expected_records, expected_error): +def test_external_oauth_request_source( + capsys, deployment_mode, token_refresh_url, expected_records, expected_error +): oauth_authenticator = SourceFixtureOauthAuthenticator( - client_id="nora", client_secret="hae_sung", refresh_token="arthur", token_refresh_endpoint=token_refresh_url + client_id="nora", + client_secret="hae_sung", + refresh_token="arthur", + token_refresh_endpoint=token_refresh_url, ) source = SourceTestFixture(authenticator=oauth_authenticator) - with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False): # clear=True clears the existing os.environ dict + with mock.patch.dict( + os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False + ): # clear=True clears the existing os.environ dict args = ["read", "--config", "config.json", "--catalog", "configured_catalog.json"] if expected_error: with pytest.raises(AirbyteTracedException): diff --git a/unit_tests/sources/test_source.py b/unit_tests/sources/test_source.py index d548a51b..c47b12a0 100644 --- a/unit_tests/sources/test_source.py +++ b/unit_tests/sources/test_source.py @@ -32,7 +32,11 @@ class MockSource(Source): def read( - self, logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None + self, + logger: logging.Logger, + config: Mapping[str, Any], + catalog: ConfiguredAirbyteCatalog, + state: MutableMapping[str, Any] = None, ): pass @@ -66,12 +70,20 @@ def catalog(): configured_catalog = { "streams": [ { - "stream": {"name": "mock_http_stream", "json_schema": {}, "supported_sync_modes": ["full_refresh"]}, + "stream": { + "name": "mock_http_stream", + "json_schema": {}, + "supported_sync_modes": ["full_refresh"], + }, "destination_sync_mode": "overwrite", "sync_mode": "full_refresh", }, { - "stream": {"name": "mock_stream", "json_schema": {}, "supported_sync_modes": ["full_refresh"]}, + "stream": { + "name": "mock_stream", + "json_schema": {}, + "supported_sync_modes": ["full_refresh"], + }, "destination_sync_mode": "overwrite", "sync_mode": "full_refresh", }, @@ -221,7 +233,10 @@ def streams(self, config): "global": { "shared_state": {"shared_key": "shared_val"}, "stream_states": [ - {"stream_state": {"created_at": "2009-07-19"}, "stream_descriptor": {"name": "movies", "namespace": "public"}} + { + "stream_state": {"created_at": "2009-07-19"}, + "stream_descriptor": {"name": "movies", "namespace": "public"}, + } ], }, } @@ -233,7 +248,9 @@ def streams(self, config): shared_state=AirbyteStateBlob({"shared_key": "shared_val"}), stream_states=[ AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="movies", namespace="public"), + stream_descriptor=StreamDescriptor( + name="movies", namespace="public" + ), stream_state=AirbyteStateBlob({"created_at": "2009-07-19"}), ) ], @@ -297,7 +314,9 @@ def test_read_state(source, incoming_state, expected_state, expected_error): with expected_error: actual = source.read_state(state_file.name) if expected_state and actual: - assert AirbyteStateMessageSerializer.dump(actual[0]) == AirbyteStateMessageSerializer.dump(expected_state[0]) + assert AirbyteStateMessageSerializer.dump( + actual[0] + ) == AirbyteStateMessageSerializer.dump(expected_state[0]) def test_read_invalid_state(source): @@ -311,7 +330,11 @@ def test_read_invalid_state(source): @pytest.mark.parametrize( "source, expected_state", [ - pytest.param(MockAbstractSource(), [], id="test_source_not_implementing_read_returns_per_stream_format"), + pytest.param( + MockAbstractSource(), + [], + id="test_source_not_implementing_read_returns_per_stream_format", + ), ], ) def test_read_state_nonexistent(source, expected_state): @@ -361,19 +384,34 @@ def test_internal_config(abstract_source, catalog): assert not non_http_stream.page_size # Test with records limit set to 1 internal_config = {"some_config": 100, "_limit": 1} - records = [r for r in abstract_source.read(logger=logger, config=internal_config, catalog=catalog, state={})] + records = [ + r + for r in abstract_source.read( + logger=logger, config=internal_config, catalog=catalog, state={} + ) + ] # 1 from http stream + 1 from non http stream, 1 for state message for each stream (2x) and 3 for stream status messages for each stream (2x) assert len(records) == 1 + 1 + 1 + 1 + 3 + 3 assert "_limit" not in abstract_source.streams_config assert "some_config" in abstract_source.streams_config # Test with records limit set to number that exceeds expceted records internal_config = {"some_config": 100, "_limit": 20} - records = [r for r in abstract_source.read(logger=logger, config=internal_config, catalog=catalog, state={})] + records = [ + r + for r in abstract_source.read( + logger=logger, config=internal_config, catalog=catalog, state={} + ) + ] assert len(records) == 3 + 3 + 1 + 1 + 3 + 3 # Check if page_size paramter is set to http instance only internal_config = {"some_config": 100, "_page_size": 2} - records = [r for r in abstract_source.read(logger=logger, config=internal_config, catalog=catalog, state={})] + records = [ + r + for r in abstract_source.read( + logger=logger, config=internal_config, catalog=catalog, state={} + ) + ] assert "_page_size" not in abstract_source.streams_config assert "some_config" in abstract_source.streams_config assert len(records) == 3 + 3 + 1 + 1 + 3 + 3 @@ -397,7 +435,12 @@ def test_internal_config_limit(mocker, abstract_source, catalog): internal_config = {"some_config": 100, "_limit": STREAM_LIMIT} catalog.streams[0].sync_mode = SyncMode.full_refresh - records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})] + records = [ + r + for r in abstract_source.read( + logger=logger_mock, config=internal_config, catalog=catalog, state={} + ) + ] assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT + STATE_COUNT logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list] # Check if log line matches number of limit @@ -406,14 +449,21 @@ def test_internal_config_limit(mocker, abstract_source, catalog): # No limit, check if state record produced for incremental stream catalog.streams[0].sync_mode = SyncMode.incremental - records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})] + records = [ + r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={}) + ] assert len(records) == FULL_RECORDS_NUMBER + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT + 1 assert records[-2].type == Type.STATE assert records[-1].type == Type.TRACE # Set limit and check if state is produced when limit is set for incremental stream logger_mock.reset_mock() - records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})] + records = [ + r + for r in abstract_source.read( + logger=logger_mock, config=internal_config, catalog=catalog, state={} + ) + ] assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT + 1 assert records[-2].type == Type.STATE assert records[-1].type == Type.TRACE @@ -436,8 +486,12 @@ def test_source_config_no_transform(mocker, abstract_source, catalog): streams = abstract_source.streams(None) http_stream, non_http_stream = streams http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA - http_stream.read_records.return_value, non_http_stream.read_records.return_value = [[{"value": 23}] * 5] * 2 - records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})] + http_stream.read_records.return_value, non_http_stream.read_records.return_value = [ + [{"value": 23}] * 5 + ] * 2 + records = [ + r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={}) + ] assert len(records) == 2 * (5 + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT + STATE_COUNT) assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": 23}] * 2 * 5 assert http_stream.get_json_schema.call_count == 5 + GET_JSON_SCHEMA_COUNT_WHEN_FILTERING @@ -455,8 +509,13 @@ def test_source_config_transform(mocker, abstract_source, catalog): http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization) non_http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization) http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA - http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}] - records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})] + http_stream.read_records.return_value, non_http_stream.read_records.return_value = ( + [{"value": 23}], + [{"value": 23}], + ) + records = [ + r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={}) + ] assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT + STATE_COUNT assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}] * 2 @@ -471,7 +530,15 @@ def test_source_config_transform_and_no_transform(mocker, abstract_source, catal http_stream, non_http_stream = streams http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization) http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA - http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}] - records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})] + http_stream.read_records.return_value, non_http_stream.read_records.return_value = ( + [{"value": 23}], + [{"value": 23}], + ) + records = [ + r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={}) + ] assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT + STATE_COUNT - assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}, {"value": 23}] + assert [r.record.data for r in records if r.type == Type.RECORD] == [ + {"value": "23"}, + {"value": 23}, + ] diff --git a/unit_tests/sources/test_source_read.py b/unit_tests/sources/test_source_read.py index 05c71d1e..a4878a8c 100644 --- a/unit_tests/sources/test_source_read.py +++ b/unit_tests/sources/test_source_read.py @@ -30,7 +30,9 @@ from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.utils import AirbyteTracedException -from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import NeverLogSliceLogger +from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import ( + NeverLogSliceLogger, +) class _MockStream(Stream): @@ -47,7 +49,11 @@ def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: return None def stream_slices( - self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None + self, + *, + sync_mode: SyncMode, + cursor_field: Optional[List[str]] = None, + stream_state: Optional[Mapping[str, Any]] = None, ) -> Iterable[Optional[Mapping[str, Any]]]: for partition in self._slice_to_records.keys(): yield {"partition": partition} @@ -72,7 +78,9 @@ def get_json_schema(self) -> Mapping[str, Any]: class _MockSource(AbstractSource): message_repository = InMemoryMessageRepository() - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: pass def set_streams(self, streams): @@ -86,10 +94,14 @@ class _MockConcurrentSource(ConcurrentSourceAdapter): message_repository = InMemoryMessageRepository() def __init__(self, logger): - concurrent_source = ConcurrentSource.create(1, 1, logger, NeverLogSliceLogger(), self.message_repository) + concurrent_source = ConcurrentSource.create( + 1, 1, logger, NeverLogSliceLogger(), self.message_repository + ) super().__init__(concurrent_source) - def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: + def check_connection( + self, logger: logging.Logger, config: Mapping[str, Any] + ) -> Tuple[bool, Optional[Any]]: pass def set_streams(self, streams): @@ -117,18 +129,28 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_e {"id": 300, "partition": "B"}, {"id": 400, "partition": "B"}, ] - stream_1_slice_to_partition = {"1": records_stream_1_partition_1, "2": records_stream_1_partition_2} - stream_2_slice_to_partition = {"A": records_stream_2_partition_1, "B": records_stream_2_partition_2} + stream_1_slice_to_partition = { + "1": records_stream_1_partition_1, + "2": records_stream_1_partition_2, + } + stream_2_slice_to_partition = { + "A": records_stream_2_partition_1, + "B": records_stream_2_partition_2, + } state = None logger = _init_logger() - source, concurrent_source = _init_sources([stream_1_slice_to_partition, stream_2_slice_to_partition], state, logger) + source, concurrent_source = _init_sources( + [stream_1_slice_to_partition, stream_2_slice_to_partition], state, logger + ) config = {} catalog = _create_configured_catalog(source._streams) # FIXME this is currently unused in this test # messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, None) - messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, None) + messages_from_concurrent_source = _read_from_source( + concurrent_source, logger, config, catalog, state, None + ) expected_messages = [ AirbyteMessage( @@ -139,7 +161,8 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_e error=None, estimate=None, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED) + stream_descriptor=StreamDescriptor(name="stream0"), + status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED), ), ), ), @@ -151,7 +174,8 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_e error=None, estimate=None, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING) + stream_descriptor=StreamDescriptor(name="stream0"), + status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING), ), ), ), @@ -195,7 +219,8 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_e error=None, estimate=None, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE) + stream_descriptor=StreamDescriptor(name="stream0"), + status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE), ), ), ), @@ -207,7 +232,8 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_e error=None, estimate=None, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED) + stream_descriptor=StreamDescriptor(name="stream1"), + status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED), ), ), ), @@ -219,7 +245,8 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_e error=None, estimate=None, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING) + stream_descriptor=StreamDescriptor(name="stream1"), + status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING), ), ), ), @@ -263,7 +290,8 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_e error=None, estimate=None, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE) + stream_descriptor=StreamDescriptor(name="stream1"), + status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE), ), ), ), @@ -281,8 +309,12 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_a_tr source, concurrent_source = _init_sources([stream_slice_to_partition], state, logger) config = {} catalog = _create_configured_catalog(source._streams) - messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException) - messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, AirbyteTracedException) + messages_from_abstract_source = _read_from_source( + source, logger, config, catalog, state, AirbyteTracedException + ) + messages_from_concurrent_source = _read_from_source( + concurrent_source, logger, config, catalog, state, AirbyteTracedException + ) _assert_status_messages(messages_from_abstract_source, messages_from_concurrent_source) _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source) @@ -300,8 +332,12 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_an_e source, concurrent_source = _init_sources([stream_slice_to_partition], state, logger) config = {} catalog = _create_configured_catalog(source._streams) - messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException) - messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, AirbyteTracedException) + messages_from_abstract_source = _read_from_source( + source, logger, config, catalog, state, AirbyteTracedException + ) + messages_from_concurrent_source = _read_from_source( + concurrent_source, logger, config, catalog, state, AirbyteTracedException + ) _assert_status_messages(messages_from_abstract_source, messages_from_concurrent_source) _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source) @@ -327,11 +363,17 @@ def _assert_status_messages(messages_from_abstract_source, messages_from_concurr def _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source): - records_from_concurrent_source = [message for message in messages_from_concurrent_source if message.type == MessageType.RECORD] + records_from_concurrent_source = [ + message for message in messages_from_concurrent_source if message.type == MessageType.RECORD + ] assert records_from_concurrent_source _verify_messages( - [message for message in messages_from_abstract_source if message.type == MessageType.RECORD], + [ + message + for message in messages_from_abstract_source + if message.type == MessageType.RECORD + ], records_from_concurrent_source, ) @@ -343,7 +385,9 @@ def _assert_errors(messages_from_abstract_source, messages_from_concurrent_sourc if message.type == MessageType.TRACE and message.trace.type == TraceType.ERROR ] errors_from_abstract_source = [ - message for message in messages_from_abstract_source if message.type == MessageType.TRACE and message.trace.type == TraceType.ERROR + message + for message in messages_from_abstract_source + if message.type == MessageType.TRACE and message.trace.type == TraceType.ERROR ] assert errors_from_concurrent_source @@ -360,7 +404,9 @@ def _init_logger(): def _init_sources(stream_slice_to_partitions, state, logger): source = _init_source(stream_slice_to_partitions, state, logger, _MockSource()) - concurrent_source = _init_source(stream_slice_to_partitions, state, logger, _MockConcurrentSource(logger)) + concurrent_source = _init_source( + stream_slice_to_partitions, state, logger, _MockConcurrentSource(logger) + ) return source, concurrent_source @@ -371,7 +417,11 @@ def _init_source(stream_slice_to_partitions, state, logger, source): source, logger, state, - FinalStateCursor(stream_name=f"stream{i}", stream_namespace=None, message_repository=InMemoryMessageRepository()), + FinalStateCursor( + stream_name=f"stream{i}", + stream_namespace=None, + message_repository=InMemoryMessageRepository(), + ), ) for i, stream_slices in enumerate(stream_slice_to_partitions) ] @@ -383,7 +433,9 @@ def _create_configured_catalog(streams): return ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( - stream=AirbyteStream(name=s.name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]), + stream=AirbyteStream( + name=s.name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh] + ), sync_mode=SyncMode.full_refresh, cursor_field=None, destination_sync_mode=DestinationSyncMode.overwrite, diff --git a/unit_tests/sources/utils/test_record_helper.py b/unit_tests/sources/utils/test_record_helper.py index b5476180..71f882df 100644 --- a/unit_tests/sources/utils/test_record_helper.py +++ b/unit_tests/sources/utils/test_record_helper.py @@ -30,7 +30,11 @@ {"id": 0, "field_A": 1.0, "field_B": "airbyte"}, AirbyteMessage( type=MessageType.RECORD, - record=AirbyteRecordMessage(stream="my_stream", data={"id": 0, "field_A": 1.0, "field_B": "airbyte"}, emitted_at=NOW), + record=AirbyteRecordMessage( + stream="my_stream", + data={"id": 0, "field_A": 1.0, "field_B": "airbyte"}, + emitted_at=NOW, + ), ), ), ], @@ -54,12 +58,18 @@ def test_data_or_record_to_airbyte_record(test_name, data, expected_message): ( "test_log_message_to_airbyte_record", AirbyteLogMessage(level=Level.INFO, message="Hello, this is a log message"), - AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message="Hello, this is a log message")), + AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage(level=Level.INFO, message="Hello, this is a log message"), + ), ), ( "test_trace_message_to_airbyte_record", AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=101), - AirbyteMessage(type=MessageType.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=101)), + AirbyteMessage( + type=MessageType.TRACE, + trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=101), + ), ), ], ) diff --git a/unit_tests/sources/utils/test_schema_helpers.py b/unit_tests/sources/utils/test_schema_helpers.py index 76b7a9b1..495c728e 100644 --- a/unit_tests/sources/utils/test_schema_helpers.py +++ b/unit_tests/sources/utils/test_schema_helpers.py @@ -15,7 +15,11 @@ import jsonref import pytest from airbyte_cdk.models import ConnectorSpecification, ConnectorSpecificationSerializer, FailureType -from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, ResourceSchemaLoader, check_config_against_spec_or_exit +from airbyte_cdk.sources.utils.schema_helpers import ( + InternalConfig, + ResourceSchemaLoader, + check_config_against_spec_or_exit, +) from airbyte_cdk.utils.traced_exception import AirbyteTracedException from pytest import fixture from pytest import raises as pytest_raises diff --git a/unit_tests/sources/utils/test_slice_logger.py b/unit_tests/sources/utils/test_slice_logger.py index 0796e876..43b54050 100644 --- a/unit_tests/sources/utils/test_slice_logger.py +++ b/unit_tests/sources/utils/test_slice_logger.py @@ -13,18 +13,75 @@ @pytest.mark.parametrize( "slice_logger, level, should_log", [ - pytest.param(DebugSliceLogger(), logging.DEBUG, True, id="debug_logger_should_log_if_level_is_debug"), - pytest.param(DebugSliceLogger(), logging.INFO, False, id="debug_logger_should_not_log_if_level_is_info"), - pytest.param(DebugSliceLogger(), logging.WARN, False, id="debug_logger_should_not_log_if_level_is_warn"), - pytest.param(DebugSliceLogger(), logging.WARNING, False, id="debug_logger_should_not_log_if_level_is_warning"), - pytest.param(DebugSliceLogger(), logging.ERROR, False, id="debug_logger_should_not_log_if_level_is_error"), - pytest.param(DebugSliceLogger(), logging.CRITICAL, False, id="always_log_logger_should_not_log_if_level_is_critical"), - pytest.param(AlwaysLogSliceLogger(), logging.DEBUG, True, id="always_log_logger_should_log_if_level_is_debug"), - pytest.param(AlwaysLogSliceLogger(), logging.INFO, True, id="always_log_logger_should_log_if_level_is_info"), - pytest.param(AlwaysLogSliceLogger(), logging.WARN, True, id="always_log_logger_should_log_if_level_is_warn"), - pytest.param(AlwaysLogSliceLogger(), logging.WARNING, True, id="always_log_logger_should_log_if_level_is_warning"), - pytest.param(AlwaysLogSliceLogger(), logging.ERROR, True, id="always_log_logger_should_log_if_level_is_error"), - pytest.param(AlwaysLogSliceLogger(), logging.CRITICAL, True, id="always_log_logger_should_log_if_level_is_critical"), + pytest.param( + DebugSliceLogger(), logging.DEBUG, True, id="debug_logger_should_log_if_level_is_debug" + ), + pytest.param( + DebugSliceLogger(), + logging.INFO, + False, + id="debug_logger_should_not_log_if_level_is_info", + ), + pytest.param( + DebugSliceLogger(), + logging.WARN, + False, + id="debug_logger_should_not_log_if_level_is_warn", + ), + pytest.param( + DebugSliceLogger(), + logging.WARNING, + False, + id="debug_logger_should_not_log_if_level_is_warning", + ), + pytest.param( + DebugSliceLogger(), + logging.ERROR, + False, + id="debug_logger_should_not_log_if_level_is_error", + ), + pytest.param( + DebugSliceLogger(), + logging.CRITICAL, + False, + id="always_log_logger_should_not_log_if_level_is_critical", + ), + pytest.param( + AlwaysLogSliceLogger(), + logging.DEBUG, + True, + id="always_log_logger_should_log_if_level_is_debug", + ), + pytest.param( + AlwaysLogSliceLogger(), + logging.INFO, + True, + id="always_log_logger_should_log_if_level_is_info", + ), + pytest.param( + AlwaysLogSliceLogger(), + logging.WARN, + True, + id="always_log_logger_should_log_if_level_is_warn", + ), + pytest.param( + AlwaysLogSliceLogger(), + logging.WARNING, + True, + id="always_log_logger_should_log_if_level_is_warning", + ), + pytest.param( + AlwaysLogSliceLogger(), + logging.ERROR, + True, + id="always_log_logger_should_log_if_level_is_error", + ), + pytest.param( + AlwaysLogSliceLogger(), + logging.CRITICAL, + True, + id="always_log_logger_should_log_if_level_is_critical", + ), ], ) def test_should_log_slice_message(slice_logger, level, should_log): @@ -41,6 +98,8 @@ def test_should_log_slice_message(slice_logger, level, should_log): ], ) def test_create_slice_log_message(_slice, expected_message): - expected_log_message = AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=expected_message)) + expected_log_message = AirbyteMessage( + type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=expected_message) + ) log_message = DebugSliceLogger().create_slice_log_message(_slice) assert log_message == expected_log_message diff --git a/unit_tests/sources/utils/test_transform.py b/unit_tests/sources/utils/test_transform.py index 9b3f7398..5d7aa1a6 100644 --- a/unit_tests/sources/utils/test_transform.py +++ b/unit_tests/sources/utils/test_transform.py @@ -47,7 +47,9 @@ "properties": { "very_nested_value": { "type": ["null", "object"], - "properties": {"very_nested_value": {"type": ["null", "number"]}}, + "properties": { + "very_nested_value": {"type": ["null", "number"]} + }, } }, } @@ -64,16 +66,36 @@ [ (SIMPLE_SCHEMA, {"value": 12}, {"value": "12"}, None), (SIMPLE_SCHEMA, {"value": 12}, {"value": "12"}, None), - (SIMPLE_SCHEMA, {"value": 12, "unexpected_value": "unexpected"}, {"value": "12", "unexpected_value": "unexpected"}, None), - (COMPLEX_SCHEMA, {"value": 1, "array": ["111", 111, {1: 111}]}, {"value": True, "array": ["111", "111", "{1: 111}"]}, None), + ( + SIMPLE_SCHEMA, + {"value": 12, "unexpected_value": "unexpected"}, + {"value": "12", "unexpected_value": "unexpected"}, + None, + ), + ( + COMPLEX_SCHEMA, + {"value": 1, "array": ["111", 111, {1: 111}]}, + {"value": True, "array": ["111", "111", "{1: 111}"]}, + None, + ), ( COMPLEX_SCHEMA, {"value": 1, "list_of_lists": [["111"], [111], [11], [{1: 1}]]}, {"value": True, "list_of_lists": [["111"], ["111"], ["11"], ["{1: 1}"]]}, None, ), - (COMPLEX_SCHEMA, {"value": 1, "nested": {"a": [1, 2, 3]}}, {"value": True, "nested": {"a": "[1, 2, 3]"}}, None), - (COMPLEX_SCHEMA, {"value": "false", "nested": {"a": [1, 2, 3]}}, {"value": False, "nested": {"a": "[1, 2, 3]"}}, None), + ( + COMPLEX_SCHEMA, + {"value": 1, "nested": {"a": [1, 2, 3]}}, + {"value": True, "nested": {"a": "[1, 2, 3]"}}, + None, + ), + ( + COMPLEX_SCHEMA, + {"value": "false", "nested": {"a": [1, 2, 3]}}, + {"value": False, "nested": {"a": "[1, 2, 3]"}}, + None, + ), (COMPLEX_SCHEMA, {}, {}, None), (COMPLEX_SCHEMA, {"int_prop": "12"}, {"int_prop": 12}, None), # Skip invalid formattted field and process other fields. @@ -93,14 +115,36 @@ # Test null field (COMPLEX_SCHEMA, {"prop": None, "array": [12]}, {"prop": "None", "array": ["12"]}, None), # If field can be null do not convert - (COMPLEX_SCHEMA, {"prop_with_null": None, "array": [12]}, {"prop_with_null": None, "array": ["12"]}, None), + ( + COMPLEX_SCHEMA, + {"prop_with_null": None, "array": [12]}, + {"prop_with_null": None, "array": ["12"]}, + None, + ), ( VERY_NESTED_SCHEMA, - {"very_nested_value": {"very_nested_value": {"very_nested_value": {"very_nested_value": {"very_nested_value": "2"}}}}}, - {"very_nested_value": {"very_nested_value": {"very_nested_value": {"very_nested_value": {"very_nested_value": 2.0}}}}}, + { + "very_nested_value": { + "very_nested_value": { + "very_nested_value": {"very_nested_value": {"very_nested_value": "2"}} + } + } + }, + { + "very_nested_value": { + "very_nested_value": { + "very_nested_value": {"very_nested_value": {"very_nested_value": 2.0}} + } + } + }, + None, + ), + ( + VERY_NESTED_SCHEMA, + {"very_nested_value": {"very_nested_value": None}}, + {"very_nested_value": {"very_nested_value": None}}, None, ), - (VERY_NESTED_SCHEMA, {"very_nested_value": {"very_nested_value": None}}, {"very_nested_value": {"very_nested_value": None}}, None), # Object without properties ({"type": "object"}, {"value": 12}, {"value": 12}, None), ( @@ -136,19 +180,28 @@ None, ), ( - {"type": "object", "properties": {"value": {"type": "array", "items": {"type": ["string"]}}}}, + { + "type": "object", + "properties": {"value": {"type": "array", "items": {"type": ["string"]}}}, + }, {"value": 10}, {"value": ["10"]}, None, ), ( - {"type": "object", "properties": {"value": {"type": "array", "items": {"type": ["object"]}}}}, + { + "type": "object", + "properties": {"value": {"type": "array", "items": {"type": ["object"]}}}, + }, {"value": "string"}, {"value": "string"}, "Failed to transform value 'string' of type 'string' to 'array', key path: '.value'", ), ( - {"type": "object", "properties": {"value": {"type": "array", "items": {"type": ["string"]}}}}, + { + "type": "object", + "properties": {"value": {"type": "array", "items": {"type": ["string"]}}}, + }, {"value": {"key": "value"}}, {"value": {"key": "value"}}, "Failed to transform value {'key': 'value'} of type 'object' to 'array', key path: '.value'", @@ -169,7 +222,13 @@ ), ( # Oneof not suported, no conversion for one_of_value should happen - {"type": "object", "properties": {"one_of_value": {"oneOf": ["string", "boolean", "null"]}, "value_2": {"type": "string"}}}, + { + "type": "object", + "properties": { + "one_of_value": {"oneOf": ["string", "boolean", "null"]}, + "value_2": {"type": "string"}, + }, + }, {"one_of_value": 12, "value_2": 12}, {"one_of_value": 12, "value_2": "12"}, None, @@ -186,19 +245,30 @@ None, ), ( - {"type": "object", "properties": {"value": {"type": "array", "items": {"type": "string"}}}}, + { + "type": "object", + "properties": {"value": {"type": "array", "items": {"type": "string"}}}, + }, {"value": {"key": "value"}}, {"value": {"key": "value"}}, "Failed to transform value {'key': 'value'} of type 'object' to 'array', key path: '.value'", ), ( - {"type": "object", "properties": {"value1": {"type": "object", "properties": {"value2": {"type": "string"}}}}}, + { + "type": "object", + "properties": { + "value1": {"type": "object", "properties": {"value2": {"type": "string"}}} + }, + }, {"value1": "value2"}, {"value1": "value2"}, "Failed to transform value 'value2' of type 'string' to 'object', key path: '.value1'", ), ( - {"type": "object", "properties": {"value": {"type": "array", "items": {"type": "object"}}}}, + { + "type": "object", + "properties": {"value": {"type": "array", "items": {"type": "object"}}}, + }, {"value": ["one", "two"]}, {"value": ["one", "two"]}, "Failed to transform value 'one' of type 'string' to 'object', key path: '.value.0'", @@ -222,7 +292,10 @@ def test_transform_wrong_config(): with pytest.raises(Exception, match="NoTransform option cannot be combined with other flags."): TypeTransformer(TransformConfig.NoTransform | TransformConfig.DefaultSchemaNormalization) - with pytest.raises(Exception, match="Please set TransformConfig.CustomSchemaNormalization config before registering custom normalizer"): + with pytest.raises( + Exception, + match="Please set TransformConfig.CustomSchemaNormalization config before registering custom normalizer", + ): class NotAStream: transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization) @@ -251,7 +324,9 @@ def transform_cb(instance, schema): def test_custom_transform_with_default_normalization(): class NotAStream: - transformer = TypeTransformer(TransformConfig.CustomSchemaNormalization | TransformConfig.DefaultSchemaNormalization) + transformer = TypeTransformer( + TransformConfig.CustomSchemaNormalization | TransformConfig.DefaultSchemaNormalization + ) @transformer.registerCustomTransform def transform_cb(instance, schema): diff --git a/unit_tests/test/mock_http/test_matcher.py b/unit_tests/test/mock_http/test_matcher.py index 61a9ecfe..a1018a01 100644 --- a/unit_tests/test/mock_http/test_matcher.py +++ b/unit_tests/test/mock_http/test_matcher.py @@ -19,7 +19,9 @@ def test_given_request_matches_when_matches_then_has_expected_match_count(self): self._matcher.matches(self._a_request) assert self._matcher.has_expected_match_count() - def test_given_request_does_not_match_when_matches_then_does_not_have_expected_match_count(self): + def test_given_request_does_not_match_when_matches_then_does_not_have_expected_match_count( + self, + ): self._a_request.matches.return_value = False self._matcher.matches(self._a_request) @@ -44,7 +46,9 @@ def test_given_expected_number_of_requests_met_when_matches_then_has_expected_ma assert _matcher.has_expected_match_count() assert _matcher.actual_number_of_matches == 2 - def test_given_expected_number_of_requests_not_met_when_matches_then_does_not_have_expected_match_count(self): + def test_given_expected_number_of_requests_not_met_when_matches_then_does_not_have_expected_match_count( + self, + ): _matcher = HttpRequestMatcher(self._request_to_match, 2) self._a_request.matches.side_effect = [True, False] _matcher.matches(self._a_request) diff --git a/unit_tests/test/mock_http/test_mocker.py b/unit_tests/test/mock_http/test_mocker.py index 2b086a17..9ca79ae5 100644 --- a/unit_tests/test/mock_http/test_mocker.py +++ b/unit_tests/test/mock_http/test_mocker.py @@ -54,7 +54,11 @@ def test_given_loose_headers_matching_when_decorate_then_match(self, http_mocker HttpResponse(_A_RESPONSE_BODY, 474), ) - requests.get(_A_URL, params=_SOME_QUERY_PARAMS, headers=_SOME_HEADERS | {"more strict query param key": "any value"}) + requests.get( + _A_URL, + params=_SOME_QUERY_PARAMS, + headers=_SOME_HEADERS | {"more strict query param key": "any value"}, + ) @HttpMocker() def test_given_post_request_match_when_decorate_then_return_response(self, http_mocker): @@ -63,13 +67,17 @@ def test_given_post_request_match_when_decorate_then_return_response(self, http_ HttpResponse(_A_RESPONSE_BODY, 474), ) - response = requests.post(_A_URL, params=_SOME_QUERY_PARAMS, headers=_SOME_HEADERS, data=_SOME_REQUEST_BODY_STR) + response = requests.post( + _A_URL, params=_SOME_QUERY_PARAMS, headers=_SOME_HEADERS, data=_SOME_REQUEST_BODY_STR + ) assert response.text == _A_RESPONSE_BODY assert response.status_code == 474 @HttpMocker() - def test_given_multiple_responses_when_decorate_get_request_then_return_response(self, http_mocker): + def test_given_multiple_responses_when_decorate_get_request_then_return_response( + self, http_mocker + ): http_mocker.get( HttpRequest(_A_URL, _SOME_QUERY_PARAMS, _SOME_HEADERS), [HttpResponse(_A_RESPONSE_BODY, 1), HttpResponse(_ANOTHER_RESPONSE_BODY, 2)], @@ -84,7 +92,9 @@ def test_given_multiple_responses_when_decorate_get_request_then_return_response assert second_response.status_code == 2 @HttpMocker() - def test_given_multiple_responses_when_decorate_delete_request_then_return_response(self, http_mocker): + def test_given_multiple_responses_when_decorate_delete_request_then_return_response( + self, http_mocker + ): http_mocker.delete( HttpRequest(_A_URL, headers=_SOME_HEADERS), [HttpResponse(_A_RESPONSE_BODY, 1), HttpResponse(_ANOTHER_RESPONSE_BODY, 2)], @@ -99,14 +109,20 @@ def test_given_multiple_responses_when_decorate_delete_request_then_return_respo assert second_response.status_code == 2 @HttpMocker() - def test_given_multiple_responses_when_decorate_post_request_then_return_response(self, http_mocker): + def test_given_multiple_responses_when_decorate_post_request_then_return_response( + self, http_mocker + ): http_mocker.post( HttpRequest(_A_URL, _SOME_QUERY_PARAMS, _SOME_HEADERS, _SOME_REQUEST_BODY_STR), [HttpResponse(_A_RESPONSE_BODY, 1), HttpResponse(_ANOTHER_RESPONSE_BODY, 2)], ) - first_response = requests.post(_A_URL, params=_SOME_QUERY_PARAMS, headers=_SOME_HEADERS, data=_SOME_REQUEST_BODY_STR) - second_response = requests.post(_A_URL, params=_SOME_QUERY_PARAMS, headers=_SOME_HEADERS, data=_SOME_REQUEST_BODY_STR) + first_response = requests.post( + _A_URL, params=_SOME_QUERY_PARAMS, headers=_SOME_HEADERS, data=_SOME_REQUEST_BODY_STR + ) + second_response = requests.post( + _A_URL, params=_SOME_QUERY_PARAMS, headers=_SOME_HEADERS, data=_SOME_REQUEST_BODY_STR + ) assert first_response.text == _A_RESPONSE_BODY assert first_response.status_code == 1 @@ -120,7 +136,10 @@ def test_given_more_requests_than_responses_when_decorate_then_raise_error(self, [HttpResponse(_A_RESPONSE_BODY, 1), HttpResponse(_ANOTHER_RESPONSE_BODY, 2)], ) - last_response = [requests.get(_A_URL, params=_SOME_QUERY_PARAMS, headers=_SOME_HEADERS) for _ in range(10)][-1] + last_response = [ + requests.get(_A_URL, params=_SOME_QUERY_PARAMS, headers=_SOME_HEADERS) + for _ in range(10) + ][-1] assert last_response.text == _ANOTHER_RESPONSE_BODY assert last_response.status_code == 2 @@ -158,7 +177,9 @@ def decorated_function(http_mocker): with pytest.raises(AssertionError): decorated_function() - def test_given_assertion_error_but_missing_request_when_decorate_then_raise_missing_http_request(self): + def test_given_assertion_error_but_missing_request_when_decorate_then_raise_missing_http_request( + self, + ): @HttpMocker() def decorated_function(http_mocker): http_mocker.get( @@ -202,9 +223,13 @@ def decorated_function(http_mocker): with pytest.raises(ValueError) as exc_info: decorated_function() - assert "more_granular" in str(exc_info.value) # the matcher corresponding to the first `http_mocker.get` is not matched + assert "more_granular" in str( + exc_info.value + ) # the matcher corresponding to the first `http_mocker.get` is not matched - def test_given_exact_number_of_call_provided_when_assert_number_of_calls_then_do_not_raise(self): + def test_given_exact_number_of_call_provided_when_assert_number_of_calls_then_do_not_raise( + self, + ): @HttpMocker() def decorated_function(http_mocker): request = HttpRequest(_A_URL) diff --git a/unit_tests/test/mock_http/test_request.py b/unit_tests/test/mock_http/test_request.py index a5a94ea0..15d1f667 100644 --- a/unit_tests/test/mock_http/test_request.py +++ b/unit_tests/test/mock_http/test_request.py @@ -8,21 +8,37 @@ class HttpRequestMatcherTest(TestCase): def test_given_query_params_as_dict_and_string_then_query_params_are_properly_considered(self): - with_string = HttpRequest("mock://test.com/path", query_params="a_query_param=q1&a_list_param=first&a_list_param=second") - with_dict = HttpRequest("mock://test.com/path", query_params={"a_query_param": "q1", "a_list_param": ["first", "second"]}) + with_string = HttpRequest( + "mock://test.com/path", + query_params="a_query_param=q1&a_list_param=first&a_list_param=second", + ) + with_dict = HttpRequest( + "mock://test.com/path", + query_params={"a_query_param": "q1", "a_list_param": ["first", "second"]}, + ) assert with_string.matches(with_dict) and with_dict.matches(with_string) def test_given_query_params_in_url_and_also_provided_then_raise_error(self): with pytest.raises(ValueError): - HttpRequest("mock://test.com/path?a_query_param=1", query_params={"another_query_param": "2"}) + HttpRequest( + "mock://test.com/path?a_query_param=1", query_params={"another_query_param": "2"} + ) def test_given_same_url_query_params_and_subset_headers_when_matches_then_return_true(self): - request_to_match = HttpRequest("mock://test.com/path", {"a_query_param": "q1"}, {"first_header": "h1"}) - actual_request = HttpRequest("mock://test.com/path", {"a_query_param": "q1"}, {"first_header": "h1", "second_header": "h2"}) + request_to_match = HttpRequest( + "mock://test.com/path", {"a_query_param": "q1"}, {"first_header": "h1"} + ) + actual_request = HttpRequest( + "mock://test.com/path", + {"a_query_param": "q1"}, + {"first_header": "h1", "second_header": "h2"}, + ) assert actual_request.matches(request_to_match) def test_given_url_differs_when_matches_then_return_false(self): - assert not HttpRequest("mock://test.com/another_path").matches(HttpRequest("mock://test.com/path")) + assert not HttpRequest("mock://test.com/another_path").matches( + HttpRequest("mock://test.com/path") + ) def test_given_query_params_differs_when_matches_then_return_false(self): request_to_match = HttpRequest("mock://test.com/path", {"a_query_param": "q1"}) @@ -31,27 +47,39 @@ def test_given_query_params_differs_when_matches_then_return_false(self): def test_given_query_params_is_subset_differs_when_matches_then_return_false(self): request_to_match = HttpRequest("mock://test.com/path", {"a_query_param": "q1"}) - actual_request = HttpRequest("mock://test.com/path", {"a_query_param": "q1", "another_query_param": "q2"}) + actual_request = HttpRequest( + "mock://test.com/path", {"a_query_param": "q1", "another_query_param": "q2"} + ) assert not actual_request.matches(request_to_match) def test_given_headers_is_subset_differs_when_matches_then_return_true(self): request_to_match = HttpRequest("mock://test.com/path", headers={"first_header": "h1"}) - actual_request = HttpRequest("mock://test.com/path", headers={"first_header": "h1", "second_header": "h2"}) + actual_request = HttpRequest( + "mock://test.com/path", headers={"first_header": "h1", "second_header": "h2"} + ) assert actual_request.matches(request_to_match) def test_given_headers_value_does_not_match_differs_when_matches_then_return_false(self): request_to_match = HttpRequest("mock://test.com/path", headers={"first_header": "h1"}) - actual_request = HttpRequest("mock://test.com/path", headers={"first_header": "value does not match"}) + actual_request = HttpRequest( + "mock://test.com/path", headers={"first_header": "value does not match"} + ) assert not actual_request.matches(request_to_match) def test_given_same_body_mappings_value_when_matches_then_return_true(self): - request_to_match = HttpRequest("mock://test.com/path", body={"first_field": "first_value", "second_field": 2}) - actual_request = HttpRequest("mock://test.com/path", body={"first_field": "first_value", "second_field": 2}) + request_to_match = HttpRequest( + "mock://test.com/path", body={"first_field": "first_value", "second_field": 2} + ) + actual_request = HttpRequest( + "mock://test.com/path", body={"first_field": "first_value", "second_field": 2} + ) assert actual_request.matches(request_to_match) def test_given_bodies_are_mapping_and_differs_when_matches_then_return_false(self): request_to_match = HttpRequest("mock://test.com/path", body={"first_field": "first_value"}) - actual_request = HttpRequest("mock://test.com/path", body={"first_field": "value does not match"}) + actual_request = HttpRequest( + "mock://test.com/path", body={"first_field": "value does not match"} + ) assert not actual_request.matches(request_to_match) def test_given_same_mapping_and_bytes_when_matches_then_return_true(self): @@ -61,7 +89,9 @@ def test_given_same_mapping_and_bytes_when_matches_then_return_true(self): def test_given_different_mapping_and_bytes_when_matches_then_return_false(self): request_to_match = HttpRequest("mock://test.com/path", body={"first_field": "first_value"}) - actual_request = HttpRequest("mock://test.com/path", body=b'{"first_field": "another value"}') + actual_request = HttpRequest( + "mock://test.com/path", body=b'{"first_field": "another value"}' + ) assert not actual_request.matches(request_to_match) def test_given_same_mapping_and_str_when_matches_then_return_true(self): @@ -71,26 +101,36 @@ def test_given_same_mapping_and_str_when_matches_then_return_true(self): def test_given_different_mapping_and_str_when_matches_then_return_false(self): request_to_match = HttpRequest("mock://test.com/path", body={"first_field": "first_value"}) - actual_request = HttpRequest("mock://test.com/path", body='{"first_field": "another value"}') + actual_request = HttpRequest( + "mock://test.com/path", body='{"first_field": "another value"}' + ) assert not actual_request.matches(request_to_match) def test_given_same_bytes_and_mapping_when_matches_then_return_true(self): - request_to_match = HttpRequest("mock://test.com/path", body=b'{"first_field": "first_value"}') + request_to_match = HttpRequest( + "mock://test.com/path", body=b'{"first_field": "first_value"}' + ) actual_request = HttpRequest("mock://test.com/path", body={"first_field": "first_value"}) assert actual_request.matches(request_to_match) def test_given_different_bytes_and_mapping_when_matches_then_return_false(self): - request_to_match = HttpRequest("mock://test.com/path", body=b'{"first_field": "first_value"}') + request_to_match = HttpRequest( + "mock://test.com/path", body=b'{"first_field": "first_value"}' + ) actual_request = HttpRequest("mock://test.com/path", body={"first_field": "another value"}) assert not actual_request.matches(request_to_match) def test_given_same_str_and_mapping_when_matches_then_return_true(self): - request_to_match = HttpRequest("mock://test.com/path", body='{"first_field": "first_value"}') + request_to_match = HttpRequest( + "mock://test.com/path", body='{"first_field": "first_value"}' + ) actual_request = HttpRequest("mock://test.com/path", body={"first_field": "first_value"}) assert actual_request.matches(request_to_match) def test_given_different_str_and_mapping_when_matches_then_return_false(self): - request_to_match = HttpRequest("mock://test.com/path", body='{"first_field": "first_value"}') + request_to_match = HttpRequest( + "mock://test.com/path", body='{"first_field": "first_value"}' + ) actual_request = HttpRequest("mock://test.com/path", body={"first_field": "another value"}) assert not actual_request.matches(request_to_match) diff --git a/unit_tests/test/mock_http/test_response_builder.py b/unit_tests/test/mock_http/test_response_builder.py index c8ccdc41..cf7fbe50 100644 --- a/unit_tests/test/mock_http/test_response_builder.py +++ b/unit_tests/test/mock_http/test_response_builder.py @@ -38,17 +38,25 @@ def _record_builder( record_id_path: Optional[Path] = None, record_cursor_path: Optional[Union[FieldPath, NestedPath]] = None, ) -> RecordBuilder: - return create_record_builder(deepcopy(response_template), records_path, record_id_path, record_cursor_path) + return create_record_builder( + deepcopy(response_template), records_path, record_id_path, record_cursor_path + ) def _any_record_builder() -> RecordBuilder: - return create_record_builder({"record_path": [{"a_record": "record value"}]}, FieldPath("record_path")) + return create_record_builder( + {"record_path": [{"a_record": "record value"}]}, FieldPath("record_path") + ) def _response_builder( - response_template: Dict[str, Any], records_path: Union[FieldPath, NestedPath], pagination_strategy: Optional[PaginationStrategy] = None + response_template: Dict[str, Any], + records_path: Union[FieldPath, NestedPath], + pagination_strategy: Optional[PaginationStrategy] = None, ) -> HttpResponseBuilder: - return create_response_builder(deepcopy(response_template), records_path, pagination_strategy=pagination_strategy) + return create_response_builder( + deepcopy(response_template), records_path, pagination_strategy=pagination_strategy + ) def _body(response: HttpResponse) -> Dict[str, Any]: @@ -57,13 +65,19 @@ def _body(response: HttpResponse) -> Dict[str, Any]: class RecordBuilderTest(TestCase): def test_given_with_id_when_build_then_set_id(self) -> None: - builder = _record_builder({_RECORDS_FIELD: [{_ID_FIELD: "an id"}]}, FieldPath(_RECORDS_FIELD), FieldPath(_ID_FIELD)) + builder = _record_builder( + {_RECORDS_FIELD: [{_ID_FIELD: "an id"}]}, + FieldPath(_RECORDS_FIELD), + FieldPath(_ID_FIELD), + ) record = builder.with_id("another id").build() assert record[_ID_FIELD] == "another id" def test_given_nested_id_when_build_then_set_id(self) -> None: builder = _record_builder( - {_RECORDS_FIELD: [{"nested": {_ID_FIELD: "id"}}]}, FieldPath(_RECORDS_FIELD), NestedPath(["nested", _ID_FIELD]) + {_RECORDS_FIELD: [{"nested": {_ID_FIELD: "id"}}]}, + FieldPath(_RECORDS_FIELD), + NestedPath(["nested", _ID_FIELD]), ) record = builder.with_id("another id").build() assert record["nested"][_ID_FIELD] == "another id" @@ -75,11 +89,17 @@ def test_given_id_path_not_provided_but_with_id_when_build_then_raise_error(self def test_given_no_id_in_template_for_path_when_build_then_raise_error(self) -> None: with pytest.raises(ValueError): - _record_builder({_RECORDS_FIELD: [{"record without id": "should fail"}]}, FieldPath(_RECORDS_FIELD), FieldPath(_ID_FIELD)) + _record_builder( + {_RECORDS_FIELD: [{"record without id": "should fail"}]}, + FieldPath(_RECORDS_FIELD), + FieldPath(_ID_FIELD), + ) def test_given_with_cursor_when_build_then_set_id(self) -> None: builder = _record_builder( - {_RECORDS_FIELD: [{_CURSOR_FIELD: "a cursor"}]}, FieldPath(_RECORDS_FIELD), record_cursor_path=FieldPath(_CURSOR_FIELD) + {_RECORDS_FIELD: [{_CURSOR_FIELD: "a cursor"}]}, + FieldPath(_RECORDS_FIELD), + record_cursor_path=FieldPath(_CURSOR_FIELD), ) record = builder.with_cursor("another cursor").build() assert record[_CURSOR_FIELD] == "another cursor" @@ -119,7 +139,9 @@ def test_given_no_cursor_in_template_for_path_when_build_then_raise_error(self) class HttpResponseBuilderTest(TestCase): def test_given_records_in_template_but_no_with_records_when_build_then_no_records(self) -> None: - builder = _response_builder({_RECORDS_FIELD: [{"a_record_field": "a record value"}]}, FieldPath(_RECORDS_FIELD)) + builder = _response_builder( + {_RECORDS_FIELD: [{"a_record_field": "a record value"}]}, FieldPath(_RECORDS_FIELD) + ) response = builder.build() assert len(_body(response)[_RECORDS_FIELD]) == 0 @@ -148,7 +170,9 @@ def test_given_pagination_with_strategy_when_build_then_apply_strategy(self) -> builder = _response_builder( {"has_more_pages": False} | _SOME_RECORDS, FieldPath(_RECORDS_FIELD), - pagination_strategy=FieldUpdatePaginationStrategy(FieldPath("has_more_pages"), "yes more page"), + pagination_strategy=FieldUpdatePaginationStrategy( + FieldPath("has_more_pages"), "yes more page" + ), ) response = builder.with_pagination().build() @@ -166,10 +190,14 @@ def test_from_resource_file(self) -> None: template = find_template("test-resource", __file__) assert template == {"test-source template": "this is a template for test-resource"} - def test_given_cwd_doesnt_have_unit_tests_as_parent_when_from_resource_file__then_raise_error(self) -> None: + def test_given_cwd_doesnt_have_unit_tests_as_parent_when_from_resource_file__then_raise_error( + self, + ) -> None: with pytest.raises(ValueError): find_template("test-resource", str(FilePath(__file__).parent.parent.parent.parent)) - def test_given_records_path_invalid_when_create_builders_from_resource_then_raise_exception(self) -> None: + def test_given_records_path_invalid_when_create_builders_from_resource_then_raise_exception( + self, + ) -> None: with pytest.raises(ValueError): create_record_builder(_A_RESPONSE_TEMPLATE, NestedPath(["invalid", "record", "path"])) diff --git a/unit_tests/test/test_entrypoint_wrapper.py b/unit_tests/test/test_entrypoint_wrapper.py index 37533057..3ead41f5 100644 --- a/unit_tests/test/test_entrypoint_wrapper.py +++ b/unit_tests/test/test_entrypoint_wrapper.py @@ -38,7 +38,10 @@ def _a_state_message(stream_name: str, stream_state: Mapping[str, Any]) -> Airby return AirbyteMessage( type=Type.STATE, state=AirbyteStateMessage( - stream=AirbyteStreamState(stream_descriptor=StreamDescriptor(name=stream_name), stream_state=AirbyteStateBlob(**stream_state)) + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name=stream_name), + stream_state=AirbyteStateBlob(**stream_state), + ) ), ) @@ -62,10 +65,15 @@ def _a_status_message(stream_name: str, status: AirbyteStreamStatus) -> AirbyteM catalog=AirbyteCatalog(streams=[]), ) _A_RECORD = AirbyteMessage( - type=Type.RECORD, record=AirbyteRecordMessage(stream="stream", data={"record key": "record value"}, emitted_at=0) + type=Type.RECORD, + record=AirbyteRecordMessage(stream="stream", data={"record key": "record value"}, emitted_at=0), +) +_A_STATE_MESSAGE = _a_state_message( + "stream_name", {"state key": "state value for _A_STATE_MESSAGE"} +) +_A_LOG = AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is an Airbyte log message") ) -_A_STATE_MESSAGE = _a_state_message("stream_name", {"state key": "state value for _A_STATE_MESSAGE"}) -_A_LOG = AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is an Airbyte log message")) _AN_ERROR_MESSAGE = AirbyteMessage( type=Type.TRACE, trace=AirbyteTraceMessage( @@ -123,7 +131,12 @@ def _validate_tmp_catalog(expected, file_path) -> None: assert ConfiguredAirbyteCatalogSerializer.load(orjson.loads(open(file_path).read())) == expected -def _create_tmp_file_validation(entrypoint, expected_config, expected_catalog: Optional[Any] = None, expected_state: Optional[Any] = None): +def _create_tmp_file_validation( + entrypoint, + expected_config, + expected_catalog: Optional[Any] = None, + expected_state: Optional[Any] = None, +): def _validate_tmp_files(self): _validate_tmp_json_file(expected_config, entrypoint.parse_args.call_args.args[0][2]) if expected_catalog: @@ -184,19 +197,25 @@ def _do_some_logging(self): def test_given_record_when_discover_then_output_has_record(self, entrypoint): entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_CATALOG_MESSAGE]) output = discover(self._a_source, _A_CONFIG) - assert AirbyteMessageSerializer.dump(output.catalog) == AirbyteMessageSerializer.dump(_A_CATALOG_MESSAGE) + assert AirbyteMessageSerializer.dump(output.catalog) == AirbyteMessageSerializer.dump( + _A_CATALOG_MESSAGE + ) @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") def test_given_log_when_discover_then_output_has_log(self, entrypoint): entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_LOG]) output = discover(self._a_source, _A_CONFIG) - assert AirbyteMessageSerializer.dump(output.logs[0]) == AirbyteMessageSerializer.dump(_A_LOG) + assert AirbyteMessageSerializer.dump(output.logs[0]) == AirbyteMessageSerializer.dump( + _A_LOG + ) @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") def test_given_trace_message_when_discover_then_output_has_trace_messages(self, entrypoint): entrypoint.return_value.run.return_value = _to_entrypoint_output([_AN_ANALYTIC_MESSAGE]) output = discover(self._a_source, _A_CONFIG) - assert AirbyteMessageSerializer.dump(output.analytics_messages[0]) == AirbyteMessageSerializer.dump(_AN_ANALYTIC_MESSAGE) + assert AirbyteMessageSerializer.dump( + output.analytics_messages[0] + ) == AirbyteMessageSerializer.dump(_AN_ANALYTIC_MESSAGE) @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True) @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") @@ -225,7 +244,9 @@ def setUp(self) -> None: @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") def test_when_read_then_ensure_parameters(self, entrypoint): - entrypoint.return_value.run.side_effect = _create_tmp_file_validation(entrypoint, _A_CONFIG, _A_CATALOG, _A_STATE) + entrypoint.return_value.run.side_effect = _create_tmp_file_validation( + entrypoint, _A_CONFIG, _A_CATALOG, _A_STATE + ) read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE) @@ -262,45 +283,64 @@ def _do_some_logging(self): def test_given_record_when_read_then_output_has_record(self, entrypoint): entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_RECORD]) output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE) - assert AirbyteMessageSerializer.dump(output.records[0]) == AirbyteMessageSerializer.dump(_A_RECORD) + assert AirbyteMessageSerializer.dump(output.records[0]) == AirbyteMessageSerializer.dump( + _A_RECORD + ) @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") def test_given_state_message_when_read_then_output_has_state_message(self, entrypoint): entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_STATE_MESSAGE]) output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE) - assert AirbyteMessageSerializer.dump(output.state_messages[0]) == AirbyteMessageSerializer.dump(_A_STATE_MESSAGE) + assert AirbyteMessageSerializer.dump( + output.state_messages[0] + ) == AirbyteMessageSerializer.dump(_A_STATE_MESSAGE) @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") - def test_given_state_message_and_records_when_read_then_output_has_records_and_state_message(self, entrypoint): - entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_RECORD, _A_STATE_MESSAGE]) + def test_given_state_message_and_records_when_read_then_output_has_records_and_state_message( + self, entrypoint + ): + entrypoint.return_value.run.return_value = _to_entrypoint_output( + [_A_RECORD, _A_STATE_MESSAGE] + ) output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE) - assert [AirbyteMessageSerializer.dump(message) for message in output.records_and_state_messages] == [ - AirbyteMessageSerializer.dump(message) for message in (_A_RECORD, _A_STATE_MESSAGE) - ] + assert [ + AirbyteMessageSerializer.dump(message) for message in output.records_and_state_messages + ] == [AirbyteMessageSerializer.dump(message) for message in (_A_RECORD, _A_STATE_MESSAGE)] @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") - def test_given_many_state_messages_and_records_when_read_then_output_has_records_and_state_message(self, entrypoint): + def test_given_many_state_messages_and_records_when_read_then_output_has_records_and_state_message( + self, entrypoint + ): state_value = {"state_key": "last state value"} last_emitted_state = AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="stream_name"), stream_state=AirbyteStateBlob(**state_value) + stream_descriptor=StreamDescriptor(name="stream_name"), + stream_state=AirbyteStateBlob(**state_value), + ) + entrypoint.return_value.run.return_value = _to_entrypoint_output( + [_A_STATE_MESSAGE, _a_state_message("stream_name", state_value)] ) - entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_STATE_MESSAGE, _a_state_message("stream_name", state_value)]) output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE) - assert AirbyteStreamStateSerializer.dump(output.most_recent_state) == AirbyteStreamStateSerializer.dump(last_emitted_state) + assert AirbyteStreamStateSerializer.dump( + output.most_recent_state + ) == AirbyteStreamStateSerializer.dump(last_emitted_state) @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") def test_given_log_when_read_then_output_has_log(self, entrypoint): entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_LOG]) output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE) - assert AirbyteMessageSerializer.dump(output.logs[0]) == AirbyteMessageSerializer.dump(_A_LOG) + assert AirbyteMessageSerializer.dump(output.logs[0]) == AirbyteMessageSerializer.dump( + _A_LOG + ) @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") def test_given_trace_message_when_read_then_output_has_trace_messages(self, entrypoint): entrypoint.return_value.run.return_value = _to_entrypoint_output([_AN_ANALYTIC_MESSAGE]) output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE) - assert AirbyteMessageSerializer.dump(output.analytics_messages[0]) == AirbyteMessageSerializer.dump(_AN_ANALYTIC_MESSAGE) + assert AirbyteMessageSerializer.dump( + output.analytics_messages[0] + ) == AirbyteMessageSerializer.dump(_AN_ANALYTIC_MESSAGE) @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") def test_given_stream_statuses_when_read_then_return_statuses(self, entrypoint): @@ -310,10 +350,15 @@ def test_given_stream_statuses_when_read_then_return_statuses(self, entrypoint): ] entrypoint.return_value.run.return_value = _to_entrypoint_output(status_messages) output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE) - assert output.get_stream_statuses(_A_STREAM_NAME) == [AirbyteStreamStatus.STARTED, AirbyteStreamStatus.COMPLETE] + assert output.get_stream_statuses(_A_STREAM_NAME) == [ + AirbyteStreamStatus.STARTED, + AirbyteStreamStatus.COMPLETE, + ] @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint") - def test_given_stream_statuses_for_many_streams_when_read_then_filter_other_streams(self, entrypoint): + def test_given_stream_statuses_for_many_streams_when_read_then_filter_other_streams( + self, entrypoint + ): status_messages = [ _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.STARTED), _a_status_message("another stream name", AirbyteStreamStatus.INCOMPLETE), diff --git a/unit_tests/test_config_observation.py b/unit_tests/test_config_observation.py index a1828544..677e318f 100644 --- a/unit_tests/test_config_observation.py +++ b/unit_tests/test_config_observation.py @@ -6,7 +6,12 @@ import time import pytest -from airbyte_cdk.config_observation import ConfigObserver, ObservedDict, create_connector_config_control_message, observe_connector_config +from airbyte_cdk.config_observation import ( + ConfigObserver, + ObservedDict, + create_connector_config_control_message, + observe_connector_config, +) from airbyte_cdk.models import AirbyteControlConnectorConfigMessage, OrchestratorType, Type @@ -14,7 +19,12 @@ class TestObservedDict: def test_update_called_on_set_item(self, mocker): mock_observer = mocker.Mock() my_observed_dict = ObservedDict( - {"key": "value", "nested_dict": {"key": "value"}, "list_of_dict": [{"key": "value"}, {"key": "value"}]}, mock_observer + { + "key": "value", + "nested_dict": {"key": "value"}, + "list_of_dict": [{"key": "value"}, {"key": "value"}], + }, + mock_observer, ) assert mock_observer.update.call_count == 0 diff --git a/unit_tests/test_connector.py b/unit_tests/test_connector.py index ea7de2e4..bc7255b9 100644 --- a/unit_tests/test_connector.py +++ b/unit_tests/test_connector.py @@ -102,7 +102,10 @@ def use_invalid_json_spec(self): @pytest.fixture def use_yaml_spec(self): - spec = {"documentationUrl": "https://airbyte.com/#yaml", "connectionSpecification": self.CONNECTION_SPECIFICATION} + spec = { + "documentationUrl": "https://airbyte.com/#yaml", + "connectionSpecification": self.CONNECTION_SPECIFICATION, + } yaml_path = os.path.join(SPEC_ROOT, "spec.yaml") with open(yaml_path, "w") as f: diff --git a/unit_tests/test_counter.py b/unit_tests/test_counter.py index 7f17ff76..f6d2c22b 100644 --- a/unit_tests/test_counter.py +++ b/unit_tests/test_counter.py @@ -15,14 +15,18 @@ def test_counter_init(): def test_counter_start_event(): with create_timer("Counter") as timer: - with mock.patch("airbyte_cdk.utils.event_timing.EventTimer.start_event") as mock_start_event: + with mock.patch( + "airbyte_cdk.utils.event_timing.EventTimer.start_event" + ) as mock_start_event: timer.start_event("test_event") mock_start_event.assert_called_with("test_event") def test_counter_finish_event(): with create_timer("Counter") as timer: - with mock.patch("airbyte_cdk.utils.event_timing.EventTimer.finish_event") as mock_finish_event: + with mock.patch( + "airbyte_cdk.utils.event_timing.EventTimer.finish_event" + ) as mock_finish_event: timer.finish_event("test_event") mock_finish_event.assert_called_with("test_event") diff --git a/unit_tests/test_entrypoint.py b/unit_tests/test_entrypoint.py index 3d4ffefa..40781e89 100644 --- a/unit_tests/test_entrypoint.py +++ b/unit_tests/test_entrypoint.py @@ -84,7 +84,9 @@ def spec_mock(mocker): control=AirbyteControlMessage( type=OrchestratorType.CONNECTOR_CONFIG, emitted_at=10, - connectorConfig=AirbyteControlConnectorConfigMessage(config={"any config": "a config value"}), + connectorConfig=AirbyteControlConnectorConfigMessage( + config={"any config": "a config value"} + ), ), ) @@ -92,36 +94,74 @@ def spec_mock(mocker): @pytest.fixture def entrypoint(mocker) -> AirbyteEntrypoint: message_repository = MagicMock() - message_repository.consume_queue.side_effect = [[message for message in [MESSAGE_FROM_REPOSITORY]], [], []] - mocker.patch.object(MockSource, "message_repository", new_callable=mocker.PropertyMock, return_value=message_repository) + message_repository.consume_queue.side_effect = [ + [message for message in [MESSAGE_FROM_REPOSITORY]], + [], + [], + ] + mocker.patch.object( + MockSource, + "message_repository", + new_callable=mocker.PropertyMock, + return_value=message_repository, + ) return AirbyteEntrypoint(MockSource()) def test_airbyte_entrypoint_init(mocker): mocker.patch.object(entrypoint_module, "init_uncaught_exception_handler") AirbyteEntrypoint(MockSource()) - entrypoint_module.init_uncaught_exception_handler.assert_called_once_with(entrypoint_module.logger) + entrypoint_module.init_uncaught_exception_handler.assert_called_once_with( + entrypoint_module.logger + ) @pytest.mark.parametrize( ["cmd", "args", "expected_args"], [ ("spec", {"debug": ""}, {"command": "spec", "debug": True}), - ("check", {"config": "config_path"}, {"command": "check", "config": "config_path", "debug": False}), - ("discover", {"config": "config_path", "debug": ""}, {"command": "discover", "config": "config_path", "debug": True}), + ( + "check", + {"config": "config_path"}, + {"command": "check", "config": "config_path", "debug": False}, + ), + ( + "discover", + {"config": "config_path", "debug": ""}, + {"command": "discover", "config": "config_path", "debug": True}, + ), ( "read", {"config": "config_path", "catalog": "catalog_path", "state": "None"}, - {"command": "read", "config": "config_path", "catalog": "catalog_path", "state": "None", "debug": False}, + { + "command": "read", + "config": "config_path", + "catalog": "catalog_path", + "state": "None", + "debug": False, + }, ), ( "read", - {"config": "config_path", "catalog": "catalog_path", "state": "state_path", "debug": ""}, - {"command": "read", "config": "config_path", "catalog": "catalog_path", "state": "state_path", "debug": True}, + { + "config": "config_path", + "catalog": "catalog_path", + "state": "state_path", + "debug": "", + }, + { + "command": "read", + "config": "config_path", + "catalog": "catalog_path", + "state": "state_path", + "debug": True, + }, ), ], ) -def test_parse_valid_args(cmd: str, args: Mapping[str, Any], expected_args, entrypoint: AirbyteEntrypoint): +def test_parse_valid_args( + cmd: str, args: Mapping[str, Any], expected_args, entrypoint: AirbyteEntrypoint +): arglist = _as_arglist(cmd, args) parsed_args = entrypoint.parse_args(arglist) assert vars(parsed_args) == expected_args @@ -135,7 +175,9 @@ def test_parse_valid_args(cmd: str, args: Mapping[str, Any], expected_args, entr ("read", {"config": "config_path", "catalog": "catalog_path"}), ], ) -def test_parse_missing_required_args(cmd: str, args: MutableMapping[str, Any], entrypoint: AirbyteEntrypoint): +def test_parse_missing_required_args( + cmd: str, args: MutableMapping[str, Any], entrypoint: AirbyteEntrypoint +): required_args = {"check": ["config"], "discover": ["config"], "read": ["config", "catalog"]} for required_arg in required_args[cmd]: argcopy = deepcopy(args) @@ -144,7 +186,11 @@ def test_parse_missing_required_args(cmd: str, args: MutableMapping[str, Any], e entrypoint.parse_args(_as_arglist(cmd, argcopy)) -def _wrap_message(submessage: Union[AirbyteConnectionStatus, ConnectorSpecification, AirbyteRecordMessage, AirbyteCatalog]) -> str: +def _wrap_message( + submessage: Union[ + AirbyteConnectionStatus, ConnectorSpecification, AirbyteRecordMessage, AirbyteCatalog + ], +) -> str: if isinstance(submessage, AirbyteConnectionStatus): message = AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=submessage) elif isinstance(submessage, ConnectorSpecification): @@ -168,7 +214,10 @@ def test_run_spec(entrypoint: AirbyteEntrypoint, mocker): messages = list(entrypoint.run(parsed_args)) - assert [orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), _wrap_message(expected)] == messages + assert [ + orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), + _wrap_message(expected), + ] == messages @pytest.fixture @@ -182,13 +231,41 @@ def config_mock(mocker, request): @pytest.mark.parametrize( "config_mock, schema, config_valid", [ - ({"username": "fake"}, {"type": "object", "properties": {"name": {"type": "string"}}, "additionalProperties": False}, False), - ({"username": "fake"}, {"type": "object", "properties": {"username": {"type": "string"}}, "additionalProperties": False}, True), - ({"username": "fake"}, {"type": "object", "properties": {"user": {"type": "string"}}}, True), - ({"username": "fake"}, {"type": "object", "properties": {"user": {"type": "string", "airbyte_secret": True}}}, True), + ( + {"username": "fake"}, + { + "type": "object", + "properties": {"name": {"type": "string"}}, + "additionalProperties": False, + }, + False, + ), + ( + {"username": "fake"}, + { + "type": "object", + "properties": {"username": {"type": "string"}}, + "additionalProperties": False, + }, + True, + ), + ( + {"username": "fake"}, + {"type": "object", "properties": {"user": {"type": "string"}}}, + True, + ), + ( + {"username": "fake"}, + {"type": "object", "properties": {"user": {"type": "string", "airbyte_secret": True}}}, + True, + ), ( {"username": "fake", "_limit": 22}, - {"type": "object", "properties": {"username": {"type": "string"}}, "additionalProperties": False}, + { + "type": "object", + "properties": {"username": {"type": "string"}}, + "additionalProperties": False, + }, True, ), ], @@ -198,18 +275,28 @@ def test_config_validate(entrypoint: AirbyteEntrypoint, mocker, config_mock, sch parsed_args = Namespace(command="check", config="config_path") check_value = AirbyteConnectionStatus(status=Status.SUCCEEDED) mocker.patch.object(MockSource, "check", return_value=check_value) - mocker.patch.object(MockSource, "spec", return_value=ConnectorSpecification(connectionSpecification=schema)) + mocker.patch.object( + MockSource, "spec", return_value=ConnectorSpecification(connectionSpecification=schema) + ) messages = list(entrypoint.run(parsed_args)) if config_valid: - assert [orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), _wrap_message(check_value)] == messages + assert [ + orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), + _wrap_message(check_value), + ] == messages else: assert len(messages) == 2 - assert messages[0] == orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode() + assert ( + messages[0] + == orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode() + ) connection_status_message = AirbyteMessage(**orjson.loads(messages[1])) assert connection_status_message.type == Type.CONNECTION_STATUS.value assert connection_status_message.connectionStatus.get("status") == Status.FAILED.value - assert connection_status_message.connectionStatus.get("message").startswith("Config validation error:") + assert connection_status_message.connectionStatus.get("message").startswith( + "Config validation error:" + ) def test_run_check(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): @@ -219,7 +306,10 @@ def test_run_check(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock messages = list(entrypoint.run(parsed_args)) - assert [orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), _wrap_message(check_value)] == messages + assert [ + orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), + _wrap_message(check_value), + ] == messages assert spec_mock.called @@ -234,7 +324,9 @@ def test_run_check_with_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mo @freezegun.freeze_time("1970-01-01T00:00:00.001Z") -def test_run_check_with_traced_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): +def test_run_check_with_traced_exception( + entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock +): exception = AirbyteTracedException.from_exception(ValueError("Any error")) parsed_args = Namespace(command="check", config="config_path") mocker.patch.object(MockSource, "check", side_effect=exception) @@ -257,13 +349,20 @@ def test_run_check_with_config_error(entrypoint: AirbyteEntrypoint, mocker, spec expected_messages = [ orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), orjson.dumps(AirbyteMessageSerializer.dump(expected_trace)).decode(), - _wrap_message(AirbyteConnectionStatus(status=Status.FAILED, message=AirbyteTracedException.from_exception(exception).message)), + _wrap_message( + AirbyteConnectionStatus( + status=Status.FAILED, + message=AirbyteTracedException.from_exception(exception).message, + ) + ), ] assert messages == expected_messages @freezegun.freeze_time("1970-01-01T00:00:00.001Z") -def test_run_check_with_transient_error(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): +def test_run_check_with_transient_error( + entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock +): exception = AirbyteTracedException.from_exception(ValueError("Any error")) exception.failure_type = FailureType.transient_error parsed_args = Namespace(command="check", config="config_path") @@ -275,12 +374,21 @@ def test_run_check_with_transient_error(entrypoint: AirbyteEntrypoint, mocker, s def test_run_discover(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): parsed_args = Namespace(command="discover", config="config_path") - expected = AirbyteCatalog(streams=[AirbyteStream(name="stream", json_schema={"k": "v"}, supported_sync_modes=[SyncMode.full_refresh])]) + expected = AirbyteCatalog( + streams=[ + AirbyteStream( + name="stream", json_schema={"k": "v"}, supported_sync_modes=[SyncMode.full_refresh] + ) + ] + ) mocker.patch.object(MockSource, "discover", return_value=expected) messages = list(entrypoint.run(parsed_args)) - assert [orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), _wrap_message(expected)] == messages + assert [ + orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), + _wrap_message(expected), + ] == messages assert spec_mock.called @@ -290,43 +398,61 @@ def test_run_discover_with_exception(entrypoint: AirbyteEntrypoint, mocker, spec with pytest.raises(ValueError): messages = list(entrypoint.run(parsed_args)) - assert [orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode()] == messages + assert [ + orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode() + ] == messages def test_run_read(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): - parsed_args = Namespace(command="read", config="config_path", state="statepath", catalog="catalogpath") + parsed_args = Namespace( + command="read", config="config_path", state="statepath", catalog="catalogpath" + ) expected = AirbyteRecordMessage(stream="stream", data={"data": "stuff"}, emitted_at=1) mocker.patch.object(MockSource, "read_state", return_value={}) mocker.patch.object(MockSource, "read_catalog", return_value={}) - mocker.patch.object(MockSource, "read", return_value=[AirbyteMessage(record=expected, type=Type.RECORD)]) + mocker.patch.object( + MockSource, "read", return_value=[AirbyteMessage(record=expected, type=Type.RECORD)] + ) messages = list(entrypoint.run(parsed_args)) - assert [orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), _wrap_message(expected)] == messages + assert [ + orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(), + _wrap_message(expected), + ] == messages assert spec_mock.called def test_given_message_emitted_during_config_when_read_then_emit_message_before_next_steps( entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock ): - parsed_args = Namespace(command="read", config="config_path", state="statepath", catalog="catalogpath") + parsed_args = Namespace( + command="read", config="config_path", state="statepath", catalog="catalogpath" + ) mocker.patch.object(MockSource, "read_catalog", side_effect=ValueError) messages = entrypoint.run(parsed_args) - assert next(messages) == orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode() + assert ( + next(messages) + == orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode() + ) with pytest.raises(ValueError): next(messages) def test_run_read_with_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): - parsed_args = Namespace(command="read", config="config_path", state="statepath", catalog="catalogpath") + parsed_args = Namespace( + command="read", config="config_path", state="statepath", catalog="catalogpath" + ) mocker.patch.object(MockSource, "read_state", return_value={}) mocker.patch.object(MockSource, "read_catalog", return_value={}) mocker.patch.object(MockSource, "read", side_effect=ValueError("Any error")) with pytest.raises(ValueError): messages = list(entrypoint.run(parsed_args)) - assert [orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode()] == messages + assert [ + orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode() + ] == messages def test_invalid_command(entrypoint: AirbyteEntrypoint, config_mock): @@ -337,17 +463,63 @@ def test_invalid_command(entrypoint: AirbyteEntrypoint, config_mock): @pytest.mark.parametrize( "deployment_mode, url, expected_error", [ - pytest.param("CLOUD", "https://airbyte.com", None, id="test_cloud_public_endpoint_is_successful"), - pytest.param("CLOUD", "https://192.168.27.30", AirbyteTracedException, id="test_cloud_private_ip_address_is_rejected"), - pytest.param("CLOUD", "https://localhost:8080/api/v1/cast", AirbyteTracedException, id="test_cloud_private_endpoint_is_rejected"), - pytest.param("CLOUD", "http://past.lives.net/api/v1/inyun", ValueError, id="test_cloud_unsecured_endpoint_is_rejected"), - pytest.param("CLOUD", "https://not:very/cash:443.money", ValueError, id="test_cloud_invalid_url_format"), - pytest.param("CLOUD", "https://192.168.27.30 ", ValueError, id="test_cloud_incorrect_ip_format_is_rejected"), - pytest.param("cloud", "https://192.168.27.30", AirbyteTracedException, id="test_case_insensitive_cloud_environment_variable"), - pytest.param("OSS", "https://airbyte.com", None, id="test_oss_public_endpoint_is_successful"), - pytest.param("OSS", "https://192.168.27.30", None, id="test_oss_private_endpoint_is_successful"), - pytest.param("OSS", "https://localhost:8080/api/v1/cast", None, id="test_oss_private_endpoint_is_successful"), - pytest.param("OSS", "http://past.lives.net/api/v1/inyun", None, id="test_oss_unsecured_endpoint_is_successful"), + pytest.param( + "CLOUD", "https://airbyte.com", None, id="test_cloud_public_endpoint_is_successful" + ), + pytest.param( + "CLOUD", + "https://192.168.27.30", + AirbyteTracedException, + id="test_cloud_private_ip_address_is_rejected", + ), + pytest.param( + "CLOUD", + "https://localhost:8080/api/v1/cast", + AirbyteTracedException, + id="test_cloud_private_endpoint_is_rejected", + ), + pytest.param( + "CLOUD", + "http://past.lives.net/api/v1/inyun", + ValueError, + id="test_cloud_unsecured_endpoint_is_rejected", + ), + pytest.param( + "CLOUD", + "https://not:very/cash:443.money", + ValueError, + id="test_cloud_invalid_url_format", + ), + pytest.param( + "CLOUD", + "https://192.168.27.30 ", + ValueError, + id="test_cloud_incorrect_ip_format_is_rejected", + ), + pytest.param( + "cloud", + "https://192.168.27.30", + AirbyteTracedException, + id="test_case_insensitive_cloud_environment_variable", + ), + pytest.param( + "OSS", "https://airbyte.com", None, id="test_oss_public_endpoint_is_successful" + ), + pytest.param( + "OSS", "https://192.168.27.30", None, id="test_oss_private_endpoint_is_successful" + ), + pytest.param( + "OSS", + "https://localhost:8080/api/v1/cast", + None, + id="test_oss_private_endpoint_is_successful", + ), + pytest.param( + "OSS", + "http://past.lives.net/api/v1/inyun", + None, + id="test_oss_unsecured_endpoint_is_successful", + ), ], ) @patch.object(requests.Session, "send", lambda self, request, **kwargs: requests.Response()) @@ -374,9 +546,15 @@ def test_filter_internal_requests(deployment_mode, url, expected_error): "incoming_message, stream_message_count, expected_message, expected_records_by_stream", [ pytest.param( - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)), + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1), + ), {HashableStreamDescriptor(name="customers"): 100.0}, - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)), + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1), + ), {HashableStreamDescriptor(name="customers"): 101.0}, id="test_handle_record_message", ), @@ -386,7 +564,8 @@ def test_filter_internal_requests(deployment_mode, url, expected_error): state=AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="customers"), stream_state=AirbyteStateBlob(updated_at="2024-02-02") + stream_descriptor=StreamDescriptor(name="customers"), + stream_state=AirbyteStateBlob(updated_at="2024-02-02"), ), ), ), @@ -396,7 +575,8 @@ def test_filter_internal_requests(deployment_mode, url, expected_error): state=AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="customers"), stream_state=AirbyteStateBlob(updated_at="2024-02-02") + stream_descriptor=StreamDescriptor(name="customers"), + stream_state=AirbyteStateBlob(updated_at="2024-02-02"), ), sourceStats=AirbyteStateStats(recordCount=100.0), ), @@ -405,9 +585,15 @@ def test_filter_internal_requests(deployment_mode, url, expected_error): id="test_handle_state_message", ), pytest.param( - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)), + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1), + ), defaultdict(float), - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)), + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1), + ), {HashableStreamDescriptor(name="customers"): 1.0}, id="test_handle_first_record_message", ), @@ -417,7 +603,8 @@ def test_filter_internal_requests(deployment_mode, url, expected_error): trace=AirbyteTraceMessage( type=TraceType.STREAM_STATUS, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="customers"), status=AirbyteStreamStatus.COMPLETE + stream_descriptor=StreamDescriptor(name="customers"), + status=AirbyteStreamStatus.COMPLETE, ), emitted_at=1, ), @@ -428,7 +615,8 @@ def test_filter_internal_requests(deployment_mode, url, expected_error): trace=AirbyteTraceMessage( type=TraceType.STREAM_STATUS, stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="customers"), status=AirbyteStreamStatus.COMPLETE + stream_descriptor=StreamDescriptor(name="customers"), + status=AirbyteStreamStatus.COMPLETE, ), emitted_at=1, ), @@ -437,10 +625,22 @@ def test_filter_internal_requests(deployment_mode, url, expected_error): id="test_handle_other_message_type", ), pytest.param( - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="others", data={"id": "12345"}, emitted_at=1)), - {HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 27.0}, - AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="others", data={"id": "12345"}, emitted_at=1)), - {HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 28.0}, + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="others", data={"id": "12345"}, emitted_at=1), + ), + { + HashableStreamDescriptor(name="customers"): 100.0, + HashableStreamDescriptor(name="others"): 27.0, + }, + AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream="others", data={"id": "12345"}, emitted_at=1), + ), + { + HashableStreamDescriptor(name="customers"): 100.0, + HashableStreamDescriptor(name="others"): 28.0, + }, id="test_handle_record_message_for_other_stream", ), pytest.param( @@ -449,31 +649,45 @@ def test_filter_internal_requests(deployment_mode, url, expected_error): state=AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="others"), stream_state=AirbyteStateBlob(updated_at="2024-02-02") + stream_descriptor=StreamDescriptor(name="others"), + stream_state=AirbyteStateBlob(updated_at="2024-02-02"), ), ), ), - {HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 27.0}, + { + HashableStreamDescriptor(name="customers"): 100.0, + HashableStreamDescriptor(name="others"): 27.0, + }, AirbyteMessage( type=Type.STATE, state=AirbyteStateMessage( type=AirbyteStateType.STREAM, stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name="others"), stream_state=AirbyteStateBlob(updated_at="2024-02-02") + stream_descriptor=StreamDescriptor(name="others"), + stream_state=AirbyteStateBlob(updated_at="2024-02-02"), ), sourceStats=AirbyteStateStats(recordCount=27.0), ), ), - {HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 0.0}, + { + HashableStreamDescriptor(name="customers"): 100.0, + HashableStreamDescriptor(name="others"): 0.0, + }, id="test_handle_state_message_for_other_stream", ), pytest.param( AirbyteMessage( - type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", namespace="public", data={"id": "12345"}, emitted_at=1) + type=Type.RECORD, + record=AirbyteRecordMessage( + stream="customers", namespace="public", data={"id": "12345"}, emitted_at=1 + ), ), {HashableStreamDescriptor(name="customers", namespace="public"): 100.0}, AirbyteMessage( - type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", namespace="public", data={"id": "12345"}, emitted_at=1) + type=Type.RECORD, + record=AirbyteRecordMessage( + stream="customers", namespace="public", data={"id": "12345"}, emitted_at=1 + ), ), {HashableStreamDescriptor(name="customers", namespace="public"): 101.0}, id="test_handle_record_message_with_descriptor", @@ -535,9 +749,13 @@ def test_filter_internal_requests(deployment_mode, url, expected_error): ), ], ) -def test_handle_record_counts(incoming_message, stream_message_count, expected_message, expected_records_by_stream): +def test_handle_record_counts( + incoming_message, stream_message_count, expected_message, expected_records_by_stream +): entrypoint = AirbyteEntrypoint(source=MockSource()) - actual_message = entrypoint.handle_record_counts(message=incoming_message, stream_message_count=stream_message_count) + actual_message = entrypoint.handle_record_counts( + message=incoming_message, stream_message_count=stream_message_count + ) assert actual_message == expected_message for stream_descriptor, message_count in stream_message_count.items(): @@ -546,4 +764,6 @@ def test_handle_record_counts(incoming_message, stream_message_count, expected_m assert message_count == expected_records_by_stream[stream_descriptor] if actual_message.type == Type.STATE: - assert isinstance(actual_message.state.sourceStats.recordCount, float), "recordCount value should be expressed as a float" + assert isinstance( + actual_message.state.sourceStats.recordCount, float + ), "recordCount value should be expressed as a float" diff --git a/unit_tests/test_exception_handler.py b/unit_tests/test_exception_handler.py index f135c19f..ee4bfaa1 100644 --- a/unit_tests/test_exception_handler.py +++ b/unit_tests/test_exception_handler.py @@ -53,7 +53,8 @@ def test_uncaught_exception_handler(): ) expected_log_message = AirbyteMessage( - type=MessageType.LOG, log=AirbyteLogMessage(level=Level.FATAL, message=f"{exception_message}\n{exception_trace}") + type=MessageType.LOG, + log=AirbyteLogMessage(level=Level.FATAL, message=f"{exception_message}\n{exception_trace}"), ) expected_trace_message = AirbyteMessage( @@ -86,4 +87,6 @@ def test_uncaught_exception_handler(): out_trace_message = AirbyteMessageSerializer.load(json.loads(trace_output)) assert out_trace_message.trace.emitted_at > 0 out_trace_message.trace.emitted_at = 0.0 # set a specific emitted_at value for testing - assert out_trace_message == expected_trace_message, "Trace message should be emitted in expected form" + assert ( + out_trace_message == expected_trace_message + ), "Trace message should be emitted in expected form" diff --git a/unit_tests/test_secure_logger.py b/unit_tests/test_secure_logger.py index 44e0a3d9..4f2abf90 100644 --- a/unit_tests/test_secure_logger.py +++ b/unit_tests/test_secure_logger.py @@ -10,7 +10,13 @@ import pytest from airbyte_cdk import AirbyteEntrypoint from airbyte_cdk.logger import AirbyteLogFormatter -from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, ConfiguredAirbyteCatalog, ConnectorSpecification, Type +from airbyte_cdk.models import ( + AirbyteMessage, + AirbyteRecordMessage, + ConfiguredAirbyteCatalog, + ConnectorSpecification, + Type, +) from airbyte_cdk.sources import Source SECRET_PROPERTY = "api_token" @@ -35,7 +41,11 @@ def read( state: MutableMapping[str, Any] = None, ) -> Iterable[AirbyteMessage]: logger.info(I_AM_A_SECRET_VALUE) - logger.info(I_AM_A_SECRET_VALUE + " plus Some non secret Value in the same log record" + NOT_A_SECRET_VALUE) + logger.info( + I_AM_A_SECRET_VALUE + + " plus Some non secret Value in the same log record" + + NOT_A_SECRET_VALUE + ) logger.info(NOT_A_SECRET_VALUE) yield AirbyteMessage( record=AirbyteRecordMessage(stream="stream", data={"data": "stuff"}, emitted_at=1), @@ -139,8 +149,12 @@ def test_airbyte_secret_is_masked_on_logger_output(source_spec, mocker, config, mocker.patch.object(MockSource, "read_catalog", return_value={}) list(entrypoint.run(parsed_args)) log_result = caplog.text - expected_secret_values = [config[k] for k, v in source_spec["properties"].items() if v.get("airbyte_secret")] - expected_plain_text_values = [config[k] for k, v in source_spec["properties"].items() if not v.get("airbyte_secret")] + expected_secret_values = [ + config[k] for k, v in source_spec["properties"].items() if v.get("airbyte_secret") + ] + expected_plain_text_values = [ + config[k] for k, v in source_spec["properties"].items() if not v.get("airbyte_secret") + ] assert all([str(v) not in log_result for v in expected_secret_values]) assert all([str(v) in log_result for v in expected_plain_text_values]) @@ -188,8 +202,12 @@ def read( list(entrypoint.run(parsed_args)) except Exception: sys.excepthook(*sys.exc_info()) - assert I_AM_A_SECRET_VALUE not in capsys.readouterr().out, "Should have filtered non-secret value from exception trace message" - assert I_AM_A_SECRET_VALUE not in caplog.text, "Should have filtered secret value from exception log message" + assert ( + I_AM_A_SECRET_VALUE not in capsys.readouterr().out + ), "Should have filtered non-secret value from exception trace message" + assert ( + I_AM_A_SECRET_VALUE not in caplog.text + ), "Should have filtered secret value from exception log message" def test_non_airbyte_secrets_are_not_masked_on_uncaught_exceptions(mocker, caplog, capsys): @@ -230,11 +248,17 @@ def read( mocker.patch.object(MockSource, "read_config", return_value=None) mocker.patch.object(MockSource, "read_state", return_value={}) mocker.patch.object(MockSource, "read_catalog", return_value={}) - mocker.patch.object(MockSource, "read", side_effect=Exception("Exception:" + NOT_A_SECRET_VALUE)) + mocker.patch.object( + MockSource, "read", side_effect=Exception("Exception:" + NOT_A_SECRET_VALUE) + ) try: list(entrypoint.run(parsed_args)) except Exception: sys.excepthook(*sys.exc_info()) - assert NOT_A_SECRET_VALUE in capsys.readouterr().out, "Should not have filtered non-secret value from exception trace message" - assert NOT_A_SECRET_VALUE in caplog.text, "Should not have filtered non-secret value from exception log message" + assert ( + NOT_A_SECRET_VALUE in capsys.readouterr().out + ), "Should not have filtered non-secret value from exception trace message" + assert ( + NOT_A_SECRET_VALUE in caplog.text + ), "Should not have filtered non-secret value from exception log message" diff --git a/unit_tests/utils/test_datetime_format_inferrer.py b/unit_tests/utils/test_datetime_format_inferrer.py index 5e76b9cf..1e69f3d1 100644 --- a/unit_tests/utils/test_datetime_format_inferrer.py +++ b/unit_tests/utils/test_datetime_format_inferrer.py @@ -22,12 +22,29 @@ ("timestamp_ms_match_string", [{"d": "1686058051000"}], {"d": "%ms"}), ("timestamp_no_match_integer", [{"d": 99}], {}), ("timestamp_no_match_string", [{"d": "99999999999999999999"}], {}), - ("timestamp_overflow", [{"d": f"{10**100}_100"}], {}), # this case was previously causing OverflowError hence this test + ( + "timestamp_overflow", + [{"d": f"{10**100}_100"}], + {}, + ), # this case was previously causing OverflowError hence this test ("simple_no_match", [{"d": "20220203"}], {}), - ("multiple_match", [{"d": "2022-02-03", "e": "2022-02-03"}], {"d": "%Y-%m-%d", "e": "%Y-%m-%d"}), + ( + "multiple_match", + [{"d": "2022-02-03", "e": "2022-02-03"}], + {"d": "%Y-%m-%d", "e": "%Y-%m-%d"}, + ), ( "multiple_no_match", - [{"d": "20220203", "r": "ccc", "e": {"something-else": "2023-03-03"}, "s": ["2023-03-03"], "x": False, "y": 123}], + [ + { + "d": "20220203", + "r": "ccc", + "e": {"something-else": "2023-03-03"}, + "s": ["2023-03-03"], + "x": False, + "y": 123, + } + ], {}, ), ("format_1", [{"d": "2022-02-03"}], {"d": "%Y-%m-%d"}), @@ -36,20 +53,47 @@ ("format_4 1", [{"d": "2022-02-03T12:34:56.000Z"}], {"d": "%Y-%m-%dT%H:%M:%S.%fZ"}), ("format_4 2", [{"d": "2022-02-03T12:34:56.000000Z"}], {"d": "%Y-%m-%dT%H:%M:%S.%fZ"}), ("format_5", [{"d": "2022-02-03 12:34:56.123456+00:00"}], {"d": "%Y-%m-%d %H:%M:%S.%f%z"}), - ("format_5 2", [{"d": "2022-02-03 12:34:56.123456+02:00"}], {"d": "%Y-%m-%d %H:%M:%S.%f%z"}), + ( + "format_5 2", + [{"d": "2022-02-03 12:34:56.123456+02:00"}], + {"d": "%Y-%m-%d %H:%M:%S.%f%z"}, + ), ("format_6", [{"d": "2022-02-03T12:34:56.123456+0000"}], {"d": "%Y-%m-%dT%H:%M:%S.%f%z"}), - ("format_6 2", [{"d": "2022-02-03T12:34:56.123456+00:00"}], {"d": "%Y-%m-%dT%H:%M:%S.%f%z"}), - ("format_6 3", [{"d": "2022-02-03T12:34:56.123456-03:00"}], {"d": "%Y-%m-%dT%H:%M:%S.%f%z"}), + ( + "format_6 2", + [{"d": "2022-02-03T12:34:56.123456+00:00"}], + {"d": "%Y-%m-%dT%H:%M:%S.%f%z"}, + ), + ( + "format_6 3", + [{"d": "2022-02-03T12:34:56.123456-03:00"}], + {"d": "%Y-%m-%dT%H:%M:%S.%f%z"}, + ), ("format_7", [{"d": "03/02/2022 12:34"}], {"d": "%d/%m/%Y %H:%M"}), ("format_8", [{"d": "2022-02"}], {"d": "%Y-%m"}), ("format_9", [{"d": "03-02-2022"}], {"d": "%d-%m-%Y"}), - ("limit_down", [{"d": "2022-02-03", "x": "2022-02-03"}, {"d": "2022-02-03", "x": "another thing"}], {"d": "%Y-%m-%d"}), - ("limit_down all", [{"d": "2022-02-03", "x": "2022-02-03"}, {"d": "also another thing", "x": "another thing"}], {}), + ( + "limit_down", + [{"d": "2022-02-03", "x": "2022-02-03"}, {"d": "2022-02-03", "x": "another thing"}], + {"d": "%Y-%m-%d"}, + ), + ( + "limit_down all", + [ + {"d": "2022-02-03", "x": "2022-02-03"}, + {"d": "also another thing", "x": "another thing"}, + ], + {}, + ), ("limit_down empty", [{"d": "2022-02-03", "x": "2022-02-03"}, {}], {}), ("limit_down unsupported type", [{"d": "2022-02-03"}, {"d": False}], {}), ("limit_down complex type", [{"d": "2022-02-03"}, {"d": {"date": "2022-03-03"}}], {}), ("limit_down different format", [{"d": "2022-02-03"}, {"d": 1686058051}], {}), - ("limit_down different format", [{"d": "2022-02-03"}, {"d": "2022-02-03T12:34:56.000000Z"}], {}), + ( + "limit_down different format", + [{"d": "2022-02-03"}, {"d": "2022-02-03T12:34:56.000000Z"}], + {}, + ), ("no scope expand", [{}, {"d": "2022-02-03"}], {}), ], ) diff --git a/unit_tests/utils/test_message_utils.py b/unit_tests/utils/test_message_utils.py index 84fabf1a..e4567164 100644 --- a/unit_tests/utils/test_message_utils.py +++ b/unit_tests/utils/test_message_utils.py @@ -80,7 +80,9 @@ def test_get_other_message_stream_descriptor_fails(): control=AirbyteControlMessage( type=OrchestratorType.CONNECTOR_CONFIG, emitted_at=10, - connectorConfig=AirbyteControlConnectorConfigMessage(config={"any config": "a config value"}), + connectorConfig=AirbyteControlConnectorConfigMessage( + config={"any config": "a config value"} + ), ), ) with pytest.raises(NotImplementedError): diff --git a/unit_tests/utils/test_rate_limiting.py b/unit_tests/utils/test_rate_limiting.py index bc9d3ece..d4a78140 100644 --- a/unit_tests/utils/test_rate_limiting.py +++ b/unit_tests/utils/test_rate_limiting.py @@ -21,7 +21,11 @@ def helper_with_exceptions(exception_type): (3, 3, 1, exceptions.ChunkedEncodingError), ], ) -def test_default_backoff_handler(max_tries: int, max_time: int, factor: int, exception_to_raise: Exception): - backoff_handler = default_backoff_handler(max_tries=max_tries, max_time=max_time, factor=factor)(helper_with_exceptions) +def test_default_backoff_handler( + max_tries: int, max_time: int, factor: int, exception_to_raise: Exception +): + backoff_handler = default_backoff_handler( + max_tries=max_tries, max_time=max_time, factor=factor + )(helper_with_exceptions) with pytest.raises(exception_to_raise): backoff_handler(exception_to_raise) diff --git a/unit_tests/utils/test_schema_inferrer.py b/unit_tests/utils/test_schema_inferrer.py index 98d227c4..535ff41d 100644 --- a/unit_tests/utils/test_schema_inferrer.py +++ b/unit_tests/utils/test_schema_inferrer.py @@ -40,7 +40,10 @@ "obj": { "type": ["object", "null"], "properties": { - "data": {"type": ["array", "null"], "items": {"type": ["number", "null"]}}, + "data": { + "type": ["array", "null"], + "items": {"type": ["number", "null"]}, + }, "other_key": {"type": ["string", "null"]}, }, } @@ -76,7 +79,14 @@ {"stream": "my_stream", "data": {"field_A": None}}, {"stream": "my_stream", "data": {"field_A": {"nested": "abc"}}}, ], - {"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": ["string", "null"]}}}}}, + { + "my_stream": { + "field_A": { + "type": ["object", "null"], + "properties": {"nested": {"type": ["string", "null"]}}, + } + } + }, id="test_any_of", ), pytest.param( @@ -84,20 +94,33 @@ {"stream": "my_stream", "data": {"field_A": None}}, {"stream": "my_stream", "data": {"field_A": {"nested": "abc", "nully": None}}}, ], - {"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": ["string", "null"]}}}}}, + { + "my_stream": { + "field_A": { + "type": ["object", "null"], + "properties": {"nested": {"type": ["string", "null"]}}, + } + } + }, id="test_any_of_with_null", ), pytest.param( [ {"stream": "my_stream", "data": {"field_A": None}}, {"stream": "my_stream", "data": {"field_A": {"nested": "abc", "nully": None}}}, - {"stream": "my_stream", "data": {"field_A": {"nested": "abc", "nully": "a string"}}}, + { + "stream": "my_stream", + "data": {"field_A": {"nested": "abc", "nully": "a string"}}, + }, ], { "my_stream": { "field_A": { "type": ["object", "null"], - "properties": {"nested": {"type": ["string", "null"]}, "nully": {"type": ["string", "null"]}}, + "properties": { + "nested": {"type": ["string", "null"]}, + "nully": {"type": ["string", "null"]}, + }, } } }, @@ -105,7 +128,10 @@ ), pytest.param( [ - {"stream": "my_stream", "data": {"field_A": {"nested": "abc", "nully": "a string"}}}, + { + "stream": "my_stream", + "data": {"field_A": {"nested": "abc", "nully": "a string"}}, + }, {"stream": "my_stream", "data": {"field_A": None}}, {"stream": "my_stream", "data": {"field_A": {"nested": "abc", "nully": None}}}, ], @@ -113,7 +139,10 @@ "my_stream": { "field_A": { "type": ["object", "null"], - "properties": {"nested": {"type": ["string", "null"]}, "nully": {"type": ["string", "null"]}}, + "properties": { + "nested": {"type": ["string", "null"]}, + "nully": {"type": ["string", "null"]}, + }, } } }, @@ -123,19 +152,30 @@ [ {"stream": "my_stream", "data": {"field_A": "abc", "nested": {"field_B": None}}}, ], - {"my_stream": {"field_A": {"type": ["string", "null"]}, "nested": {"type": ["object", "null"], "properties": {}}}}, + { + "my_stream": { + "field_A": {"type": ["string", "null"]}, + "nested": {"type": ["object", "null"], "properties": {}}, + } + }, id="test_nested_null", ), pytest.param( [ - {"stream": "my_stream", "data": {"field_A": "abc", "nested": [{"field_B": None, "field_C": "abc"}]}}, + { + "stream": "my_stream", + "data": {"field_A": "abc", "nested": [{"field_B": None, "field_C": "abc"}]}, + }, ], { "my_stream": { "field_A": {"type": ["string", "null"]}, "nested": { "type": ["array", "null"], - "items": {"type": ["object", "null"], "properties": {"field_C": {"type": ["string", "null"]}}}, + "items": { + "type": ["object", "null"], + "properties": {"field_C": {"type": ["string", "null"]}}, + }, }, } }, @@ -144,14 +184,20 @@ pytest.param( [ {"stream": "my_stream", "data": {"field_A": "abc", "nested": None}}, - {"stream": "my_stream", "data": {"field_A": "abc", "nested": [{"field_B": None, "field_C": "abc"}]}}, + { + "stream": "my_stream", + "data": {"field_A": "abc", "nested": [{"field_B": None, "field_C": "abc"}]}, + }, ], { "my_stream": { "field_A": {"type": ["string", "null"]}, "nested": { "type": ["array", "null"], - "items": {"type": ["object", "null"], "properties": {"field_C": {"type": ["string", "null"]}}}, + "items": { + "type": ["object", "null"], + "properties": {"field_C": {"type": ["string", "null"]}}, + }, }, } }, @@ -176,7 +222,10 @@ { "title": "Nested_2", "type": "location", - "value": {"nested_key_1": "GB", "nested_key_2": "United Kingdom"}, + "value": { + "nested_key_1": "GB", + "nested_key_2": "United Kingdom", + }, }, ], } @@ -200,7 +249,10 @@ {"type": "array", "items": {"type": "string"}}, { "type": "object", - "properties": {"nested_key_1": {"type": "string"}, "nested_key_2": {"type": "string"}}, + "properties": { + "nested_key_1": {"type": "string"}, + "nested_key_2": {"type": "string"}, + }, }, ] }, @@ -218,7 +270,9 @@ def test_schema_derivation(input_records: List, expected_schemas: Mapping): inferrer = SchemaInferrer() for record in input_records: - inferrer.accumulate(AirbyteRecordMessage(stream=record["stream"], data=record["data"], emitted_at=NOW)) + inferrer.accumulate( + AirbyteRecordMessage(stream=record["stream"], data=record["data"], emitted_at=NOW) + ) for stream_name, expected_schema in expected_schemas.items(): assert inferrer.get_stream_schema(stream_name) == { @@ -250,7 +304,9 @@ def _create_inferrer_with_required_field(is_pk: bool, field: List[List[str]]) -> def test_field_is_on_root(is_pk: bool): inferrer = _create_inferrer_with_required_field(is_pk, [["property"]]) - inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property": _ANY_VALUE}, emitted_at=NOW)) + inferrer.accumulate( + AirbyteRecordMessage(stream=_STREAM_NAME, data={"property": _ANY_VALUE}, emitted_at=NOW) + ) assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property"] assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["type"] == "string" @@ -266,11 +322,17 @@ def test_field_is_on_root(is_pk: bool): def test_field_is_nested(is_pk: bool): inferrer = _create_inferrer_with_required_field(is_pk, [["property", "nested_property"]]) - inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property": {"nested_property": _ANY_VALUE}}, emitted_at=NOW)) + inferrer.accumulate( + AirbyteRecordMessage( + stream=_STREAM_NAME, data={"property": {"nested_property": _ANY_VALUE}}, emitted_at=NOW + ) + ) assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property"] assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["type"] == "object" - assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["required"] == ["nested_property"] + assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["required"] == [ + "nested_property" + ] @pytest.mark.parametrize( @@ -283,7 +345,11 @@ def test_field_is_nested(is_pk: bool): def test_field_is_composite(is_pk: bool): inferrer = _create_inferrer_with_required_field(is_pk, [["property 1"], ["property 2"]]) inferrer.accumulate( - AirbyteRecordMessage(stream=_STREAM_NAME, data={"property 1": _ANY_VALUE, "property 2": _ANY_VALUE}, emitted_at=NOW) + AirbyteRecordMessage( + stream=_STREAM_NAME, + data={"property 1": _ANY_VALUE, "property 2": _ANY_VALUE}, + emitted_at=NOW, + ) ) assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property 1", "property 2"] @@ -296,22 +362,37 @@ def test_field_is_composite(is_pk: bool): ], ) def test_field_is_composite_and_nested(is_pk: bool): - inferrer = _create_inferrer_with_required_field(is_pk, [["property 1", "nested"], ["property 2"]]) + inferrer = _create_inferrer_with_required_field( + is_pk, [["property 1", "nested"], ["property 2"]] + ) inferrer.accumulate( - AirbyteRecordMessage(stream=_STREAM_NAME, data={"property 1": {"nested": _ANY_VALUE}, "property 2": _ANY_VALUE}, emitted_at=NOW) + AirbyteRecordMessage( + stream=_STREAM_NAME, + data={"property 1": {"nested": _ANY_VALUE}, "property 2": _ANY_VALUE}, + emitted_at=NOW, + ) ) assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property 1", "property 2"] assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["type"] == "object" assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 2"]["type"] == "string" - assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["required"] == ["nested"] - assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["properties"]["nested"]["type"] == "string" + assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["required"] == [ + "nested" + ] + assert ( + inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["properties"][ + "nested" + ]["type"] + == "string" + ) def test_given_pk_does_not_exist_when_get_inferred_schemas_then_raise_error(): inferrer = SchemaInferrer([["pk does not exist"]]) - inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id": _ANY_VALUE}, emitted_at=NOW)) + inferrer.accumulate( + AirbyteRecordMessage(stream=_STREAM_NAME, data={"id": _ANY_VALUE}, emitted_at=NOW) + ) with pytest.raises(SchemaValidationException) as exception: inferrer.get_stream_schema(_STREAM_NAME) @@ -321,7 +402,9 @@ def test_given_pk_does_not_exist_when_get_inferred_schemas_then_raise_error(): def test_given_pk_path_is_partially_valid_when_get_inferred_schemas_then_validation_error_mentions_where_the_issue_is(): inferrer = SchemaInferrer([["id", "nested pk that does not exist"]]) - inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id": _ANY_VALUE}, emitted_at=NOW)) + inferrer.accumulate( + AirbyteRecordMessage(stream=_STREAM_NAME, data={"id": _ANY_VALUE}, emitted_at=NOW) + ) with pytest.raises(SchemaValidationException) as exception: inferrer.get_stream_schema(_STREAM_NAME) @@ -332,7 +415,9 @@ def test_given_pk_path_is_partially_valid_when_get_inferred_schemas_then_validat def test_given_composite_pk_but_only_one_path_valid_when_get_inferred_schemas_then_valid_path_is_required(): inferrer = SchemaInferrer([["id 1"], ["id 2"]]) - inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id 1": _ANY_VALUE}, emitted_at=NOW)) + inferrer.accumulate( + AirbyteRecordMessage(stream=_STREAM_NAME, data={"id 1": _ANY_VALUE}, emitted_at=NOW) + ) with pytest.raises(SchemaValidationException) as exception: inferrer.get_stream_schema(_STREAM_NAME) @@ -342,7 +427,9 @@ def test_given_composite_pk_but_only_one_path_valid_when_get_inferred_schemas_th def test_given_composite_pk_but_only_one_path_valid_when_get_inferred_schemas_then_validation_error_mentions_where_the_issue_is(): inferrer = SchemaInferrer([["id 1"], ["id 2"]]) - inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id 1": _ANY_VALUE}, emitted_at=NOW)) + inferrer.accumulate( + AirbyteRecordMessage(stream=_STREAM_NAME, data={"id 1": _ANY_VALUE}, emitted_at=NOW) + ) with pytest.raises(SchemaValidationException) as exception: inferrer.get_stream_schema(_STREAM_NAME) diff --git a/unit_tests/utils/test_secret_utils.py b/unit_tests/utils/test_secret_utils.py index 39c6ff73..846d0e12 100644 --- a/unit_tests/utils/test_secret_utils.py +++ b/unit_tests/utils/test_secret_utils.py @@ -3,7 +3,13 @@ # import pytest -from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets, filter_secrets, get_secret_paths, get_secrets, update_secrets +from airbyte_cdk.utils.airbyte_secrets_utils import ( + add_to_secrets, + filter_secrets, + get_secret_paths, + get_secrets, + update_secrets, +) SECRET_STRING_KEY = "secret_key1" SECRET_STRING_VALUE = "secret_value" @@ -15,11 +21,19 @@ NOT_SECRET_VALUE = "unimportant value" -flat_spec_with_secret = {"properties": {SECRET_STRING_KEY: {"type": "string", "airbyte_secret": True}, NOT_SECRET_KEY: {"type": "string"}}} +flat_spec_with_secret = { + "properties": { + SECRET_STRING_KEY: {"type": "string", "airbyte_secret": True}, + NOT_SECRET_KEY: {"type": "string"}, + } +} flat_config_with_secret = {SECRET_STRING_KEY: SECRET_STRING_VALUE, NOT_SECRET_KEY: NOT_SECRET_VALUE} flat_spec_with_secret_int = { - "properties": {SECRET_INT_KEY: {"type": "integer", "airbyte_secret": True}, NOT_SECRET_KEY: {"type": "string"}} + "properties": { + SECRET_INT_KEY: {"type": "integer", "airbyte_secret": True}, + NOT_SECRET_KEY: {"type": "string"}, + } } flat_config_with_secret_int = {SECRET_INT_KEY: SECRET_INT_VALUE, NOT_SECRET_KEY: NOT_SECRET_VALUE} @@ -35,11 +49,17 @@ "oneOf": [ { "type": "object", - "properties": {SECRET_STRING_2_KEY: {"type": "string", "airbyte_secret": True}, NOT_SECRET_KEY: {"type": "string"}}, + "properties": { + SECRET_STRING_2_KEY: {"type": "string", "airbyte_secret": True}, + NOT_SECRET_KEY: {"type": "string"}, + }, }, { "type": "object", - "properties": {SECRET_INT_KEY: {"type": "integer", "airbyte_secret": True}, NOT_SECRET_KEY: {"type": "string"}}, + "properties": { + SECRET_INT_KEY: {"type": "integer", "airbyte_secret": True}, + NOT_SECRET_KEY: {"type": "string"}, + }, }, ], }, @@ -83,8 +103,22 @@ (flat_spec_with_secret, [[SECRET_STRING_KEY]]), (flat_spec_without_secrets, []), (flat_spec_with_secret_int, [[SECRET_INT_KEY]]), - (spec_with_oneof_secrets, [[SECRET_STRING_KEY], ["credentials", SECRET_STRING_2_KEY], ["credentials", SECRET_INT_KEY]]), - (spec_with_nested_secrets, [[SECRET_STRING_KEY], ["credentials", SECRET_STRING_2_KEY], ["credentials", SECRET_INT_KEY]]), + ( + spec_with_oneof_secrets, + [ + [SECRET_STRING_KEY], + ["credentials", SECRET_STRING_2_KEY], + ["credentials", SECRET_INT_KEY], + ], + ), + ( + spec_with_nested_secrets, + [ + [SECRET_STRING_KEY], + ["credentials", SECRET_STRING_2_KEY], + ["credentials", SECRET_INT_KEY], + ], + ), ], ) def test_get_secret_paths(spec, expected): @@ -97,17 +131,33 @@ def test_get_secret_paths(spec, expected): (flat_spec_with_secret, flat_config_with_secret, [SECRET_STRING_VALUE]), (flat_spec_without_secrets, flat_config_without_secrets, []), (flat_spec_with_secret_int, flat_config_with_secret_int, [SECRET_INT_VALUE]), - (spec_with_oneof_secrets, config_with_oneof_secrets_1, [SECRET_STRING_VALUE, SECRET_STRING_2_VALUE]), - (spec_with_oneof_secrets, config_with_oneof_secrets_2, [SECRET_STRING_VALUE, SECRET_INT_VALUE]), - (spec_with_nested_secrets, config_with_nested_secrets, [SECRET_STRING_VALUE, SECRET_STRING_2_VALUE, SECRET_INT_VALUE]), + ( + spec_with_oneof_secrets, + config_with_oneof_secrets_1, + [SECRET_STRING_VALUE, SECRET_STRING_2_VALUE], + ), + ( + spec_with_oneof_secrets, + config_with_oneof_secrets_2, + [SECRET_STRING_VALUE, SECRET_INT_VALUE], + ), + ( + spec_with_nested_secrets, + config_with_nested_secrets, + [SECRET_STRING_VALUE, SECRET_STRING_2_VALUE, SECRET_INT_VALUE], + ), ], ) def test_get_secrets(spec, config, expected): - assert get_secrets(spec, config) == expected, f"Expected the spec {spec} and config {config} to produce {expected}" + assert ( + get_secrets(spec, config) == expected + ), f"Expected the spec {spec} and config {config} to produce {expected}" def test_secret_filtering(): - sensitive_str = f"{SECRET_STRING_VALUE} {NOT_SECRET_VALUE} {SECRET_STRING_VALUE} {SECRET_STRING_2_VALUE}" + sensitive_str = ( + f"{SECRET_STRING_VALUE} {NOT_SECRET_VALUE} {SECRET_STRING_VALUE} {SECRET_STRING_2_VALUE}" + ) update_secrets([]) filtered = filter_secrets(sensitive_str) diff --git a/unit_tests/utils/test_stream_status_utils.py b/unit_tests/utils/test_stream_status_utils.py index 4862a1e0..494eb7ee 100644 --- a/unit_tests/utils/test_stream_status_utils.py +++ b/unit_tests/utils/test_stream_status_utils.py @@ -2,11 +2,21 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from airbyte_cdk.models import AirbyteMessage, AirbyteStream, AirbyteStreamStatus, SyncMode, TraceType +from airbyte_cdk.models import ( + AirbyteMessage, + AirbyteStream, + AirbyteStreamStatus, + SyncMode, + TraceType, +) from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message +from airbyte_cdk.utils.stream_status_utils import ( + as_airbyte_message as stream_status_as_airbyte_message, +) -stream = AirbyteStream(name="name", namespace="namespace", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]) +stream = AirbyteStream( + name="name", namespace="namespace", json_schema={}, supported_sync_modes=[SyncMode.full_refresh] +) def test_started_as_message(): diff --git a/unit_tests/utils/test_traced_exception.py b/unit_tests/utils/test_traced_exception.py index 90c83165..2d2bcc81 100644 --- a/unit_tests/utils/test_traced_exception.py +++ b/unit_tests/utils/test_traced_exception.py @@ -32,7 +32,9 @@ def raised_exception(): def test_build_from_existing_exception(raised_exception): - traced_exc = AirbyteTracedException.from_exception(raised_exception, message="my user-friendly message") + traced_exc = AirbyteTracedException.from_exception( + raised_exception, message="my user-friendly message" + ) assert traced_exc.message == "my user-friendly message" assert traced_exc.internal_message == "an error has occurred" assert traced_exc.failure_type == FailureType.system_error @@ -48,9 +50,15 @@ def test_exception_as_airbyte_message(): assert airbyte_message.trace.type == TraceType.ERROR assert airbyte_message.trace.emitted_at > 0 assert airbyte_message.trace.error.failure_type == FailureType.system_error - assert airbyte_message.trace.error.message == "Something went wrong in the connector. See the logs for more details." + assert ( + airbyte_message.trace.error.message + == "Something went wrong in the connector. See the logs for more details." + ) assert airbyte_message.trace.error.internal_message == "an internal message" - assert airbyte_message.trace.error.stack_trace == "airbyte_cdk.utils.traced_exception.AirbyteTracedException: an internal message\n" + assert ( + airbyte_message.trace.error.stack_trace + == "airbyte_cdk.utils.traced_exception.AirbyteTracedException: an internal message\n" + ) def test_existing_exception_as_airbyte_message(raised_exception): @@ -60,7 +68,10 @@ def test_existing_exception_as_airbyte_message(raised_exception): assert isinstance(airbyte_message, AirbyteMessage) assert airbyte_message.type == MessageType.TRACE assert airbyte_message.trace.type == TraceType.ERROR - assert airbyte_message.trace.error.message == "Something went wrong in the connector. See the logs for more details." + assert ( + airbyte_message.trace.error.message + == "Something went wrong in the connector. See the logs for more details." + ) assert airbyte_message.trace.error.internal_message == "an error has occurred" assert airbyte_message.trace.error.stack_trace.startswith("Traceback (most recent call last):") assert airbyte_message.trace.error.stack_trace.endswith( @@ -69,7 +80,11 @@ def test_existing_exception_as_airbyte_message(raised_exception): def test_config_error_as_connection_status_message(): - traced_exc = AirbyteTracedException("an internal message", message="Config validation error", failure_type=FailureType.config_error) + traced_exc = AirbyteTracedException( + "an internal message", + message="Config validation error", + failure_type=FailureType.config_error, + ) airbyte_message = traced_exc.as_connection_status_message() assert isinstance(airbyte_message, AirbyteMessage) @@ -79,7 +94,9 @@ def test_config_error_as_connection_status_message(): def test_other_error_as_connection_status_message(): - traced_exc = AirbyteTracedException("an internal message", failure_type=FailureType.system_error) + traced_exc = AirbyteTracedException( + "an internal message", failure_type=FailureType.system_error + ) airbyte_message = traced_exc.as_connection_status_message() assert airbyte_message is None @@ -87,7 +104,9 @@ def test_other_error_as_connection_status_message(): def test_emit_message(capsys): traced_exc = AirbyteTracedException( - internal_message="internal message", message="user-friendly message", exception=RuntimeError("oh no") + internal_message="internal message", + message="user-friendly message", + exception=RuntimeError("oh no"), ) expected_message = AirbyteMessage( @@ -112,7 +131,9 @@ def test_emit_message(capsys): assert printed_message == expected_message -def test_given_both_init_and_as_message_with_stream_descriptor_when_as_airbyte_message_use_init_stream_descriptor() -> None: +def test_given_both_init_and_as_message_with_stream_descriptor_when_as_airbyte_message_use_init_stream_descriptor() -> ( + None +): traced_exc = AirbyteTracedException(stream_descriptor=_A_STREAM_DESCRIPTOR) message = traced_exc.as_airbyte_message(stream_descriptor=_ANOTHER_STREAM_DESCRIPTOR) assert message.trace.error.stream_descriptor == _A_STREAM_DESCRIPTOR @@ -126,8 +147,12 @@ def test_given_both_init_and_as_sanitized_airbyte_message_with_stream_descriptor assert message.trace.error.stream_descriptor == _A_STREAM_DESCRIPTOR -def test_given_both_from_exception_and_as_message_with_stream_descriptor_when_as_airbyte_message_use_init_stream_descriptor() -> None: - traced_exc = AirbyteTracedException.from_exception(_AN_EXCEPTION, stream_descriptor=_A_STREAM_DESCRIPTOR) +def test_given_both_from_exception_and_as_message_with_stream_descriptor_when_as_airbyte_message_use_init_stream_descriptor() -> ( + None +): + traced_exc = AirbyteTracedException.from_exception( + _AN_EXCEPTION, stream_descriptor=_A_STREAM_DESCRIPTOR + ) message = traced_exc.as_airbyte_message(stream_descriptor=_ANOTHER_STREAM_DESCRIPTOR) assert message.trace.error.stream_descriptor == _A_STREAM_DESCRIPTOR @@ -135,6 +160,8 @@ def test_given_both_from_exception_and_as_message_with_stream_descriptor_when_as def test_given_both_from_exception_and_as_sanitized_airbyte_message_with_stream_descriptor_when_as_airbyte_message_use_init_stream_descriptor() -> ( None ): - traced_exc = AirbyteTracedException.from_exception(_AN_EXCEPTION, stream_descriptor=_A_STREAM_DESCRIPTOR) + traced_exc = AirbyteTracedException.from_exception( + _AN_EXCEPTION, stream_descriptor=_A_STREAM_DESCRIPTOR + ) message = traced_exc.as_sanitized_airbyte_message(stream_descriptor=_ANOTHER_STREAM_DESCRIPTOR) assert message.trace.error.stream_descriptor == _A_STREAM_DESCRIPTOR