Skip to content

Commit

Permalink
queue/channel.close
Browse files Browse the repository at this point in the history
  • Loading branch information
fwbrasil committed Nov 26, 2023
1 parent 96d1c93 commit c7b6c91
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 65 deletions.
107 changes: 71 additions & 36 deletions kyo-core/shared/src/main/scala/kyo/concurrent/channels.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,35 @@ object channels {

def size: Int > IOs

def offer[S](v: T > S): Boolean > (IOs with S)
def offer(v: T): Boolean > IOs

def offerUnit[S](v: T > S): Unit > (IOs with S)
def offerUnit(v: T): Unit > IOs

def poll: Option[T] > IOs

def isEmpty: Boolean > IOs

def isFull: Boolean > IOs

def putFiber[S](v: T > S): Fiber[Unit] > (IOs with S)
def putFiber(v: T): Fiber[Unit] > IOs

def takeFiber: Fiber[T] > IOs

def put[S](v: T > S): Unit > (S with Fibers) =
def put(v: T): Unit > Fibers =
putFiber(v).map(_.get)

def take: T > Fibers =
takeFiber.map(_.get)

def isClosed: Boolean > IOs

def close: Option[Seq[T]] > IOs
}

object Channels {

private val placeholder = Fibers.unsafeInitPromise[Unit]
private val closed = IOs.fail("Channel closed!")

def init[T](
capacity: Int,
Expand All @@ -53,51 +58,45 @@ object channels {
val takes = new MpmcUnboundedXaddArrayQueue[Promise[T]](8)
val puts = new MpmcUnboundedXaddArrayQueue[(T, Promise[Unit])](8)

def size = queue.size
def isEmpty = queue.isEmpty
def isFull = queue.isFull
def size = op(u.size())
def isEmpty = op(u.isEmpty())
def isFull = op(u.isFull())

def offer[S](v: T > S) =
v.map { v =>
IOs[Boolean, S] {
try u.offer(v)
finally flush()
}
def offer(v: T) =
op {
try u.offer(v)
finally flush()
}
def offerUnit[S](v: T > S) =
v.map { v =>
IOs[Unit, S] {
try {
u.offer(v)
()
} finally flush()
}
def offerUnit(v: T) =
op {
try {
u.offer(v)
()
} finally flush()
}
val poll =
IOs {
op {
try u.poll()
finally flush()
}

def putFiber[S](v: T > S) =
v.map { v =>
IOs[Fiber[Unit], S] {
try {
if (u.offer(v)) {
Fibers.value(())
} else {
val p = Fibers.unsafeInitPromise[Unit]
puts.add((v, p))
p
}
} finally {
flush()
def putFiber(v: T) =
op {
try {
if (u.offer(v)) {
Fibers.value(())
} else {
val p = Fibers.unsafeInitPromise[Unit]
puts.add((v, p))
p
}
} finally {
flush()
}
}

val takeFiber =
IOs {
op {
try {
u.poll() match {
case Some(v) =>
Expand All @@ -112,6 +111,42 @@ object channels {
}
}

/*inline*/
def op[T]( /*inline*/ v: => T): T > IOs =
IOs[T, Any] {
if (u.isClosed()) {
closed
} else {
v
}
}

def isClosed = queue.isClosed

def close =
IOs[Option[Seq[T]], Any] {
u.close() match {
case None =>
None
case r: Some[Seq[T]] =>
def dropTakes(): Unit > IOs =
takes.poll() match {
case null => ()
case p =>
p.interrupt.map(_ => dropTakes())
}
def dropPuts(): Unit > IOs =
puts.poll() match {
case null => ()
case (_, p) =>
p.interrupt.map(_ => dropPuts())
}
dropTakes()
.andThen(dropPuts())
.andThen(r)
}
}

@tailrec private def flush(): Unit = {
// This method ensures that all values are processed
// and handles interrupted fibers by discarding them.
Expand Down
97 changes: 68 additions & 29 deletions kyo-core/shared/src/main/scala/kyo/concurrent/queues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,64 +8,103 @@ import org.jctools.queues._
import java.util.ArrayDeque
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec
import java.util.concurrent.atomic.AtomicBoolean

object queues {

private val closed = IOs.fail("Queue closed!")

class Queue[T] private[queues] (private[kyo] val unsafe: Queues.Unsafe[T]) {
def capacity: Int = unsafe.capacity
def size: Int > IOs = IOs(unsafe.size())
def isEmpty: Boolean > IOs = IOs(unsafe.isEmpty())
def isFull: Boolean > IOs = IOs(unsafe.isFull())
def offer[S](v: T > S): Boolean > (IOs with S) = v.map(v => IOs(unsafe.offer(v)))
def poll: Option[T] > IOs = IOs(unsafe.poll())
def peek: Option[T] > IOs = IOs(unsafe.peek())

def capacity: Int = unsafe.capacity
def size: Int > IOs = op(unsafe.size())
def isEmpty: Boolean > IOs = op(unsafe.isEmpty())
def isFull: Boolean > IOs = op(unsafe.isFull())
def offer(v: T): Boolean > IOs = op(unsafe.offer(v))
def poll: Option[T] > IOs = op(unsafe.poll())
def peek: Option[T] > IOs = op(unsafe.peek())
def drain: Seq[T] > IOs = op(unsafe.drain())
def isClosed: Boolean > IOs = IOs(unsafe.isClosed())
def close: Option[Seq[T]] > IOs = IOs(unsafe.close())

/*inline*/
private def op[T]( /*inline*/ v: => T): T > IOs =
IOs {
if (unsafe.isClosed()) {
closed
} else {
v
}
}
}

object Queues {

private[kyo] trait Unsafe[T] {
private[kyo] abstract class Unsafe[T]
extends AtomicBoolean(false) {

def capacity: Int
def size(): Int
def isEmpty(): Boolean
def isFull(): Boolean
def offer(v: T): Boolean
def poll(): Option[T]
def peek(): Option[T]

def drain(): Seq[T] = {
def loop(acc: List[T]): List[T] =
poll() match {
case None =>
acc.reverse
case Some(v) =>
loop(v :: acc)
}
loop(Nil)
}

def isClosed(): Boolean =
super.get()

def close(): Option[Seq[T]] =
super.compareAndSet(false, true) match {
case false =>
None
case true =>
Some(drain())
}
}

class Unbounded[T] private[queues] (unsafe: Queues.Unsafe[T]) extends Queue[T](unsafe) {

def add[S](v: T > S): Unit > (IOs with S) = v.map(offer(_)).unit
}

private val zeroCapacity =
new Queue(
new Unsafe[Any] {
def capacity = 0
def size() = 0
def isEmpty() = true
def isFull() = true
def offer(v: Any) = false
def poll() = None
def peek() = None
}
)

def init[T](capacity: Int, access: Access = Access.Mpmc): Queue[T] > IOs =
IOs {
capacity match {
case c if (c <= 0) =>
zeroCapacity.asInstanceOf[Queue[T]]
new Queue(
new Unsafe[T] {
def capacity = 0
def size() = 0
def isEmpty() = true
def isFull() = true
def offer(v: T) = false
def poll() = None
def peek() = None
}
)
case 1 =>
new Queue(
new AtomicReference[T] with Unsafe[T] {
new Unsafe[T] {
val state = new AtomicReference[T]
def capacity = 1
def size() = if (get == null) 0 else 1
def isEmpty() = get == null
def isFull() = get != null
def offer(v: T) = compareAndSet(null.asInstanceOf[T], v)
def poll() = Option(getAndSet(null.asInstanceOf[T]))
def peek() = Option(get)
def size() = if (state.get() == null) 0 else 1
def isEmpty() = state.get() == null
def isFull() = state.get() != null
def offer(v: T) = state.compareAndSet(null.asInstanceOf[T], v)
def poll() = Option(state.getAndSet(null.asInstanceOf[T]))
def peek() = Option(state.get())
}
)
case Int.MaxValue =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import kyo.concurrent.queues._
import kyo.concurrent.timers._
import kyo._
import kyo.ios._
import kyo.tries._
import kyoTest.KyoTest

import scala.concurrent.duration._
Expand Down Expand Up @@ -84,6 +85,62 @@ class channelsTest extends KyoTest {
v <- f.get
} yield assert(!d1 && d2 && v == 1)
}
"close" - {
"empty" in runJVM {
for {
c <- Channels.init[Int](2)
r <- c.close
t <- Tries.run(c.offer(1))
} yield assert(r == Some(Seq()) && t.isFailure)
}
"non-empty" in runJVM {
for {
c <- Channels.init[Int](2)
_ <- c.put(1)
_ <- c.put(2)
r <- c.close
t <- Tries.run(c.isEmpty)
} yield assert(r == Some(Seq(1, 2)) && t.isFailure)
}
"pending take" in runJVM {
for {
c <- Channels.init[Int](2)
f <- c.takeFiber
r <- c.close
d <- f.getTry
t <- Tries.run(c.isFull)
} yield assert(r == Some(Seq()) && d.isFailure && t.isFailure)
}
"pending put" in runJVM {
for {
c <- Channels.init[Int](2)
_ <- c.put(1)
_ <- c.put(2)
f <- c.putFiber(3)
r <- c.close
d <- f.getTry
t <- Tries.run(c.offerUnit(1))
} yield assert(r == Some(Seq(1, 2)) && d.isFailure && t.isFailure)
}
"no buffer w/ pending put" in runJVM {
for {
c <- Channels.init[Int](0)
f <- c.putFiber(1)
r <- c.close
d <- f.getTry
t <- Tries.run(c.poll)
} yield assert(r == Some(Seq()) && d.isFailure && t.isFailure)
}
"no buffer w/ pending take" in runJVM {
for {
c <- Channels.init[Int](0)
f <- c.takeFiber
r <- c.close
d <- f.getTry
t <- Tries.run(c.put(1))
} yield assert(r == Some(Seq()) && d.isFailure && t.isFailure)
}
}
"no buffer" in runJVM {
for {
c <- Channels.init[Int](0)
Expand Down
Loading

0 comments on commit c7b6c91

Please sign in to comment.