
Commit 99e03f4
Merge pull request #277 from profunktor/feature/pipelining-using-hlists
Pipelining using HLists
gvolpe authored May 9, 2020
2 parents 466437b + dbaa990 commit 99e03f4
Showing 10 changed files with 209 additions and 99 deletions.
@@ -17,24 +17,30 @@
package dev.profunktor.redis4cats

import cats.effect._
import cats.effect.implicits._
import cats.implicits._
import dev.profunktor.redis4cats.effect.Log
import dev.profunktor.redis4cats.hlist._
import scala.util.control.NoStackTrace

object pipeline {

case class RedisPipeline[F[_]: Bracket[*[_], Throwable]: Log, K, V](
case object PipelineError extends NoStackTrace

case class RedisPipeline[F[_]: Concurrent: Log: Timer, K, V](
cmd: RedisCommands[F, K, V]
) {
def run[A](fa: F[A]): F[A] =
F.info("Pipeline started") *>
cmd.disableAutoFlush
.bracketCase(_ => fa) {
case (_, ExitCase.Completed) => cmd.flushCommands *> F.info("Pipeline completed")
case (_, ExitCase.Error(e)) => F.error(s"Pipeline failed: ${e.getMessage}")
case (_, ExitCase.Canceled) => F.error("Pipeline canceled")
}
.guarantee(cmd.enableAutoFlush)

def exec[T <: HList, R <: HList](commands: T)(implicit w: Witness.Aux[T, R]): F[R] =
Runner[F].exec(
Runner.Ops(
name = "Pipeline",
mainCmd = cmd.disableAutoFlush,
onComplete = (_: Runner.CancelFibers[F]) => cmd.flushCommands,
onError = F.unit,
afterCompletion = cmd.enableAutoFlush,
mkError = () => PipelineError
)
)(commands)

}

}
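
A note on the types: the new exec relies on the HList encoding and the Witness type class from the existing dev.profunktor.redis4cats.hlist module, which predates this PR (the old transaction code further down already uses it) and is not part of this diff. The sketch below is an assumption about its shape, with names mirroring the module; the real definitions may differ in detail. Witness.Aux[T, R] ties an HList of commands T to the HList of their results R, which is what lets exec return F[R].

// Assumed shape of the hlist encoding behind Witness.Aux[T, R]; not part of this diff.
import scala.language.higherKinds

sealed trait HList
final case class HCons[H, T <: HList](head: H, tail: T) extends HList
case object HNil extends HList

// Witness relates an HList of actions, T = F[A] :: F[B] :: HNil,
// to the HList of their results, R = A :: B :: HNil.
trait Witness[T <: HList] { type R <: HList }

object Witness {
  type Aux[T0 <: HList, R0 <: HList] = Witness[T0] { type R = R0 }

  implicit val hnil: Aux[HNil.type, HNil.type] =
    new Witness[HNil.type] { type R = HNil.type }

  implicit def hcons[F[_], A, T <: HList](implicit w: Witness[T]): Aux[HCons[F[A], T], HCons[A, w.R]] =
    new Witness[HCons[F[A], T]] { type R = HCons[A, w.R] }
}
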
@@ -259,13 +259,13 @@ private[redis4cats] class BaseRedis[F[_]: Concurrent: ContextShift, K, V](

/******************************* AutoFlush API **********************************/
override def enableAutoFlush: F[Unit] =
async.flatMap(c => F.delay(c.setAutoFlushCommands(true)))
blocker.blockOn(async.flatMap(c => blocker.delay(c.setAutoFlushCommands(true))))

override def disableAutoFlush: F[Unit] =
async.flatMap(c => F.delay(c.setAutoFlushCommands(false)))
blocker.blockOn(async.flatMap(c => blocker.delay(c.setAutoFlushCommands(false))))

override def flushCommands: F[Unit] =
async.flatMap(c => F.delay(c.flushCommands()))
blocker.blockOn(async.flatMap(c => blocker.delay(c.flushCommands())))

/******************************* Strings API **********************************/
override def append(key: K, value: V): F[Unit] =
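
The AutoFlush overrides above now route the underlying Lettuce calls (setAutoFlushCommands, flushCommands) through the Blocker, keeping potentially blocking work off the compute pool. A minimal sketch of the pattern, with a hypothetical helper name (onBlocker) and a generic acquire step standing in for the async client lookup:

import cats.effect.{ Blocker, ContextShift, Sync }
import cats.implicits._

object BlockingSketch {
  // Resolve the underlying client (acquire), then run the blocking Lettuce call
  // on the Blocker's dedicated thread pool, mirroring the overrides above.
  def onBlocker[F[_]: Sync: ContextShift, C, A](blocker: Blocker)(acquire: F[C])(use: C => A): F[A] =
    blocker.blockOn(acquire.flatMap(c => blocker.delay(use(c))))
}
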
@@ -0,0 +1,93 @@
/*
* Copyright 2018-2020 ProfunKtor
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package dev.profunktor.redis4cats

import cats.effect._
import cats.effect.concurrent.Deferred
import cats.effect.implicits._
import cats.implicits._
import dev.profunktor.redis4cats.effect.Log
import dev.profunktor.redis4cats.hlist._
import scala.concurrent.duration._

object Runner {
type CancelFibers[F[_]] = Throwable => F[Unit]

case class Ops[F[_]](
name: String,
mainCmd: F[Unit],
onComplete: CancelFibers[F] => F[Unit],
onError: F[Unit],
afterCompletion: F[Unit],
mkError: () => Throwable
)

def apply[F[_]: Concurrent: Log: Timer]: RunnerPartiallyApplied[F] =
new RunnerPartiallyApplied[F]
}

private[redis4cats] class RunnerPartiallyApplied[F[_]: Concurrent: Log: Timer] {

def exec[T <: HList, R <: HList](ops: Runner.Ops[F])(commands: T)(implicit w: Witness.Aux[T, R]): F[R] =
Deferred[F, Either[Throwable, w.R]].flatMap { promise =>
def cancelFibers[A](fibs: HList)(err: Throwable): F[Unit] =
joinOrCancel(fibs, HNil)(false).void >> promise.complete(err.asLeft)

F.info(s"${ops.name} started") >>
Resource
.makeCase(ops.mainCmd >> runner(commands, HNil)) {
case ((fibs: HList), ExitCase.Completed) =>
for {
_ <- F.info(s"${ops.name} completed")
_ <- ops.onComplete(cancelFibers(fibs))
tr <- joinOrCancel(fibs, HNil)(true)
// Casting here is fine since we have a `Witness` that proves this true
_ <- promise.complete(tr.asInstanceOf[w.R].asRight)
} yield ()
case ((fibs: HList), ExitCase.Error(e)) =>
F.error(s"${ops.name} failed: ${e.getMessage}") >>
ops.onError.guarantee(cancelFibers(fibs)(ops.mkError()))
case ((fibs: HList), ExitCase.Canceled) =>
F.error(s"${ops.name} canceled") >>
ops.onError.guarantee(cancelFibers(fibs)(ops.mkError()))
case _ =>
F.error("Kernel panic: the impossible happened!")
}
.use(_ => F.unit)
.guarantee(ops.afterCompletion) >> promise.get.rethrow.timeout(3.seconds)
}

// Forks every command in order
private def runner[H <: HList, G <: HList](ys: H, res: G): F[Any] =
ys match {
case HNil => F.pure(res)
case HCons((h: F[_] @unchecked), t) => h.start.flatMap(fb => runner(t, fb :: res))
}

// Joins or cancels the fibers corresponding to the previously executed commands
private def joinOrCancel[H <: HList, G <: HList](ys: H, res: G)(isJoin: Boolean): F[Any] =
ys match {
case HNil => F.pure(res)
case HCons((h: Fiber[F, Any] @unchecked), t) if isJoin =>
h.join.flatMap(x => joinOrCancel(t, x :: res)(isJoin))
case HCons((h: Fiber[F, Any] @unchecked), t) =>
h.cancel.flatMap(x => joinOrCancel(t, x :: res)(isJoin))
case HCons(h, t) =>
F.error(s"Unexpected result: ${h.toString}") >> joinOrCancel(t, res)(isJoin)
}

}
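
Runner is the shared engine now used by both RedisPipeline.exec and RedisTransaction.exec: mainCmd runs first, every command in the HList is forked, and on completion the fibers are joined (or cancelled on error/cancellation) so their results complete the Deferred promise, which is awaited with a 3-second timeout. A minimal usage sketch under two assumptions: it sits inside the dev.profunktor.redis4cats package (RunnerPartiallyApplied is package-private), and the set/get arguments are hypothetical stand-ins for real Redis commands.

package dev.profunktor.redis4cats

import cats.effect.{ Concurrent, Timer }
import cats.implicits._
import dev.profunktor.redis4cats.effect.Log
import dev.profunktor.redis4cats.hlist._

object RunnerSketch {
  // Each F[A] in the input HList becomes an A in the output HList, in the same order.
  def sketch[F[_]: Concurrent: Log: Timer](set: F[Unit], get: F[Option[String]]): F[Option[String]] =
    Runner[F]
      .exec(
        Runner.Ops(
          name = "Sketch",
          mainCmd = Concurrent[F].unit, // nothing to run before forking
          onComplete = (_: Runner.CancelFibers[F]) => Concurrent[F].unit,
          onError = Concurrent[F].unit,
          afterCompletion = Concurrent[F].unit,
          mkError = () => new Exception("sketch failed")
        )
      )(set :: get :: HNil)                 // T = F[Unit] :: F[Option[String]] :: HNil
      .map { case _ ~: res ~: HNil => res } // R = Unit :: Option[String] :: HNil
}
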
@@ -17,12 +17,9 @@
package dev.profunktor.redis4cats

import cats.effect._
import cats.effect.concurrent._
import cats.effect.implicits._
import cats.implicits._
import dev.profunktor.redis4cats.effect.Log
import dev.profunktor.redis4cats.hlist._
import scala.concurrent.duration._
import scala.util.control.NoStackTrace

object transactions {
@@ -46,52 +43,16 @@ object transactions {
* may end in unexpected results such as a dead lock.
*/
def exec[T <: HList, R <: HList](commands: T)(implicit w: Witness.Aux[T, R]): F[R] =
Deferred[F, Either[Throwable, w.R]].flatMap { promise =>
// Forks every command in order
def runner[H <: HList, G <: HList](ys: H, res: G): F[Any] =
ys match {
case HNil => F.pure(res)
case HCons((h: F[_] @unchecked), t) => h.start.flatMap(fb => runner(t, fb :: res))
}

// Joins or cancel fibers correspondent to previous executed commands
def joinOrCancel[H <: HList, G <: HList](ys: H, res: G)(isJoin: Boolean): F[Any] =
ys match {
case HNil => F.pure(res)
case HCons((h: Fiber[F, Any] @unchecked), t) if isJoin =>
h.join.flatMap(x => joinOrCancel(t, x :: res)(isJoin))
case HCons((h: Fiber[F, Any] @unchecked), t) =>
h.cancel.flatMap(x => joinOrCancel(t, x :: res)(isJoin))
case HCons(h, t) =>
F.error(s"Unexpected result: ${h.toString}") >> joinOrCancel(t, res)(isJoin)
}

def cancelFibers(fibs: HList, err: Throwable = TransactionAborted): F[Unit] =
joinOrCancel(fibs, HNil)(false).void >> promise.complete(err.asLeft)

val tx =
Resource.makeCase(cmd.multi >> runner(commands, HNil)) {
case ((fibs: HList), ExitCase.Completed) =>
for {
_ <- F.info("Transaction completed")
_ <- cmd.exec.handleErrorWith(e => cancelFibers(fibs, e) >> F.raiseError(e))
tr <- joinOrCancel(fibs, HNil)(true)
// Casting here is fine since we have a `Witness` that proves this true
_ <- promise.complete(tr.asInstanceOf[w.R].asRight)
} yield ()
case ((fibs: HList), ExitCase.Error(e)) =>
F.error(s"Transaction failed: ${e.getMessage}") >>
cmd.discard.guarantee(cancelFibers(fibs))
case ((fibs: HList), ExitCase.Canceled) =>
F.error("Transaction canceled") >>
cmd.discard.guarantee(cancelFibers(fibs))
case _ =>
F.error("Kernel panic: the impossible happened!")
}

F.info("Transaction started") >>
(tx.use(_ => F.unit) >> promise.get.rethrow).timeout(3.seconds)
}
Runner[F].exec(
Runner.Ops(
name = "Transaction",
mainCmd = cmd.multi,
onComplete = (f: Runner.CancelFibers[F]) => cmd.exec.handleErrorWith(e => f(e) >> F.raiseError(e)),
onError = cmd.discard,
afterCompletion = F.unit,
mkError = () => TransactionAborted
)
)(commands)

}

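
With the shared Runner, a discarded transaction is reported as TransactionAborted (the mkError value completed into the result promise), just as the pipeline demo below handles PipelineError. A minimal call-site sketch, assuming it sits in the dev.profunktor.redis4cats package so RedisCommands resolves as in the files above; the keys and the caller-supplied log action are hypothetical.

package dev.profunktor.redis4cats

import cats.effect.{ Concurrent, Timer }
import cats.implicits._
import dev.profunktor.redis4cats.effect.Log
import dev.profunktor.redis4cats.hlist._
import dev.profunktor.redis4cats.transactions._

object TxSketch {
  // Run a two-command transaction and fall back when it is aborted.
  def runTx[F[_]: Concurrent: Log: Timer](
      cmd: RedisCommands[F, String, String],
      log: String => F[Unit]
  ): F[Unit] = {
    val operations = cmd.set("tx-key", "value") :: cmd.get("tx-key") :: HNil

    RedisTransaction(cmd)
      .exec(operations)
      .flatMap { case _ ~: res ~: HNil => log(s"committed, got: $res") }
      .recoverWith { case TransactionAborted => log("transaction was aborted") }
  }
}
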
@@ -20,18 +20,20 @@ import cats.effect._
import cats.implicits._
import dev.profunktor.redis4cats.connection._
import dev.profunktor.redis4cats.effect.Log
import dev.profunktor.redis4cats.hlist._
import dev.profunktor.redis4cats.pipeline._
import scala.concurrent.duration._
import java.util.concurrent.TimeoutException

object RedisPipelineDemo extends LoggerIOApp {

import Demo._

def program(implicit log: Log[IO]): IO[Unit] = {
val key = "testp"
val key1 = "testp1"
val key2 = "testp2"

val showResult: Int => Option[String] => IO[Unit] = n =>
_.fold(putStrLn(s"Not found key $key-$n"))(s => putStrLn(s))
val showResult: String => Option[String] => IO[Unit] = key =>
_.fold(putStrLn(s"Not found key: $key"))(s => putStrLn(s"$key: $s"))

val commandsApi: Resource[IO, RedisCommands[IO, String, String]] =
for {
@@ -42,16 +44,29 @@

commandsApi
.use { cmd =>
def traversal(f: Int => IO[Unit]): IO[Unit] =
List.range(0, 50).traverse(f).void
val getters =
cmd.get(key1).flatTap(showResult(key1)) *>
cmd.get(key2).flatTap(showResult(key2))

val setters: IO[Unit] =
traversal(n => cmd.set(s"$key-$n", (n * 2).toString).start.void)
val operations =
cmd.set(key1, "noop") :: cmd.set(key2, "windows") :: cmd.get(key1) ::
cmd.set(key1, "nix") :: cmd.set(key2, "linux") :: cmd.get(key1) :: HNil

val getters: IO[Unit] =
traversal(n => cmd.get(s"$key-$n").flatMap(showResult(n)))
val prog =
RedisPipeline(cmd)
.exec(operations)
.flatMap {
case _ ~: _ ~: res1 ~: _ ~: _ ~: res2 ~: HNil =>
putStrLn(s"res1: $res1, res2: $res2")
}
.onError {
case PipelineError =>
putStrLn("[Error] - Pipeline failed")
case _: TimeoutException =>
putStrLn("[Error] - Timeout")
}

RedisPipeline(cmd).run(setters) *> IO.sleep(2.seconds) *> getters
getters >> prog >> getters >> putStrLn("keep doing stuff...")
}
}

@@ -44,8 +44,6 @@ object RedisTransactionsDemo extends LoggerIOApp {

commandsApi
.use { cmd =>
val tx = RedisTransaction(cmd)

val getters =
cmd.get(key1).flatTap(showResult(key1)) *>
cmd.get(key2).flatTap(showResult(key2))
@@ -59,7 +57,8 @@

//type Res = Unit :: Unit :: Option[String] :: Unit :: Unit :: Option[String] :: HNil
val prog =
tx.exec(operations)
RedisTransaction(cmd)
.exec(operations)
.flatMap {
case _ ~: _ ~: res1 ~: _ ~: _ ~: res2 ~: HNil =>
putStrLn(s"res1: $res1, res2: $res2")
@@ -40,6 +40,8 @@ class RedisClusterSpec extends Redis4CatsFunSuite(true) with TestScenarios {

test("cluster: scripts")(withRedis(scriptsScenario))

test("cluster: pipelining")(withRedisCluster(pipelineScenario))

// FIXME: The Cluster impl cannot connect to a single node just yet
// test("cluster: transactions")(withRedisCluster(transactionScenario))

@@ -36,6 +36,8 @@ class RedisSpec extends Redis4CatsFunSuite(false) with TestScenarios {

test("connection api")(withRedis(connectionScenario))

test("pipelining")(withRedis(pipelineScenario))

test("transactions: successful")(withRedis(transactionScenario))

test("transactions: canceled")(withRedis(canceledTransactionScenario))
@@ -23,7 +23,8 @@ import cats.implicits._
import dev.profunktor.redis4cats.effect.Log
import dev.profunktor.redis4cats.effects._
import dev.profunktor.redis4cats.hlist._
import dev.profunktor.redis4cats.transactions._
import dev.profunktor.redis4cats.pipeline.RedisPipeline
import dev.profunktor.redis4cats.transactions.RedisTransaction
import io.lettuce.core.GeoArgs
import scala.concurrent.duration._

@@ -249,17 +250,34 @@
_ <- IO(assert(slowLogLen.isValidLong))
} yield ()

def pipelineScenario(cmd: RedisCommands[IO, String, String]): IO[Unit] = {
val key1 = "testp1"
val key2 = "testp2"

val operations =
cmd.set(key1, "osx") :: cmd.set(key2, "windows") :: cmd.get(key1) :: cmd.sIsMember("foo", "bar") ::
cmd.set(key1, "nix") :: cmd.set(key2, "linux") :: cmd.get(key1) :: HNil

RedisPipeline(cmd).exec(operations).map {
case _ ~: _ ~: res1 ~: res2 ~: _ ~: _ ~: res3 ~: HNil =>
assert(res1.contains("osx"))
assert(res2 === false)
assert(res3.contains("nix"))
case tr =>
assert(false, s"Unexpected result: $tr")
}

}

def transactionScenario(cmd: RedisCommands[IO, String, String]): IO[Unit] = {
val key1 = "test1"
val key2 = "test2"

val tx = RedisTransaction(cmd)

val operations =
cmd.set(key1, "osx") :: cmd.set(key2, "windows") :: cmd.get(key1) :: cmd.sIsMember("foo", "bar") ::
cmd.set(key1, "nix") :: cmd.set(key2, "linux") :: cmd.get(key1) :: HNil

tx.exec(operations).map {
RedisTransaction(cmd).exec(operations).map {
case _ ~: _ ~: res1 ~: res2 ~: _ ~: _ ~: res3 ~: HNil =>
assert(res1.contains("osx"))
assert(res2 === false)