fix: Fixing implementation of revoke listener to not wait for streams termination while signaling streams to finish

wookievx committed Dec 21, 2024
1 parent 4807bcb commit 32f5a53
Showing 4 changed files with 84 additions and 33 deletions.
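For orientation before the per-file diffs, here is a minimal, self-contained cats-effect sketch (not the library's actual internals) of the pattern this commit introduces: on revocation the rebalance listener only signals each revoked partition stream to terminate and returns the corresponding await-effects as a `Chain`, and the consumer actor awaits those effects separately, bounded by a timeout. `AssignmentSignals`, `stop`, `finished`, and `worker` are illustrative stand-ins; only the method names mirror the diff.

```scala
import scala.concurrent.duration.*

import cats.data.Chain
import cats.effect.{Deferred, IO, IOApp}
import cats.syntax.all.*

object RevokeSketch extends IOApp.Simple {

  // Per-partition signals, mirroring the diff's signalStreamToTerminate / awaitStreamFinishedSignal.
  final case class AssignmentSignals(stop: Deferred[IO, Unit], finished: Deferred[IO, Unit]) {
    def signalStreamToTerminate: IO[Boolean] = stop.complete(())
    def awaitStreamFinishedSignal: IO[Unit]  = finished.get
  }

  // onRevoked now only signals every revoked stream and returns the await-effects,
  // instead of sequencing (waiting on) them inside the rebalance listener.
  def onRevoked(revoked: List[AssignmentSignals]): IO[Chain[IO[Unit]]] =
    Chain
      .fromSeq(revoked)
      .traverse(s => s.signalStreamToTerminate.as(s.awaitStreamFinishedSignal))

  val run: IO[Unit] =
    for {
      stop     <- Deferred[IO, Unit]
      finished <- Deferred[IO, Unit]
      signals   = AssignmentSignals(stop, finished)
      // stand-in for a partition stream: it finishes shortly after being told to stop
      worker   <- (stop.get *> IO.sleep(100.millis) *> finished.complete(()).void).start
      awaits   <- onRevoked(List(signals)) // returns immediately; streams are merely signalled
      _        <- awaits.sequence_.timeoutTo(5.seconds, IO.println("revoke wait timed out")) // the actor awaits separately
      _        <- worker.joinWithNever
    } yield ()
}
```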
14 changes: 7 additions & 7 deletions modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala
@@ -12,7 +12,7 @@ import scala.collection.immutable.SortedSet
import scala.concurrent.duration.FiniteDuration
import scala.util.matching.Regex
import cats.{Applicative, Foldable, Functor, Reducible}
import cats.data.{NonEmptySet, OptionT}
import cats.data.{Chain, NonEmptySet, OptionT}
import cats.effect.*
import cats.effect.implicits.*
import cats.effect.std.*
@@ -268,17 +268,15 @@ object KafkaConsumer {
): OnRebalance[F] =
OnRebalance(
onRevoked = revoked => {
val finishSignals = for {
for {
finishers <- assignmentRef.modify(_.partition(entry => !revoked.contains(entry._1)))
revokeFinishers <- finishers
.toVector
revokeFinishers <- Chain
.fromIterableOnce(finishers)
.traverse {
case (_, assignmentSignals) =>
assignmentSignals.signalStreamToTerminate.as(assignmentSignals.awaitStreamFinishedSignal)
}
} yield revokeFinishers

finishSignals.flatMap(revokes => revokes.sequence_)
},
onAssigned = assignedPartitions => {
for {
@@ -447,7 +445,9 @@ object KafkaConsumer {
assignmentRef.updateAndGet(_ ++ assigned).flatMap(updateQueue.offer),
onRevoked = revoked =>
initialAssignmentDone >>
assignmentRef.updateAndGet(_ -- revoked).flatMap(updateQueue.offer)
assignmentRef.updateAndGet(_ -- revoked)
.flatMap(updateQueue.offer)
.as(Chain.empty)
)

Stream
18 changes: 14 additions & 4 deletions modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala
@@ -223,14 +223,24 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V](
)).run(withRebalancing)
}
.flatMap { res =>
val onRevoked =
res.onRebalances.foldLeft(F.unit)(_ >> _.onRevoked(revoked))
val onRevokedStarted =
res.onRebalances.foldLeft(F.pure(Chain.empty[F[Unit]])) { (acc, next) =>
for {
acc <- acc
next <- next.onRevoked(revoked)
} yield acc ++ next
}

res.logRevoked >>
res.completeWithRecords >>
res.completeWithoutRecords >>
res.removeRevokedRecords >>
onRevoked.timeout(settings.sessionTimeout) //just to be extra-safe timeout this revoke
onRevokedStarted //first we need to trigger interruption for all the consuming streams
.flatMap(_.sequence_) //second we wait for all the streams to finish processing (Eager mode returns immediately)
.timeoutTo(
settings.sessionTimeout,
ref.get.flatMap(state => log(LogEntry.RevokeTimeoutOccurred(revoked, state)))
)
}
}

@@ -630,7 +640,7 @@ private[kafka] object KafkaConsumerActor {

final case class OnRebalance[F[_]](
onAssigned: SortedSet[TopicPartition] => F[Unit],
onRevoked: SortedSet[TopicPartition] => F[Unit],
onRevoked: SortedSet[TopicPartition] => F[Chain[F[Unit]]],
) {

override def toString: String =
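In the KafkaConsumerActor change above, the old code bounded the revoke wait with `timeout(settings.sessionTimeout)`, which raises a `TimeoutException`; the new code uses `timeoutTo`, which runs a fallback effect instead, here logging a `RevokeTimeoutOccurred` entry. A tiny sketch of the difference, assuming plain cats-effect `IO` (names are illustrative):

```scala
import scala.concurrent.duration.*

import cats.effect.IO

object TimeoutVsTimeoutTo {
  // a long-running effect standing in for "wait until all revoked streams finish"
  val slow: IO[Unit] = IO.sleep(10.seconds)

  // old behaviour: `timeout` raises a TimeoutException when `slow` overruns,
  // which could surface as a failed revoke callback
  val failing: IO[Unit] = slow.timeout(1.second)

  // new behaviour: `timeoutTo` runs a fallback effect instead of raising;
  // the commit uses it to log a RevokeTimeoutOccurred entry and carry on
  val logging: IO[Unit] = slow.timeoutTo(1.second, IO.println("revoke wait timed out; continuing"))
}
```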
10 changes: 10 additions & 0 deletions modules/core/src/main/scala/fs2/kafka/internal/LogEntry.scala
@@ -223,6 +223,16 @@ private[kafka] object LogEntry {

}

final case class RevokeTimeoutOccurred[F[_]](
revoked: Set[TopicPartition],
state: State[F, ?, ?]
) extends LogEntry {
override def level: LogLevel = Info

override def message: String =
s"Consuming streams did not signal processing completion of [$revoked]. Current state [$state]."
}

def recordsString[F[_]](
records: Records[F]
): String =
75 changes: 53 additions & 22 deletions modules/core/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala
@@ -8,22 +8,16 @@ package fs2.kafka

import scala.collection.immutable.SortedSet
import scala.concurrent.duration.*

import cats.data.NonEmptySet
import cats.effect.{Clock, Fiber, IO, Ref}
import cats.effect.std.Queue
import cats.effect.std.{Queue, Semaphore}
import cats.effect.unsafe.implicits.global
import cats.syntax.all.*
import fs2.concurrent.SignallingRef
import fs2.kafka.consumer.KafkaConsumeChunk.CommitNow
import fs2.kafka.internal.converters.collection.*
import fs2.Stream

import org.apache.kafka.clients.consumer.{
ConsumerConfig,
CooperativeStickyAssignor,
NoOffsetForPartitionException
}
import org.apache.kafka.clients.consumer.{ConsumerConfig, CooperativeStickyAssignor, NoOffsetForPartitionException}
import org.apache.kafka.common.errors.TimeoutException
import org.apache.kafka.common.TopicPartition
import org.scalatest.Assertion
@@ -71,15 +65,17 @@ final class KafkaConsumerSpec extends BaseKafkaSpec {
}
}

it("should consume all records at least once with subscribing for several consumers") {
def testMultipleConsumersCorrectConsumption(
customizeSettings: ConsumerSettings[IO, String, String] => ConsumerSettings[IO, String, String]
) = {
withTopic { topic =>
createCustomTopic(topic, partitions = 3)
val produced = (0 until 5).map(n => s"key-$n" -> s"value->$n")
publishToKafka(topic, produced)

val consumed =
KafkaConsumer
.stream(consumerSettings[IO].withGroupId("test"))
.stream(customizeSettings(consumerSettings[IO].withGroupId("test")))
.subscribeTo(topic)
.evalMap(IO.sleep(3.seconds).as(_)) // sleep a bit to trigger potential race condition with _.stream
.records
@@ -101,6 +97,14 @@ final class KafkaConsumerSpec extends BaseKafkaSpec {
}
}

it("should consume all records at least once with subscribing for several consumers") {
testMultipleConsumersCorrectConsumption(identity)
}

it("should consume all records at least once with subscribing for several consumers in graceful mode") {
testMultipleConsumersCorrectConsumption(_.withRebalanceRevokeMode(RebalanceRevokeMode.Graceful))
}

it("should consume records with assign by partitions") {
withTopic { topic =>
createCustomTopic(topic, partitions = 3)
@@ -1212,7 +1216,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec {
}

describe("KafkaConsumer#stream") {
it("should wait for previous generation of streams to start consuming messages with RebalanceRevokeMode#Graceful") {
it("should wait for previous generation of streams to finish before starting consuming messages with RebalanceRevokeMode#Graceful") {
withTopic { topic =>
createCustomTopic(topic, partitions = 2) //minimal amount of partitions for two consumers
def recordRange(from: Int, _until: Int) = (from until _until).map(n => s"key-$n" -> s"value-$n")
@@ -1222,50 +1226,77 @@ final class KafkaConsumerSpec extends BaseKafkaSpec {
publishToKafka(topic, produced)
}

val settings = consumerSettings[IO].withGroupId("rebalance-test-group").withRebalanceRevokeMode(RebalanceRevokeMode.Graceful).withAutoOffsetReset(AutoOffsetReset.EarliestOffsetReset)

// tracking that consumption is unique by explicitly committing after each message
// the expected timeline looks like this:
// ----stream1---|lock acquired|--|producing second batch|--|rebalance-triggered|--|lock-released|---...
// the idea is to check that all the messages are consumed by one processor (exactly-once processing implies that)
val consumed = for {
lock <- Semaphore[IO](1)
ref <- Ref.of[IO, Vector[(String, String)]](Vector.empty)
_ <- produceRange(0, 10)
_ <-
KafkaConsumer
.stream(consumerSettings[IO].withRebalanceRevokeMode(RebalanceRevokeMode.Graceful))
.stream(settings)
.evalTap(_.subscribeTo(topic))
.flatMap(
_.stream
.evalMap { record =>
ref.update(_ :+ (record.record.key -> record.record.value)).as(record.offset)
lock.permit.use { _ =>
ref.update(_ :+ (record.record.key -> record.record.value)).as(record)
}
}
.evalMap { r =>
// if the key is the last one, return None and terminate the stream
if (r.record.key == "key-29") {
r.offset.commit.as(None)
} else {
r.offset.commit.as(Some(r))
}
}
.evalTap(_.commit)
.unNoneTerminate
)
.interruptAfter(3.seconds)
.compile
.drain
.race {
Clock[IO].sleep(1.second) *>
Clock[IO].sleep(100.millis) *>
lock.acquire *>
produceRange(10, 20) *>
KafkaConsumer
.stream(consumerSettings[IO].withRebalanceRevokeMode(RebalanceRevokeMode.Graceful))
.stream(settings)
.evalTap(_.subscribeTo(topic))
.evalTap(_ => lock.release)
.flatMap( c =>
fs2.Stream.exec(produceRange(20, 30)) ++
c.stream
.evalMap { record =>
ref.update(_ :+ (record.record.key -> record.record.value)).as(record.offset)
ref.update(_ :+ (record.record.key -> record.record.value)).as(record)
}
.evalMap { r =>
// if the key is the last one, return None and terminate the stream
if (r.record.key == "key-29") {
r.offset.commit.as(None)
} else {
r.offset.commit.as(Some(r))
}
}
.evalTap(_.commit)
.unNoneTerminate
)
.interruptAfter(3.seconds)
.compile
.drain
}
.timeout(5.seconds)
res <- ref.get
} yield res

val res = consumed.unsafeRunSync()

//expected behavior is that no duplicate consumption is performed
res.toSet should have size res.length.toLong
(res should contain).theSameElementsAs(recordRange(0, 10) ++ recordRange(10, 20) ++ recordRange(20, 30))
val resultSet = res.toSet
resultSet should have size res.length.toLong
val expectedSet = (recordRange(0, 10) ++ recordRange(10, 20) ++ recordRange(20, 30)).toSet
resultSet shouldEqual expectedSet
}
}
}
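For context on what the new and updated tests exercise: graceful revoke handling is opted into via the consumer settings. Below is a hypothetical, minimal usage sketch (bootstrap servers, group id, topic name, and the per-record commit are placeholders for illustration); the relevant call is `withRebalanceRevokeMode(RebalanceRevokeMode.Graceful)`, added by this PR.

```scala
import cats.effect.*
import fs2.kafka.*

object GracefulRevokeExample extends IOApp.Simple {

  // placeholder bootstrap servers, group id and topic; the relevant call is
  // withRebalanceRevokeMode(RebalanceRevokeMode.Graceful), introduced by this PR
  val settings: ConsumerSettings[IO, String, String] =
    ConsumerSettings[IO, String, String]
      .withBootstrapServers("localhost:9092")
      .withGroupId("my-group")
      .withAutoOffsetReset(AutoOffsetReset.Earliest)
      .withRebalanceRevokeMode(RebalanceRevokeMode.Graceful)

  val run: IO[Unit] =
    KafkaConsumer
      .stream(settings)
      .subscribeTo("my-topic")
      .records
      .evalMap(_.offset.commit) // commit each record individually for simplicity
      .compile
      .drain
}
```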
