Skip to content

Commit

Permalink
[prelude][core][stm] fix issue with preemption in nested async comput…
Browse files Browse the repository at this point in the history
…ations
  • Loading branch information
fwbrasil committed Dec 13, 2024
1 parent d53aa85 commit 0b58f86
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 17 deletions.
15 changes: 8 additions & 7 deletions kyo-core/shared/src/main/scala/kyo/scheduler/IOPromise.scala
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,15 @@ private[kyo] class IOPromise[+E, +A](init: State[E, A]) extends Safepoint.Interc
blockLoop(this)
end block

protected def stateString(): String =
state.get() match
case p: Pending[?, ?] => s"Pending(waiters = ${p.waiters})"
case l: Linked[?, ?] => s"Linked(promise = ${l.p})"
case r => s"Done(result = ${r.asInstanceOf[Result[Any, Any]].show})"

override def toString =
val stateString =
state.get() match
case p: Pending[?, ?] => s"Pending(waiters = ${p.waiters})"
case l: Linked[?, ?] => s"Linked(promise = ${l.p})"
case r => s"Done(result = ${r.asInstanceOf[Result[Any, Any]].show})"
s"IOPromise(state = ${stateString})"
end toString
s"IOPromise(state = ${stateString()})"

end IOPromise

private[kyo] object IOPromise:
Expand Down
3 changes: 3 additions & 0 deletions kyo-core/shared/src/main/scala/kyo/scheduler/IOTask.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ sealed private[kyo] class IOTask[Ctx, E, A] private (

private inline def nullResult = null.asInstanceOf[A < Ctx & Async & Abort[E]]

override def toString =
s"IOTask(state = ${stateString()}, preempt = ${{ shouldPreempt() }}, finalizers = ${finalizers.size()}, curr = ${curr})"

end IOTask

object IOTask:
Expand Down
55 changes: 55 additions & 0 deletions kyo-core/shared/src/test/scala/kyo/AsyncTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1014,4 +1014,59 @@ class AsyncTest extends Test:
}
}

"preemption is properly handled in nested Async computations" - {
"simple" in run {
Async.run(Async.run(Async.delay(100.millis)(42))).map(_.get).map(_.get).map { result =>
assert(result == 42)
}
}
"with nested eval" in run {
import AllowUnsafe.embrace.danger
val task = IO.Unsafe.evalOrThrow(Async.run(Async.delay(100.millis)(42)))
Async.run(task).map(_.get).map(_.get).map { result =>
assert(result == 42)
}
}
"with multiple nested evals" in run {
import AllowUnsafe.embrace.danger
val innerTask = IO.Unsafe.evalOrThrow(Async.run(Async.delay(100.millis)(42)))
val middleTask = IO.Unsafe.evalOrThrow(Async.run(innerTask))
val outerTask = IO.Unsafe.evalOrThrow(Async.run(middleTask))
Async.run(outerTask).map(_.get).map(_.get).map(_.get).map(_.get).map { result =>
assert(result == 42)
}
}
"with eval inside async computation" in run {
import AllowUnsafe.embrace.danger
Async.run {
Async.delay(100.millis) {
IO.Unsafe.evalOrThrow(Async.run(42)).get
}
}.map(_.get).map { result =>
assert(result == 42)
}
}
"with interleaved evals and delays" in run {
import AllowUnsafe.embrace.danger
val task1 = IO.Unsafe.evalOrThrow(Async.run(Async.delay(100.millis)(1)))
val task2 = Async.delay(100.millis) {
IO.Unsafe.evalOrThrow(Async.run(task1)).get
}
val task3 = IO.Unsafe.evalOrThrow(Async.run(task2))
Async.run(task3).map(_.get).map(_.get).map(_.get).map { result =>
assert(result == 1)
}
}
"with race" in run {
Async.run {
Async.race(
Async.run(Async.delay(100.millis)(1)).map(_.get),
Async.run(Async.delay(200.millis)(2)).map(_.get)
)
}.map(_.get).map { result =>
assert(result == 1)
}
}
}

end AsyncTest
28 changes: 20 additions & 8 deletions kyo-prelude/jvm/src/test/scala/kyo/kernel/SafepointTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,22 @@ class SafepointTest extends Test:
executed = true
true

Safepoint.immediate(interceptor)((1: Int < Any).map(_ + 1).eval)
Safepoint.immediate(interceptor)((1: Int < Any).map(_ + 1)).eval
assert(executed)
}

"eval removes the interceptor" in {
var executed = false
val interceptor = new TestInterceptor:
def ensure(f: () => Unit): Unit = ()
def enter(frame: Frame, value: Any): Boolean =
executed = true
true

Safepoint.immediate(interceptor)((1: Int < Any).map(_ + 1).eval)
assert(!executed)
}

"restore previous interceptor" in {
var count = 0
val interceptor1 = new TestInterceptor:
Expand All @@ -212,7 +224,7 @@ class SafepointTest extends Test:
true

Safepoint.immediate(interceptor1) {
Safepoint.immediate(interceptor2)((1: Int < Any).map(_ + 1).eval)
Safepoint.immediate(interceptor2)((1: Int < Any).map(_ + 1))
}.eval

assert(count == 11)
Expand Down Expand Up @@ -508,8 +520,8 @@ class SafepointTest extends Test:
assert(interceptor.ensuresAdded.size == 1)
assert(interceptor.ensuresRemoved.isEmpty)
42
}.eval
}
}
}.eval

assert(interceptor.ensuresAdded.size == 1)
assert(interceptor.ensuresRemoved.size == 1)
Expand All @@ -525,8 +537,8 @@ class SafepointTest extends Test:
Safepoint.immediate(interceptor) {
testEnsure {
42
}.eval
}
}
}.eval

assert(interceptor.ensuresRemoved.size == 1)
}
Expand All @@ -544,8 +556,8 @@ class SafepointTest extends Test:
42
}
}
}.eval
}
}
}.eval

assert(interceptor.ensuresAdded.size == 3)
assert(interceptor.ensuresRemoved.size == 3)
Expand Down
12 changes: 10 additions & 2 deletions kyo-prelude/shared/src/main/scala/kyo/kernel/Safepoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ final class Safepoint private () extends Trace.Owner:
interceptor = newInterceptor
state = state.withInterceptor(newInterceptor != null)

override def toString(): String =
val currentState = state
s"Safepoint(depth=${currentState.depth}, threadId=${currentState.threadId}, interceptor=${interceptor})"

end Safepoint

object Safepoint:
Expand Down Expand Up @@ -166,8 +170,12 @@ object Safepoint:
private[kernel] inline def eval[A](
inline f: Safepoint ?=> A
)(using inline frame: Frame): A =
val self = Safepoint.get
self.withNewTrace(f(using self))
val self = Safepoint.get
val prevInterceptor = self.interceptor
self.setInterceptor(null)
try self.withNewTrace(f(using self))
finally
self.interceptor = prevInterceptor
end eval

private[kernel] inline def handle[V, A, S](value: V)(
Expand Down
24 changes: 24 additions & 0 deletions kyo-stm/shared/src/test/scala/kyo/STMTest.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package kyo

import scala.concurrent.Future

class STMTest extends Test:

"Transaction isolation" - {
Expand Down Expand Up @@ -715,4 +717,26 @@ class STMTest extends Test:
}
}

"bug #925" in run {
def unsafeToFuture[A: Flat](a: => A < (Async & Abort[Throwable])): Future[A] =
import kyo.AllowUnsafe.embrace.danger
IO.Unsafe.evalOrThrow(
Async.run(a).map(_.toFuture)
)
end unsafeToFuture

val ex = new Exception

val faultyTransaction: Int < STM = TRef.init(42).map { r =>
throw ex
r.get
}

val task = Async.runAndBlock(Duration.Infinity)(Async.fromFuture(unsafeToFuture(STM.run(faultyTransaction))))

Abort.run(task).map { result =>
assert(result == Result.fail(ex))
}
}

end STMTest

0 comments on commit 0b58f86

Please sign in to comment.