From 78ca8d9d4f11a1ce638ec7c49fe0d273e5f8b7bf Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sun, 10 Nov 2024 12:27:20 -0800 Subject: [PATCH] fix tests --- airbyte_cdk/models/airbyte_protocol.py | 5 ++++- .../models/airbyte_protocol_serializers.py | 2 +- .../sources/connector_state_manager.py | 3 ++- .../config/file_based_stream_config.py | 22 +++++++++---------- airbyte_cdk/sources/file_based/remote_file.py | 7 ++---- pyproject.toml | 1 + .../sources/test_connector_state_manager.py | 6 ++++- 7 files changed, 25 insertions(+), 21 deletions(-) diff --git a/airbyte_cdk/models/airbyte_protocol.py b/airbyte_cdk/models/airbyte_protocol.py index e74e9195..057ff835 100644 --- a/airbyte_cdk/models/airbyte_protocol.py +++ b/airbyte_cdk/models/airbyte_protocol.py @@ -81,10 +81,13 @@ class AirbyteGlobalState: @dataclass class AirbyteStateMessage: type: Optional[models.AirbyteStateType] = None - stream: Optional[models.AirbyteStreamState] = None + + # These two use custom classes defined above + stream: Optional[AirbyteStreamState] = None global_: Annotated[Optional[AirbyteGlobalState], Alias("global")] = ( None # "global" is a reserved keyword in python ⇒ Alias is used for (de-)serialization ) + data: Optional[dict[str, Any]] = None sourceStats: Optional[models.AirbyteStateStats] = None destinationStats: Optional[models.AirbyteStateStats] = None diff --git a/airbyte_cdk/models/airbyte_protocol_serializers.py b/airbyte_cdk/models/airbyte_protocol_serializers.py index 5d6e48b5..8ae8f2cc 100644 --- a/airbyte_cdk/models/airbyte_protocol_serializers.py +++ b/airbyte_cdk/models/airbyte_protocol_serializers.py @@ -19,7 +19,7 @@ class AirbyteStateBlobType(CustomType[AirbyteStateBlob, dict[str, Any]]): def serialize(self, value: AirbyteStateBlob) -> dict[str, Any]: # cant use orjson.dumps() directly because private attributes are excluded, e.g. "__ab_full_refresh_sync_complete" - return dict(value.__dict__.items()) + return {k: v for k, v in value.__dict__.items()} def deserialize(self, value: dict[str, Any]) -> AirbyteStateBlob: return AirbyteStateBlob(value) diff --git a/airbyte_cdk/sources/connector_state_manager.py b/airbyte_cdk/sources/connector_state_manager.py index 43811832..f1f58d86 100644 --- a/airbyte_cdk/sources/connector_state_manager.py +++ b/airbyte_cdk/sources/connector_state_manager.py @@ -62,7 +62,7 @@ def get_stream_state(self, stream_name: str, namespace: str | None) -> MutableMa HashableStreamDescriptor(name=stream_name, namespace=namespace) ) if stream_state: - return copy.deepcopy(dict(stream_state.__dict__.items())) + return copy.deepcopy({k: v for k, v in stream_state.__dict__.items()}) return {} def update_state_for_stream( @@ -125,6 +125,7 @@ def _extract_from_state_message( for per_stream_state in global_state.stream_states # type: ignore[union-attr] # global_state has shared_state } return shared_state, streams + streams = { HashableStreamDescriptor( name=per_stream_state.stream.stream_descriptor.name, diff --git a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py index 3ec5a042..9a47e0a0 100644 --- a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py +++ b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py @@ -3,26 +3,24 @@ # from __future__ import annotations +from collections.abc import Mapping + +# ruff: noqa: TCH001, TCH002, TCH003 # Don't move imports to TYPE_CHECKING blocks from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import Any from pydantic.v1 import BaseModel, Field, validator +from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat +from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat +from airbyte_cdk.sources.file_based.config.excel_format import ExcelFormat +from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat +from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat +from airbyte_cdk.sources.file_based.config.unstructured_format import UnstructuredFormat from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError from airbyte_cdk.sources.file_based.schema_helpers import type_mapping_to_jsonschema -if TYPE_CHECKING: - from collections.abc import Mapping - - from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat - from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat - from airbyte_cdk.sources.file_based.config.excel_format import ExcelFormat - from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat - from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat - from airbyte_cdk.sources.file_based.config.unstructured_format import UnstructuredFormat - - PrimaryKeyType = str | list[str] | None diff --git a/airbyte_cdk/sources/file_based/remote_file.py b/airbyte_cdk/sources/file_based/remote_file.py index 4b46f02e..a801468a 100644 --- a/airbyte_cdk/sources/file_based/remote_file.py +++ b/airbyte_cdk/sources/file_based/remote_file.py @@ -3,15 +3,12 @@ # from __future__ import annotations -from typing import TYPE_CHECKING +# ruff: noqa: TCH003 # Don'e move types to TYPE_CHECKING blocks. Pydantic needs them at runtime. +from datetime import datetime from pydantic.v1 import BaseModel -if TYPE_CHECKING: - from datetime import datetime - - class RemoteFile(BaseModel): """A file in a file-based stream.""" diff --git a/pyproject.toml b/pyproject.toml index 54cea412..11837d3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,7 @@ ignore = [ # Consider re-enabling these when we have time to address them: "A003", # Class attribute 'type' is shadowing a Python builtin "BLE001", # Do not catch blind exception: Exception + "C416", # Allow unnecessary-comprehensions. Auto-fix sometimes unsafe if operating over a mapping. "D", # pydocstyle (Docstring conventions) "D102", # Missing docstring in public method "D103", # Missing docstring in public function diff --git a/unit_tests/sources/test_connector_state_manager.py b/unit_tests/sources/test_connector_state_manager.py index 585d7ffc..7c0594a2 100644 --- a/unit_tests/sources/test_connector_state_manager.py +++ b/unit_tests/sources/test_connector_state_manager.py @@ -158,7 +158,11 @@ ), ), ) -def test_initialize_state_manager(input_stream_state, expected_stream_state, expected_error): +def test_initialize_state_manager( + input_stream_state, + expected_stream_state, + expected_error, +) -> None: if isinstance(input_stream_state, list): input_stream_state = [ AirbyteStateMessageSerializer.load(state_obj) for state_obj in list(input_stream_state)