
fix(Low-Code Concurrent CDK): Refactor the low-code AsyncRetriever to use an underlying StreamSlicer (#170)
brianjlai authored Dec 18, 2024
1 parent 9563c33 commit 57e1b52
Showing 7 changed files with 295 additions and 45 deletions.
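In short: async-job orchestration moves out of AsyncRetriever and into the new AsyncJobPartitionRouter, so the retriever can treat job management as just another stream slicer. Below is a minimal sketch of the resulting wiring, assuming job_repository, message_repository, record_selector, stream_slicer, and config are built elsewhere (as in create_async_retriever below); the JobTracker import path is not shown in this diff and is assumed:

    from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator
    from airbyte_cdk.sources.declarative.async_job.job_tracker import JobTracker  # assumed path
    from airbyte_cdk.sources.declarative.partition_routers import AsyncJobPartitionRouter
    from airbyte_cdk.sources.declarative.retrievers.async_retriever import AsyncRetriever

    # The router now owns the orchestrator factory and the upstream slicer.
    partition_router = AsyncJobPartitionRouter(
        job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
            job_repository,  # assumed to be built elsewhere
            stream_slices,
            JobTracker(1),  # concurrent-job limit, as in create_async_retriever
            message_repository,  # assumed to be built elsewhere
            has_bulk_parent=False,
        ),
        stream_slicer=stream_slicer,  # the slicer that produces the input slices
        config=config,
        parameters={},
    )

    # The retriever only needs a record selector plus the router.
    retriever = AsyncRetriever(
        record_selector=record_selector,
        stream_slicer=partition_router,
        config=config,
        parameters={},
    )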
@@ -329,6 +329,9 @@
     SinglePartitionRouter,
     SubstreamPartitionRouter,
 )
+from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
+    AsyncJobPartitionRouter,
+)
 from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import (
     ParentStreamConfig,
 )
@@ -2260,22 +2263,28 @@ def create_async_retriever(
             urls_extractor=urls_extractor,
         )

-        return AsyncRetriever(
+        async_job_partition_router = AsyncJobPartitionRouter(
             job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
                 job_repository,
                 stream_slices,
-                JobTracker(
-                    1
-                ),  # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1
+                JobTracker(1),
+                # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1
                 self._message_repository,
-                has_bulk_parent=False,  # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk
+                has_bulk_parent=False,
+                # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk
             ),
-            record_selector=record_selector,
             stream_slicer=stream_slicer,
             config=config,
             parameters=model.parameters or {},
         )

+        return AsyncRetriever(
+            record_selector=record_selector,
+            stream_slicer=async_job_partition_router,
+            config=config,
+            parameters=model.parameters or {},
+        )
+
     @staticmethod
     def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec:
         return Spec(
10 changes: 9 additions & 1 deletion airbyte_cdk/sources/declarative/partition_routers/__init__.py
@@ -2,10 +2,18 @@
 # Copyright (c) 2022 Airbyte, Inc., all rights reserved.
 #

+from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import AsyncJobPartitionRouter
 from airbyte_cdk.sources.declarative.partition_routers.cartesian_product_stream_slicer import CartesianProductStreamSlicer
 from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ListPartitionRouter
 from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import SinglePartitionRouter
 from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import SubstreamPartitionRouter
 from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter

-__all__ = ["CartesianProductStreamSlicer", "ListPartitionRouter", "SinglePartitionRouter", "SubstreamPartitionRouter", "PartitionRouter"]
+__all__ = [
+    "AsyncJobPartitionRouter",
+    "CartesianProductStreamSlicer",
+    "ListPartitionRouter",
+    "SinglePartitionRouter",
+    "SubstreamPartitionRouter",
+    "PartitionRouter"
+]
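With the added export, the new router can be imported from the partition_routers package root like its siblings. A quick sanity check, assuming an installed CDK with this change:

    from airbyte_cdk.sources.declarative.partition_routers import AsyncJobPartitionRouter
    from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer

    # The router subclasses the concurrent StreamSlicer interface (see the new
    # module below), which is what lets AsyncRetriever accept it as a stream_slicer.
    assert issubclass(AsyncJobPartitionRouter, StreamSlicer)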
65 changes: 65 additions & 0 deletions airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from dataclasses import InitVar, dataclass, field
from typing import Any, Callable, Iterable, Mapping, Optional

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import (
    AsyncJobOrchestrator,
    AsyncPartition,
)
from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import (
    SinglePartitionRouter,
)
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import Config, StreamSlice
from airbyte_cdk.utils.traced_exception import AirbyteTracedException


@dataclass
class AsyncJobPartitionRouter(StreamSlicer):
    """
    Partition router that creates async jobs in a source API, periodically polls for job
    completion, and supplies the completed job URL locations as stream slices so that
    records can be extracted.
    """

    config: Config
    parameters: InitVar[Mapping[str, Any]]
    job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator]
    stream_slicer: StreamSlicer = field(
        default_factory=lambda: SinglePartitionRouter(parameters={})
    )

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        self._job_orchestrator_factory = self.job_orchestrator_factory
        self._job_orchestrator: Optional[AsyncJobOrchestrator] = None
        self._parameters = parameters

    def stream_slices(self) -> Iterable[StreamSlice]:
        slices = self.stream_slicer.stream_slices()
        self._job_orchestrator = self._job_orchestrator_factory(slices)

        for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
            yield StreamSlice(
                partition=dict(completed_partition.stream_slice.partition)
                | {"partition": completed_partition},
                cursor_slice=completed_partition.stream_slice.cursor_slice,
            )

    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
        """
        This method of fetching records extends beyond what a PartitionRouter/StreamSlicer should
        be responsible for. However, this was added in because the JobOrchestrator is required to
        retrieve records. And without defining fetch_records() on this class, we're stuck with either
        passing the JobOrchestrator to the AsyncRetriever or storing it on multiple classes.
        """

        if not self._job_orchestrator:
            raise AirbyteTracedException(
                message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support",
                internal_message="AsyncPartitionRepository is expected to be accessed only after `stream_slices`",
                failure_type=FailureType.system_error,
            )

        return self._job_orchestrator.fetch_records(partition=partition)
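The class enforces an implicit lifecycle: stream_slices() must run first, since it builds the orchestrator that fetch_records() depends on. A usage sketch under that assumption, where make_orchestrator is a hypothetical stand-in for the factory wired up in create_async_retriever:

    router = AsyncJobPartitionRouter(
        config={},
        parameters={},
        job_orchestrator_factory=make_orchestrator,  # hypothetical factory
    )

    for stream_slice in router.stream_slices():  # starts async jobs, polls to completion
        partition = stream_slice["partition"]  # the completed AsyncPartition embedded above
        for record in router.fetch_records(partition):
            ...  # hand records to a RecordSelector, as AsyncRetriever does

    # Calling fetch_records() before stream_slices() raises the
    # AirbyteTracedException above, because no orchestrator exists yet.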
39 changes: 8 additions & 31 deletions airbyte_cdk/sources/declarative/retrievers/async_retriever.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2024 Airbyte, Inc., all rights reserved.


-from dataclasses import InitVar, dataclass, field
-from typing import Any, Callable, Iterable, Mapping, Optional
+from dataclasses import InitVar, dataclass
+from typing import Any, Iterable, Mapping, Optional

 from typing_extensions import deprecated

@@ -12,9 +12,10 @@
     AsyncPartition,
 )
 from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
-from airbyte_cdk.sources.declarative.partition_routers import SinglePartitionRouter
+from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
+    AsyncJobPartitionRouter,
+)
 from airbyte_cdk.sources.declarative.retrievers import Retriever
-from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
 from airbyte_cdk.sources.source import ExperimentalClassWarning
 from airbyte_cdk.sources.streams.core import StreamData
 from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
@@ -29,15 +30,10 @@
 class AsyncRetriever(Retriever):
     config: Config
     parameters: InitVar[Mapping[str, Any]]
-    job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator]
     record_selector: RecordSelector
-    stream_slicer: StreamSlicer = field(
-        default_factory=lambda: SinglePartitionRouter(parameters={})
-    )
+    stream_slicer: AsyncJobPartitionRouter

     def __post_init__(self, parameters: Mapping[str, Any]) -> None:
-        self._job_orchestrator_factory = self.job_orchestrator_factory
-        self.__job_orchestrator: Optional[AsyncJobOrchestrator] = None
         self._parameters = parameters

     @property
@@ -54,17 +50,6 @@ def state(self, value: StreamState) -> None:
"""
pass

@property
def _job_orchestrator(self) -> AsyncJobOrchestrator:
if not self.__job_orchestrator:
raise AirbyteTracedException(
message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support",
internal_message="AsyncPartitionRepository is expected to be accessed only after `stream_slices`",
failure_type=FailureType.system_error,
)

return self.__job_orchestrator

def _get_stream_state(self) -> StreamState:
"""
Gets the current state of the stream.
@@ -99,15 +84,7 @@ def _validate_and_get_stream_slice_partition(
return stream_slice["partition"] # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices

def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
slices = self.stream_slicer.stream_slices()
self.__job_orchestrator = self._job_orchestrator_factory(slices)

for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
yield StreamSlice(
partition=dict(completed_partition.stream_slice.partition)
| {"partition": completed_partition},
cursor_slice=completed_partition.stream_slice.cursor_slice,
)
return self.stream_slicer.stream_slices()

def read_records(
self,
@@ -116,7 +93,7 @@ def read_records(
     ) -> Iterable[StreamData]:
         stream_state: StreamState = self._get_stream_state()
         partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
-        records: Iterable[Mapping[str, Any]] = self._job_orchestrator.fetch_records(partition)
+        records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition)

         yield from self.record_selector.filter_and_transform(
             all_data=records,
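After this change, AsyncRetriever is a thin facade: stream_slices() delegates to the router, and read_records() asks that same router for a partition's records before running them through the record selector. A sketch of the read path, with retriever wired as in the earlier sketch and an empty records_schema as a placeholder:

    for stream_slice in retriever.stream_slices():  # delegates to AsyncJobPartitionRouter
        for record in retriever.read_records(records_schema={}, stream_slice=stream_slice):
            print(record)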
21 changes: 14 additions & 7 deletions unit_tests/sources/declarative/async_job/test_integration.py
@@ -20,6 +20,9 @@
 from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus
 from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor
 from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
+from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
+    AsyncJobPartitionRouter,
+)
 from airbyte_cdk.sources.declarative.retrievers.async_retriever import AsyncRetriever
 from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
 from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
@@ -35,7 +38,7 @@

 class MockAsyncJobRepository(AsyncJobRepository):
     def start(self, stream_slice: StreamSlice) -> AsyncJob:
-        return AsyncJob("a_job_id", StreamSlice(partition={}, cursor_slice={}))
+        return AsyncJob("a_job_id", stream_slice)

     def update_jobs_status(self, jobs: Set[AsyncJob]) -> None:
         for job in jobs:
@@ -79,12 +82,16 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
                 config={},
                 parameters={},
                 record_selector=noop_record_selector,
-                stream_slicer=self._stream_slicer,
-                job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
-                    MockAsyncJobRepository(),
-                    stream_slices,
-                    JobTracker(_NO_LIMIT),
-                    self._message_repository,
+                stream_slicer=AsyncJobPartitionRouter(
+                    stream_slicer=self._stream_slicer,
+                    job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
+                        MockAsyncJobRepository(),
+                        stream_slices,
+                        JobTracker(_NO_LIMIT),
+                        self._message_repository,
+                    ),
+                    config={},
+                    parameters={},
                 ),
             ),
             config={},
(diffs for the two remaining changed files did not load)
