diff --git a/modules/effects/src/main/scala/dev/profunktor/redis4cats/runner.scala b/modules/effects/src/main/scala/dev/profunktor/redis4cats/runner.scala index 59f7f70f..92cc9bfb 100644 --- a/modules/effects/src/main/scala/dev/profunktor/redis4cats/runner.scala +++ b/modules/effects/src/main/scala/dev/profunktor/redis4cats/runner.scala @@ -78,26 +78,29 @@ private[redis4cats] class RunnerPartiallyApplied[F[_]: Concurrent: Log: Timer] { def exec[T <: HList, R <: HList](ops: Runner.Ops[F])(commands: T)(implicit w: Witness.Aux[T, R]): F[R] = (Deferred[F, Either[Throwable, w.R]], F.delay(UUID.randomUUID), getTxDelay).tupled.flatMap { case (promise, uuid, txDelay) => - def cancelFibers[A](fibs: HList)(err: Throwable): F[Unit] = - joinOrCancel(fibs, HNil)(false).void >> promise.complete(err.asLeft) + def cancelFibers[A](fibs: HList)(after: F[Unit])(err: Throwable): F[Unit] = + joinOrCancel(fibs, HNil)(false).void.guarantee(after) >> promise.complete(err.asLeft) + + def onErrorOrCancelation(fibs: HList): F[Unit] = + cancelFibers(fibs)(ops.onError)(ops.mkError()) F.debug(s"${ops.name} started - ID: $uuid") >> Resource .makeCase(ops.mainCmd >> runner(commands, HNil)) { - case ((fibs: HList), ExitCase.Completed) => + case (fibs, ExitCase.Completed) => for { _ <- F.debug(s"${ops.name} completed - ID: $uuid") - _ <- ops.onComplete(cancelFibers(fibs)) + _ <- ops.onComplete(cancelFibers(fibs)(F.unit)) tr <- joinOrCancel(fibs, HNil)(true) // Casting here is fine since we have a `Witness` that proves this true _ <- promise.complete(tr.asInstanceOf[w.R].asRight) } yield () - case ((fibs: HList), ExitCase.Error(e)) => + case (fibs, ExitCase.Error(e)) => F.error(s"${ops.name} failed: ${e.getMessage} - ID: $uuid") >> - ops.onError.guarantee(cancelFibers(fibs)(ops.mkError())) - case ((fibs: HList), ExitCase.Canceled) => + onErrorOrCancelation(fibs) + case (fibs, ExitCase.Canceled) => F.error(s"${ops.name} canceled - ID: $uuid") >> - ops.onError.guarantee(cancelFibers(fibs)(ops.mkError())) + onErrorOrCancelation(fibs) } .use(_ => F.sleep(txDelay).void) .guarantee(ops.afterCompletion) >> promise.get.rethrow.timeout(3.seconds) diff --git a/modules/tests/src/test/scala/dev/profunktor/redis4cats/TestScenarios.scala b/modules/tests/src/test/scala/dev/profunktor/redis4cats/TestScenarios.scala index fcc28dc9..5e980bb4 100644 --- a/modules/tests/src/test/scala/dev/profunktor/redis4cats/TestScenarios.scala +++ b/modules/tests/src/test/scala/dev/profunktor/redis4cats/TestScenarios.scala @@ -449,14 +449,23 @@ trait TestScenarios { self: FunSuite => } def canceledTransactionScenario(cmd: RedisCommands[IO, String, String]): IO[Unit] = { - val tx = RedisTransaction(cmd) + val key1 = "tx-1" + val key2 = "tx-2" + val tx = RedisTransaction(cmd) - val commands = - cmd.set("tx-1", "v1") :: cmd.set("tx-2", "v2") :: cmd.set("tx-3", "v3") :: HNil + val commands = cmd.set(key1, "v1") :: cmd.set(key2, "v2") :: cmd.set("tx-3", "v3") :: HNil - // Transaction should be canceled - IO.race(tx.exec(commands).attempt.void, IO.unit) >> - cmd.get("tx-1").map(assertEquals(_, None)) // no keys written + // We race it with a plain `IO.unit` so the transaction may or may not start at all but the result should be the same + val verifyKey1 = + IO.race(tx.exec(commands).attempt.void, IO.unit) >> + cmd.get(key1).map(assertEquals(_, None)) // no keys written + + // We race it with a sleep to make sure the transaction gets time to start running + val verifyKey2 = + IO.race(tx.exec(commands).attempt.void, IO.sleep(20.millis).void) >> + cmd.get(key2).map(assertEquals(_, None)) // no keys written + + verifyKey1 >> verifyKey2 } def scriptsScenario(cmd: RedisCommands[IO, String, String]): IO[Unit] = {