diff --git a/quixstreams/kafka/producer.py b/quixstreams/kafka/producer.py
index 6aca67ec9..98d6ea5b2 100644
--- a/quixstreams/kafka/producer.py
+++ b/quixstreams/kafka/producer.py
@@ -71,6 +71,7 @@ def __init__(
             **{"logger": logger, "error_cb": error_callback},
         }
         self._inner_producer: Optional[ConfluentProducer] = None
+        self._transactional: bool = False
 
     def produce(
         self,
@@ -156,66 +157,24 @@ def flush(self, timeout: Optional[float] = None) -> int:
     def _producer(self) -> ConfluentProducer:
         if not self._inner_producer:
             self._inner_producer = ConfluentProducer(self._producer_config)
+            if self._transactional:
+                self._inner_producer.init_transactions()
         return self._inner_producer
 
-    def __len__(self):
-        return len(self._producer)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        logger.debug("Flushing kafka producer")
-        self.flush()
-        logger.debug("Kafka producer flushed")
-
-
-class TransactionalProducer(Producer):
-    """
-    A separate producer class used only internally for transactions
-    (transactions are only needed when using a consumer).
-    """
-
-    def __init__(
-        self,
-        broker_address: Union[str, ConnectionConfig],
-        logger: logging.Logger = logger,
-        error_callback: Callable[[KafkaError], None] = _default_error_cb,
-        extra_config: Optional[dict] = None,
-        transactional_id: str = str(uuid.uuid4()),
-    ):
-        super().__init__(
-            broker_address=broker_address,
-            logger=logger,
-            error_callback=error_callback,
-            extra_config=extra_config,
-        )
-        # remake config to avoid overriding anything in the Application's
-        # producer config, which is used in Application.get_producer().
+    def __use_transactions__(self, transactional_id: Optional[str] = None):
         self._producer_config = {
             **self._producer_config,
             "enable.idempotence": True,
-            "transactional.id": transactional_id,
+            # generate the id per call; a def-time default would be shared
+            "transactional.id": transactional_id or str(uuid.uuid4()),
         }
-        self._active_transaction = False
-
-    @property
-    def active_transaction(self):
-        return self._active_transaction
-
-    @property
-    def _producer(self) -> ConfluentProducer:
-        if not self._inner_producer:
-            self._inner_producer = ConfluentProducer(self._producer_config)
-            self._inner_producer.init_transactions()
-        return self._inner_producer
+        self._transactional = True
 
-    def begin_transaction(self):
+    def __begin_transaction__(self):
         logger.debug("Starting Kafka transaction...")
         self._producer.begin_transaction()
-        self._active_transaction = True
 
-    def send_offsets_to_transaction(
+    def __send_offsets_to_transaction__(
         self,
         positions: List[TopicPartition],
         group_metadata: GroupMetadata,
@@ -225,14 +183,23 @@ def send_offsets_to_transaction(
             positions, group_metadata, timeout if timeout is not None else -1
         )
 
-    def abort_transaction(self, timeout: Optional[float] = None):
+    def __abort_transaction__(self, timeout: Optional[float] = None):
         logger.debug("Aborting Kafka transaction...")
         self._producer.abort_transaction(timeout if timeout is not None else -1)
-        self._active_transaction = False
         logger.debug("Kafka transaction aborted successfully!")
 
-    def commit_transaction(self, timeout: Optional[float] = None):
+    def __commit_transaction__(self, timeout: Optional[float] = None):
         logger.debug("Committing Kafka transaction...")
         self._producer.commit_transaction(timeout if timeout is not None else -1)
-        self._active_transaction = False
         logger.debug("Kafka transaction committed successfully!")
+
+    def __len__(self):
+        return len(self._producer)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        logger.debug("Flushing kafka producer")
+        self.flush()
+        logger.debug("Kafka producer flushed")
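
With `TransactionalProducer` removed, transactions are now an opt-in mode on the base `Producer`: `__use_transactions__()` flags the instance, and `init_transactions()` runs lazily the first time the inner confluent producer is built. A minimal sketch of the intended call sequence; the broker address and topic name are placeholders, not part of this diff:

```python
from quixstreams.kafka.producer import Producer

# Sketch of the new transactional flow; "localhost:9092" and "out-topic"
# are illustrative placeholders.
producer = Producer(broker_address="localhost:9092")
producer.__use_transactions__()   # marks the producer transactional
producer.__begin_transaction__()  # first access builds the inner producer
                                  # and runs init_transactions()
producer.produce(topic="out-topic", key=b"key", value=b"value")
producer.flush()
producer.__commit_transaction__(timeout=30)
```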
logger.debug("Kafka producer flushed") diff --git a/quixstreams/rowproducer.py b/quixstreams/rowproducer.py index d6aa278f9..0bd399b02 100644 --- a/quixstreams/rowproducer.py +++ b/quixstreams/rowproducer.py @@ -1,6 +1,6 @@ import logging from time import sleep -from typing import Optional, Any, Union, Dict, Tuple, List +from typing import Optional, Any, Union, Dict, Tuple, List, Callable from confluent_kafka import TopicPartition, KafkaException, KafkaError, Message from confluent_kafka.admin import GroupMetadata @@ -9,7 +9,7 @@ from .error_callbacks import ProducerErrorCallback, default_on_producer_error from .kafka.configuration import ConnectionConfig from .kafka.exceptions import KafkaProducerDeliveryError -from .kafka.producer import Producer, TransactionalProducer +from .kafka.producer import Producer from .models import Topic, Row, Headers logger = logging.getLogger(__name__) @@ -47,23 +47,19 @@ def __init__( transactional: bool = False, ): + self._producer = Producer( + broker_address=broker_address, + extra_config=extra_config, + ) if transactional: - self._producer = TransactionalProducer( - broker_address=broker_address, - extra_config=extra_config, - ) - else: - self._producer = Producer( - broker_address=broker_address, - extra_config=extra_config, - ) + self._producer.__use_transactions__() self._on_error: Optional[ProducerErrorCallback] = ( on_error or default_on_producer_error ) self._tp_offsets: Dict[Tuple[str, int], int] = {} self._error: Optional[KafkaError] = None - self._transactional = transactional + self._active_transaction = False def produce_row( self, @@ -173,22 +169,41 @@ def offsets(self) -> Dict[Tuple[str, int], int]: return self._tp_offsets def begin_transaction(self): - self._producer.begin_transaction() + self._producer.__begin_transaction__() + self._active_transaction = True def abort_transaction(self, timeout: Optional[float] = None): - # Skip abort if no active transaction since it throws an exception if at least - # one transaction was successfully completed at some point. - # This avoids polluting the stack trace in the case where a transaction was - # not active as expected (because of some other exception already raised). - if self._producer.active_transaction: - self._producer.abort_transaction(timeout) + """ + Attempt an abort if an active transaction. + + Else, skip since it throws an exception if at least + one transaction was successfully completed at some point. + + This avoids polluting the stack trace in the case where a transaction was + not active as expected (because of some other exception already raised) + and a cleanup abort is attempted. + + NOTE: under normal circumstances a transaction will be open due to how + the Checkpoint inits another immediately after committing. + """ + if self._active_transaction: + self._producer.__abort_transaction__(timeout) + self._active_transaction = False else: logger.debug( "No Kafka transaction to abort, " "likely due to some other exception occurring" ) - def _retriable_commit_op(self, operation, args): + def _retriable_commit_op(self, operation: Callable, args: list): + """ + Some specific failure cases from sending offsets or committing a transaction + are retriable, which is worth re-attempting since the transaction is + almost complete (we flushed before attempting to commit). + + NOTE: During testing, most other operations (including producing) + did not generate "retriable" errors. 
+ """ attempts_remaining = 3 backoff_seconds = 1 op_name = operation.__name__ @@ -226,10 +241,11 @@ def commit_transaction( timeout: Optional[float] = None, ): self._retriable_commit_op( - self._producer.send_offsets_to_transaction, + self._producer.__send_offsets_to_transaction__, [positions, group_metadata, timeout], ) - self._retriable_commit_op(self._producer.commit_transaction, [timeout]) + self._retriable_commit_op(self._producer.__commit_transaction__, [timeout]) + self._active_transaction = False def __enter__(self): return self diff --git a/tests/test_quixstreams/test_rowproducer.py b/tests/test_quixstreams/test_rowproducer.py index 112089ecc..857f3720b 100644 --- a/tests/test_quixstreams/test_rowproducer.py +++ b/tests/test_quixstreams/test_rowproducer.py @@ -11,7 +11,7 @@ from confluent_kafka import TopicPartition from quixstreams.kafka.exceptions import KafkaProducerDeliveryError -from quixstreams.kafka.producer import TransactionalProducer +from quixstreams.kafka.producer import Producer from quixstreams.models import ( JSONSerializer, SerializationError, @@ -188,6 +188,9 @@ def test_produce_and_commit( topic_manager_topic_factory, row_consumer_factory, ): + """ + Simplest transactional consume + produce pattern + """ topic_args = dict( create_topic=True, value_serializer="json", @@ -231,7 +234,7 @@ def test_produce_and_commit( == consumer_end_offset ) - # downstream consumer should only get the committed messages + # downstream consumer gets the committed messages rows = [] with row_consumer_factory(auto_offset_reset="earliest") as consumer: consumer.subscribe([topic_out]) @@ -242,28 +245,6 @@ def test_produce_and_commit( assert row.key == key assert row.value == value - def test_retriable_op_error(self): - class MockKafkaError(Exception): - def retriable(self): - return True - - call_args = [["my", "offsets"], "consumer_metadata", 1] - error = ConfluentKafkaException(MockKafkaError()) - - mock_producer = create_autospec(TransactionalProducer) - mock_producer.send_offsets_to_transaction.__name__ = "send_offsets" - mock_producer.send_offsets_to_transaction.side_effect = [error, None] - with patch( - "quixstreams.rowproducer.TransactionalProducer", return_value=mock_producer - ): - row_producer = RowProducer(broker_address="lol", transactional=True) - row_producer.commit_transaction(*call_args) - - mock_producer.send_offsets_to_transaction.assert_has_calls( - [call(*call_args)] * 2 - ) - mock_producer.commit_transaction.assert_called_once() - def test_produce_after_aborted_transaction( self, row_producer, @@ -271,6 +252,14 @@ def test_produce_after_aborted_transaction( topic_manager_topic_factory, row_consumer_factory, ): + """ + transactional consume + produce pattern, but we mimic a failed transaction by + aborting it directly (after producing + flushing to the downstream topic). + + Then, redo the consume + produce (and successfully commit the transaction). + + We confirm offset behavior from both failed and successful transactions. + """ topic_args = dict( create_topic=True, value_serializer="json", @@ -350,7 +339,7 @@ def consume_and_produce(consumer, producer): # as further proof the initial messages actually made it to the topic # (and thus were ignored) we can inspect our message offsets. - # Produced offsets 0-2 were aborted; all aborts (direct or timeout) are followed + # Produced offsets 0-2 were aborted; all direct aborts are followed # by an abort marker (offset 3). # The next valid offset (which is our first successful message) should be 4. 
@@ -371,6 +360,13 @@ def test_produce_transaction_timeout_no_abort(
         topic_manager_topic_factory,
         row_consumer_factory,
     ):
+        """
+        Validate the behavior around a transaction that times out via the
+        producer config "transaction.timeout.ms".
+
+        A timeout should invalidate that producer for further transactions
+        (which also raises an exception to cause the Application to terminate).
+        """
         topic_args = dict(
             create_topic=True,
             value_serializer="json",
@@ -447,3 +443,29 @@
             producer.begin_transaction()
         kafka_error = e.value.args[0]
         assert kafka_error.code() == ConfluentKafkaError._FENCED
+
+    def test_retriable_op_error(self):
+        """
+        Some specific failure cases from sending offsets or committing a
+        transaction are retriable.
+        """
+
+        class MockKafkaError(Exception):
+            def retriable(self):
+                return True
+
+        call_args = [["my", "offsets"], "consumer_metadata", 1]
+        error = ConfluentKafkaException(MockKafkaError())
+
+        mock_producer = create_autospec(Producer)
+        mock_producer.__send_offsets_to_transaction__.__name__ = "send_offsets"
+        mock_producer.__commit_transaction__.__name__ = "commit"
+        mock_producer.__send_offsets_to_transaction__.side_effect = [error, None]
+        with patch("quixstreams.rowproducer.Producer", return_value=mock_producer):
+            row_producer = RowProducer(broker_address="lol", transactional=True)
+            row_producer.commit_transaction(*call_args)
+
+        mock_producer.__send_offsets_to_transaction__.assert_has_calls(
+            [call(*call_args)] * 2
+        )
+        mock_producer.__commit_transaction__.assert_called_once()
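
In the test above, `MockKafkaError` stands in for confluent-kafka's `KafkaError` because `_retriable_commit_op` only inspects `error.retriable()`. The same shape can be built from the real type; a sketch assuming `KafkaError`'s documented `(error, reason, fatal, retriable)` positional constructor, with an arbitrary illustrative error code:

```python
from confluent_kafka import KafkaError, KafkaException

# Build a retriable KafkaException the way the client would surface one.
# _TRANSPORT is an arbitrary code choice; retriable=True is what matters here.
error = KafkaException(
    KafkaError(KafkaError._TRANSPORT, "simulated transient failure", False, True)
)
assert error.args[0].retriable()
```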