feat(Low-Code Concurrent CDK): Add ConcurrentPerPartitionCursor #111

Open · wants to merge 17 commits into main

Changes from 7 commits
56 changes: 56 additions & 0 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -20,6 +20,9 @@
ClientSideIncrementalRecordFilterDecorator,
)
from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.incremental.per_partition_with_global import (
PerPartitionWithGlobalCursor,
)
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
@@ -306,6 +309,59 @@ def _group_streams(
cursor=final_state_cursor,
)
)
elif (
Contributor: Should we filter only on list partition router until we support the global cursor part?
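A sketch of the narrower guard this comment suggests, under the assumption that ListPartitionRouter is the only router to allow until global-cursor support lands (the helper name is hypothetical):

```python
from airbyte_cdk.sources.declarative.partition_routers import ListPartitionRouter


def _uses_list_partition_router(declarative_stream) -> bool:
    # Only opt into the concurrent per-partition cursor when the slicer is
    # driven by a list-based partition router.
    slicer = getattr(declarative_stream.retriever, "stream_slicer", None)
    return isinstance(getattr(slicer, "_partition_router", None), ListPartitionRouter)
```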

incremental_sync_component_definition
and incremental_sync_component_definition.get("type", "")
== DatetimeBasedCursorModel.__name__
and self._stream_supports_concurrent_partition_processing(
declarative_stream=declarative_stream
)
and hasattr(declarative_stream.retriever, "stream_slicer")
and isinstance(declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor)
):
stream_state = state_manager.get_stream_state(
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
)
partition_router = declarative_stream.retriever.stream_slicer._partition_router

cursor = self._constructor.create_concurrent_cursor_from_perpartition_cursor(
state_manager=state_manager,
model_type=DatetimeBasedCursorModel,
component_definition=incremental_sync_component_definition,
stream_name=declarative_stream.name,
stream_namespace=declarative_stream.namespace,
config=config or {},
stream_state=stream_state,
partition_router=partition_router,
)

Contributor (⚠️ Potential issue): Handling potential None in stream_state: when retrieving stream_state, do we need to handle the case where it might be None? Ensuring that stream_state is properly initialized could prevent unexpected errors during cursor creation. Wdyt?
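A minimal sketch of the guard being asked for, assuming an empty mapping is an acceptable default when no state has been persisted:

```python
stream_state = (
    state_manager.get_stream_state(
        stream_name=declarative_stream.name, namespace=declarative_stream.namespace
    )
    or {}  # fall back to an empty state instead of passing None to the cursor factory
)
```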


partition_generator = StreamSlicerPartitionGenerator(
DeclarativePartitionFactory(
declarative_stream.name,
declarative_stream.get_json_schema(),
self._retriever_factory(
name_to_stream_mapping[declarative_stream.name],
config,
stream_state,
),
self.message_repository,
),
cursor,
)
Contributor (🛠️ Refactor suggestion): Error handling for partition generator creation: in constructing the partition_generator, are there potential edge cases where dependencies might not be properly initialized? Should we add error handling or input validation to make the code more robust against such scenarios? Wdyt?
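One possible pre-flight check, sketched with hypothetical naming; the actual validation surface would be up to the authors:

```python
from typing import Any, Mapping


def _validate_partition_inputs(
    stream_name: str, name_to_stream_mapping: Mapping[str, Any]
) -> None:
    # Hypothetical guard: fail fast if the stream has no manifest definition
    # before wiring it into a DeclarativePartitionFactory.
    if stream_name not in name_to_stream_mapping:
        raise ValueError(f"No manifest definition found for stream '{stream_name}'")
```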


concurrent_streams.append(
DefaultStream(
partition_generator=partition_generator,
name=declarative_stream.name,
json_schema=declarative_stream.get_json_schema(),
availability_strategy=AlwaysAvailableAvailabilityStrategy(),
primary_key=get_primary_key_from_stream(declarative_stream.primary_key),
cursor_field=cursor.cursor_field.cursor_field_key,
logger=self.logger,
cursor=cursor,
)
)
else:
synchronous_streams.append(declarative_stream)
else:
8 changes: 3 additions & 5 deletions airbyte_cdk/sources/declarative/extractors/record_filter.py
@@ -59,13 +59,11 @@ class ClientSideIncrementalRecordFilterDecorator(RecordFilter):

def __init__(
self,
date_time_based_cursor: DatetimeBasedCursor,
substream_cursor: Optional[Union[PerPartitionWithGlobalCursor, GlobalSubstreamCursor]],
cursor: Union[DatetimeBasedCursor, PerPartitionWithGlobalCursor, GlobalSubstreamCursor],
**kwargs: Any,
):
super().__init__(**kwargs)
self._date_time_based_cursor = date_time_based_cursor
self._substream_cursor = substream_cursor
self._cursor = cursor

def filter_records(
self,
@@ -77,7 +75,7 @@ def filter_records(
records = (
record
for record in records
if (self._substream_cursor or self._date_time_based_cursor).should_be_synced(
if self._cursor.should_be_synced(
Contributor: This change is so beautiful that I want to cry ❤️

# Record is created on the fly to align with cursors interface; stream name is ignored as we don't need it here
# Record stream name is empty because it is not used during the filtering
Record(data=record, associated_slice=stream_slice, stream_name="")
3 changes: 3 additions & 0 deletions airbyte_cdk/sources/declarative/incremental/__init__.py
@@ -2,6 +2,7 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from airbyte_cdk.sources.declarative.incremental.concurrent_partition_cursor import ConcurrentCursorFactory, ConcurrentPerPartitionCursor
from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import GlobalSubstreamCursor
@@ -14,6 +15,8 @@

__all__ = [
"CursorFactory",
"ConcurrentCursorFactory"
"ConcurrentPerPartitionCursor",
Contributor (⚠️ Potential issue, on lines +18 to +19): Comma missing in __all__ list: there's a missing comma after "ConcurrentCursorFactory", which could lead to import errors. Should we add the comma to fix this issue? Wdyt?

"DatetimeBasedCursor",
"DeclarativeCursor",
"GlobalSubstreamCursor",
airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
@@ -0,0 +1,270 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import copy
import logging
from collections import OrderedDict
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional

from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import (
PerPartitionKeySerializer,
)
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, CursorField
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState

logger = logging.getLogger("airbyte")


class ConcurrentCursorFactory:
def __init__(self, create_function: Callable[..., Cursor]):
self._create_function = create_function

Contributor: Should the input type for the Callable be StreamState?

def create(self, stream_state: Mapping[str, Any]) -> Cursor:
return self._create_function(stream_state=stream_state)
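One way to answer the typing question above, sketched under the assumption that stream_state is always passed by keyword (as create() does); a typing.Protocol can express a keyword-only callable where a plain Callable cannot:

```python
from typing import Any, Mapping, Protocol

from airbyte_cdk.sources.streams.concurrent.cursor import Cursor


class CursorCreateFn(Protocol):
    """Hypothetical callable type for ConcurrentCursorFactory's create_function."""

    def __call__(self, *, stream_state: Mapping[str, Any]) -> Cursor: ...
```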


class ConcurrentPerPartitionCursor(Cursor):
"""
Manages state per partition when a stream has many partitions, to prevent data loss or duplication.

**Partition Limitation and Limit Reached Logic**

- **DEFAULT_MAX_PARTITIONS_NUMBER**: The maximum number of partitions to keep in memory (default is 10,000).
- **_cursor_per_partition**: An ordered dictionary that stores cursors for each partition.
- **_over_limit**: A counter that increments each time an oldest partition is removed when the limit is exceeded.

The class ensures that the number of partitions tracked does not exceed the `DEFAULT_MAX_PARTITIONS_NUMBER` to prevent excessive memory usage.

- When the number of partitions exceeds the limit, the oldest partitions are removed from `_cursor_per_partition`, and `_over_limit` is incremented accordingly.
- The `limit_reached` method returns `True` when `_over_limit` exceeds `DEFAULT_MAX_PARTITIONS_NUMBER`, indicating that the global cursor should be used instead of per-partition cursors.

This approach avoids unnecessary switching to a global cursor due to temporary spikes in partition counts, ensuring that switching is only done when a sustained high number of partitions is observed.
"""

DEFAULT_MAX_PARTITIONS_NUMBER = 10000
_NO_STATE: Mapping[str, Any] = {}
_NO_CURSOR_STATE: Mapping[str, Any] = {}
_KEY = 0
_VALUE = 1
_state_to_migrate_from: Mapping[str, Any] = {}

def __init__(
self,
cursor_factory: ConcurrentCursorFactory,
partition_router: PartitionRouter,
stream_name: str,
stream_namespace: Optional[str],
stream_state: Any,
message_repository: MessageRepository,
connector_state_manager: ConnectorStateManager,
cursor_field: CursorField,
) -> None:
self._stream_name = stream_name
self._stream_namespace = stream_namespace
self._message_repository = message_repository
self._connector_state_manager = connector_state_manager
self._cursor_field = cursor_field

self._cursor_factory = cursor_factory
self._partition_router = partition_router

# The dict is ordered to ensure that once the maximum number of partitions is reached,
# the oldest partitions can be efficiently removed, maintaining the most recent partitions.
self._cursor_per_partition: OrderedDict[str, Cursor] = OrderedDict()
self._over_limit = 0
self._partition_serializer = PerPartitionKeySerializer()

self._set_initial_state(stream_state)

@property
def cursor_field(self) -> CursorField:
Contributor: Should we add this to the Cursor interface? If we put it there, it seems like the return type will need to be Optional because of the FinalStateCursor.

return self._cursor_field

@property
def state(self) -> MutableMapping[str, Any]:
states = []
for partition_tuple, cursor in self._cursor_per_partition.items():
cursor_state = cursor._connector_state_converter.convert_to_state_message(
cursor._cursor_field, cursor.state
)
if cursor_state:
states.append(
{
"partition": self._to_dict(partition_tuple),
"cursor": copy.deepcopy(cursor_state),
}
)
state: dict[str, Any] = {"states": states}
return state

Contributor: I have two confusions regarding this method. First, why is Cursor.state public? I know it was already like that before your changes, but I'm trying to understand where it is used and why it needs to be public; before your changes, it didn't seem like we needed it to be. Second, is there a way to access the information we need without reaching into private attributes? It feels like what we actually need to make public is the representation of the data that is emitted as an AirbyteMessage of type state (in other words, the one in ConcurrentCursor._emit_state_message), because when PerPartitionCursor wants to emit the state, it does not care about the cursor's implementation; it cares about packaging it into a state message. Based on these two points, can we make the state method more explicit (something like as_state_message), with a comment saying it is used by PerPartitionCursor and NOT by other non-cursor classes to emit state?

Contributor (🛠️ Refactor suggestion): Potential performance issue when generating state: the state property iterates over self._cursor_per_partition.items() and deep-copies each cursor state, which could become a bottleneck with a large number of partitions. Should we consider optimizing this by avoiding deep copies or processing states incrementally? Wdyt?

Contributor: Is there a reason we're doing a deepcopy here? If so, I think we should document it.
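A sketch of the copy-free variant these comments hint at, assuming convert_to_state_message returns a fresh mapping on every call so the deepcopy is purely defensive; if callers may mutate the result, the copy should stay and be documented instead:

```python
@property
def state(self) -> MutableMapping[str, Any]:
    # Hypothetical copy-free variant: trust convert_to_state_message to
    # return a new mapping per call rather than deep-copying it again.
    states = [
        {"partition": self._to_dict(partition_tuple), "cursor": cursor_state}
        for partition_tuple, cursor in self._cursor_per_partition.items()
        if (
            cursor_state := cursor._connector_state_converter.convert_to_state_message(
                cursor._cursor_field, cursor.state
            )
        )
    ]
    return {"states": states}
```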

def close_partition(self, partition: Partition) -> None:
self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)].close_partition_without_emit(partition=partition)

def ensure_at_least_one_state_emitted(self) -> None:
"""
The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
called.
"""
self._emit_state_message()

def _emit_state_message(self) -> None:
self._connector_state_manager.update_state_for_stream(
self._stream_name,
self._stream_namespace,
self.state,
)
state_message = self._connector_state_manager.create_state_message(
self._stream_name, self._stream_namespace
)
self._message_repository.emit_message(state_message)


def stream_slices(self) -> Iterable[StreamSlice]:
slices = self._partition_router.stream_slices()
for partition in slices:
yield from self.generate_slices_from_partition(partition)

def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
# Ensure the maximum number of partitions is not exceeded
self._ensure_partition_limit()

cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition))
if not cursor:
partition_state = (
self._state_to_migrate_from
if self._state_to_migrate_from
else self._NO_CURSOR_STATE
)
cursor = self._create_cursor(partition_state)
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor

for cursor_slice in cursor.stream_slices():
yield StreamSlice(
partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
)

def _ensure_partition_limit(self) -> None:
"""
Ensure the maximum number of partitions is not exceeded. If so, the oldest added partition will be dropped.
"""
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
self._over_limit += 1
oldest_partition = self._cursor_per_partition.popitem(last=False)[
0
] # Remove the oldest partition
logger.warning(
f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._over_limit}."
)

Contributor (⚠️ Potential issue, on lines +216 to +228): Clarification on partition limit logic: in _ensure_partition_limit, we increment _over_limit every time we remove a partition once the limit is reached, yet limit_reached checks whether _over_limit > DEFAULT_MAX_PARTITIONS_NUMBER. Is this the intended behavior? Should the condition be adjusted to properly reflect when the limit is truly exceeded? Wdyt?

def limit_reached(self) -> bool:
return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER

Contributor (⚠️ Potential issue, on lines +229 to +231): Possible off-by-one error in limit_reached: the method returns True when _over_limit > DEFAULT_MAX_PARTITIONS_NUMBER. Given how _over_limit is incremented, could this condition lead to an off-by-one error? Should we revisit this logic to ensure accurate limit checks? Wdyt?
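A back-of-the-envelope sketch of the semantics the two comments above are questioning (the numbers are illustrative only, not from the PR):

```python
MAX = 10_000  # DEFAULT_MAX_PARTITIONS_NUMBER

def evictions(partitions_seen: int) -> int:
    # _ensure_partition_limit keeps at most MAX - 1 cursors, so every partition
    # beyond the first 9,999 causes one eviction and one _over_limit increment.
    return max(0, partitions_seen - (MAX - 1))

# limit_reached only flips once evictions exceed MAX, i.e. after roughly twice
# the nominal partition budget has been observed within a single sync.
print(evictions(15_000), evictions(15_000) > MAX)  # 5001 False
print(evictions(25_000), evictions(25_000) > MAX)  # 15001 True
```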

def _set_initial_state(self, stream_state: StreamState) -> None:
"""
Set the initial state for the cursors.

This method initializes the state for each partition cursor using the provided stream state.
If a partition state is provided in the stream state, it will update the corresponding partition cursor with this state.

Additionally, it sets the parent state for partition routers that are based on parent streams. If a partition router
does not have parent streams, this step will be skipped due to the default PartitionRouter implementation.

Args:
stream_state (StreamState): The state of the streams to be set. The format of the stream state should be:
{
"states": [
{
"partition": {
"partition_key": "value"
},
"cursor": {
"last_updated": "2023-05-27T00:00:00Z"
}
}
],
"parent_state": {
"parent_stream_name": {
"last_updated": "2023-05-27T00:00:00Z"
}
}
}
"""
if not stream_state:
return

if "states" not in stream_state:
# We assume that `stream_state` is in a global format that can be applied to all partitions.
# Example: {"global_state_format_key": "global_state_format_value"}
self._state_to_migrate_from = stream_state

else:
for state in stream_state["states"]:
self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
self._create_cursor(state["cursor"])
)

# set default state for missing partitions if it is per partition with fallback to global
if "state" in stream_state:
self._state_to_migrate_from = stream_state["state"]

# Set parent state for partition routers based on parent streams
self._partition_router.set_initial_state(stream_state)

def observe(self, record: Record) -> None:
self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)].observe(record)

def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
return self._partition_serializer.to_partition_key(partition)

def _to_dict(self, partition_key: str) -> Mapping[str, Any]:
return self._partition_serializer.to_partition(partition_key)

def _create_cursor(self, cursor_state: Any) -> Cursor:
cursor = self._cursor_factory.create(stream_state=cursor_state)
return cursor

def should_be_synced(self, record: Record) -> bool:
return self._get_cursor(record).should_be_synced(record)

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
if not first.associated_slice or not second.associated_slice:
raise ValueError(
f"Both records should have an associated slice but got {first.associated_slice} and {second.associated_slice}"
)
if first.associated_slice.partition != second.associated_slice.partition:
raise ValueError(
f"To compare records, partition should be the same but got {first.associated_slice.partition} and {second.associated_slice.partition}"
)

return self._get_cursor(first).is_greater_than_or_equal(
self._convert_record_to_cursor_record(first),
self._convert_record_to_cursor_record(second),
)

@staticmethod
def _convert_record_to_cursor_record(record: Record) -> Record:
return Record(
record.data,
StreamSlice(partition={}, cursor_slice=record.associated_slice.cursor_slice)
if record.associated_slice
else None,
)

def _get_cursor(self, record: Record) -> Cursor:
if not record.associated_slice:
raise ValueError(
"Invalid state as stream slices that are emitted should refer to an existing cursor"
)
partition_key = self._to_partition_key(record.associated_slice.partition)
if partition_key not in self._cursor_per_partition:
raise ValueError(
"Invalid state as stream slices that are emitted should refer to an existing cursor"
)
cursor = self._cursor_per_partition[partition_key]
return cursor
@@ -303,6 +303,15 @@ def get_request_body_json(
raise ValueError("A partition needs to be provided in order to get request body json")

def should_be_synced(self, record: Record) -> bool:
if self._to_partition_key(record.associated_slice.partition) not in self._cursor_per_partition:
partition_state = (
self._state_to_migrate_from
if self._state_to_migrate_from
else self._NO_CURSOR_STATE
)
cursor = self._create_cursor(partition_state)

self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)] = cursor
return self._get_cursor(record).should_be_synced(
self._convert_record_to_cursor_record(record)
)

Contributor (⚠️ Potential issue): Should we add a check for record.associated_slice being None in should_be_synced? The method accesses record.associated_slice.partition without verifying that record.associated_slice is not None, which could raise an AttributeError. Should we add a check to ensure record.associated_slice is not None before proceeding? Wdyt?
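A minimal sketch of the suggested guard (hypothetical; raising mirrors how _get_cursor already treats records without an associated slice):

```python
def should_be_synced(self, record: Record) -> bool:
    if record.associated_slice is None:
        # Guard added per the comment above: fail loudly instead of raising
        # an AttributeError on record.associated_slice.partition below.
        raise ValueError("Record must have an associated slice to be evaluated per partition")
    partition_key = self._to_partition_key(record.associated_slice.partition)
    if partition_key not in self._cursor_per_partition:
        partition_state = self._state_to_migrate_from or self._NO_CURSOR_STATE
        self._cursor_per_partition[partition_key] = self._create_cursor(partition_state)
    return self._get_cursor(record).should_be_synced(
        self._convert_record_to_cursor_record(record)
    )
```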