most tests done, need a few more
tim-quix committed Jun 13, 2024
1 parent 35d360b commit 05420cd
Showing 7 changed files with 425 additions and 31 deletions.
5 changes: 2 additions & 3 deletions quixstreams/checkpointing/checkpoint.py
@@ -119,9 +119,9 @@ def commit(self):
         """

         if not self._tp_offsets:
-            logger.debug("Nothing to commit")
             if self._exactly_once:
                 self._producer.abort_transaction()
+            logger.debug("Nothing to commit")
             return

         # Step 1. Produce the changelogs
@@ -174,8 +174,7 @@ def commit(self):
                 produced_offsets.get(changelog_tp) if changelog_tp is not None else None
             )
             if changelog_offset is not None:
-                # Increment the changelog offset by one to match the high watermark
-                # in Kafka
+                # Increment the changelog offset to match the high watermark in Kafka
                 changelog_offset += self._changelog_offset_update
             transaction.flush(
                 processed_offset=offset, changelog_offset=changelog_offset
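Note: the offset bump above is now the configurable `self._changelog_offset_update` rather than a literal `+ 1`, presumably because a committed transaction's control marker also occupies an offset, so exactly-once checkpoints must advance the stored changelog offset further to match the broker's high watermark. The first hunk's reordering matters too: an empty exactly-once checkpoint must still abort the transaction it opened. A minimal sketch of that pattern, assuming a confluent-kafka-style transactional producer (function and argument names here are illustrative, not the library's exact API):

def close_checkpoint(producer, tp_offsets, exactly_once: bool) -> bool:
    """Return True if the checkpoint was empty and has been closed out."""
    if not tp_offsets:
        if exactly_once:
            # An exactly-once checkpoint opens a transaction up front, so it
            # must be aborted here even though nothing was produced.
            producer.abort_transaction()
        return True
    return False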
10 changes: 7 additions & 3 deletions quixstreams/kafka/producer.py
@@ -190,9 +190,13 @@ def __init__(
             error_callback=error_callback,
             extra_config=extra_config,
         )
-        self._producer_config.update(
-            {"enable.idempotence": True, "transactional.id": transactional_id}
-        )
+        # remake config to avoid overriding anything in the Application's
+        # producer config, which is used in Application.get_producer().
+        self._producer_config = {
+            **self._producer_config,
+            "enable.idempotence": True,
+            "transactional.id": transactional_id,
+        }
         self._active_transaction = False

     @property
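Note: the switch from `dict.update()` to a dict literal is a shared-mutable-state fix. `update()` mutates the config dict in place, and that dict is shared with the Application's regular producer, so the transactional settings would leak into `Application.get_producer()`. A minimal illustration (config values are placeholders):

base = {"bootstrap.servers": "localhost:9092"}

# In-place update mutates the shared dict:
shared = base
shared.update({"transactional.id": "my-app"})
assert "transactional.id" in base  # leaked into the original config

# Spreading into a new dict leaves the original untouched:
base = {"bootstrap.servers": "localhost:9092"}
own = {**base, "enable.idempotence": True, "transactional.id": "my-app"}
assert "transactional.id" not in base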
38 changes: 22 additions & 16 deletions quixstreams/rowproducer.py
@@ -188,29 +188,21 @@ def abort_transaction(self, timeout: Optional[float] = None):
                 "likely due to some other exception occurring"
             )

-    def commit_transaction(
-        self,
-        positions: List[TopicPartition],
-        group_metadata: GroupMetadata,
-        timeout: Optional[float] = None,
-    ):
-        attempts_remaining = 3
+    def _retriable_commit_op(self, operation, args):
+        attempts_remaining = 5
         backoff_seconds = 1
+        op_name = operation.__name__
         while attempts_remaining:
             try:
-                self._producer.send_offsets_to_transaction(
-                    positions, group_metadata, timeout=timeout
-                )
-                self._producer.commit_transaction(timeout=timeout)
-                return
+                return operation(*args)
             # Errors do not manifest from these calls via producer error_cb.
             # NOTE: Manual flushing earlier keeps error handling here to a minimum.
             except KafkaException as e:
-                error: KafkaError = e.args[0]
+                error = e.args[0]
                 if error.retriable():
                     attempts_remaining -= 1
                     logger.debug(
-                        f"Kafka Transaction commit attempt failed, but is retriable; "
+                        f"Kafka transaction operation {op_name} failed, but can retry; "
                         f"attempts remaining: {attempts_remaining}. "
                     )
                     if attempts_remaining:
@@ -220,13 +212,27 @@ def commit_transaction(
                         sleep(backoff_seconds)
                 else:
                     # Just treat all errors besides retriable as fatal.
-                    logger.error("Error while attempting to commit Kafka transaction.")
+                    logger.error(
+                        f"Error occurred during Kafka transaction operation {op_name}"
+                    )
                     raise
         raise KafkaProducerTransactionCommitFailed(
-            "All Kafka transaction commit attempts failed; "
+            f"All Kafka {op_name} attempts failed; "
             "aborting transaction and shutting down Application..."
         )

+    def commit_transaction(
+        self,
+        positions: List[TopicPartition],
+        group_metadata: GroupMetadata,
+        timeout: Optional[float] = None,
+    ):
+        self._retriable_commit_op(
+            self._producer.send_offsets_to_transaction,
+            [positions, group_metadata, timeout],
+        )
+        self._retriable_commit_op(self._producer.commit_transaction, [timeout])
+
     def __enter__(self):
         return self
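Note: factoring the loop into `_retriable_commit_op` gives both transactional calls, offset-sending and the commit itself, one shared policy: up to five attempts, a fixed one-second backoff, retries only for errors Kafka flags as retriable, and an immediate raise for everything else. A self-contained sketch of the same pattern (not the library's exact code):

from time import sleep

from confluent_kafka import KafkaException


def retry_if_retriable(operation, args, attempts=5, backoff_seconds=1):
    """Run operation(*args), retrying only errors Kafka marks as retriable."""
    while attempts:
        try:
            return operation(*args)
        except KafkaException as e:
            error = e.args[0]  # the underlying KafkaError
            if not error.retriable():
                raise  # fatal errors surface immediately
            attempts -= 1
            if attempts:
                sleep(backoff_seconds)
    raise RuntimeError(f"all {operation.__name__} attempts failed")

Wrapping `send_offsets_to_transaction` and `commit_transaction` as two separate operations also means a retriable failure during the commit does not re-send the consumer offsets.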
3 changes: 3 additions & 0 deletions tests/containerhelper.py
@@ -42,6 +42,9 @@ def create_kafka_container() -> Tuple[DockerContainer, str, int]:
         .with_env("KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR", "1")
         .with_env("KAFKA_GROUP_INITIAL_REBALANCE_DELAY_MS", "10")
         .with_env("CLUSTER_ID", kraft_cluster_id)
+        .with_env("KAFKA_TRANSACTION_STATE_LOG_NUM_PARTITIONS", "1")
+        .with_env("KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR", "1")
+        .with_env("KAFKA_TRANSACTION_STATE_LOG_MIN_ISR", "1")
         .with_bind_ports(kafka_port, kafka_port)
     )
     return kafka_container, broker_list, kafka_port
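Note: assuming the container image translates KAFKA_* environment variables into the matching broker properties (as the Confluent images do), these map to transaction.state.log.num.partitions, transaction.state.log.replication.factor, and transaction.state.log.min.isr. Their defaults (50, 3, and 2) cannot be satisfied by a single-node test broker, so without these overrides the transaction coordinator could not create its internal topic and transactional producers would fail to initialize.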
15 changes: 13 additions & 2 deletions tests/test_quixstreams/fixtures.py
@@ -229,11 +229,13 @@ def factory(
         broker_address: str = kafka_container.broker_address,
         extra_config: dict = None,
         on_error: Optional[ProducerErrorCallback] = None,
+        transactional: bool = False,
     ) -> RowProducer:
         return RowProducer(
             broker_address=broker_address,
             extra_config=extra_config,
             on_error=on_error,
+            transactional=transactional,
         )

     return factory
@@ -244,6 +246,11 @@ def row_producer(row_producer_factory):
     return row_producer_factory()


+@pytest.fixture()
+def transactional_row_producer(row_producer_factory):
+    return row_producer_factory(transactional=True)
+
+
 @pytest.fixture()
 def row_factory():
     """
@@ -291,6 +298,7 @@ def factory(
         auto_create_topics: bool = True,
         use_changelog_topics: bool = True,
         topic_manager: Optional[TopicManager] = None,
+        exactly_once_guarantees: bool = False,
     ) -> Application:
         state_dir = state_dir or (tmp_path / "state").absolute()
         return Application(
@@ -308,6 +316,7 @@
             auto_create_topics=auto_create_topics,
             use_changelog_topics=use_changelog_topics,
             topic_manager=topic_manager,
+            exactly_once_guarantees=exactly_once_guarantees,
         )

     return factory
@@ -574,15 +583,17 @@ def topic_manager_topic_factory(topic_manager_factory):
     """

     def factory(
-        name: Optional[str] = str(uuid.uuid4()),
+        name: Optional[str] = None,
         partitions: int = 1,
         create_topic: bool = False,
         key_serializer: Optional[Union[Serializer, str]] = None,
         value_serializer: Optional[Union[Serializer, str]] = None,
         key_deserializer: Optional[Union[Deserializer, str]] = None,
         value_deserializer: Optional[Union[Deserializer, str]] = None,
         timestamp_extractor: Optional[TimestampExtractor] = None,
-    ):
+    ) -> Topic:
+        if not name:
+            name = str(uuid.uuid4())
         topic_manager = topic_manager_factory()
         topic_args = {
             "key_serializer": key_serializer,
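Note: the `name` default fix addresses a classic Python pitfall: default argument values are evaluated once, at function definition time, so `name=str(uuid.uuid4())` would give every topic created by the fixture the same "random" name. Defaulting to `None` and generating inside the body restores a fresh UUID per call. For illustration:

import uuid

def stale(name=str(uuid.uuid4())):   # default computed once, at import
    return name

def fresh(name=None):                # default computed per call
    return name or str(uuid.uuid4())

assert stale() == stale()  # identical every time
assert fresh() != fresh()  # unique every time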
105 changes: 99 additions & 6 deletions tests/test_quixstreams/test_app.py
@@ -527,12 +527,109 @@ def test_consumer_group_default(self):
         assert app._consumer_group == "quixstreams-default"


+class TestAppExactlyOnce:
+
+    def test_exactly_once(
+        self,
+        app_factory,
+        topic_manager_factory,
+        row_consumer_factory,
+        executor,
+        row_factory,
+    ):
+        """
+        An Application that forwards messages to a new topic crashes after
+        producing 2 messages, then restarts (and so reprocesses all 3 messages).
+        The second run processes all 3 messages and commits the transaction.
+        The 2 uncommitted produces from the first run should be ignored by a
+        downstream consumer.
+        """
+        processed_count = 0
+        total_messages = 3
+        fail_idx = 1
+        done = Future()
+
+        def on_message_processed(*_):
+            # Set the callback to track total messages processed
+            # The callback is not triggered if processing fails
+            nonlocal processed_count
+            processed_count += 1
+            # Stop processing after consuming all the messages
+            if processed_count == total_messages:
+                done.set_result(True)
+
+        class ForceFail(Exception): ...
+
+        def fail_once(value):
+            # sleep here to ensure produced messages make it to the topic
+            time.sleep(2)
+            if processed_count == fail_idx:
+                raise ForceFail
+            return value
+
+        consumer_group = str(uuid.uuid4())
+        topic_in_name = str(uuid.uuid4())
+        topic_out_name = str(uuid.uuid4())
+
+        def get_app(fail: bool):
+            app = app_factory(
+                commit_interval=30,
+                auto_offset_reset="earliest",
+                on_message_processed=on_message_processed,
+                consumer_group=consumer_group,
+                exactly_once_guarantees=True,
+            )
+            topic_in = app.topic(topic_in_name, value_deserializer="json")
+            topic_out = app.topic(topic_out_name, value_serializer="json")
+            sdf = app.dataframe(topic_in)
+            sdf = sdf.to_topic(topic_out)
+            if fail:
+                sdf = sdf.apply(fail_once)
+            return app, sdf, topic_in, topic_out
+
+        # first run of the app, which encounters an error during processing
+        app, sdf, topic_in, topic_out = get_app(fail=True)
+
+        # produce initial messages to consume
+        with app.get_producer() as producer:
+            for i in range(total_messages):
+                msg = topic_in.serialize(key=str(i), value={"my_val": str(i)})
+                producer.produce(topic=topic_in.name, key=msg.key, value=msg.value)
+
+        with pytest.raises(ForceFail):
+            app.run(sdf)
+        assert processed_count == fail_idx
+
+        # re-init the app; this time it won't fail
+        processed_count = 0
+        app, sdf, topic_in, topic_out = get_app(fail=False)
+        executor.submit(_stop_app_on_future, app, done, 10.0)
+        app.run(sdf)
+
+        # only committed messages are read by a downstream consumer
+        with row_consumer_factory(auto_offset_reset="earliest") as row_consumer:
+            row_consumer.subscribe([topic_out])
+            rows = []
+            while (row := row_consumer.poll_row(timeout=5)) is not None:
+                rows.append(row)
+            lowwater, highwater = row_consumer.get_watermark_offsets(
+                TopicPartition(topic_out.name, 0), 3
+            )
+        assert len(rows) == total_messages
+
+        # Sanity check that the uncommitted messages actually made it to the topic
+        assert lowwater == 0
+        assert rows[0].offset == fail_idx + 2 == 3
+        assert highwater == rows[-1].offset + 2 == 7

 class TestAppGroupBy:

     def test_group_by(
         self,
         app_factory,
-        topic_manager_factory,
         row_consumer_factory,
         executor,
         row_factory,
@@ -552,7 +649,6 @@ def on_message_processed(*_):
             if processed_count == total_messages:
                 done.set_result(True)

-        topic_manager = topic_manager_factory()  # just to make topic_config objects
         processed_count = 0

         timestamp = 1000
Expand All @@ -562,8 +658,8 @@ def on_message_processed(*_):
total_messages = expected_message_count * 2 # groupby reproduces each message
app = app_factory(
auto_offset_reset="earliest",
topic_manager=topic_manager,
on_message_processed=on_message_processed,
exactly_once_guarantees=True,
)

app_topic_in = app.topic(
@@ -620,7 +716,6 @@ def on_message_processed(*_):
     def test_group_by_with_window(
         self,
         app_factory,
-        topic_manager_factory,
         row_consumer_factory,
         executor,
         row_factory,
@@ -640,7 +735,6 @@ def on_message_processed(*_):
             if processed_count == total_messages:
                 done.set_result(True)

-        topic_manager = topic_manager_factory()  # just to make topic_config objects
         processed_count = 0

         timestamp = 1000
@@ -650,7 +744,6 @@ def on_message_processed(*_):
         total_messages = expected_message_count * 2  # groupby reproduces each message
         app = app_factory(
             auto_offset_reset="earliest",
-            topic_manager=topic_manager,
             on_message_processed=on_message_processed,
         )

[diff for the seventh changed file not loaded]
