Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(Low-Code Concurrent CDK): Refactor the low-code AsyncRetriever to use an underlying StreamSlicer #170

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@
SinglePartitionRouter,
SubstreamPartitionRouter,
)
from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
AsyncJobPartitionRouter,
)
from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import (
ParentStreamConfig,
)
Expand Down Expand Up @@ -2228,22 +2231,28 @@ def create_async_retriever(
urls_extractor=urls_extractor,
)

return AsyncRetriever(
async_job_partition_router = AsyncJobPartitionRouter(
job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
job_repository,
stream_slices,
JobTracker(
1
), # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1
JobTracker(1),
# FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1
brianjlai marked this conversation as resolved.
Show resolved Hide resolved
self._message_repository,
has_bulk_parent=False, # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk
has_bulk_parent=False,
# FIXME work would need to be done here in order to detect if a stream has a parent stream that is bulk
),
record_selector=record_selector,
stream_slicer=stream_slicer,
config=config,
parameters=model.parameters or {},
)

return AsyncRetriever(
record_selector=record_selector,
stream_slicer=async_job_partition_router,
config=config,
parameters=model.parameters or {},
)

@staticmethod
def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec:
return Spec(
Expand Down
10 changes: 9 additions & 1 deletion airbyte_cdk/sources/declarative/partition_routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import AsyncJobPartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.cartesian_product_stream_slicer import CartesianProductStreamSlicer
from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ListPartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import SinglePartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import SubstreamPartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter

__all__ = ["CartesianProductStreamSlicer", "ListPartitionRouter", "SinglePartitionRouter", "SubstreamPartitionRouter", "PartitionRouter"]
# Names exported by `from airbyte_cdk.sources.declarative.partition_routers import *`;
# every entry corresponds to one of the imports at the top of this module.
__all__ = [
"AsyncJobPartitionRouter",
"CartesianProductStreamSlicer",
"ListPartitionRouter",
"SinglePartitionRouter",
"SubstreamPartitionRouter",
"PartitionRouter"
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from dataclasses import InitVar, dataclass, field
from typing import Any, Callable, Iterable, Mapping, Optional

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import (
AsyncJobOrchestrator,
AsyncPartition,
)
from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import (
SinglePartitionRouter,
)
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
brianjlai marked this conversation as resolved.
Show resolved Hide resolved
from airbyte_cdk.sources.types import Config, StreamSlice
from airbyte_cdk.utils.traced_exception import AirbyteTracedException


@dataclass
class AsyncJobPartitionRouter(StreamSlicer):
    """
    Partition router that creates async jobs in a source API, periodically polls for job
    completion, and supplies the completed job URL locations as stream slices so that
    records can be extracted.

    The router wraps another `StreamSlicer`: the slices produced by that inner slicer are
    handed to `job_orchestrator_factory`, and every partition the resulting orchestrator
    reports as completed is re-emitted as a `StreamSlice`.
    """

    # Connector configuration; stored on the instance but not read by this class itself.
    config: Config
    # Init-only value, kept as `_parameters` by `__post_init__`.
    parameters: InitVar[Mapping[str, Any]]
    # Builds the orchestrator from the slices produced by the wrapped `stream_slicer`.
    job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator]
    # Source of the slices jobs are created for; defaults to a single partition.
    stream_slicer: StreamSlicer = field(
        default_factory=lambda: SinglePartitionRouter(parameters={})
    )

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        """Initialise private state; the orchestrator is created later by `stream_slices`."""
        self._job_orchestrator_factory = self.job_orchestrator_factory
        self._job_orchestrator: Optional[AsyncJobOrchestrator] = None
        self._parameters = parameters

    def stream_slices(self) -> Iterable[StreamSlice]:
        """
        Yield one `StreamSlice` per completed async partition.

        Each yielded slice carries the original slice's partition values plus a
        `"partition"` key holding the completed `AsyncPartition`, and the original
        slice's `cursor_slice`.

        This is a generator: the orchestrator is only created (and stored on the
        instance) once iteration actually starts, not when this method is called.
        """
        slices = self.stream_slicer.stream_slices()
        orchestrator = self._job_orchestrator_factory(slices)
        # Keep a reference so that `fetch_records` can reuse the same orchestrator.
        self._job_orchestrator = orchestrator

        for completed_partition in orchestrator.create_and_get_completed_partitions():
            original_slice = completed_partition.stream_slice
            yield StreamSlice(
                partition=dict(original_slice.partition) | {"partition": completed_partition},
                cursor_slice=original_slice.cursor_slice,
            )

    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
        """
        This method of fetching records extends beyond what a PartitionRouter/StreamSlicer should
        be responsible for. However, this was added in because the JobOrchestrator is required to
        retrieve records. And without defining fetch_records() on this class, we're stuck with either
        passing the JobOrchestrator to the AsyncRetriever or storing it on multiple classes.

        :param partition: the completed partition whose records should be fetched
        :return: the records the orchestrator returns for `partition`
        :raises AirbyteTracedException: if no orchestrator exists yet, i.e. `stream_slices`
            has not started being iterated
        """
        # Explicit `is None` check: only the "not created yet" state is an error, and the
        # check must not depend on the truthiness of the orchestrator object.
        if self._job_orchestrator is None:
            raise AirbyteTracedException(
                message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support",
                internal_message="AsyncJobPartitionRouter job orchestrator is expected to be accessed only after `stream_slices`",
                failure_type=FailureType.system_error,
            )

        return self._job_orchestrator.fetch_records(partition=partition)
39 changes: 8 additions & 31 deletions airbyte_cdk/sources/declarative/retrievers/async_retriever.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.


from dataclasses import InitVar, dataclass, field
from typing import Any, Callable, Iterable, Mapping, Optional
from dataclasses import InitVar, dataclass
from typing import Any, Iterable, Mapping, Optional

from typing_extensions import deprecated

Expand All @@ -12,9 +12,10 @@
AsyncPartition,
)
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
from airbyte_cdk.sources.declarative.partition_routers import SinglePartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
AsyncJobPartitionRouter,
)
from airbyte_cdk.sources.declarative.retrievers import Retriever
from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
from airbyte_cdk.sources.source import ExperimentalClassWarning
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
Expand All @@ -29,15 +30,10 @@
class AsyncRetriever(Retriever):
config: Config
parameters: InitVar[Mapping[str, Any]]
job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator]
record_selector: RecordSelector
stream_slicer: StreamSlicer = field(
default_factory=lambda: SinglePartitionRouter(parameters={})
)
stream_slicer: AsyncJobPartitionRouter

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._job_orchestrator_factory = self.job_orchestrator_factory
self.__job_orchestrator: Optional[AsyncJobOrchestrator] = None
self._parameters = parameters

@property
Expand All @@ -54,17 +50,6 @@ def state(self, value: StreamState) -> None:
"""
pass

@property
def _job_orchestrator(self) -> AsyncJobOrchestrator:
if not self.__job_orchestrator:
raise AirbyteTracedException(
message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support",
internal_message="AsyncPartitionRepository is expected to be accessed only after `stream_slices`",
failure_type=FailureType.system_error,
)

return self.__job_orchestrator

def _get_stream_state(self) -> StreamState:
"""
Gets the current state of the stream.
Expand Down Expand Up @@ -99,15 +84,7 @@ def _validate_and_get_stream_slice_partition(
return stream_slice["partition"] # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices

def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
slices = self.stream_slicer.stream_slices()
self.__job_orchestrator = self._job_orchestrator_factory(slices)

for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
yield StreamSlice(
partition=dict(completed_partition.stream_slice.partition)
| {"partition": completed_partition},
cursor_slice=completed_partition.stream_slice.cursor_slice,
)
return self.stream_slicer.stream_slices()

def read_records(
self,
Expand All @@ -116,7 +93,7 @@ def read_records(
) -> Iterable[StreamData]:
stream_state: StreamState = self._get_stream_state()
partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
records: Iterable[Mapping[str, Any]] = self._job_orchestrator.fetch_records(partition)
records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition)
brianjlai marked this conversation as resolved.
Show resolved Hide resolved

yield from self.record_selector.filter_and_transform(
all_data=records,
Expand Down
19 changes: 13 additions & 6 deletions unit_tests/sources/declarative/async_job/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus
from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
AsyncJobPartitionRouter,
)
from airbyte_cdk.sources.declarative.retrievers.async_retriever import AsyncRetriever
from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
Expand Down Expand Up @@ -79,12 +82,16 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
config={},
parameters={},
record_selector=noop_record_selector,
stream_slicer=self._stream_slicer,
job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
MockAsyncJobRepository(),
stream_slices,
JobTracker(_NO_LIMIT),
self._message_repository,
stream_slicer=AsyncJobPartitionRouter(
stream_slicer=self._stream_slicer,
job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
MockAsyncJobRepository(),
stream_slices,
JobTracker(_NO_LIMIT),
self._message_repository,
),
config={},
parameters={},
),
),
config={},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from airbyte_cdk import AirbyteTracedException
from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator
from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator, JwtAuthenticator
from airbyte_cdk.sources.declarative.auth.token import (
ApiKeyAuthenticator,
Expand All @@ -40,6 +41,7 @@
ResumableFullRefreshCursor,
)
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.models import AsyncRetriever as AsyncRetrieverModel
from airbyte_cdk.sources.declarative.models import CheckStream as CheckStreamModel
from airbyte_cdk.sources.declarative.models import (
CompositeErrorHandler as CompositeErrorHandlerModel,
Expand Down Expand Up @@ -85,6 +87,7 @@
ModelToComponentFactory,
)
from airbyte_cdk.sources.declarative.partition_routers import (
AsyncJobPartitionRouter,
CartesianProductStreamSlicer,
ListPartitionRouter,
SinglePartitionRouter,
Expand All @@ -102,6 +105,7 @@
WaitTimeFromHeaderBackoffStrategy,
WaitUntilTimeFromHeaderBackoffStrategy,
)
from airbyte_cdk.sources.declarative.requesters.http_job_repository import AsyncHttpJobRepository
from airbyte_cdk.sources.declarative.requesters.paginators import DefaultPaginator
from airbyte_cdk.sources.declarative.requesters.paginators.strategies import (
CursorPaginationStrategy,
Expand All @@ -121,6 +125,7 @@
from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
from airbyte_cdk.sources.declarative.retrievers import (
AsyncRetriever,
SimpleRetriever,
SimpleRetrieverTestReadDecorator,
)
Expand All @@ -138,6 +143,7 @@
from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import (
SingleUseRefreshTokenOauth2Authenticator,
)
from airbyte_cdk.sources.types import StreamSlice
from unit_tests.sources.declarative.parsers.testing_components import (
TestingCustomSubstreamPartitionRouter,
TestingSomeComponent,
Expand Down Expand Up @@ -3294,3 +3300,97 @@ def test_create_custom_record_extractor():
}
component = factory.create_component(CustomRecordExtractorModel, definition, {})
assert isinstance(component, CustomRecordExtractor)


def test_create_async_retriever():
    """The factory should build an AsyncRetriever backed by an AsyncJobPartitionRouter."""

    def bearer_authenticator():
        # A fresh dict per call so the requester definitions never share an object.
        return {
            "type": "BearerAuthenticator",
            "api_token": "{{ config['api_key'] }}",
        }

    def http_requester(path, url_base, http_method, authenticated=False):
        requester = {
            "type": "HttpRequester",
            "path": path,
            "url_base": url_base,
            "http_method": http_method,
        }
        if authenticated:
            requester["authenticator"] = bearer_authenticator()
        return requester

    def dpath_extractor(*field_path):
        return {"type": "DpathExtractor", "field_path": list(field_path)}

    sendgrid_url = "https://api.sendgrid.com"
    job_url_template = "{{stream_slice['url']}}"

    config = {"api_key": "123"}

    definition = {
        "type": "AsyncRetriever",
        "status_mapping": {
            "failed": ["failed"],
            "running": ["pending"],
            "timeout": ["timeout"],
            "completed": ["ready"],
        },
        "urls_extractor": dpath_extractor("urls"),
        "record_selector": {
            "type": "RecordSelector",
            "extractor": dpath_extractor("data"),
        },
        "status_extractor": dpath_extractor("status"),
        "polling_requester": http_requester(
            "/v3/marketing/contacts/exports/{{stream_slice['create_job_response'].json()['id'] }}",
            sendgrid_url,
            "GET",
            authenticated=True,
        ),
        "creation_requester": http_requester(
            "/v3/marketing/contacts/exports",
            sendgrid_url,
            "POST",
            authenticated=True,
        ),
        "download_requester": http_requester(job_url_template, "", "GET"),
        "abort_requester": http_requester("{{stream_slice['url']}}/abort", "", "POST"),
        "delete_requester": http_requester(job_url_template, "", "POST"),
    }

    retriever = factory.create_component(
        model_type=AsyncRetrieverModel,
        component_definition=definition,
        name="test_stream",
        primary_key="id",
        stream_slicer=None,
        transformations=[],
        config=config,
    )

    assert isinstance(retriever, AsyncRetriever)

    # The retriever's slicer must be the async router, wrapping the default single partition.
    router = retriever.stream_slicer
    assert isinstance(router, AsyncJobPartitionRouter)
    assert isinstance(router.stream_slicer, SinglePartitionRouter)

    # The factory stored on the router must produce a real orchestrator.
    empty_slice = StreamSlice(partition={}, cursor_slice={})
    orchestrator = router.job_orchestrator_factory([empty_slice])
    assert isinstance(orchestrator, AsyncJobOrchestrator)

    # Every requester/extractor from the definition must be wired into the repository.
    repository = orchestrator._job_repository
    assert isinstance(repository, AsyncHttpJobRepository)
    for attribute_name in (
        "creation_requester",
        "polling_requester",
        "download_retriever",
        "abort_requester",
        "delete_requester",
        "status_extractor",
        "urls_extractor",
    ):
        assert getattr(repository, attribute_name)

    record_selector = retriever.record_selector
    record_extractor = record_selector.extractor
    assert isinstance(record_selector, RecordSelector)
    assert isinstance(record_extractor, DpathExtractor)
    assert record_extractor.field_path == ["data"]
Loading
Loading