From 89aa066d67da8a6940c5013394e069767c2b40aa Mon Sep 17 00:00:00 2001 From: Flavio Brasil Date: Thu, 7 Dec 2023 22:27:06 -0800 Subject: [PATCH] wip --- .../main/scala/kyo/concurrent/fibers.scala | 32 +++++++------- .../shared/src/main/scala/kyo/joins.scala | 44 +++++++++++++++++++ .../src/test/scala/kyoTest/directTest.scala | 8 ++-- .../shared/src/main/scala/kyo/llm/ais.scala | 27 +----------- 4 files changed, 66 insertions(+), 45 deletions(-) diff --git a/kyo-core/shared/src/main/scala/kyo/concurrent/fibers.scala b/kyo-core/shared/src/main/scala/kyo/concurrent/fibers.scala index 76ac83508..2f094065e 100644 --- a/kyo-core/shared/src/main/scala/kyo/concurrent/fibers.scala +++ b/kyo-core/shared/src/main/scala/kyo/concurrent/fibers.scala @@ -205,7 +205,9 @@ object fibers { Locals.save.map(st => Fiber.promise(IOTask(IOs(v), st))) /*inline*/ - def init[T, S](j: Joins[S])( /*inline*/ v: => T > (S with Fibers))(implicit f: Flat[T > (S with Fibers)]): Fiber[j.M[T]] > (S with IOs) = + def init[T, S](j: Joins[S])( /*inline*/ v: => T > (S with Fibers))(implicit + f: Flat[T > (S with Fibers)] + ): Fiber[j.M[T]] > (S with IOs) = j.save.map { st => init(j.handle(st, v)) } @@ -268,13 +270,6 @@ object fibers { } } - def parallel[T, S](j: Joins[S])(l: Seq[T > (S with Fibers)])( - implicit f: Flat[T > Fibers] - ): Seq[T] > (S with Fibers) = - j.save.map { st => - j.handle(st, l).map(parallel(_)).map(j.resume) - } - def parallel[T](l: Seq[T > Fibers])(implicit f: Flat[T > Fibers]): Seq[T] > Fibers = l.size match { case 0 => Seq.empty @@ -283,6 +278,13 @@ object fibers { Fibers.get(parallelFiber[T](l)) } + def parallel[T, S](j: Joins[S])(l: Seq[T > (S with Fibers)])( + implicit f: Flat[T > Fibers] + ): Seq[T] > (S with Fibers) = + j.save.map { st => + j.handle(st, l).map(parallel(_)).map(j.resume) + } + def parallelFiber[T](l: Seq[T > Fibers])(implicit f: Flat[T > Fibers]): Fiber[Seq[T]] > IOs = l.size match { case 0 => Fiber.done(Seq.empty) @@ -317,13 +319,6 @@ object fibers { } } - def race[T, S](j: Joins[S])(l: Seq[T > (S with Fibers)])(implicit - f: Flat[T > (S with Fibers)] - ): T > (S with Fibers) = - j.save.map { st => - j.handle(st, l).map(race(_)).map(j.resume) - } - def race[T](l: Seq[T > Fibers])(implicit f: Flat[T > Fibers]): T > Fibers = l.size match { case 0 => IOs.fail("Can't race an empty list.") @@ -332,6 +327,13 @@ object fibers { Fibers.get(raceFiber[T](l)) } + def race[T, S](j: Joins[S])(l: Seq[T > (S with Fibers)])(implicit + f: Flat[T > (S with Fibers)] + ): T > (S with Fibers) = + j.save.map { st => + j.handle(st, l).map(race(_)).map(j.resume) + } + def raceFiber[T](l: Seq[T > Fibers])(implicit f: Flat[T > Fibers]): Fiber[T] > IOs = l.size match { case 0 => IOs.fail("Can't race an empty list.") diff --git a/kyo-core/shared/src/main/scala/kyo/joins.scala b/kyo-core/shared/src/main/scala/kyo/joins.scala index 38eb80bbb..085d54dc0 100644 --- a/kyo-core/shared/src/main/scala/kyo/joins.scala +++ b/kyo-core/shared/src/main/scala/kyo/joins.scala @@ -3,7 +3,11 @@ package kyo import kyo._ import kyo.core._ import kyo.ios._ +import kyo.aborts._ +import kyo.envs._ import kyo.lists.Lists +import izumi.reflect._ +import kyo.resources.Resources object joins { @@ -74,5 +78,45 @@ object joins { def apply[E1: Joins, E2: Joins, E3: Joins, E4: Joins, E5: Joins] : Joins[E1 with E2 with E3 with E4 with E5] = Joins[E1].andThen(Joins[E2]).andThen(Joins[E3]).andThen(Joins[E4]).andThen(Joins[E5]) + + implicit def aborts[E: Tag]: Joins[Aborts[E]] = + new Joins[Aborts[E]] { + type State = Unit + type M[T] = Abort[E]#Value[T] + val aborts = Aborts[E] + + def save = () + def handle[T, S](s: State, v: T > (Aborts[E] & S))(implicit f: Flat[T > (Aborts[E] & S)]) = + aborts.run(v) + def resume[T, S](v: M[T] > S) = + aborts.get(v) + } + + implicit def envs[E: Tag]: Joins[Envs[E]] = + new Joins[Envs[E]] { + type State = E + type M[T] = T + val envs = Envs[E] + + def save = envs.get + def handle[T, S](s: State, v: T > (Envs[E] & S))(implicit f: Flat[T > (Envs[E] & S)]) = + envs.run(s)(v) + def resume[T, S](v: M[T] > S) = + v + } + + implicit val lists: Joins[Lists] = + new Joins[Lists] { + type State = Unit + type M[T] = List[T] + + def save = () + def handle[T, S](s: Unit, v: T > (Lists & S))(implicit f: Flat[T > (Lists & S)]) = + Lists.run(v) + def resume[T, S](v: List[T] > S): T > (Lists & S) = + Lists.foreach(v) + } + + } } diff --git a/kyo-direct/src/test/scala/kyoTest/directTest.scala b/kyo-direct/src/test/scala/kyoTest/directTest.scala index 98fa79ec2..d6c0abab6 100644 --- a/kyo-direct/src/test/scala/kyoTest/directTest.scala +++ b/kyo-direct/src/test/scala/kyoTest/directTest.scala @@ -153,8 +153,8 @@ class directTest extends KyoTest { "lists" in { import kyo.lists._ - val x = Lists.foreach(List(1, -2, -3)) - val y = Lists.foreach(List("ab", "cde")) + val x = Lists.get(List(1, -2, -3)) + val y = Lists.get(List("ab", "cde")) val v: Int > Lists = defer { @@ -172,8 +172,8 @@ class directTest extends KyoTest { "lists + filter" in { import kyo.lists._ - val x = Lists.foreach(List(1, -2, -3)) - val y = Lists.foreach(List("ab", "cde")) + val x = Lists.get(List(1, -2, -3)) + val y = Lists.get(List("ab", "cde")) val v: Int > Lists = defer { diff --git a/kyo-llm/shared/src/main/scala/kyo/llm/ais.scala b/kyo-llm/shared/src/main/scala/kyo/llm/ais.scala index a26927700..fc80b695d 100644 --- a/kyo-llm/shared/src/main/scala/kyo/llm/ais.scala +++ b/kyo-llm/shared/src/main/scala/kyo/llm/ais.scala @@ -5,7 +5,6 @@ import kyo.llm.completions._ import kyo.llm.configs._ import kyo.llm.contexts._ import kyo.llm.tools._ -import kyo.concurrent.Joins import kyo.concurrent.atomics._ import kyo.concurrent.fibers._ import kyo.ios._ @@ -155,7 +154,7 @@ object ais { } yield r } - object AIs extends Joins[AIs] { + object AIs { type Effects = Sums[State] with Requests @@ -220,30 +219,6 @@ object ais { State.get.map { st => Tries.run[T, S](f).map(r => State.set(st).map(_ => r.get)) } - - def race[T](l: Seq[T > AIs])(implicit f: Flat[T > AIs]): T > AIs = - State.get.map { st => - Requests.race[(T, State)](l.map(State.run[T, Requests](st))) - .map { - case (v, st) => - State.set(st).map(_ => v) - } - } - - def parallel[T](l: Seq[T > AIs])(implicit f: Flat[T > AIs]): Seq[T] > AIs = - State.get.map { st => - Requests.parallel[(T, State)](l.map(State.run[T, Requests](st))) - .map { rl => - val r = rl.map(_._1) - val st = - rl.map(_._2) - .foldLeft(Map.empty: State) { - case (acc, st) => - summer.add(acc, st) - } - State.set(st).map(_ => r) - } - } } object internal {