Skip to content

Commit

Permalink
Simplify the Kinesis Source implementation (#650)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-quix authored and tim-quix committed Nov 21, 2024
1 parent e52fb80 commit 649f418
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 113 deletions.
159 changes: 81 additions & 78 deletions quixstreams/sources/community/kinesis/consumer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import time
from typing import Callable, Literal, Optional, Protocol
from typing import Callable, Generator, Literal, Optional, TypedDict

from typing_extensions import Self

from quixstreams.sources import StatefulSource

try:
import boto3
from botocore.exceptions import ClientError
Expand All @@ -23,43 +25,39 @@
AutoOffsetResetType = Literal["earliest", "latest"]


class KinesisCheckpointer(Protocol):
class KinesisStreamShardsNotFound(Exception):
"""Raised when the Kinesis Stream has no shards"""


class KinesisCheckpointer:
def __init__(self, stateful_source: StatefulSource, commit_interval: float = 5.0):
self._source = stateful_source
self._last_committed_at = time.monotonic()
self._commit_interval = commit_interval

@property
def last_committed_at(self) -> float:
yield ...
return self._last_committed_at

def get(self, key: str) -> Optional[str]: ...
def get(self, shard_id: str) -> Optional[str]:
return self._source.state.get(shard_id)

def set(self, key: str, value: str): ...
def set(self, shard_id: str, sequence_number: str):
self._source.state.set(shard_id, sequence_number)

def commit(self, force: bool = False): ...
def commit(self, force: bool = False):
if (
(now := time.monotonic()) - self._last_committed_at > self._commit_interval
) or force:
self._source.flush()
self._last_committed_at = now


class Authentication:
def __init__(
self,
aws_region: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_endpoint_url: Optional[str] = None,
):
"""
:param aws_region: The AWS region.
NOTE: can alternatively set the AWS_REGION environment variable
:param aws_access_key_id: the AWS access key ID.
NOTE: can alternatively set the AWS_ACCESS_KEY_ID environment variable
:param aws_secret_access_key: the AWS secret access key.
NOTE: can alternatively set the AWS_SECRET_ACCESS_KEY environment variable
:param aws_endpoint_url: the endpoint URL to use; only required for connecting
to a locally hosted Kinesis.
NOTE: can alternatively set the AWS_ENDPOINT_URL_KINESIS environment variable
"""
self.auth = {
"endpoint_url": aws_endpoint_url,
"region_name": aws_region,
"aws_access_key_id": aws_access_key_id,
"aws_secret_access_key": aws_secret_access_key,
}
class AWSCredentials(TypedDict):
endpoint_url: Optional[str]
region_name: Optional[str]
aws_access_key_id: Optional[str]
aws_secret_access_key: Optional[str]


class KinesisConsumer:
Expand All @@ -71,15 +69,15 @@ class KinesisConsumer:
def __init__(
self,
stream_name: str,
auth: Authentication,
credentials: AWSCredentials,
message_processor: Callable[[KinesisRecord], None],
checkpointer: KinesisCheckpointer,
auto_offset_reset: AutoOffsetResetType = "latest",
max_records_per_shard: int = 1000,
backoff_secs: float = 5.0,
):
self._stream = stream_name
self._auth = auth
self._credentials = credentials
self._message_processor = message_processor
self._checkpointer = checkpointer
self._shard_iterators: dict[str, str] = {}
Expand All @@ -89,19 +87,36 @@ def __init__(
self._auto_offset_reset = _OFFSET_RESET_DICT[auto_offset_reset]
self._client: Optional[KinesisClient] = None

def process_shards(self):
"""
Process records from the Stream shards one by one and checkpoint their
sequence numbers.
"""
# Iterate over shards one by one
for shard_id in self._shard_iterators:
# Poll records from each shard
for record in self._poll_records(shard_id=shard_id):
# Process the record
self._message_processor(record)
# Save the sequence number of the processed record
self._checkpointer.set(shard_id, record["SequenceNumber"])

def commit(self, force: bool = False):
"""
Commit the checkpoint and save the progress of the processed shards.
"""
self._checkpointer.commit(force=force)

def __enter__(self) -> Self:
self.start()
self._init_client()
self._init_shards()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
pass

def _init_client(self):
self._client = boto3.client("kinesis", **self._auth.auth)

def _process_record(self, shard_id: str, record: KinesisRecord):
self._message_processor(record)
self._checkpointer.set(shard_id, record["SequenceNumber"])
self._client = boto3.client("kinesis", **self._credentials)

def _list_shards(self) -> list[ShardTypeDef]:
"""List all shards in the stream."""
Expand All @@ -113,78 +128,66 @@ def _list_shards(self) -> list[ShardTypeDef]:
shards.extend(response["Shards"])
return shards

def _get_shard_iterator(self, shard_id: str):
def _get_shard_iterator(self, shard_id: str) -> str:
if sequence_number := self._checkpointer.get(shard_id):
additional_kwargs = {
kwargs = {
"ShardIteratorType": "AFTER_SEQUENCE_NUMBER",
"StartingSequenceNumber": sequence_number,
}
else:
additional_kwargs = {
kwargs = {
"ShardIteratorType": self._auto_offset_reset,
}
response: GetShardIteratorOutputTypeDef = self._client.get_shard_iterator(
StreamName=self._stream, ShardId=shard_id, **additional_kwargs
StreamName=self._stream, ShardId=shard_id, **kwargs
)
return response["ShardIterator"]

def _init_shards(self):
if not (shards := [shard["ShardId"] for shard in self._list_shards()]):
raise ValueError(f"No shards for stream {self._stream}")
raise KinesisStreamShardsNotFound(
f'Shards not found for stream "{self._stream}"'
)
self._shard_iterators = {
shard: self._get_shard_iterator(shard) for shard in shards
}

def _poll_and_process_shard(self, shard_id):
"""Read records from a shard."""
def _poll_records(self, shard_id: str) -> Generator[KinesisRecord, None, None]:
"""
Poll records from the Kinesis Stream shard.
If the shard is backed off, no records will be returned.
:param shard_id: shard id.
"""
if (
backoff_time := self._shard_backoff.get(shard_id)
) and time.monotonic() < backoff_time:
backoff_time := self._shard_backoff.get(shard_id, 0.0)
) and backoff_time > time.monotonic():
# The shard is backed off, exit early
return

try:
response = self._client.get_records(
ShardIterator=self._shard_iterators[shard_id],
Limit=self._max_records_per_shard,
)

for record in response.get("Records", []):
self._process_record(shard_id, record)

# Update the shard iterator for the next batch
self._shard_iterators[shard_id] = response["NextShardIterator"]
self._shard_backoff[shard_id] = 0

except ClientError as e:
error_code = e.response["Error"]["Code"]
logger.error(f"Error reading from shard {shard_id}: {error_code}")
if error_code == "ProvisionedThroughputExceededException":
# The shard is backed off by Kinesis, update the backoff deadline
self._shard_backoff[shard_id] = time.monotonic() + self._backoff_secs
elif error_code == "ExpiredIteratorException":
logger.error(f"Shard iterator expired for shard {shard_id}.")
raise
else:
logger.error(f"Unrecoverable error: {e}")
raise
else:
# Yield records for the shard
for record in response.get("Records", []):
yield record

def start(self):
self._init_client()
self._init_shards()

def poll_and_process_shards(self):
for shard in self._shard_iterators:
self._poll_and_process_shard(shard)

def commit(self, force: bool = False):
self._checkpointer.commit(force=force)

def run(self):
"""For running _without_ using Quix Streams Source framework."""
try:
self.start()
while True:
self.poll_and_process_shards()
self.commit()
except Exception as e:
logger.debug(f"KinesisConsumer encountered an error: {e}")
finally:
logger.debug("Stopping KinesisConsumer...")
# Update the shard iterator for the next batch
self._shard_iterators[shard_id] = response["NextShardIterator"]
self._shard_backoff[shard_id] = 0
48 changes: 13 additions & 35 deletions quixstreams/sources/community/kinesis/kinesis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import time
from typing import Optional

from quixstreams.models.topics import Topic
from quixstreams.sources.base import StatefulSource

from .consumer import (
Authentication,
AutoOffsetResetType,
AWSCredentials,
KinesisCheckpointer,
KinesisConsumer,
KinesisRecord,
Expand Down Expand Up @@ -88,16 +87,19 @@ def __init__(
shard when Kinesis consumer encounters handled/expected errors.
"""
self._stream_name = stream_name
self._auth = Authentication(
aws_endpoint_url=aws_endpoint_url,
aws_region=aws_region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
self._credentials: AWSCredentials = {
"endpoint_url": aws_endpoint_url,
"region_name": aws_region,
"aws_access_key_id": aws_access_key_id,
"aws_secret_access_key": aws_secret_access_key,
}

self._auto_offset_reset = auto_offset_reset
self._max_records_per_shard = max_records_per_shard
self._retry_backoff_secs = retry_backoff_secs
self._checkpointer = SourceCheckpointer(self, commit_interval)
self._checkpointer = KinesisCheckpointer(
stateful_source=self, commit_interval=commit_interval
)
super().__init__(
name=f"kinesis_{self._stream_name}", shutdown_timeout=shutdown_timeout
)
Expand Down Expand Up @@ -126,37 +128,13 @@ def _handle_kinesis_message(self, message: KinesisRecord):
def run(self):
with KinesisConsumer(
stream_name=self._stream_name,
auth=self._auth,
credentials=self._credentials,
message_processor=self._handle_kinesis_message,
auto_offset_reset=self._auto_offset_reset,
checkpointer=self._checkpointer,
max_records_per_shard=self._max_records_per_shard,
backoff_secs=self._retry_backoff_secs,
) as consumer:
while self._running:
consumer.poll_and_process_shards()
consumer.process_shards()
consumer.commit()


class SourceCheckpointer(KinesisCheckpointer):
def __init__(self, stateful_source: StatefulSource, commit_interval: float = 5.0):
self._source = stateful_source
self._last_committed_at = time.monotonic()
self._commit_interval = commit_interval

@property
def last_committed_at(self) -> float:
return self._last_committed_at

def get(self, key: str) -> Optional[str]:
return self._source.state.get(key)

def set(self, key: str, value: str):
self._source.state.set(key, value)

def commit(self, force: bool = False):
if (
(now := time.monotonic()) - self._last_committed_at > self._commit_interval
) or force:
self._source.flush()
self._last_committed_at = now

0 comments on commit 649f418

Please sign in to comment.