update Producer class and remove TransactionalProducer
tim-quix committed Jun 14, 2024
1 parent 3b3a70f commit 198c81a
Showing 3 changed files with 105 additions and 100 deletions.
73 changes: 20 additions & 53 deletions quixstreams/kafka/producer.py
@@ -71,6 +71,7 @@ def __init__(
**{"logger": logger, "error_cb": error_callback},
}
self._inner_producer: Optional[ConfluentProducer] = None
self._transactional: bool = False

def produce(
self,
@@ -156,66 +157,23 @@ def flush(self, timeout: Optional[float] = None) -> int:
def _producer(self) -> ConfluentProducer:
if not self._inner_producer:
self._inner_producer = ConfluentProducer(self._producer_config)
if self._transactional:
self._inner_producer.init_transactions()
return self._inner_producer

def __len__(self):
return len(self._producer)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
logger.debug("Flushing kafka producer")
self.flush()
logger.debug("Kafka producer flushed")


class TransactionalProducer(Producer):
"""
A separate producer class used only internally for transactions
(transactions are only needed when using a consumer).
"""

def __init__(
self,
broker_address: Union[str, ConnectionConfig],
logger: logging.Logger = logger,
error_callback: Callable[[KafkaError], None] = _default_error_cb,
extra_config: Optional[dict] = None,
transactional_id: str = str(uuid.uuid4()),
):
super().__init__(
broker_address=broker_address,
logger=logger,
error_callback=error_callback,
extra_config=extra_config,
)
# remake config to avoid overriding anything in the Application's
# producer config, which is used in Application.get_producer().
def __use_transactions__(self, transactional_id: str = str(uuid.uuid4())):
self._producer_config = {
**self._producer_config,
"enable.idempotence": True,
"transactional.id": transactional_id,
}
self._active_transaction = False

@property
def active_transaction(self):
return self._active_transaction

@property
def _producer(self) -> ConfluentProducer:
if not self._inner_producer:
self._inner_producer = ConfluentProducer(self._producer_config)
self._inner_producer.init_transactions()
return self._inner_producer
self._transactional = True

def begin_transaction(self):
def __begin_transaction__(self):
logger.debug("Starting Kafka transaction...")
self._producer.begin_transaction()
self._active_transaction = True

def send_offsets_to_transaction(
def __send_offsets_to_transaction__(
self,
positions: List[TopicPartition],
group_metadata: GroupMetadata,
@@ -225,14 +183,23 @@ def send_offsets_to_transaction(
positions, group_metadata, timeout if timeout is not None else -1
)

def abort_transaction(self, timeout: Optional[float] = None):
def __abort_transaction__(self, timeout: Optional[float] = None):
logger.debug("Aborting Kafka transaction...")
self._producer.abort_transaction(timeout if timeout is not None else -1)
self._active_transaction = False
logger.debug("Kafka transaction aborted successfully!")

def commit_transaction(self, timeout: Optional[float] = None):
def __commit_transaction__(self, timeout: Optional[float] = None):
logger.debug("Committing Kafka transaction...")
self._producer.commit_transaction(timeout if timeout is not None else -1)
self._active_transaction = False
logger.debug("Kafka transaction committed successfully!")

def __len__(self):
return len(self._producer)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
logger.debug("Flushing kafka producer")
self.flush()
logger.debug("Kafka producer flushed")
60 changes: 38 additions & 22 deletions quixstreams/rowproducer.py
@@ -1,6 +1,6 @@
import logging
from time import sleep
from typing import Optional, Any, Union, Dict, Tuple, List
from typing import Optional, Any, Union, Dict, Tuple, List, Callable

from confluent_kafka import TopicPartition, KafkaException, KafkaError, Message
from confluent_kafka.admin import GroupMetadata
@@ -9,7 +9,7 @@
from .error_callbacks import ProducerErrorCallback, default_on_producer_error
from .kafka.configuration import ConnectionConfig
from .kafka.exceptions import KafkaProducerDeliveryError
from .kafka.producer import Producer, TransactionalProducer
from .kafka.producer import Producer
from .models import Topic, Row, Headers

logger = logging.getLogger(__name__)
@@ -47,23 +47,19 @@ def __init__(
transactional: bool = False,
):

self._producer = Producer(
broker_address=broker_address,
extra_config=extra_config,
)
if transactional:
self._producer = TransactionalProducer(
broker_address=broker_address,
extra_config=extra_config,
)
else:
self._producer = Producer(
broker_address=broker_address,
extra_config=extra_config,
)
self._producer.__use_transactions__()

self._on_error: Optional[ProducerErrorCallback] = (
on_error or default_on_producer_error
)
self._tp_offsets: Dict[Tuple[str, int], int] = {}
self._error: Optional[KafkaError] = None
self._transactional = transactional
self._active_transaction = False

def produce_row(
self,
@@ -173,22 +169,41 @@ def offsets(self) -> Dict[Tuple[str, int], int]:
return self._tp_offsets

def begin_transaction(self):
self._producer.begin_transaction()
self._producer.__begin_transaction__()
self._active_transaction = True

def abort_transaction(self, timeout: Optional[float] = None):
# Skip abort if no active transaction since it throws an exception if at least
# one transaction was successfully completed at some point.
# This avoids polluting the stack trace in the case where a transaction was
# not active as expected (because of some other exception already raised).
if self._producer.active_transaction:
self._producer.abort_transaction(timeout)
"""
Attempt an abort if there is an active transaction.
Otherwise skip it, since aborting raises an exception once at least
one transaction has been successfully completed at some point.
This avoids polluting the stack trace when a transaction was not
active as expected (because some other exception was already raised)
and a cleanup abort is attempted.
NOTE: under normal circumstances a transaction will be open, since the
Checkpoint starts another one immediately after committing.
"""
if self._active_transaction:
self._producer.__abort_transaction__(timeout)
self._active_transaction = False
else:
logger.debug(
"No Kafka transaction to abort, "
"likely due to some other exception occurring"
)

def _retriable_commit_op(self, operation, args):
def _retriable_commit_op(self, operation: Callable, args: list):
"""
Some specific failure cases from sending offsets or committing a transaction
are retriable, and retrying is worthwhile since the transaction is
almost complete (we flushed before attempting to commit).
NOTE: During testing, most other operations (including producing)
did not generate "retriable" errors.
"""
attempts_remaining = 3
backoff_seconds = 1
op_name = operation.__name__
@@ -226,10 +241,11 @@ def commit_transaction(
timeout: Optional[float] = None,
):
self._retriable_commit_op(
self._producer.send_offsets_to_transaction,
self._producer.__send_offsets_to_transaction__,
[positions, group_metadata, timeout],
)
self._retriable_commit_op(self._producer.commit_transaction, [timeout])
self._retriable_commit_op(self._producer.__commit_transaction__, [timeout])
self._active_transaction = False

def __enter__(self):
return self
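
The docstring on _retriable_commit_op above carries the main reasoning: only sending offsets and committing the transaction are worth retrying, since by that point everything has already been flushed. A self-contained sketch of that retry pattern, using the same attempt count and backoff as the diff (3 attempts, 1-second backoff) and assuming the usual confluent-kafka convention that a KafkaException wraps a KafkaError exposing retriable(); the actual method body may differ in detail:

import logging
import time
from typing import Callable

from confluent_kafka import KafkaException

logger = logging.getLogger(__name__)


def retriable_commit_op(operation: Callable, args: list):
    """Retry a commit-phase transactional operation on retriable Kafka errors."""
    attempts_remaining = 3
    backoff_seconds = 1
    op_name = operation.__name__
    while True:
        try:
            return operation(*args)
        except KafkaException as e:
            error = e.args[0]
            attempts_remaining -= 1
            # Only errors librdkafka marks as retriable are retried; anything else,
            # or running out of attempts, propagates so the transaction gets aborted.
            if not (error.retriable() and attempts_remaining):
                raise
            logger.debug(
                f"Kafka operation '{op_name}' failed with a retriable error; "
                f"retrying in {backoff_seconds}s ({attempts_remaining} attempts left)"
            )
            time.sleep(backoff_seconds)
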
72 changes: 47 additions & 25 deletions tests/test_quixstreams/test_rowproducer.py
@@ -11,7 +11,7 @@
from confluent_kafka import TopicPartition

from quixstreams.kafka.exceptions import KafkaProducerDeliveryError
from quixstreams.kafka.producer import TransactionalProducer
from quixstreams.kafka.producer import Producer
from quixstreams.models import (
JSONSerializer,
SerializationError,
@@ -188,6 +188,9 @@ def test_produce_and_commit(
topic_manager_topic_factory,
row_consumer_factory,
):
"""
Simplest transactional consume + produce pattern
"""
topic_args = dict(
create_topic=True,
value_serializer="json",
@@ -231,7 +234,7 @@ def test_produce_and_commit(
== consumer_end_offset
)

# downstream consumer should only get the committed messages
# downstream consumer gets the committed messages
rows = []
with row_consumer_factory(auto_offset_reset="earliest") as consumer:
consumer.subscribe([topic_out])
@@ -242,35 +245,21 @@
assert row.key == key
assert row.value == value

def test_retriable_op_error(self):
class MockKafkaError(Exception):
def retriable(self):
return True

call_args = [["my", "offsets"], "consumer_metadata", 1]
error = ConfluentKafkaException(MockKafkaError())

mock_producer = create_autospec(TransactionalProducer)
mock_producer.send_offsets_to_transaction.__name__ = "send_offsets"
mock_producer.send_offsets_to_transaction.side_effect = [error, None]
with patch(
"quixstreams.rowproducer.TransactionalProducer", return_value=mock_producer
):
row_producer = RowProducer(broker_address="lol", transactional=True)
row_producer.commit_transaction(*call_args)

mock_producer.send_offsets_to_transaction.assert_has_calls(
[call(*call_args)] * 2
)
mock_producer.commit_transaction.assert_called_once()

def test_produce_after_aborted_transaction(
self,
row_producer,
transactional_row_producer,
topic_manager_topic_factory,
row_consumer_factory,
):
"""
Transactional consume + produce pattern, but we mimic a failed transaction by
aborting it directly (after producing + flushing to the downstream topic).
Then we redo the consume + produce (and successfully commit the transaction).
We confirm offset behavior for both the failed and successful transactions.
"""
topic_args = dict(
create_topic=True,
value_serializer="json",
@@ -350,7 +339,7 @@ def consume_and_produce(consumer, producer):
# as further proof the initial messages actually made it to the topic
# (and thus were ignored) we can inspect our message offsets.

# Produced offsets 0-2 were aborted; all aborts (direct or timeout) are followed
# Produced offsets 0-2 were aborted; all direct aborts are followed
# by an abort marker (offset 3).
# The next valid offset (which is our first successful message) should be 4.
# Note that the lowwater is still 0, meaning the messages were successfully added
@@ -371,6 +360,13 @@ def test_produce_transaction_timeout_no_abort(
topic_manager_topic_factory,
row_consumer_factory,
):
"""
Validate the behavior around a transaction that times out via the producer
config transaction.timeout.ms
A timeout should invalidate that producer from further transactions
(which also raises an exception to cause the Application to terminate).
"""
topic_args = dict(
create_topic=True,
value_serializer="json",
@@ -447,3 +443,29 @@ def test_produce_transaction_timeout_no_abort(
producer.begin_transaction()
kafka_error = e.value.args[0]
assert kafka_error.code() == ConfluentKafkaError._FENCED

def test_retriable_op_error(self):
"""
Some specific failure cases from sending offsets or committing a transaction
are retriable.
"""

class MockKafkaError(Exception):
def retriable(self):
return True

call_args = [["my", "offsets"], "consumer_metadata", 1]
error = ConfluentKafkaException(MockKafkaError())

mock_producer = create_autospec(Producer)
mock_producer.__send_offsets_to_transaction__.__name__ = "send_offsets"
mock_producer.__commit_transaction__.__name__ = "commit"
mock_producer.__send_offsets_to_transaction__.side_effect = [error, None]
with patch("quixstreams.rowproducer.Producer", return_value=mock_producer):
row_producer = RowProducer(broker_address="lol", transactional=True)
row_producer.commit_transaction(*call_args)

mock_producer.__send_offsets_to_transaction__.assert_has_calls(
[call(*call_args)] * 2
)
mock_producer.__commit_transaction__.assert_called_once()
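
The timeout test above pins down the failure mode for transaction.timeout.ms: once the broker aborts a timed-out transaction it fences the producer, and a later transactional call raises a KafkaException whose error code is _FENCED. A sketch of how application code might observe that, with a placeholder broker address and a deliberately small timeout; exactly which call first surfaces the fencing can vary, so this only illustrates the error-code check the test makes:

import time

from confluent_kafka import KafkaException
from confluent_kafka import KafkaError as ConfluentKafkaError

from quixstreams.rowproducer import RowProducer

producer = RowProducer(
    broker_address="localhost:9092",                  # placeholder
    extra_config={"transaction.timeout.ms": 10000},   # deliberately small for illustration
    transactional=True,
)

producer.begin_transaction()
# Leave the transaction open past transaction.timeout.ms so the broker aborts it
# and fences this producer.
time.sleep(12)

try:
    producer.begin_transaction()
except KafkaException as e:
    if e.args[0].code() == ConfluentKafkaError._FENCED:
        # The producer can no longer transact; the Application treats this as fatal.
        raise
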
