Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers committed Nov 10, 2024
1 parent 879f78b commit 78ca8d9
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 21 deletions.
5 changes: 4 additions & 1 deletion airbyte_cdk/models/airbyte_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,13 @@ class AirbyteGlobalState:
@dataclass
class AirbyteStateMessage:
type: Optional[models.AirbyteStateType] = None
stream: Optional[models.AirbyteStreamState] = None

# These two use custom classes defined above
stream: Optional[AirbyteStreamState] = None
global_: Annotated[Optional[AirbyteGlobalState], Alias("global")] = (
None # "global" is a reserved keyword in python ⇒ Alias is used for (de-)serialization
)

data: Optional[dict[str, Any]] = None
sourceStats: Optional[models.AirbyteStateStats] = None
destinationStats: Optional[models.AirbyteStateStats] = None
Expand Down
2 changes: 1 addition & 1 deletion airbyte_cdk/models/airbyte_protocol_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class AirbyteStateBlobType(CustomType[AirbyteStateBlob, dict[str, Any]]):
def serialize(self, value: AirbyteStateBlob) -> dict[str, Any]:
# can't use orjson.dumps() directly because private attributes are excluded, e.g. "__ab_full_refresh_sync_complete"
return dict(value.__dict__.items())
return {k: v for k, v in value.__dict__.items()}

def deserialize(self, value: dict[str, Any]) -> AirbyteStateBlob:
return AirbyteStateBlob(value)
Expand Down
3 changes: 2 additions & 1 deletion airbyte_cdk/sources/connector_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_stream_state(self, stream_name: str, namespace: str | None) -> MutableMa
HashableStreamDescriptor(name=stream_name, namespace=namespace)
)
if stream_state:
return copy.deepcopy(dict(stream_state.__dict__.items()))
return copy.deepcopy({k: v for k, v in stream_state.__dict__.items()})
return {}

def update_state_for_stream(
Expand Down Expand Up @@ -125,6 +125,7 @@ def _extract_from_state_message(
for per_stream_state in global_state.stream_states # type: ignore[union-attr] # global_state has shared_state
}
return shared_state, streams

streams = {
HashableStreamDescriptor(
name=per_stream_state.stream.stream_descriptor.name,
Expand Down
22 changes: 10 additions & 12 deletions airbyte_cdk/sources/file_based/config/file_based_stream_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,24 @@
#
from __future__ import annotations

from collections.abc import Mapping

# ruff: noqa: TCH001, TCH002, TCH003 # Don't move imports to TYPE_CHECKING blocks
from enum import Enum
from typing import TYPE_CHECKING, Any
from typing import Any

from pydantic.v1 import BaseModel, Field, validator

from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
from airbyte_cdk.sources.file_based.config.excel_format import ExcelFormat
from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat
from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat
from airbyte_cdk.sources.file_based.config.unstructured_format import UnstructuredFormat
from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError
from airbyte_cdk.sources.file_based.schema_helpers import type_mapping_to_jsonschema


if TYPE_CHECKING:
from collections.abc import Mapping

from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
from airbyte_cdk.sources.file_based.config.excel_format import ExcelFormat
from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat
from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat
from airbyte_cdk.sources.file_based.config.unstructured_format import UnstructuredFormat


PrimaryKeyType = str | list[str] | None


Expand Down
7 changes: 2 additions & 5 deletions airbyte_cdk/sources/file_based/remote_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
#
from __future__ import annotations

from typing import TYPE_CHECKING
# ruff: noqa: TCH003 # Don't move types to TYPE_CHECKING blocks. Pydantic needs them at runtime.
from datetime import datetime

from pydantic.v1 import BaseModel


if TYPE_CHECKING:
from datetime import datetime


class RemoteFile(BaseModel):
"""A file in a file-based stream."""

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ ignore = [
# Consider re-enabling these when we have time to address them:
"A003", # Class attribute 'type' is shadowing a Python builtin
"BLE001", # Do not catch blind exception: Exception
"C416", # Allow unnecessary-comprehensions. Auto-fix sometimes unsafe if operating over a mapping.
"D", # pydocstyle (Docstring conventions)
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
Expand Down
6 changes: 5 additions & 1 deletion unit_tests/sources/test_connector_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,11 @@
),
),
)
def test_initialize_state_manager(input_stream_state, expected_stream_state, expected_error):
def test_initialize_state_manager(
input_stream_state,
expected_stream_state,
expected_error,
) -> None:
if isinstance(input_stream_state, list):
input_stream_state = [
AirbyteStateMessageSerializer.load(state_obj) for state_obj in list(input_stream_state)
Expand Down

0 comments on commit 78ca8d9

Please sign in to comment.