From c7b6c9172b6d467dbde1115a01631a29458bec37 Mon Sep 17 00:00:00 2001 From: Flavio Brasil Date: Sat, 25 Nov 2023 21:49:33 -0800 Subject: [PATCH] queue/channel.close --- .../main/scala/kyo/concurrent/channels.scala | 107 ++++++++++++------ .../main/scala/kyo/concurrent/queues.scala | 97 +++++++++++----- .../kyoTest/concurrent/channelsTest.scala | 57 ++++++++++ .../scala/kyoTest/concurrent/queuesTest.scala | 36 ++++++ 4 files changed, 232 insertions(+), 65 deletions(-) diff --git a/kyo-core/shared/src/main/scala/kyo/concurrent/channels.scala b/kyo-core/shared/src/main/scala/kyo/concurrent/channels.scala index a30ab4320..8cd20db01 100644 --- a/kyo-core/shared/src/main/scala/kyo/concurrent/channels.scala +++ b/kyo-core/shared/src/main/scala/kyo/concurrent/channels.scala @@ -16,9 +16,9 @@ object channels { def size: Int > IOs - def offer[S](v: T > S): Boolean > (IOs with S) + def offer(v: T): Boolean > IOs - def offerUnit[S](v: T > S): Unit > (IOs with S) + def offerUnit(v: T): Unit > IOs def poll: Option[T] > IOs @@ -26,20 +26,25 @@ object channels { def isFull: Boolean > IOs - def putFiber[S](v: T > S): Fiber[Unit] > (IOs with S) + def putFiber(v: T): Fiber[Unit] > IOs def takeFiber: Fiber[T] > IOs - def put[S](v: T > S): Unit > (S with Fibers) = + def put(v: T): Unit > Fibers = putFiber(v).map(_.get) def take: T > Fibers = takeFiber.map(_.get) + + def isClosed: Boolean > IOs + + def close: Option[Seq[T]] > IOs } object Channels { private val placeholder = Fibers.unsafeInitPromise[Unit] + private val closed = IOs.fail("Channel closed!") def init[T]( capacity: Int, @@ -53,51 +58,45 @@ object channels { val takes = new MpmcUnboundedXaddArrayQueue[Promise[T]](8) val puts = new MpmcUnboundedXaddArrayQueue[(T, Promise[Unit])](8) - def size = queue.size - def isEmpty = queue.isEmpty - def isFull = queue.isFull + def size = op(u.size()) + def isEmpty = op(u.isEmpty()) + def isFull = op(u.isFull()) - def offer[S](v: T > S) = - v.map { v => - IOs[Boolean, S] { - try u.offer(v) - finally flush() - } + def offer(v: T) = + op { + try u.offer(v) + finally flush() } - def offerUnit[S](v: T > S) = - v.map { v => - IOs[Unit, S] { - try { - u.offer(v) - () - } finally flush() - } + def offerUnit(v: T) = + op { + try { + u.offer(v) + () + } finally flush() } val poll = - IOs { + op { try u.poll() finally flush() } - def putFiber[S](v: T > S) = - v.map { v => - IOs[Fiber[Unit], S] { - try { - if (u.offer(v)) { - Fibers.value(()) - } else { - val p = Fibers.unsafeInitPromise[Unit] - puts.add((v, p)) - p - } - } finally { - flush() + def putFiber(v: T) = + op { + try { + if (u.offer(v)) { + Fibers.value(()) + } else { + val p = Fibers.unsafeInitPromise[Unit] + puts.add((v, p)) + p } + } finally { + flush() } } val takeFiber = - IOs { + op { try { u.poll() match { case Some(v) => @@ -112,6 +111,42 @@ object channels { } } + /*inline*/ + def op[T]( /*inline*/ v: => T): T > IOs = + IOs[T, Any] { + if (u.isClosed()) { + closed + } else { + v + } + } + + def isClosed = queue.isClosed + + def close = + IOs[Option[Seq[T]], Any] { + u.close() match { + case None => + None + case r: Some[Seq[T]] => + def dropTakes(): Unit > IOs = + takes.poll() match { + case null => () + case p => + p.interrupt.map(_ => dropTakes()) + } + def dropPuts(): Unit > IOs = + puts.poll() match { + case null => () + case (_, p) => + p.interrupt.map(_ => dropPuts()) + } + dropTakes() + .andThen(dropPuts()) + .andThen(r) + } + } + @tailrec private def flush(): Unit = { // This method ensures that all values are processed // and handles interrupted fibers by discarding them. diff --git a/kyo-core/shared/src/main/scala/kyo/concurrent/queues.scala b/kyo-core/shared/src/main/scala/kyo/concurrent/queues.scala index f6504fcbc..5ea44f1d6 100644 --- a/kyo-core/shared/src/main/scala/kyo/concurrent/queues.scala +++ b/kyo-core/shared/src/main/scala/kyo/concurrent/queues.scala @@ -8,22 +8,41 @@ import org.jctools.queues._ import java.util.ArrayDeque import java.util.concurrent.atomic.AtomicReference import scala.annotation.tailrec +import java.util.concurrent.atomic.AtomicBoolean object queues { + private val closed = IOs.fail("Queue closed!") + class Queue[T] private[queues] (private[kyo] val unsafe: Queues.Unsafe[T]) { - def capacity: Int = unsafe.capacity - def size: Int > IOs = IOs(unsafe.size()) - def isEmpty: Boolean > IOs = IOs(unsafe.isEmpty()) - def isFull: Boolean > IOs = IOs(unsafe.isFull()) - def offer[S](v: T > S): Boolean > (IOs with S) = v.map(v => IOs(unsafe.offer(v))) - def poll: Option[T] > IOs = IOs(unsafe.poll()) - def peek: Option[T] > IOs = IOs(unsafe.peek()) + + def capacity: Int = unsafe.capacity + def size: Int > IOs = op(unsafe.size()) + def isEmpty: Boolean > IOs = op(unsafe.isEmpty()) + def isFull: Boolean > IOs = op(unsafe.isFull()) + def offer(v: T): Boolean > IOs = op(unsafe.offer(v)) + def poll: Option[T] > IOs = op(unsafe.poll()) + def peek: Option[T] > IOs = op(unsafe.peek()) + def drain: Seq[T] > IOs = op(unsafe.drain()) + def isClosed: Boolean > IOs = IOs(unsafe.isClosed()) + def close: Option[Seq[T]] > IOs = IOs(unsafe.close()) + + /*inline*/ + private def op[T]( /*inline*/ v: => T): T > IOs = + IOs { + if (unsafe.isClosed()) { + closed + } else { + v + } + } } object Queues { - private[kyo] trait Unsafe[T] { + private[kyo] abstract class Unsafe[T] + extends AtomicBoolean(false) { + def capacity: Int def size(): Int def isEmpty(): Boolean @@ -31,6 +50,28 @@ object queues { def offer(v: T): Boolean def poll(): Option[T] def peek(): Option[T] + + def drain(): Seq[T] = { + def loop(acc: List[T]): List[T] = + poll() match { + case None => + acc.reverse + case Some(v) => + loop(v :: acc) + } + loop(Nil) + } + + def isClosed(): Boolean = + super.get() + + def close(): Option[Seq[T]] = + super.compareAndSet(false, true) match { + case false => + None + case true => + Some(drain()) + } } class Unbounded[T] private[queues] (unsafe: Queues.Unsafe[T]) extends Queue[T](unsafe) { @@ -38,34 +79,32 @@ object queues { def add[S](v: T > S): Unit > (IOs with S) = v.map(offer(_)).unit } - private val zeroCapacity = - new Queue( - new Unsafe[Any] { - def capacity = 0 - def size() = 0 - def isEmpty() = true - def isFull() = true - def offer(v: Any) = false - def poll() = None - def peek() = None - } - ) - def init[T](capacity: Int, access: Access = Access.Mpmc): Queue[T] > IOs = IOs { capacity match { case c if (c <= 0) => - zeroCapacity.asInstanceOf[Queue[T]] + new Queue( + new Unsafe[T] { + def capacity = 0 + def size() = 0 + def isEmpty() = true + def isFull() = true + def offer(v: T) = false + def poll() = None + def peek() = None + } + ) case 1 => new Queue( - new AtomicReference[T] with Unsafe[T] { + new Unsafe[T] { + val state = new AtomicReference[T] def capacity = 1 - def size() = if (get == null) 0 else 1 - def isEmpty() = get == null - def isFull() = get != null - def offer(v: T) = compareAndSet(null.asInstanceOf[T], v) - def poll() = Option(getAndSet(null.asInstanceOf[T])) - def peek() = Option(get) + def size() = if (state.get() == null) 0 else 1 + def isEmpty() = state.get() == null + def isFull() = state.get() != null + def offer(v: T) = state.compareAndSet(null.asInstanceOf[T], v) + def poll() = Option(state.getAndSet(null.asInstanceOf[T])) + def peek() = Option(state.get()) } ) case Int.MaxValue => diff --git a/kyo-core/shared/src/test/scala/kyoTest/concurrent/channelsTest.scala b/kyo-core/shared/src/test/scala/kyoTest/concurrent/channelsTest.scala index eecc63149..7ed0ab887 100644 --- a/kyo-core/shared/src/test/scala/kyoTest/concurrent/channelsTest.scala +++ b/kyo-core/shared/src/test/scala/kyoTest/concurrent/channelsTest.scala @@ -6,6 +6,7 @@ import kyo.concurrent.queues._ import kyo.concurrent.timers._ import kyo._ import kyo.ios._ +import kyo.tries._ import kyoTest.KyoTest import scala.concurrent.duration._ @@ -84,6 +85,62 @@ class channelsTest extends KyoTest { v <- f.get } yield assert(!d1 && d2 && v == 1) } + "close" - { + "empty" in runJVM { + for { + c <- Channels.init[Int](2) + r <- c.close + t <- Tries.run(c.offer(1)) + } yield assert(r == Some(Seq()) && t.isFailure) + } + "non-empty" in runJVM { + for { + c <- Channels.init[Int](2) + _ <- c.put(1) + _ <- c.put(2) + r <- c.close + t <- Tries.run(c.isEmpty) + } yield assert(r == Some(Seq(1, 2)) && t.isFailure) + } + "pending take" in runJVM { + for { + c <- Channels.init[Int](2) + f <- c.takeFiber + r <- c.close + d <- f.getTry + t <- Tries.run(c.isFull) + } yield assert(r == Some(Seq()) && d.isFailure && t.isFailure) + } + "pending put" in runJVM { + for { + c <- Channels.init[Int](2) + _ <- c.put(1) + _ <- c.put(2) + f <- c.putFiber(3) + r <- c.close + d <- f.getTry + t <- Tries.run(c.offerUnit(1)) + } yield assert(r == Some(Seq(1, 2)) && d.isFailure && t.isFailure) + } + "no buffer w/ pending put" in runJVM { + for { + c <- Channels.init[Int](0) + f <- c.putFiber(1) + r <- c.close + d <- f.getTry + t <- Tries.run(c.poll) + } yield assert(r == Some(Seq()) && d.isFailure && t.isFailure) + } + "no buffer w/ pending take" in runJVM { + for { + c <- Channels.init[Int](0) + f <- c.takeFiber + r <- c.close + d <- f.getTry + t <- Tries.run(c.put(1)) + } yield assert(r == Some(Seq()) && d.isFailure && t.isFailure) + } + } "no buffer" in runJVM { for { c <- Channels.init[Int](0) diff --git a/kyo-core/shared/src/test/scala/kyoTest/concurrent/queuesTest.scala b/kyo-core/shared/src/test/scala/kyoTest/concurrent/queuesTest.scala index 902e6c554..950324709 100644 --- a/kyo-core/shared/src/test/scala/kyoTest/concurrent/queuesTest.scala +++ b/kyo-core/shared/src/test/scala/kyoTest/concurrent/queuesTest.scala @@ -3,6 +3,7 @@ package kyoTest.concurrent import kyo.concurrent.queues._ import kyo._ import kyo.ios._ +import kyo.tries._ import kyoTest.KyoTest import kyo.concurrent.Access @@ -62,6 +63,41 @@ class queuesTest extends KyoTest { } } + "close" in run { + for { + q <- Queues.init[Int](2) + b <- q.offer(1) + c1 <- q.close + v1 <- Tries.run(q.size) + v2 <- Tries.run(q.isEmpty) + v3 <- Tries.run(q.isFull) + v4 <- Tries.run(q.offer(2)) + v5 <- Tries.run(q.poll) + v6 <- Tries.run(q.peek) + v7 <- Tries.run(q.drain) + c2 <- q.close + } yield assert( + b && c1 == Some(Seq(1)) && + v1.isFailure && + v2.isFailure && + v3.isFailure && + v4.isFailure && + v5.isFailure && + v6.isFailure && + v7.isFailure && + c2.isEmpty + ) + } + + "drain" in run { + for { + q <- Queues.init[Int](2) + _ <- q.offer(1) + _ <- q.offer(2) + v <- q.drain + } yield assert(v == Seq(1, 2)) + } + "unbounded" - { access.foreach { access => access.toString() - {