From 649f418ef596428689697fb3e573eea0c8ef2117 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Thu, 21 Nov 2024 16:57:36 +0100 Subject: [PATCH] Simplify the Kinesis Source implementation (#650) --- .../sources/community/kinesis/consumer.py | 159 +++++++++--------- .../sources/community/kinesis/kinesis.py | 48 ++---- 2 files changed, 94 insertions(+), 113 deletions(-) diff --git a/quixstreams/sources/community/kinesis/consumer.py b/quixstreams/sources/community/kinesis/consumer.py index 3dce7f330..ec1cb3458 100644 --- a/quixstreams/sources/community/kinesis/consumer.py +++ b/quixstreams/sources/community/kinesis/consumer.py @@ -1,9 +1,11 @@ import logging import time -from typing import Callable, Literal, Optional, Protocol +from typing import Callable, Generator, Literal, Optional, TypedDict from typing_extensions import Self +from quixstreams.sources import StatefulSource + try: import boto3 from botocore.exceptions import ClientError @@ -23,43 +25,39 @@ AutoOffsetResetType = Literal["earliest", "latest"] -class KinesisCheckpointer(Protocol): +class KinesisStreamShardsNotFound(Exception): + """Raised when the Kinesis Stream has no shards""" + + +class KinesisCheckpointer: + def __init__(self, stateful_source: StatefulSource, commit_interval: float = 5.0): + self._source = stateful_source + self._last_committed_at = time.monotonic() + self._commit_interval = commit_interval + @property def last_committed_at(self) -> float: - yield ... + return self._last_committed_at - def get(self, key: str) -> Optional[str]: ... + def get(self, shard_id: str) -> Optional[str]: + return self._source.state.get(shard_id) - def set(self, key: str, value: str): ... + def set(self, shard_id: str, sequence_number: str): + self._source.state.set(shard_id, sequence_number) - def commit(self, force: bool = False): ... + def commit(self, force: bool = False): + if ( + (now := time.monotonic()) - self._last_committed_at > self._commit_interval + ) or force: + self._source.flush() + self._last_committed_at = now -class Authentication: - def __init__( - self, - aws_region: Optional[str] = None, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_endpoint_url: Optional[str] = None, - ): - """ - :param aws_region: The AWS region. - NOTE: can alternatively set the AWS_REGION environment variable - :param aws_access_key_id: the AWS access key ID. - NOTE: can alternatively set the AWS_ACCESS_KEY_ID environment variable - :param aws_secret_access_key: the AWS secret access key. - NOTE: can alternatively set the AWS_SECRET_ACCESS_KEY environment variable - :param aws_endpoint_url: the endpoint URL to use; only required for connecting - to a locally hosted Kinesis. - NOTE: can alternatively set the AWS_ENDPOINT_URL_KINESIS environment variable - """ - self.auth = { - "endpoint_url": aws_endpoint_url, - "region_name": aws_region, - "aws_access_key_id": aws_access_key_id, - "aws_secret_access_key": aws_secret_access_key, - } +class AWSCredentials(TypedDict): + endpoint_url: Optional[str] + region_name: Optional[str] + aws_access_key_id: Optional[str] + aws_secret_access_key: Optional[str] class KinesisConsumer: @@ -71,7 +69,7 @@ class KinesisConsumer: def __init__( self, stream_name: str, - auth: Authentication, + credentials: AWSCredentials, message_processor: Callable[[KinesisRecord], None], checkpointer: KinesisCheckpointer, auto_offset_reset: AutoOffsetResetType = "latest", @@ -79,7 +77,7 @@ def __init__( backoff_secs: float = 5.0, ): self._stream = stream_name - self._auth = auth + self._credentials = credentials self._message_processor = message_processor self._checkpointer = checkpointer self._shard_iterators: dict[str, str] = {} @@ -89,19 +87,36 @@ def __init__( self._auto_offset_reset = _OFFSET_RESET_DICT[auto_offset_reset] self._client: Optional[KinesisClient] = None + def process_shards(self): + """ + Process records from the Stream shards one by one and checkpoint their + sequence numbers. + """ + # Iterate over shards one by one + for shard_id in self._shard_iterators: + # Poll records from each shard + for record in self._poll_records(shard_id=shard_id): + # Process the record + self._message_processor(record) + # Save the sequence number of the processed record + self._checkpointer.set(shard_id, record["SequenceNumber"]) + + def commit(self, force: bool = False): + """ + Commit the checkpoint and save the progress of the + """ + self._checkpointer.commit(force=force) + def __enter__(self) -> Self: - self.start() + self._init_client() + self._init_shards() return self def __exit__(self, exc_type, exc_val, exc_tb): pass def _init_client(self): - self._client = boto3.client("kinesis", **self._auth.auth) - - def _process_record(self, shard_id: str, record: KinesisRecord): - self._message_processor(record) - self._checkpointer.set(shard_id, record["SequenceNumber"]) + self._client = boto3.client("kinesis", **self._credentials) def _list_shards(self) -> list[ShardTypeDef]: """List all shards in the stream.""" @@ -113,51 +128,54 @@ def _list_shards(self) -> list[ShardTypeDef]: shards.extend(response["Shards"]) return shards - def _get_shard_iterator(self, shard_id: str): + def _get_shard_iterator(self, shard_id: str) -> str: if sequence_number := self._checkpointer.get(shard_id): - additional_kwargs = { + kwargs = { "ShardIteratorType": "AFTER_SEQUENCE_NUMBER", "StartingSequenceNumber": sequence_number, } else: - additional_kwargs = { + kwargs = { "ShardIteratorType": self._auto_offset_reset, } response: GetShardIteratorOutputTypeDef = self._client.get_shard_iterator( - StreamName=self._stream, ShardId=shard_id, **additional_kwargs + StreamName=self._stream, ShardId=shard_id, **kwargs ) return response["ShardIterator"] def _init_shards(self): if not (shards := [shard["ShardId"] for shard in self._list_shards()]): - raise ValueError(f"No shards for stream {self._stream}") + raise KinesisStreamShardsNotFound( + f'Shards not found for stream "{self._stream}"' + ) self._shard_iterators = { shard: self._get_shard_iterator(shard) for shard in shards } - def _poll_and_process_shard(self, shard_id): - """Read records from a shard.""" + def _poll_records(self, shard_id: str) -> Generator[KinesisRecord, None, None]: + """ + Poll records from the Kinesis Stream shard. + + If the shared is backed off, no records will be returned. + + :param shard_id: shard id. + """ if ( - backoff_time := self._shard_backoff.get(shard_id) - ) and time.monotonic() < backoff_time: + backoff_time := self._shard_backoff.get(shard_id, 0.0) + ) and backoff_time > time.monotonic(): + # The shard is backed off, exit early return + try: response = self._client.get_records( ShardIterator=self._shard_iterators[shard_id], Limit=self._max_records_per_shard, ) - - for record in response.get("Records", []): - self._process_record(shard_id, record) - - # Update the shard iterator for the next batch - self._shard_iterators[shard_id] = response["NextShardIterator"] - self._shard_backoff[shard_id] = 0 - except ClientError as e: error_code = e.response["Error"]["Code"] logger.error(f"Error reading from shard {shard_id}: {error_code}") if error_code == "ProvisionedThroughputExceededException": + # The shard is backed off by Kinesis, update the backoff deadline self._shard_backoff[shard_id] = time.monotonic() + self._backoff_secs elif error_code == "ExpiredIteratorException": logger.error(f"Shard iterator expired for shard {shard_id}.") @@ -165,26 +183,11 @@ def _poll_and_process_shard(self, shard_id): else: logger.error(f"Unrecoverable error: {e}") raise + else: + # Yield records for the shard + for record in response.get("Records", []): + yield record - def start(self): - self._init_client() - self._init_shards() - - def poll_and_process_shards(self): - for shard in self._shard_iterators: - self._poll_and_process_shard(shard) - - def commit(self, force: bool = False): - self._checkpointer.commit(force=force) - - def run(self): - """For running _without_ using Quix Streams Source framework.""" - try: - self.start() - while True: - self.poll_and_process_shards() - self.commit() - except Exception as e: - logger.debug(f"KinesisConsumer encountered an error: {e}") - finally: - logger.debug("Stopping KinesisConsumer...") + # Update the shard iterator for the next batch + self._shard_iterators[shard_id] = response["NextShardIterator"] + self._shard_backoff[shard_id] = 0 diff --git a/quixstreams/sources/community/kinesis/kinesis.py b/quixstreams/sources/community/kinesis/kinesis.py index 1d05f8705..9ef3c8c85 100644 --- a/quixstreams/sources/community/kinesis/kinesis.py +++ b/quixstreams/sources/community/kinesis/kinesis.py @@ -1,12 +1,11 @@ -import time from typing import Optional from quixstreams.models.topics import Topic from quixstreams.sources.base import StatefulSource from .consumer import ( - Authentication, AutoOffsetResetType, + AWSCredentials, KinesisCheckpointer, KinesisConsumer, KinesisRecord, @@ -88,16 +87,19 @@ def __init__( shard when Kinesis consumer encounters handled/expected errors. """ self._stream_name = stream_name - self._auth = Authentication( - aws_endpoint_url=aws_endpoint_url, - aws_region=aws_region, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - ) + self._credentials: AWSCredentials = { + "endpoint_url": aws_endpoint_url, + "region_name": aws_region, + "aws_access_key_id": aws_access_key_id, + "aws_secret_access_key": aws_secret_access_key, + } + self._auto_offset_reset = auto_offset_reset self._max_records_per_shard = max_records_per_shard self._retry_backoff_secs = retry_backoff_secs - self._checkpointer = SourceCheckpointer(self, commit_interval) + self._checkpointer = KinesisCheckpointer( + stateful_source=self, commit_interval=commit_interval + ) super().__init__( name=f"kinesis_{self._stream_name}", shutdown_timeout=shutdown_timeout ) @@ -126,7 +128,7 @@ def _handle_kinesis_message(self, message: KinesisRecord): def run(self): with KinesisConsumer( stream_name=self._stream_name, - auth=self._auth, + credentials=self._credentials, message_processor=self._handle_kinesis_message, auto_offset_reset=self._auto_offset_reset, checkpointer=self._checkpointer, @@ -134,29 +136,5 @@ def run(self): backoff_secs=self._retry_backoff_secs, ) as consumer: while self._running: - consumer.poll_and_process_shards() + consumer.process_shards() consumer.commit() - - -class SourceCheckpointer(KinesisCheckpointer): - def __init__(self, stateful_source: StatefulSource, commit_interval: float = 5.0): - self._source = stateful_source - self._last_committed_at = time.monotonic() - self._commit_interval = commit_interval - - @property - def last_committed_at(self) -> float: - return self._last_committed_at - - def get(self, key: str) -> Optional[str]: - return self._source.state.get(key) - - def set(self, key: str, value: str): - self._source.state.set(key, value) - - def commit(self, force: bool = False): - if ( - (now := time.monotonic()) - self._last_committed_at > self._commit_interval - ) or force: - self._source.flush() - self._last_committed_at = now