From ef516432d0f831068a6ca4a648dde7fcfde064da Mon Sep 17 00:00:00 2001 From: brianjlai Date: Fri, 20 Dec 2024 17:46:38 -0800 Subject: [PATCH 1/5] Refactor retrievers, paginators, and pagination strategy to be stateless --- .../paginators/default_paginator.py | 84 ++++++----- .../requesters/paginators/no_pagination.py | 17 ++- .../requesters/paginators/paginator.py | 13 +- .../strategies/cursor_pagination_strategy.py | 17 ++- .../paginators/strategies/offset_increment.py | 21 ++- .../paginators/strategies/page_increment.py | 27 ++-- .../strategies/pagination_strategy.py | 15 +- .../paginators/strategies/stop_condition.py | 17 ++- .../retrievers/simple_retriever.py | 142 ++++++++++++------ .../test_cursor_pagination_strategy.py | 18 +-- .../paginators/test_default_paginator.py | 63 ++------ .../paginators/test_no_paginator.py | 2 +- .../paginators/test_offset_increment.py | 58 +++---- .../paginators/test_page_increment.py | 72 ++++----- .../paginators/test_stop_condition.py | 14 +- .../retrievers/test_simple_retriever.py | 52 ++++--- .../test_manifest_declarative_source.py | 13 +- 17 files changed, 345 insertions(+), 300 deletions(-) diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py index e9476447a..50fbe636b 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py @@ -112,27 +112,40 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: ) if isinstance(self.url_base, str): self.url_base = InterpolatedString(string=self.url_base, parameters=parameters) - self._token: Optional[Any] = self.pagination_strategy.initial_token + # self._token: Optional[Any] = self.pagination_strategy.initial_token + + def get_initial_token(self) -> Optional[Any]: + """ + Return the page token that should be used for the first request of a stream + + 
WARNING: get_initial_token() should not be used by streams that use RFR and perform checkpointing + of state using page numbers, because paginators are stateless + """ + return self.pagination_strategy.initial_token def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: - self._token = self.pagination_strategy.next_page_token( - response, last_page_size, last_record + next_page_token = self.pagination_strategy.next_page_token( + response=response, + last_page_size=last_page_size, + last_record=last_record, + last_page_token_value=last_page_token_value, ) - if self._token: - return {"next_page_token": self._token} + if next_page_token: + return {"next_page_token": next_page_token} else: return None - def path(self) -> Optional[str]: - if ( - self._token - and self.page_token_option - and isinstance(self.page_token_option, RequestPath) - ): + def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]: + token = next_page_token.get("next_page_token") if next_page_token else None + if token and self.page_token_option and isinstance(self.page_token_option, RequestPath): # Replace url base to only return the path - return str(self._token).replace(self.url_base.eval(self.config), "") # type: ignore # url_base is casted to a InterpolatedString in __post_init__ + return str(token).replace(self.url_base.eval(self.config), "") # type: ignore # url_base is casted to a InterpolatedString in __post_init__ else: return None @@ -143,7 +156,7 @@ def get_request_params( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> MutableMapping[str, Any]: - return self._get_request_options(RequestOptionType.request_parameter) + return self._get_request_options(RequestOptionType.request_parameter, 
next_page_token) def get_request_headers( self, @@ -152,7 +165,7 @@ def get_request_headers( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, str]: - return self._get_request_options(RequestOptionType.header) + return self._get_request_options(RequestOptionType.header, next_page_token) def get_request_body_data( self, @@ -161,7 +174,7 @@ def get_request_body_data( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: - return self._get_request_options(RequestOptionType.body_data) + return self._get_request_options(RequestOptionType.body_data, next_page_token) def get_request_body_json( self, @@ -170,25 +183,21 @@ def get_request_body_json( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: - return self._get_request_options(RequestOptionType.body_json) - - def reset(self, reset_value: Optional[Any] = None) -> None: - if reset_value: - self.pagination_strategy.reset(reset_value=reset_value) - else: - self.pagination_strategy.reset() - self._token = self.pagination_strategy.initial_token + return self._get_request_options(RequestOptionType.body_json, next_page_token) - def _get_request_options(self, option_type: RequestOptionType) -> MutableMapping[str, Any]: + def _get_request_options( + self, option_type: RequestOptionType, next_page_token: Optional[Mapping[str, Any]] + ) -> MutableMapping[str, Any]: options = {} + token = next_page_token.get("next_page_token") if next_page_token else None if ( self.page_token_option - and self._token is not None + and token is not None and isinstance(self.page_token_option, RequestOption) and self.page_token_option.inject_into == option_type ): - options[self.page_token_option.field_name.eval(config=self.config)] = self._token # type: ignore # field_name is always cast to an interpolated string + 
options[self.page_token_option.field_name.eval(config=self.config)] = token # type: ignore # field_name is always cast to an interpolated string if ( self.page_size_option and self.pagination_strategy.get_page_size() @@ -217,17 +226,26 @@ def __init__(self, decorated: Paginator, maximum_number_of_pages: int = 5) -> No self._decorated = decorated self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL + def get_initial_token(self) -> Optional[Any]: + return self._decorated.get_initial_token() + def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: if self._page_count >= self._maximum_number_of_pages: return None self._page_count += 1 - return self._decorated.next_page_token(response, last_page_size, last_record) + return self._decorated.next_page_token( + response, last_page_size, last_record, last_page_token_value + ) - def path(self) -> Optional[str]: - return self._decorated.path() + def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]: + return self._decorated.path(next_page_token) def get_request_params( self, @@ -272,7 +290,3 @@ def get_request_body_json( return self._decorated.get_request_body_json( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token ) - - def reset(self, reset_value: Optional[Any] = None) -> None: - self._decorated.reset() - self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py index cb0592793..230899cab 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py @@ -19,7 +19,7 @@ class 
NoPagination(Paginator): parameters: InitVar[Mapping[str, Any]] - def path(self) -> Optional[str]: + def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]: return None def get_request_params( @@ -58,11 +58,14 @@ def get_request_body_json( ) -> Mapping[str, Any]: return {} + def get_initial_token(self) -> Optional[Any]: + return None + def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] - ) -> Mapping[str, Any]: + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any], + ) -> Optional[Mapping[str, Any]]: return {} - - def reset(self, reset_value: Optional[Any] = None) -> None: - # No state to reset - pass diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py index d47124628..2def49e3a 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py @@ -24,14 +24,18 @@ class Paginator(ABC, RequestOptionsProvider): """ @abstractmethod - def reset(self, reset_value: Optional[Any] = None) -> None: + def get_initial_token(self) -> Optional[Any]: """ - Reset the pagination's inner state + Get the page token that should be included in the request to get the first page of records """ @abstractmethod def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any], ) -> Optional[Mapping[str, Any]]: """ Returns the next_page_token to use to fetch the next page of records. 
@@ -39,12 +43,13 @@ def next_page_token( :param response: the response to process :param last_page_size: the number of records read from the response :param last_record: the last record extracted from the response + :param last_page_token_value: The current value of the page token made on the last request :return: A mapping {"next_page_token": } for the next page from the input response object. Returning None means there are no more pages to read in this response. """ pass @abstractmethod - def path(self) -> Optional[str]: + def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]: """ Returns the URL path to hit to fetch the next page of records diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py index beebf9e83..e35c84c7c 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py @@ -43,7 +43,6 @@ class CursorPaginationStrategy(PaginationStrategy): ) def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._initial_cursor = None if isinstance(self.cursor_value, str): self._cursor_value = InterpolatedString.create(self.cursor_value, parameters=parameters) else: @@ -57,10 +56,19 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: @property def initial_token(self) -> Optional[Any]: - return self._initial_cursor + """ + CursorPaginationStrategy does not have an initial value because the next cursor is typically included + in the response of the first request. For Resumable Full Refresh streams that checkpoint the page + cursor, the next cursor should be read from the state or stream slice object. 
+ """ + return None def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any] = None, ) -> Optional[Any]: decoded_response = next(self.decoder.decode(response)) @@ -87,8 +95,5 @@ def next_page_token( ) return token if token else None - def reset(self, reset_value: Optional[Any] = None) -> None: - self._initial_cursor = reset_value - def get_page_size(self) -> Optional[int]: return self.page_size diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py index 37ba3bbfa..2e09592f1 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py @@ -52,7 +52,6 @@ class OffsetIncrement(PaginationStrategy): inject_on_first_request: bool = False def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._offset = 0 page_size = str(self.page_size) if isinstance(self.page_size, int) else self.page_size if page_size: self._page_size: Optional[InterpolatedString] = InterpolatedString( @@ -64,11 +63,15 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: @property def initial_token(self) -> Optional[Any]: if self.inject_on_first_request: - return self._offset + return 0 return None def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any] = None, ) -> Optional[Any]: decoded_response = next(self.decoder.decode(response)) @@ -78,9 +81,17 @@ def next_page_token( and last_page_size < self._page_size.eval(self.config, response=decoded_response) ) 
or last_page_size == 0: return None + elif last_page_token_value is None: + # If the OffsetIncrement strategy does not inject on the first request, the incoming last_page_token_value + # will be None. For this case, we assume that None was the first page and progress to the next offset + return 0 + last_page_size + elif not isinstance(last_page_token_value, int): + raise ValueError( + "The page token for a OffsetIncrement pagination strategy must be an integer" + ) else: - self._offset += last_page_size - return self._offset + next_page_token_value = last_page_token_value + last_page_size + return next_page_token_value def reset(self, reset_value: Optional[Any] = 0) -> None: if not isinstance(reset_value, int): diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py index 2227fffec..a482c0443 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py @@ -31,7 +31,6 @@ class PageIncrement(PaginationStrategy): inject_on_first_request: bool = False def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._page = self.start_from_page if isinstance(self.page_size, int) or (self.page_size is None): self._page_size = self.page_size else: @@ -43,28 +42,30 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: @property def initial_token(self) -> Optional[Any]: if self.inject_on_first_request: - return self._page + return self.start_from_page return None def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any], ) -> Optional[Any]: # Stop paginating when there are fewer records than the page size or the current page has 
no records if (self._page_size and last_page_size < self._page_size) or last_page_size == 0: return None - else: - self._page += 1 - return self._page - - def reset(self, reset_value: Optional[Any] = None) -> None: - if reset_value is None: - self._page = self.start_from_page - elif not isinstance(reset_value, int): + elif last_page_token_value is None: + # If the PageIncrement strategy does not inject on the first request, the incoming last_page_token_value + # may be None. When this is the case, we assume we've already requested the first page specified by + # start_from_page and must now get the next page + return self.start_from_page + 1 + elif not isinstance(last_page_token_value, int): raise ValueError( - f"Reset value {reset_value} for PageIncrement pagination strategy was not an integer" + "The page token for a PageIncrement pagination strategy must be an integer" ) else: - self._page = reset_value + return last_page_token_value + 1 def get_page_size(self) -> Optional[int]: return self._page_size diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py index a55dcb131..255fa70c4 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py @@ -4,7 +4,7 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Mapping, Optional import requests @@ -26,22 +26,21 @@ def initial_token(self) -> Optional[Any]: @abstractmethod def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any], ) -> Optional[Any]: """ :param response: response to process :param 
last_page_size: the number of records read from the response :param last_record: the last record extracted from the response + :param last_page_token_value: The current value of the page token made on the last request :return: next page token. Returns None if there are no more pages to fetch """ pass - @abstractmethod - def reset(self, reset_value: Optional[Any] = None) -> None: - """ - Reset the pagination's inner state - """ - @abstractmethod def get_page_size(self) -> Optional[int]: """ diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py index 7722c5e73..7c89ba552 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py @@ -44,16 +44,19 @@ def __init__(self, _delegate: PaginationStrategy, stop_condition: PaginationStop self._stop_condition = stop_condition def next_page_token( - self, response: requests.Response, last_page_size: int, last_record: Optional[Record] + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any] = None, ) -> Optional[Any]: - # We evaluate in reverse order because the assumption is that most of the APIs using data feed structure will return records in - # descending order. In terms of performance/memory, we return the records lazily + # We evaluate in reverse order because the assumption is that most of the APIs using data feed structure + # will return records in descending order. 
In terms of performance/memory, we return the records lazily if last_record and self._stop_condition.is_met(last_record): return None - return self._delegate.next_page_token(response, last_page_size, last_record) - - def reset(self, reset_value: Optional[Any] = None) -> None: - self._delegate.reset(reset_value) + return self._delegate.next_page_token( + response, last_page_size, last_record, last_page_token_value + ) def get_page_size(self) -> Optional[int]: return self._delegate.get_page_size() diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index cc7040595..5560bd384 100644 --- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -46,6 +46,13 @@ FULL_REFRESH_SYNC_COMPLETE_KEY = "__ab_full_refresh_sync_complete" +@dataclass +class LastResponseValue: + last_response: Optional[requests.Response] = None + last_page_size: int = 0 + last_record: Optional[Record] = None + + @dataclass class SimpleRetriever(Retriever): """ @@ -90,9 +97,6 @@ class SimpleRetriever(Retriever): def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._paginator = self.paginator or NoPagination(parameters=parameters) - self._last_response: Optional[requests.Response] = None - self._last_page_size: int = 0 - self._last_record: Optional[Record] = None self._parameters = parameters self._name = ( InterpolatedString(self._name, parameters=parameters) @@ -100,10 +104,6 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: else self._name ) - # This mapping is used during a resumable full refresh syncs to indicate whether a partition has started syncing - # records. 
Partitions serve as the key and map to True if they already began processing records - self._partition_started: MutableMapping[Any, bool] = dict() - @property # type: ignore def name(self) -> str: """ @@ -251,17 +251,13 @@ def _request_body_json( raise ValueError("Request body json cannot be a string") return body_json - def _paginator_path( - self, - ) -> Optional[str]: + def _paginator_path(self, next_page_token: Optional[Mapping[str, Any]] = None) -> Optional[str]: """ If the paginator points to a path, follow it, else return nothing so the requester is used. - :param stream_state: - :param stream_slice: :param next_page_token: :return: """ - return self._paginator.path() + return self._paginator.path(next_page_token=next_page_token) def _parse_response( self, @@ -270,10 +266,10 @@ def _parse_response( records_schema: Mapping[str, Any], stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Record]: + ) -> Iterable[Union[Record, LastResponseValue]]: if not response: - self._last_response = None yield from [] + return LastResponseValue(last_response=None, last_page_size=0, last_record=None) else: self._last_response = response record_generator = self.record_selector.select_records( @@ -283,11 +279,16 @@ def _parse_response( stream_slice=stream_slice, next_page_token=next_page_token, ) - self._last_page_size = 0 + + last_page_size = 0 + last_record = None for record in record_generator: - self._last_page_size += 1 - self._last_record = record + last_page_size += 1 + last_record = record yield record + return LastResponseValue( + last_response=response, last_page_size=last_page_size, last_record=last_record + ) @property # type: ignore def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: @@ -299,7 +300,13 @@ def primary_key(self, value: str) -> None: if not isinstance(value, property): self._primary_key = value - def _next_page_token(self, response: requests.Response) -> 
Optional[Mapping[str, Any]]: + def _next_page_token( + self, + response: requests.Response, + last_page_size: int, + last_record: Optional[Record], + last_page_token_value: Optional[Any], + ) -> Optional[Mapping[str, Any]]: """ Specifies a pagination strategy. @@ -307,7 +314,12 @@ def _next_page_token(self, response: requests.Response) -> Optional[Mapping[str, :return: The token for the next page from the input response object. Returning None means there are no more pages to read in this response. """ - return self._paginator.next_page_token(response, self._last_page_size, self._last_record) + return self._paginator.next_page_token( + response=response, + last_page_size=last_page_size, + last_record=last_record, + last_page_token_value=last_page_token_value, + ) def _fetch_next_page( self, @@ -316,7 +328,7 @@ def _fetch_next_page( next_page_token: Optional[Mapping[str, Any]] = None, ) -> Optional[requests.Response]: return self.requester.send_request( - path=self._paginator_path(), + path=self._paginator_path(next_page_token=next_page_token), stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token, @@ -350,15 +362,35 @@ def _read_pages( stream_slice: StreamSlice, ) -> Iterable[StreamData]: pagination_complete = False - next_page_token = None + initial_token = self._paginator.get_initial_token() + next_page_token = {"next_page_token": initial_token} if initial_token else None while not pagination_complete: response = self._fetch_next_page(stream_state, stream_slice, next_page_token) - yield from records_generator_fn(response) + + last_page_size = 0 + last_record = None + + # todo: There has to be a better way of yielding records and still emitting a final return value + try: + yield from records_generator_fn(response) + except StopIteration as e: + last_response_value = e.value + if isinstance(last_response_value, LastResponseValue): + last_page_size = last_response_value.last_page_size + last_record = last_response_value.last_record if 
not response: pagination_complete = True else: - next_page_token = self._next_page_token(response) + last_page_token_value = ( + next_page_token.get("next_page_token") if next_page_token else None + ) + next_page_token = self._next_page_token( + response=response, + last_page_size=last_page_size, + last_record=last_record, + last_page_token_value=last_page_token_value, + ) if not next_page_token: pagination_complete = True @@ -371,15 +403,39 @@ def _read_single_page( stream_state: Mapping[str, Any], stream_slice: StreamSlice, ) -> Iterable[StreamData]: - response = self._fetch_next_page(stream_state, stream_slice) - yield from records_generator_fn(response) + initial_token = stream_state.get("next_page_token") + if initial_token is None: + initial_token = self._paginator.get_initial_token() + next_page_token = {"next_page_token": initial_token} if initial_token else None + + response = self._fetch_next_page(stream_state, stream_slice, next_page_token) + + last_page_size = 0 + last_record = None + + # todo: There has to be a better way of yielding records and still emitting a final return value + try: + record_generator = records_generator_fn(response) + while True: + yield next(record_generator) + except StopIteration as e: + last_response_value = e.value + if isinstance(last_response_value, LastResponseValue): + last_page_size = last_response_value.last_page_size + last_record = last_response_value.last_record if not response: next_page_token: Mapping[str, Any] = {FULL_REFRESH_SYNC_COMPLETE_KEY: True} else: - next_page_token = self._next_page_token(response) or { - FULL_REFRESH_SYNC_COMPLETE_KEY: True - } + last_page_token_value = ( + next_page_token.get("next_page_token") if next_page_token else None + ) + next_page_token = self._next_page_token( + response=response, + last_page_size=last_page_size, + last_record=last_record, + last_page_token_value=last_page_token_value, + ) or {FULL_REFRESH_SYNC_COMPLETE_KEY: True} if self.cursor: self.cursor.close_slice( @@ 
-414,25 +470,14 @@ def read_records( if self.cursor and isinstance(self.cursor, ResumableFullRefreshCursor): stream_state = self.state - # Before syncing the RFR stream, we check if the job's prior attempt was successful and don't need to fetch more records - # The platform deletes stream state for full refresh streams before starting a new job, so we don't need to worry about - # this value existing for the initial attempt + # Before syncing the RFR stream, we check if the job's prior attempt was successful and don't need to + # fetch more records. The platform deletes stream state for full refresh streams before starting a + # new job, so we don't need to worry about this value existing for the initial attempt if stream_state.get(FULL_REFRESH_SYNC_COMPLETE_KEY): return - cursor_value = stream_state.get("next_page_token") - - # The first attempt to read a page for the current partition should reset the paginator to the current - # cursor state which is initially assigned to the incoming state from the platform - partition_key = self._to_partition_key(_slice.partition) - if partition_key not in self._partition_started: - self._partition_started[partition_key] = True - self._paginator.reset(reset_value=cursor_value) yield from self._read_single_page(record_generator, stream_state, _slice) else: - # Fixing paginator types has a long tail of dependencies - self._paginator.reset() - for stream_data in self._read_pages(record_generator, self.state, _slice): current_record = self._extract_record(stream_data, _slice) if self.cursor and current_record: @@ -518,13 +563,18 @@ def _parse_records( stream_state: Mapping[str, Any], records_schema: Mapping[str, Any], stream_slice: Optional[StreamSlice], - ) -> Iterable[StreamData]: - yield from self._parse_response( + ) -> Iterable[Union[StreamData, LastResponseValue]]: + record_generator = self._parse_response( response, stream_slice=stream_slice, stream_state=stream_state, records_schema=records_schema, ) + try: + while True: + 
yield next(record_generator) + except StopIteration as e: + return e.value def must_deduplicate_query_params(self) -> bool: return True @@ -562,7 +612,7 @@ def _fetch_next_page( next_page_token: Optional[Mapping[str, Any]] = None, ) -> Optional[requests.Response]: return self.requester.send_request( - path=self._paginator_path(), + path=self._paginator_path(next_page_token=next_page_token), stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token, diff --git a/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py b/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py index 83d46918d..997920687 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_cursor_pagination_strategy.py @@ -12,6 +12,7 @@ from airbyte_cdk.sources.declarative.requesters.paginators.strategies.cursor_pagination_strategy import ( CursorPaginationStrategy, ) +from airbyte_cdk.sources.types import Record @pytest.mark.parametrize( @@ -79,7 +80,7 @@ def test_cursor_pagination_strategy(template_string, stop_condition, expected_to "characters": {}, } response._content = json.dumps(response_body).encode("utf-8") - last_record = {"id": 1, "more_records": True} + last_record = Record(data={"id": 1, "more_records": True}, stream_name="stream_name") token = strategy.next_page_token(response, 1, last_record) assert expected_token == token @@ -111,18 +112,3 @@ def test_last_record_is_node_if_no_records(): response = requests.Response() next_page_token = strategy.next_page_token(response, 0, None) assert next_page_token is None - - -def test_reset_with_initial_token(): - strategy = CursorPaginationStrategy( - page_size=10, - cursor_value="{{ response.next_page }}", - config={}, - parameters={}, - ) - - assert strategy.initial_token is None - - 
strategy.reset("https://for-all-mankind.nasa.com/api/v1/astronauts") - - assert strategy.initial_token == "https://for-all-mankind.nasa.com/api/v1/astronauts" diff --git a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py index 1cd34c42f..406449d06 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py @@ -208,12 +208,12 @@ def test_default_paginator_with_cursor( json.dumps(response_body).encode("utf-8") if decoder == JsonDecoder else response_body ) - actual_next_page_token = paginator.next_page_token(response, 2, last_record) - actual_next_path = paginator.path() - actual_request_params = paginator.get_request_params() - actual_headers = paginator.get_request_headers() - actual_body_data = paginator.get_request_body_data() - actual_body_json = paginator.get_request_body_json() + actual_next_page_token = paginator.next_page_token(response, 2, last_record, None) + actual_next_path = paginator.path(actual_next_page_token) + actual_request_params = paginator.get_request_params(next_page_token=actual_next_page_token) + actual_headers = paginator.get_request_headers(next_page_token=actual_next_page_token) + actual_body_data = paginator.get_request_body_data(next_page_token=actual_next_page_token) + actual_body_json = paginator.get_request_body_json(next_page_token=actual_next_page_token) assert actual_next_page_token == expected_next_page_token assert actual_next_path == expected_updated_path assert actual_request_params == expected_request_params @@ -281,8 +281,8 @@ def test_paginator_request_param_interpolation( response_body = {"next": "https://airbyte.io/next_url"} response._content = json.dumps(response_body).encode("utf-8") last_record = {"id": 1} - paginator.next_page_token(response, 2, last_record) - actual_request_params = 
paginator.get_request_params() + next_page_token = paginator.next_page_token(response, 2, last_record, None) + actual_request_params = paginator.get_request_params(next_page_token=next_page_token) assert actual_request_params == expected_request_params @@ -314,48 +314,6 @@ def test_page_size_option_cannot_be_set_if_strategy_has_no_limit(): pass -@pytest.mark.parametrize( - "inject_on_first_request", - [ - (True), - (False), - ], - ids=[ - "test_reset_inject_on_first_request", - "test_reset_no_inject_on_first_request", - ], -) -def test_reset(inject_on_first_request): - page_size_request_option = RequestOption( - inject_into=RequestOptionType.request_parameter, field_name="limit", parameters={} - ) - page_token_request_option = RequestOption( - inject_into=RequestOptionType.request_parameter, field_name="offset", parameters={} - ) - url_base = "https://airbyte.io" - config = {} - strategy = OffsetIncrement( - config={}, page_size=2, inject_on_first_request=inject_on_first_request, parameters={} - ) - paginator = DefaultPaginator( - strategy, - config, - url_base, - parameters={}, - page_size_option=page_size_request_option, - page_token_option=page_token_request_option, - ) - initial_request_parameters = paginator.get_request_params() - response = requests.Response() - response._content = json.dumps({}).encode("utf-8") - paginator.next_page_token(response, 2, {"a key": "a value"}) - request_parameters_for_second_request = paginator.get_request_params() - paginator.reset() - request_parameters_after_reset = paginator.get_request_params() - assert initial_request_parameters == request_parameters_after_reset - assert request_parameters_for_second_request != request_parameters_after_reset - - def test_initial_token_with_offset_pagination(): page_size_request_option = RequestOption( inject_into=RequestOptionType.request_parameter, field_name="limit", parameters={} @@ -374,7 +332,10 @@ def test_initial_token_with_offset_pagination(): 
page_size_option=page_size_request_option, page_token_option=page_token_request_option, ) - initial_request_parameters = paginator.get_request_params() + initial_token = paginator.get_initial_token() + next_page_token = {"next_page_token": initial_token} + + initial_request_parameters = paginator.get_request_params(next_page_token=next_page_token) assert initial_request_parameters == {"limit": 2, "offset": 0} diff --git a/unit_tests/sources/declarative/requesters/paginators/test_no_paginator.py b/unit_tests/sources/declarative/requesters/paginators/test_no_paginator.py index 92bcc55a2..21beed576 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_no_paginator.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_no_paginator.py @@ -9,5 +9,5 @@ def test(): paginator = NoPagination(parameters={}) - next_page_token = paginator.next_page_token(requests.Response(), 0, []) + next_page_token = paginator.next_page_token(requests.Response(), 0, [], None) assert next_page_token == {} diff --git a/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py b/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py index d443132ed..4cd827e88 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_offset_increment.py @@ -14,36 +14,46 @@ @pytest.mark.parametrize( - "page_size, parameters, last_page_size, last_record, expected_next_page_token, expected_offset", + "page_size, parameters, last_page_size, last_record, last_page_token_value, expected_next_page_token, expected_offset", [ - pytest.param("2", {}, 2, {"id": 1}, 2, 2, id="test_same_page_size"), - pytest.param(2, {}, 2, {"id": 1}, 2, 2, id="test_same_page_size"), + pytest.param("2", {}, 2, {"id": 1}, 4, 6, 2, id="test_same_page_size"), + pytest.param(2, {}, 2, {"id": 1}, 4, 6, 2, id="test_same_page_size"), pytest.param( "{{ parameters['page_size'] }}", 
{"page_size": 3}, 2, {"id": 1}, + 3, None, 0, id="test_larger_page_size", ), - pytest.param(None, {}, 0, [], None, 0, id="test_stop_if_no_records"), + pytest.param(None, {}, 0, [], 3, None, 0, id="test_stop_if_no_records"), pytest.param( "{{ response['page_metadata']['limit'] }}", {}, 2, {"id": 1}, + 3, None, 0, id="test_page_size_from_response", ), + pytest.param( + 2, {}, 2, {"id": 1}, None, 2, 2, id="test_get_second_page_with_first_page_not_injected" + ), ], ) def test_offset_increment_paginator_strategy( - page_size, parameters, last_page_size, last_record, expected_next_page_token, expected_offset + page_size, + parameters, + last_page_size, + last_record, + last_page_token_value, + expected_next_page_token, + expected_offset, ): paginator_strategy = OffsetIncrement(page_size=page_size, parameters=parameters, config={}) - assert paginator_strategy._offset == 0 response = requests.Response() @@ -51,12 +61,16 @@ def test_offset_increment_paginator_strategy( response_body = {"next": "https://airbyte.io/next_url", "page_metadata": {"limit": 5}} response._content = json.dumps(response_body).encode("utf-8") - next_page_token = paginator_strategy.next_page_token(response, last_page_size, last_record) + next_page_token = paginator_strategy.next_page_token( + response, last_page_size, last_record, last_page_token_value + ) assert expected_next_page_token == next_page_token - assert expected_offset == paginator_strategy._offset - paginator_strategy.reset() - assert 0 == paginator_strategy._offset + # Validate that the PaginationStrategy is stateless and calling next_page_token() again returns the same result + next_page_token = paginator_strategy.next_page_token( + response, last_page_size, last_record, last_page_token_value + ) + assert expected_next_page_token == next_page_token def test_offset_increment_paginator_strategy_rises(): @@ -85,27 +99,3 @@ def test_offset_increment_paginator_strategy_initial_token( ) assert paginator_strategy.initial_token == 
expected_initial_token - - -@pytest.mark.parametrize( - "reset_value, expected_initial_token, expected_error", - [ - pytest.param(25, 25, None, id="test_reset_with_offset_value"), - pytest.param(None, 0, None, id="test_reset_with_default"), - pytest.param("Nope", None, ValueError, id="test_reset_with_invalid_value"), - ], -) -def test_offset_increment_reset(reset_value, expected_initial_token, expected_error): - paginator_strategy = OffsetIncrement( - page_size=20, parameters={}, config={}, inject_on_first_request=True - ) - - if expected_error: - with pytest.raises(expected_error): - paginator_strategy.reset(reset_value=reset_value) - else: - if reset_value is None: - paginator_strategy.reset() - else: - paginator_strategy.reset(reset_value=reset_value) - assert paginator_strategy.initial_token == expected_initial_token diff --git a/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py b/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py index 56564f925..32af20b50 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_page_increment.py @@ -14,26 +14,45 @@ @pytest.mark.parametrize( - "page_size, start_from, last_page_size, last_record, expected_next_page_token, expected_offset", + "page_size, start_from, last_page_size, last_record, last_page_token_value, expected_next_page_token, expected_offset", [ - pytest.param(2, 1, 2, {"id": 1}, 2, 2, id="test_same_page_size_start_from_0"), - pytest.param(3, 1, 2, {"id": 1}, None, 1, id="test_larger_page_size_start_from_0"), - pytest.param(2, 0, 2, {"id": 1}, 1, 1, id="test_same_page_size_start_from_1"), - pytest.param(3, 0, 2, {"id": 1}, None, 0, id="test_larger_page_size_start_from_0"), - pytest.param(None, 0, 0, None, None, 0, id="test_no_page_size"), - pytest.param("2", 0, 2, {"id": 1}, 1, 1, id="test_page_size_from_string"), + pytest.param(2, 1, 2, {"id": 1}, 3, 4, 2, 
id="test_same_page_size_start_from_1"), + pytest.param(3, 1, 2, {"id": 1}, 3, None, 1, id="test_larger_page_size_start_from_1"), + pytest.param(2, 0, 2, {"id": 1}, 3, 4, 1, id="test_same_page_size_start_from_0"), + pytest.param(3, 0, 2, {"id": 1}, 3, None, 0, id="test_larger_page_size_start_from_0"), + pytest.param(None, 0, 0, None, 2, None, 0, id="test_no_page_size"), + pytest.param("2", 0, 2, {"id": 1}, 3, 4, 1, id="test_page_size_from_string"), pytest.param( - "{{ config['value'] }}", 0, 2, {"id": 1}, 1, 1, id="test_page_size_from_config" + "{{ config['value'] }}", 0, 2, {"id": 1}, 3, 4, 1, id="test_page_size_from_config" + ), + pytest.param( + 2, 0, 2, {"id": 1}, None, 1, 2, id="test_start_from_not_injected_returns_second_page" + ), + pytest.param( + 2, + 10, + 2, + {"id": 1}, + None, + 11, + 2, + id="test_non_default_start_from_not_injected_returns_next_page", ), ], ) def test_page_increment_paginator_strategy( - page_size, start_from, last_page_size, last_record, expected_next_page_token, expected_offset + page_size, + start_from, + last_page_size, + last_record, + last_page_token_value, + expected_next_page_token, + expected_offset, ): paginator_strategy = PageIncrement( page_size=page_size, parameters={}, start_from_page=start_from, config={"value": 2} ) - assert paginator_strategy._page == start_from + assert paginator_strategy.start_from_page == start_from response = requests.Response() @@ -41,12 +60,16 @@ def test_page_increment_paginator_strategy( response_body = {"next": "https://airbyte.io/next_url"} response._content = json.dumps(response_body).encode("utf-8") - next_page_token = paginator_strategy.next_page_token(response, last_page_size, last_record) + next_page_token = paginator_strategy.next_page_token( + response, last_page_size, last_record, last_page_token_value + ) assert expected_next_page_token == next_page_token - assert expected_offset == paginator_strategy._page - paginator_strategy.reset() - assert start_from == paginator_strategy._page 
+ # Validate that the PaginationStrategy is stateless and calling next_page_token() again returns the same result + next_page_token = paginator_strategy.next_page_token( + response, last_page_size, last_record, last_page_token_value + ) + assert expected_next_page_token == next_page_token @pytest.mark.parametrize( @@ -82,24 +105,3 @@ def test_page_increment_paginator_strategy_initial_token( ) assert paginator_strategy.initial_token == expected_initial_token - - -@pytest.mark.parametrize( - "reset_value, expected_initial_token, expected_error", - [ - pytest.param(25, 25, None, id="test_reset_with_offset_value"), - pytest.param(None, 0, None, id="test_reset_with_default"), - pytest.param("Nope", None, ValueError, id="test_reset_with_invalid_value"), - ], -) -def test_offset_increment_reset(reset_value, expected_initial_token, expected_error): - paginator_strategy = PageIncrement( - page_size=100, parameters={}, config={}, inject_on_first_request=True - ) - - if expected_error: - with pytest.raises(expected_error): - paginator_strategy.reset(reset_value=reset_value) - else: - paginator_strategy.reset(reset_value=reset_value) - assert paginator_strategy.initial_token == expected_initial_token diff --git a/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py b/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py index ea1d38e24..5561f92ab 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_stop_condition.py @@ -86,7 +86,9 @@ def test_given_stop_condition_is_not_met_when_next_page_token_then_delegate( next_page_token = decorator.next_page_token(ANY_RESPONSE, 2, last_record) assert next_page_token == mocked_pagination_strategy.next_page_token.return_value - mocked_pagination_strategy.next_page_token.assert_called_once_with(ANY_RESPONSE, 2, last_record) + mocked_pagination_strategy.next_page_token.assert_called_once_with( + 
ANY_RESPONSE, 2, last_record, None + ) mocked_stop_condition.is_met.assert_has_calls([call(last_record)]) @@ -100,15 +102,9 @@ def test_given_no_records_when_next_page_token_then_delegate( next_page_token = decorator.next_page_token(ANY_RESPONSE, 0, NO_RECORD) assert next_page_token == mocked_pagination_strategy.next_page_token.return_value - mocked_pagination_strategy.next_page_token.assert_called_once_with(ANY_RESPONSE, 0, NO_RECORD) - - -def test_when_reset_then_delegate(mocked_pagination_strategy, mocked_stop_condition): - decorator = StopConditionPaginationStrategyDecorator( - mocked_pagination_strategy, mocked_stop_condition + mocked_pagination_strategy.next_page_token.assert_called_once_with( + ANY_RESPONSE, 0, NO_RECORD, None ) - decorator.reset() - mocked_pagination_strategy.reset.assert_called_once_with(None) def test_when_get_page_size_then_delegate(mocked_pagination_strategy, mocked_stop_condition): diff --git a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index b54527c13..79ffc9bed 100644 --- a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -3,6 +3,7 @@ # import json +from typing import Iterable, Union from unittest.mock import MagicMock, Mock, patch import pytest @@ -26,6 +27,7 @@ from airbyte_cdk.sources.declarative.requesters.request_option import RequestOptionType from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod from airbyte_cdk.sources.declarative.retrievers.simple_retriever import ( + LastResponseValue, SimpleRetriever, SimpleRetrieverTestReadDecorator, ) @@ -51,10 +53,11 @@ def test_simple_retriever_full(mock_http_stream): requester.get_request_params.return_value = request_params paginator = MagicMock() + paginator.get_initial_token.return_value = None next_page_token = {"cursor": "cursor_value"} paginator.path.return_value = None 
paginator.next_page_token.return_value = next_page_token - paginator.get_requesyyt_headers.return_value = {} + paginator.get_request_headers.return_value = {} record_selector = MagicMock() record_selector.select_records.return_value = records @@ -66,6 +69,10 @@ def test_simple_retriever_full(mock_http_stream): response = requests.Response() response.status_code = 200 + last_page_size = 2 + last_record = Record(data={"id": "1a"}, stream_name="stream_name") + last_page_token_value = 0 + underlying_state = {"date": "2021-01-01"} cursor.get_stream_state.return_value = underlying_state @@ -102,18 +109,31 @@ def test_simple_retriever_full(mock_http_stream): assert retriever.primary_key == primary_key assert retriever.state == underlying_state - assert retriever._next_page_token(response) == next_page_token + assert ( + retriever._next_page_token(response, last_page_size, last_record, last_page_token_value) + == next_page_token + ) assert retriever._request_params(None, None, None) == {} assert retriever.stream_slices() == stream_slices - assert retriever._last_response is None - assert retriever._last_record is None - assert list(retriever._parse_response(response, stream_state={}, records_schema={})) == records - assert retriever._last_response == response - assert retriever._last_page_size == 2 + # assert retriever._last_response is None + # assert retriever._last_record is None + # assert list(retriever._parse_response(response, stream_state={}, records_schema={})) == records + # assert retriever._last_response == response + # assert retriever._last_page_size == 2 + + try: + assert ( + list(retriever._parse_response(response, stream_state={}, records_schema={})) == records + ) + except StopIteration as e: + last_response_values = e.value + assert isinstance(last_response_values, LastResponseValue) + assert last_response_values.last_response == response + assert last_response_values.last_record == last_record + assert last_response_values.last_page_size == 2 [r for r 
in retriever.read_records(SyncMode.full_refresh)] - paginator.reset.assert_called() @patch.object(SimpleRetriever, "_read_pages", return_value=iter([*request_response_logs, *records])) @@ -144,7 +164,6 @@ def test_simple_retriever_with_request_response_logs(mock_http_stream): ) actual_messages = [r for r in retriever.read_records(SyncMode.full_refresh)] - paginator.reset.assert_called() assert isinstance(actual_messages[0], AirbyteLogMessage) assert isinstance(actual_messages[1], AirbyteLogMessage) @@ -209,7 +228,7 @@ def test_simple_retriever_resumable_full_refresh_cursor_page_increment( url_base="https://airbyte.io", parameters={}, ) - paginator.reset = Mock(wraps=paginator.reset) + # paginator.reset = Mock(wraps=paginator.reset) stream_slicer = ResumableFullRefreshCursor(parameters={}) if initial_state: @@ -243,8 +262,6 @@ def test_simple_retriever_resumable_full_refresh_cursor_page_increment( assert actual_records == expected_records[5:] assert retriever.state == {"__ab_full_refresh_sync_complete": True} - paginator.reset.assert_called_once_with(reset_value=expected_reset_value) - @pytest.mark.parametrize( "initial_state, expected_reset_value, expected_next_page", @@ -331,7 +348,6 @@ def test_simple_retriever_resumable_full_refresh_cursor_reset_cursor_pagination( "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens", json=response_body_2, ) - stream.retriever.paginator.reset = Mock(wraps=stream.retriever.paginator.reset) stream_slicer = ResumableFullRefreshCursor(parameters={}) if initial_state: stream_slicer.set_initial_state(initial_state) @@ -360,8 +376,6 @@ def test_simple_retriever_resumable_full_refresh_cursor_reset_cursor_pagination( assert actual_records == expected_records[5:] assert stream.retriever.state == {"__ab_full_refresh_sync_complete": True} - stream.retriever.paginator.reset.assert_called_once_with(reset_value=expected_reset_value) - def test_simple_retriever_resumable_full_refresh_cursor_reset_skip_completed_stream(): 
expected_records = [ @@ -391,7 +405,7 @@ def test_simple_retriever_resumable_full_refresh_cursor_reset_skip_completed_str url_base="https://airbyte.io", parameters={}, ) - paginator.reset = Mock(wraps=paginator.reset) + paginator.get_initial_token = Mock(wraps=paginator.get_initial_token) stream_slicer = ResumableFullRefreshCursor(parameters={}) stream_slicer.set_initial_state({"__ab_full_refresh_sync_complete": True}) @@ -416,7 +430,7 @@ def test_simple_retriever_resumable_full_refresh_cursor_reset_skip_completed_str assert len(actual_records) == 0 assert retriever.state == {"__ab_full_refresh_sync_complete": True} - paginator.reset.assert_not_called() + paginator.get_initial_token.assert_not_called() @pytest.mark.parametrize( @@ -614,8 +628,6 @@ def test_request_body_data( paginator.get_request_body_data.return_value = paginator_body_data requester = MagicMock(use_cache=False) - # stream_slicer = MagicMock() - # stream_slicer.get_request_body_data.return_value = request_options_provider_body_data request_option_provider = MagicMock() request_option_provider.get_request_body_data.return_value = request_options_provider_body_data @@ -667,7 +679,7 @@ def test_path(test_name, requester_path, paginator_path, expected_path): config={}, ) - actual_path = retriever._paginator_path() + actual_path = retriever._paginator_path(next_page_token=None) assert actual_path == expected_path diff --git a/unit_tests/sources/declarative/test_manifest_declarative_source.py b/unit_tests/sources/declarative/test_manifest_declarative_source.py index a132757a6..e672c59b9 100644 --- a/unit_tests/sources/declarative/test_manifest_declarative_source.py +++ b/unit_tests/sources/declarative/test_manifest_declarative_source.py @@ -1278,7 +1278,7 @@ def _create_page(response_body): ) * 10, [{"ABC": 0}, {"AED": 1}], - [call({}, {})], + [call({}, {}, None)], ), ( "test_read_manifest_with_added_fields", @@ -1365,7 +1365,7 @@ def _create_page(response_body): {"ABC": 0, "added_field_key": 
"added_field_value"}, {"AED": 1, "added_field_key": "added_field_value"}, ], - [call({}, {})], + [call({}, {}, None)], ), ( "test_read_manifest_with_flatten_fields", @@ -1535,7 +1535,14 @@ def _create_page(response_body): ) * 10, [{"ABC": 0}, {"AED": 1}, {"USD": 2}], - [call({}, {}), call({"next_page_token": "next"}, {"next_page_token": "next"})], + [ + call({}, {}, None), + call( + {"next_page_token": "next"}, + {"next_page_token": "next"}, + {"next_page_token": "next"}, + ), + ], ), ( "test_no_pagination_with_partition_router", From 96e6cb17f5574cfef6c132bd3ca6113f6dedeb3c Mon Sep 17 00:00:00 2001 From: brianjlai Date: Mon, 23 Dec 2024 17:10:14 -0800 Subject: [PATCH 2/5] solve problem with connector builder server test reads, rework record yielding to be simpler, fix tests, formatting, mypy errors --- .../concurrent_declarative_source.py | 31 ++++++-- .../paginators/default_paginator.py | 11 ++- .../requesters/paginators/no_pagination.py | 2 +- .../requesters/paginators/paginator.py | 2 +- .../retrievers/simple_retriever.py | 74 +++++++------------ .../test_concurrent_declarative_source.py | 18 +++-- .../test_manifest_declarative_source.py | 2 +- 7 files changed, 70 insertions(+), 70 deletions(-) diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index aa3cea705..7cf8eb833 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -194,10 +194,11 @@ def _group_streams( # Some low-code sources use a combination of DeclarativeStream and regular Python streams. 
We can't inspect # these legacy Python streams the way we do low-code streams to determine if they are concurrent compatible, # so we need to treat them as synchronous - if ( - isinstance(declarative_stream, DeclarativeStream) - and name_to_stream_mapping[declarative_stream.name]["retriever"]["type"] + if isinstance(declarative_stream, DeclarativeStream) and ( + name_to_stream_mapping[declarative_stream.name]["retriever"]["type"] == "SimpleRetriever" + or name_to_stream_mapping[declarative_stream.name]["retriever"]["type"] + == "AsyncRetriever" ): incremental_sync_component_definition = name_to_stream_mapping[ declarative_stream.name @@ -217,6 +218,11 @@ def _group_streams( and not incremental_sync_component_definition ) + is_async_job_stream = ( + name_to_stream_mapping[declarative_stream.name].get("retriever", {}).get("type") + == "AsyncRetriever" + ) + if self._is_datetime_incremental_without_partition_routing( declarative_stream, incremental_sync_component_definition ): @@ -268,15 +274,24 @@ def _group_streams( elif ( is_substream_without_incremental or is_without_partition_router_or_cursor ) and hasattr(declarative_stream.retriever, "stream_slicer"): + if is_async_job_stream: + async_retriever = declarative_stream.retriever + + def async_retriever_factory_method() -> Retriever: + return async_retriever + + retriever_factory = async_retriever_factory_method + else: + retriever_factory = self._retriever_factory( + name_to_stream_mapping[declarative_stream.name], + config, + {}, + ) partition_generator = StreamSlicerPartitionGenerator( DeclarativePartitionFactory( declarative_stream.name, declarative_stream.get_json_schema(), - self._retriever_factory( - name_to_stream_mapping[declarative_stream.name], - config, - {}, - ), + retriever_factory, self.message_repository, ), declarative_stream.retriever.stream_slicer, diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py 
b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py index 50fbe636b..e876e4577 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py @@ -130,6 +130,7 @@ def next_page_token( last_record: Optional[Record], last_page_token_value: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: + print("At the DefaultPaginator") next_page_token = self.pagination_strategy.next_page_token( response=response, last_page_size=last_page_size, @@ -141,7 +142,7 @@ def next_page_token( else: return None - def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]: + def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]: token = next_page_token.get("next_page_token") if next_page_token else None if token and self.page_token_option and isinstance(self.page_token_option, RequestPath): # Replace url base to only return the path @@ -213,6 +214,9 @@ class PaginatorTestReadDecorator(Paginator): """ In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of pages that are queried throughout a read command. 
+ + WARNING: This decorator is not currently thread-safe like the rest of the low-code framework because it has + an internal state to track the current number of pages counted so that it can exit early during a test read """ _PAGE_COUNT_BEFORE_FIRST_NEXT_CALL = 1 @@ -227,6 +231,7 @@ def __init__(self, decorated: Paginator, maximum_number_of_pages: int = 5) -> No self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL def get_initial_token(self) -> Optional[Any]: + self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL return self._decorated.get_initial_token() def next_page_token( @@ -236,6 +241,8 @@ def next_page_token( last_record: Optional[Record], last_page_token_value: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: + print("At the PaginatorTestReadDecorator") + print(f"page count = {self._page_count} and max pages = {self._maximum_number_of_pages}") if self._page_count >= self._maximum_number_of_pages: return None @@ -244,7 +251,7 @@ def next_page_token( response, last_page_size, last_record, last_page_token_value ) - def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]: + def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]: return self._decorated.path(next_page_token) def get_request_params( diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py index 230899cab..7de91f5e9 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py @@ -19,7 +19,7 @@ class NoPagination(Paginator): parameters: InitVar[Mapping[str, Any]] - def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]: + def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]: return None def get_request_params( diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py 
b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py index 2def49e3a..8b1fea69b 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py @@ -49,7 +49,7 @@ def next_page_token( pass @abstractmethod - def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]: + def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]: """ Returns the URL path to hit to fetch the next page of records diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index 5560bd384..a8d9035b7 100644 --- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -9,6 +9,7 @@ from typing import ( Any, Callable, + Generator, Iterable, List, Mapping, @@ -266,13 +267,11 @@ def _parse_response( records_schema: Mapping[str, Any], stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Union[Record, LastResponseValue]]: + ) -> Iterable[Record]: if not response: yield from [] - return LastResponseValue(last_response=None, last_page_size=0, last_record=None) else: - self._last_response = response - record_generator = self.record_selector.select_records( + yield from self.record_selector.select_records( response=response, stream_state=stream_state, records_schema=records_schema, @@ -280,16 +279,6 @@ def _parse_response( next_page_token=next_page_token, ) - last_page_size = 0 - last_record = None - for record in record_generator: - last_page_size += 1 - last_record = record - yield record - return LastResponseValue( - last_response=response, last_page_size=last_page_size, last_record=last_record - ) - @property # type: ignore def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: """The stream's primary key""" @@ -357,27 +346,24 @@ 
def _fetch_next_page( # This logic is similar to _read_pages in the HttpStream class. When making changes here, consider making changes there as well. def _read_pages( self, - records_generator_fn: Callable[[Optional[requests.Response]], Iterable[StreamData]], + records_generator_fn: Callable[[Optional[requests.Response]], Iterable[Record]], stream_state: Mapping[str, Any], stream_slice: StreamSlice, - ) -> Iterable[StreamData]: + ) -> Iterable[Record]: pagination_complete = False initial_token = self._paginator.get_initial_token() - next_page_token = {"next_page_token": initial_token} if initial_token else None + next_page_token: Optional[Mapping[str, Any]] = ( + {"next_page_token": initial_token} if initial_token else None + ) while not pagination_complete: response = self._fetch_next_page(stream_state, stream_slice, next_page_token) last_page_size = 0 - last_record = None - - # todo: There has to be a better way of yielding records and still emitting a final return value - try: - yield from records_generator_fn(response) - except StopIteration as e: - last_response_value = e.value - if isinstance(last_response_value, LastResponseValue): - last_page_size = last_response_value.last_page_size - last_record = last_response_value.last_record + last_record: Optional[Record] = None + for record in records_generator_fn(response): + last_page_size += 1 + last_record = record + yield record if not response: pagination_complete = True @@ -399,33 +385,28 @@ def _read_pages( def _read_single_page( self, - records_generator_fn: Callable[[Optional[requests.Response]], Iterable[StreamData]], + records_generator_fn: Callable[[Optional[requests.Response]], Iterable[Record]], stream_state: Mapping[str, Any], stream_slice: StreamSlice, ) -> Iterable[StreamData]: initial_token = stream_state.get("next_page_token") if initial_token is None: initial_token = self._paginator.get_initial_token() - next_page_token = {"next_page_token": initial_token} if initial_token else None + 
next_page_token: Optional[Mapping[str, Any]] = ( + {"next_page_token": initial_token} if initial_token else None + ) response = self._fetch_next_page(stream_state, stream_slice, next_page_token) last_page_size = 0 - last_record = None - - # todo: There has to be a better way of yielding records and still emitting a final return value - try: - record_generator = records_generator_fn(response) - while True: - yield next(record_generator) - except StopIteration as e: - last_response_value = e.value - if isinstance(last_response_value, LastResponseValue): - last_page_size = last_response_value.last_page_size - last_record = last_response_value.last_record + last_record: Optional[Record] = None + for record in records_generator_fn(response): + last_page_size += 1 + last_record = record + yield record if not response: - next_page_token: Mapping[str, Any] = {FULL_REFRESH_SYNC_COMPLETE_KEY: True} + next_page_token = {FULL_REFRESH_SYNC_COMPLETE_KEY: True} else: last_page_token_value = ( next_page_token.get("next_page_token") if next_page_token else None @@ -563,18 +544,13 @@ def _parse_records( stream_state: Mapping[str, Any], records_schema: Mapping[str, Any], stream_slice: Optional[StreamSlice], - ) -> Iterable[Union[StreamData, LastResponseValue]]: - record_generator = self._parse_response( + ) -> Iterable[Record]: + yield from self._parse_response( response, stream_slice=stream_slice, stream_state=stream_state, records_schema=records_schema, ) - try: - while True: - yield next(record_generator) - except StopIteration as e: - return e.value def must_deduplicate_query_params(self) -> bool: return True diff --git a/unit_tests/sources/declarative/test_concurrent_declarative_source.py b/unit_tests/sources/declarative/test_concurrent_declarative_source.py index 18f5a97f8..3b5dd50c9 100644 --- a/unit_tests/sources/declarative/test_concurrent_declarative_source.py +++ b/unit_tests/sources/declarative/test_concurrent_declarative_source.py @@ -651,13 +651,15 @@ def 
test_group_streams(): concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG) # 1 full refresh stream, 2 incremental streams, 1 substream w/o incremental, 1 list based substream w/o incremental - assert len(concurrent_streams) == 5 + # 1 async job stream + assert len(concurrent_streams) == 6 ( concurrent_stream_0, concurrent_stream_1, concurrent_stream_2, concurrent_stream_3, concurrent_stream_4, + concurrent_stream_5, ) = concurrent_streams assert isinstance(concurrent_stream_0, DefaultStream) assert concurrent_stream_0.name == "party_members" @@ -669,13 +671,13 @@ def test_group_streams(): assert concurrent_stream_3.name == "party_members_skills" assert isinstance(concurrent_stream_4, DefaultStream) assert concurrent_stream_4.name == "arcana_personas" + assert isinstance(concurrent_stream_5, DefaultStream) + assert concurrent_stream_5.name == "async_job_stream" # 1 substream w/ incremental, 1 stream with async retriever - assert len(synchronous_streams) == 2 + assert len(synchronous_streams) == 1 assert isinstance(synchronous_streams[0], DeclarativeStream) assert synchronous_streams[0].name == "palace_enemies" - assert isinstance(synchronous_streams[1], DeclarativeStream) - assert synchronous_streams[1].name == "async_job_stream" @freezegun.freeze_time(time_to_freeze=datetime(2024, 9, 1, 0, 0, 0, 0, tzinfo=timezone.utc)) @@ -1456,10 +1458,10 @@ def test_streams_with_stream_state_interpolation_should_be_synchronous(): ) concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG) - # 1 full refresh stream, 2 with parent stream without incremental dependency - assert len(concurrent_streams) == 3 - # 2 incremental stream with interpolation on state (locations and party_members), 1 incremental with parent stream (palace_enemies), 1 stream with async retriever - assert len(synchronous_streams) == 4 + # 1 full refresh stream, 2 with parent stream without incremental dependency, 1 stream with async retriever + assert 
len(concurrent_streams) == 4 + # 2 incremental stream with interpolation on state (locations and party_members), 1 incremental with parent stream (palace_enemies) + assert len(synchronous_streams) == 3 def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurrent(): diff --git a/unit_tests/sources/declarative/test_manifest_declarative_source.py b/unit_tests/sources/declarative/test_manifest_declarative_source.py index e672c59b9..b3c9ab4bb 100644 --- a/unit_tests/sources/declarative/test_manifest_declarative_source.py +++ b/unit_tests/sources/declarative/test_manifest_declarative_source.py @@ -1449,7 +1449,7 @@ def _create_page(response_body): {"ABC": 0, "id": 1}, {"AED": 1, "id": 2}, ], - [call({}, {})], + [call({}, {}, None)], ), ( "test_read_with_pagination_no_partitions", From 74e4b5e8dc1d8175fd71ff21450a40d1e9d2803f Mon Sep 17 00:00:00 2001 From: brianjlai Date: Thu, 26 Dec 2024 14:55:08 -0800 Subject: [PATCH 3/5] more test edge cases and code cleanup --- .../paginators/default_paginator.py | 4 - .../retrievers/test_simple_retriever.py | 242 ++++++++++++++++-- 2 files changed, 219 insertions(+), 27 deletions(-) diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py index e876e4577..59255c75b 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py @@ -112,7 +112,6 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: ) if isinstance(self.url_base, str): self.url_base = InterpolatedString(string=self.url_base, parameters=parameters) - # self._token: Optional[Any] = self.pagination_strategy.initial_token def get_initial_token(self) -> Optional[Any]: """ @@ -130,7 +129,6 @@ def next_page_token( last_record: Optional[Record], last_page_token_value: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: - 
print("At the DefaultPaginator") next_page_token = self.pagination_strategy.next_page_token( response=response, last_page_size=last_page_size, @@ -241,8 +239,6 @@ def next_page_token( last_record: Optional[Record], last_page_token_value: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: - print("At the PaginatorTestReadDecorator") - print(f"page count = {self._page_count} and max pages = {self._maximum_number_of_pages}") if self._page_count >= self._maximum_number_of_pages: return None diff --git a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index 79ffc9bed..5878c758f 100644 --- a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -3,7 +3,8 @@ # import json -from typing import Iterable, Union +from functools import partial +from typing import Any, Iterable, Mapping, Optional from unittest.mock import MagicMock, Mock, patch import pytest @@ -12,6 +13,8 @@ from airbyte_cdk import YamlDeclarativeSource from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type from airbyte_cdk.sources.declarative.auth.declarative_authenticator import NoAuth +from airbyte_cdk.sources.declarative.decoders import JsonDecoder +from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordSelector from airbyte_cdk.sources.declarative.incremental import ( DatetimeBasedCursor, DeclarativeCursor, @@ -23,15 +26,18 @@ ) from airbyte_cdk.sources.declarative.partition_routers import SinglePartitionRouter from airbyte_cdk.sources.declarative.requesters.paginators import DefaultPaginator -from airbyte_cdk.sources.declarative.requesters.paginators.strategies import PageIncrement +from airbyte_cdk.sources.declarative.requesters.paginators.strategies import ( + CursorPaginationStrategy, + PageIncrement, +) from airbyte_cdk.sources.declarative.requesters.request_option 
import RequestOptionType from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod from airbyte_cdk.sources.declarative.retrievers.simple_retriever import ( - LastResponseValue, SimpleRetriever, SimpleRetrieverTestReadDecorator, ) from airbyte_cdk.sources.types import Record, StreamSlice +from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer A_SLICE_STATE = {"slice_state": "slice state value"} A_STREAM_SLICE = StreamSlice(cursor_slice={"stream slice": "slice value"}, partition={}) @@ -116,25 +122,6 @@ def test_simple_retriever_full(mock_http_stream): assert retriever._request_params(None, None, None) == {} assert retriever.stream_slices() == stream_slices - # assert retriever._last_response is None - # assert retriever._last_record is None - # assert list(retriever._parse_response(response, stream_state={}, records_schema={})) == records - # assert retriever._last_response == response - # assert retriever._last_page_size == 2 - - try: - assert ( - list(retriever._parse_response(response, stream_state={}, records_schema={})) == records - ) - except StopIteration as e: - last_response_values = e.value - assert isinstance(last_response_values, LastResponseValue) - assert last_response_values.last_response == response - assert last_response_values.last_record == last_record - assert last_response_values.last_page_size == 2 - - [r for r in retriever.read_records(SyncMode.full_refresh)] - @patch.object(SimpleRetriever, "_read_pages", return_value=iter([*request_response_logs, *records])) def test_simple_retriever_with_request_response_logs(mock_http_stream): @@ -228,7 +215,6 @@ def test_simple_retriever_resumable_full_refresh_cursor_page_increment( url_base="https://airbyte.io", parameters={}, ) - # paginator.reset = Mock(wraps=paginator.reset) stream_slicer = ResumableFullRefreshCursor(parameters={}) if initial_state: @@ -859,3 +845,213 @@ def test_emit_log_request_response_messages(mocker): 
requester.send_request.call_args_list[0][1]["log_formatter"](response) == format_http_message_mock.return_value ) + + +def test_retriever_last_page_size_for_page_increment(): + requester = MagicMock() + requester.send_request.return_value = MagicMock() + + paginator = DefaultPaginator( + config={}, + pagination_strategy=PageIncrement(config={}, page_size=5, parameters={}), + url_base="https://airbyte.io", + parameters={}, + ) + + retriever = SimpleRetriever( + name="employees", + primary_key=primary_key, + requester=requester, + paginator=paginator, + record_selector=MagicMock(), + stream_slicer=SinglePartitionRouter(parameters={}), + parameters={}, + config={}, + ) + + expected_records = [ + Record(data={"id": "1a", "name": "Cross Product Sales"}, stream_name="departments"), + Record(data={"id": "2b", "name": "Foreign Exchange"}, stream_name="departments"), + Record(data={"id": "3c", "name": "Wealth Management"}, stream_name="departments"), + Record(data={"id": "4d", "name": "Investment Banking Division"}, stream_name="departments"), + ] + + def mock_parse_records(response: Optional[requests.Response]) -> Iterable[Record]: + yield from expected_records + + actual_records = list( + retriever._read_pages( + records_generator_fn=mock_parse_records, + stream_state={}, + stream_slice=StreamSlice(cursor_slice={}, partition={}), + ) + ) + assert actual_records == expected_records + + +def test_retriever_last_record_for_page_increment(): + requester = MagicMock() + requester.send_request.return_value = MagicMock() + + paginator = DefaultPaginator( + config={}, + pagination_strategy=CursorPaginationStrategy( + cursor_value="{{ last_record['id'] }}", + stop_condition="{{ last_record['last_record'] }}", + config={}, + parameters={}, + ), + url_base="https://airbyte.io", + parameters={}, + ) + + retriever = SimpleRetriever( + name="employees", + primary_key=primary_key, + requester=requester, + paginator=paginator, + record_selector=MagicMock(), + 
+ stream_slicer=SinglePartitionRouter(parameters={}), + parameters={}, + config={}, + ) + + expected_records = [ + Record(data={"id": "a", "name": "Cross Product Sales"}, stream_name="departments"), + Record(data={"id": "b", "name": "Foreign Exchange"}, stream_name="departments"), + Record(data={"id": "c", "name": "Wealth Management"}, stream_name="departments"), + Record( + data={"id": "d", "name": "Investment Banking Division", "last_record": True}, + stream_name="departments", + ), + ] + + def mock_parse_records(response: Optional[requests.Response]) -> Iterable[Record]: + yield from expected_records + + actual_records = list( + retriever._read_pages( + records_generator_fn=mock_parse_records, + stream_state={}, + stream_slice=StreamSlice(cursor_slice={}, partition={}), + ) + ) + assert actual_records == expected_records + + +def test_retriever_is_stateless(): + """ + Special test case to verify that retrieving the pages for a given slice does not affect an internal + state of the component. Specifically, because this test doesn't call any type of reset, invoking the + _read_pages() method twice will fail if there is an internal state (and is therefore not stateless) + because the page count will not be reset.
+ """ + + page_response_1 = requests.Response() + page_response_1.status_code = 200 + page_response_1._content = json.dumps( + { + "employees": [ + {"id": "0", "first_name": "eric", "last_name": "tao"}, + {"id": "1", "first_name": "rishi", "last_name": "ramdani"}, + {"id": "2", "first_name": "harper", "last_name": "stern"}, + {"id": "3", "first_name": "erobertric", "last_name": "spearing"}, + {"id": "4", "first_name": "yasmin", "last_name": "kara-hanani"}, + ] + } + ).encode("utf-8") + + page_response_2 = requests.Response() + page_response_2.status_code = 200 + page_response_2._content = json.dumps( + { + "employees": [ + {"id": "5", "first_name": "daria", "last_name": "greenock"}, + {"id": "6", "first_name": "venetia", "last_name": "berens"}, + {"id": "7", "first_name": "kenny", "last_name": "killbane"}, + ] + } + ).encode("utf-8") + + def mock_send_request( + next_page_token: Optional[Mapping[str, Any]] = None, **kwargs + ) -> Optional[requests.Response]: + page_number = next_page_token.get("next_page_token") if next_page_token else None + if page_number is None: + return page_response_1 + elif page_number == 1: + return page_response_2 + else: + raise ValueError(f"Requested an invalid page number {page_number}") + + requester = MagicMock() + requester.send_request.side_effect = mock_send_request + + decoder = JsonDecoder(parameters={}) + extractor = DpathExtractor( + field_path=["employees"], decoder=decoder, config=config, parameters={} + ) + record_selector = RecordSelector( + name="employees", + extractor=extractor, + record_filter=None, + transformations=[], + config=config, + parameters={}, + schema_normalization=TypeTransformer(TransformConfig.DefaultSchemaNormalization), + ) + + paginator = DefaultPaginator( + config={}, + pagination_strategy=PageIncrement(config={}, page_size=5, parameters={}), + url_base="https://airbyte.io", + parameters={}, + ) + + retriever = SimpleRetriever( + name="employees", + primary_key=primary_key, + requester=requester, + 
paginator=paginator, + record_selector=record_selector, + stream_slicer=SinglePartitionRouter(parameters={}), + parameters={}, + config={}, + ) + + _slice = StreamSlice(cursor_slice={}, partition={}) + + record_generator = partial( + retriever._parse_records, + stream_state=retriever.state or {}, + stream_slice=_slice, + records_schema={}, + ) + + # We call _read_pages() because the existing read_records() used to modify and reset state whereas + # _read_pages() did not invoke any methods to reset state + actual_records = list( + retriever._read_pages( + records_generator_fn=record_generator, stream_state={}, stream_slice=_slice + ) + ) + assert len(actual_records) == 8 + assert actual_records[0] == Record( + data={"id": "0", "first_name": "eric", "last_name": "tao"}, stream_name="employees" + ) + assert actual_records[7] == Record( + data={"id": "7", "first_name": "kenny", "last_name": "killbane"}, stream_name="employees" + ) + + actual_records = list( + retriever._read_pages( + records_generator_fn=record_generator, stream_state={}, stream_slice=_slice + ) + ) + assert len(actual_records) == 8 + assert actual_records[2] == Record( + data={"id": "2", "first_name": "harper", "last_name": "stern"}, stream_name="employees" + ) + assert actual_records[5] == Record( + data={"id": "5", "first_name": "daria", "last_name": "greenock"}, stream_name="employees" + ) From 9ac92bf24c0f2dfca12d4e8d94c9a5ada6affe45 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Thu, 26 Dec 2024 15:58:04 -0800 Subject: [PATCH 4/5] some comments --- .../sources/declarative/concurrent_declarative_source.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 7cf8eb833..24d925a8a 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -275,6 +275,8 @@ def _group_streams( 
is_substream_without_incremental or is_without_partition_router_or_cursor ) and hasattr(declarative_stream.retriever, "stream_slicer"): if is_async_job_stream: + # A stream's AsyncRetriever must be shared across all partitions because it uses a + # shared JobRepository to manage the state of job requests and when they are ready async_retriever = declarative_stream.retriever def async_retriever_factory_method() -> Retriever: From c704a22b37b8fc1a27430aa50ae917c89b0e0437 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Thu, 26 Dec 2024 16:19:02 -0800 Subject: [PATCH 5/5] remove unneeded class --- .../retrievers/simple_retriever.py | 21 +------------------ 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index a8d9035b7..d167a84bc 100644 --- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -6,19 +6,7 @@ from dataclasses import InitVar, dataclass, field from functools import partial from itertools import islice -from typing import ( - Any, - Callable, - Generator, - Iterable, - List, - Mapping, - MutableMapping, - Optional, - Set, - Tuple, - Union, -) +from typing import Any, Callable, Iterable, List, Mapping, Optional, Set, Tuple, Union import requests @@ -47,13 +35,6 @@ FULL_REFRESH_SYNC_COMPLETE_KEY = "__ab_full_refresh_sync_complete" -@dataclass -class LastResponseValue: - last_response: Optional[requests.Response] = None - last_page_size: int = 0 - last_record: Optional[Record] = None - - @dataclass class SimpleRetriever(Retriever): """