diff --git a/modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala b/modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala index c52fe843b..5564f8a11 100644 --- a/modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala +++ b/modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala @@ -12,7 +12,7 @@ import scala.collection.immutable.SortedSet import scala.concurrent.duration.FiniteDuration import scala.util.matching.Regex import cats.{Applicative, Foldable, Functor, Reducible} -import cats.data.{NonEmptySet, OptionT} +import cats.data.{Chain, NonEmptySet, OptionT} import cats.effect.* import cats.effect.implicits.* import cats.effect.std.* @@ -268,17 +268,15 @@ object KafkaConsumer { ): OnRebalance[F] = OnRebalance( onRevoked = revoked => { - val finishSignals = for { + for { finishers <- assignmentRef.modify(_.partition(entry => !revoked.contains(entry._1))) - revokeFinishers <- finishers - .toVector + revokeFinishers <- Chain + .fromIterableOnce(finishers) .traverse { case (_, assignmentSignals) => assignmentSignals.signalStreamToTerminate.as(assignmentSignals.awaitStreamFinishedSignal) } } yield revokeFinishers - - finishSignals.flatMap(revokes => revokes.sequence_) }, onAssigned = assignedPartitions => { for { @@ -447,7 +445,9 @@ object KafkaConsumer { assignmentRef.updateAndGet(_ ++ assigned).flatMap(updateQueue.offer), onRevoked = revoked => initialAssignmentDone >> - assignmentRef.updateAndGet(_ -- revoked).flatMap(updateQueue.offer) + assignmentRef.updateAndGet(_ -- revoked) + .flatMap(updateQueue.offer) + .as(Chain.empty) ) Stream diff --git a/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala b/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala index e2c049625..a468dc786 100644 --- a/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala +++ b/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala @@ -223,14 +223,24 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V]( )).run(withRebalancing) } .flatMap { res => - val onRevoked = - res.onRebalances.foldLeft(F.unit)(_ >> _.onRevoked(revoked)) + val onRevokedStarted = + res.onRebalances.foldLeft(F.pure(Chain.empty[F[Unit]])) { (acc, next) => + for { + acc <- acc + next <- next.onRevoked(revoked) + } yield acc ++ next + } res.logRevoked >> res.completeWithRecords >> res.completeWithoutRecords >> res.removeRevokedRecords >> - onRevoked.timeout(settings.sessionTimeout) //just to be extra-safe timeout this revoke + onRevokedStarted //first we need to trigger interruption for all the consuming streams + .flatMap(_.sequence_) //second we await for all the streams to finish processing (Eager mode returns immediately) + .timeoutTo( + settings.sessionTimeout, + ref.get.flatMap(state => log(LogEntry.RevokeTimeoutOccurred(revoked, state))) + ) } } @@ -630,7 +640,7 @@ private[kafka] object KafkaConsumerActor { final case class OnRebalance[F[_]]( onAssigned: SortedSet[TopicPartition] => F[Unit], - onRevoked: SortedSet[TopicPartition] => F[Unit], + onRevoked: SortedSet[TopicPartition] => F[Chain[F[Unit]]], ) { override def toString: String = diff --git a/modules/core/src/main/scala/fs2/kafka/internal/LogEntry.scala b/modules/core/src/main/scala/fs2/kafka/internal/LogEntry.scala index 782945a1d..3e8a396b7 100644 --- a/modules/core/src/main/scala/fs2/kafka/internal/LogEntry.scala +++ b/modules/core/src/main/scala/fs2/kafka/internal/LogEntry.scala @@ -223,6 +223,16 @@ private[kafka] object LogEntry { } + final case class RevokeTimeoutOccurred[F[_]]( + revoked: Set[TopicPartition], + state: State[F, ?, ?] + ) extends LogEntry { + override def level: LogLevel = Info + + override def message: String = + s"Consuming streams did not signal processing completion of [$revoked]. Current state [$state]." + } + def recordsString[F[_]]( records: Records[F] ): String = diff --git a/modules/core/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala b/modules/core/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala index 1fc2a800e..1dc3f26d6 100644 --- a/modules/core/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala +++ b/modules/core/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala @@ -8,22 +8,16 @@ package fs2.kafka import scala.collection.immutable.SortedSet import scala.concurrent.duration.* - import cats.data.NonEmptySet import cats.effect.{Clock, Fiber, IO, Ref} -import cats.effect.std.Queue +import cats.effect.std.{Queue, Semaphore} import cats.effect.unsafe.implicits.global import cats.syntax.all.* import fs2.concurrent.SignallingRef import fs2.kafka.consumer.KafkaConsumeChunk.CommitNow import fs2.kafka.internal.converters.collection.* import fs2.Stream - -import org.apache.kafka.clients.consumer.{ - ConsumerConfig, - CooperativeStickyAssignor, - NoOffsetForPartitionException -} +import org.apache.kafka.clients.consumer.{ConsumerConfig, CooperativeStickyAssignor, NoOffsetForPartitionException} import org.apache.kafka.common.errors.TimeoutException import org.apache.kafka.common.TopicPartition import org.scalatest.Assertion @@ -71,7 +65,9 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { } } - it("should consume all records at least once with subscribing for several consumers") { + def testMultipleConsumersCorrectConsumption( + customizeSettings: ConsumerSettings[IO, String, String] => ConsumerSettings[IO, String, String] + ) = { withTopic { topic => createCustomTopic(topic, partitions = 3) val produced = (0 until 5).map(n => s"key-$n" -> s"value->$n") @@ -79,7 +75,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { val consumed = KafkaConsumer - .stream(consumerSettings[IO].withGroupId("test")) + .stream(customizeSettings(consumerSettings[IO].withGroupId("test"))) .subscribeTo(topic) .evalMap(IO.sleep(3.seconds).as(_)) // sleep a bit to trigger potential race condition with _.stream .records @@ -101,6 +97,14 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { } } + it("should consume all records at least once with subscribing for several consumers") { + testMultipleConsumersCorrectConsumption(identity) + } + + it("should consume all records at least once with subscribing for several consumers in graceful mode") { + testMultipleConsumersCorrectConsumption(_.withRebalanceRevokeMode(RebalanceRevokeMode.Graceful)) + } + it("should consume records with assign by partitions") { withTopic { topic => createCustomTopic(topic, partitions = 3) @@ -1212,7 +1216,7 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { } describe("KafkaConsumer#stream") { - it("should wait for previous generation of streams to start consuming messages with RebalanceRevokeMode#Graceful") { + it("should wait for previous generation of streams to finish before starting consuming messages with RebalanceRevokeMode#Graceful") { withTopic { topic => createCustomTopic(topic, partitions = 2) //minimal amount of partitions for two consumers def recordRange(from: Int, _until: Int) = (from until _until).map(n => s"key-$n" -> s"value-$n") @@ -1222,50 +1226,77 @@ final class KafkaConsumerSpec extends BaseKafkaSpec { publishToKafka(topic, produced) } + val settings = consumerSettings[IO].withGroupId("rebalance-test-group").withRebalanceRevokeMode(RebalanceRevokeMode.Graceful).withAutoOffsetReset(AutoOffsetReset.EarliestOffsetReset) + // tracking consumption for being unique by explicitly commiting after each message + // the expected timeline looks like this: + // ----stream1---|locked auquired|--|producing second batch|--|rebalance-triggered|--|lock-released|---... + //idea is that we check if all the messages are consumed by one processor (exactly once processing implies that) val consumed = for { + lock <- Semaphore[IO](1) ref <- Ref.of[IO, Vector[(String, String)]](Vector.empty) _ <- produceRange(0, 10) _ <- KafkaConsumer - .stream(consumerSettings[IO].withRebalanceRevokeMode(RebalanceRevokeMode.Graceful)) + .stream(settings) .evalTap(_.subscribeTo(topic)) .flatMap( _.stream .evalMap { record => - ref.update(_ :+ (record.record.key -> record.record.value)).as(record.offset) + lock.permit.use { _ => + ref.update(_ :+ (record.record.key -> record.record.value)).as(record) + } + } + .evalMap { r => + //if key is last return none and terminate stream + if (r.record.key == "key-29") { + r.offset.commit.as(None) + } else { + r.offset.commit.as(Some(r)) + } } - .evalTap(_.commit) + .unNoneTerminate ) - .interruptAfter(3.seconds) .compile .drain .race { - Clock[IO].sleep(1.second) *> + Clock[IO].sleep(100.millis) *> + lock.acquire *> produceRange(10, 20) *> KafkaConsumer - .stream(consumerSettings[IO].withRebalanceRevokeMode(RebalanceRevokeMode.Graceful)) + .stream(settings) .evalTap(_.subscribeTo(topic)) + .evalTap(_ => lock.release) .flatMap( c => fs2.Stream.exec(produceRange(20, 30)) ++ c.stream .evalMap { record => - ref.update(_ :+ (record.record.key -> record.record.value)).as(record.offset) + ref.update(_ :+ (record.record.key -> record.record.value)).as(record) + } + .evalMap { r => + //if key is last return none and terminate stream + if (r.record.key == "key-29") { + r.offset.commit.as(None) + } else { + r.offset.commit.as(Some(r)) + } } - .evalTap(_.commit) + .unNoneTerminate ) - .interruptAfter(3.seconds) .compile .drain } + .timeout(5.seconds) res <- ref.get } yield res val res = consumed.unsafeRunSync() //expected behavior is that no duplicate consumption is performed - res.toSet should have size res.length.toLong - (res should contain).theSameElementsAs(recordRange(0, 10) ++ recordRange(10, 20) ++ recordRange(20, 30)) + val resultSet = res.toSet + resultSet should have size res.length.toLong + val expectedSet = (recordRange(0, 10) ++ recordRange(10, 20) ++ recordRange(20, 30)).toSet + resultSet shouldEqual expectedSet } } }