Skip to content

Commit

Permalink
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 73 deletions.
2 changes: 1 addition & 1 deletion kyo-bench/src/main/scala/kyo/bench/TRefMultiBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TRefMultiBench(parallelism: Int) extends Bench.ForkOnly(parallelism):

STM.runtime[IO].flatMap { stm =>
for
refs <- stm.commit(Seq.fill(parallelism)(stm.TVar.of(0)).sequence)
refs <- Seq.fill(parallelism)(stm.commit(stm.TVar.of(0))).sequence
_ <- refs.map(ref => stm.commit(ref.modify(_ + 1))).parSequence_
result <- stm.commit(refs.traverse(_.get).map(_.sum))
yield result
Expand Down
11 changes: 5 additions & 6 deletions kyo-core/shared/src/main/scala/kyo/Retry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,17 @@ object Retry:
SafeClassTag[E],
Frame
): A < (Async & Abort[E] & S) =
Loop(schedule) { schedule =>
Abort.run[E](v).map(_.fold { r =>
Abort.run[E](v).map {
case Result.Success(r) => r
case error: Result.Error[?] =>
Clock.now.map { now =>
schedule.next(now).map { (delay, nextSchedule) =>
Async.delay(delay)(Loop.continue(nextSchedule))
Async.delay(delay)(Retry[E](nextSchedule)(v))
}.getOrElse {
Abort.get(r)
Abort.get(error)
}
}
}(Loop.done(_)))
}
end apply
end RetryOps

/** Creates a RetryOps instance for the specified error type.
Expand Down
2 changes: 1 addition & 1 deletion kyo-prelude/shared/src/main/scala/kyo/Local.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ object Local:
ContextEffect.suspendAndMap(tag, Map.empty)(map => f(map.getOrElse(this, default).asInstanceOf[A]))

def let[B, S](value: A)(v: B < S)(using Frame) =
ContextEffect.handle(tag, Map(this -> value), _.updated(this, value.asInstanceOf[AnyRef]))(v)
ContextEffect.handle(tag, Map.empty[Local[?], AnyRef].updated(this, value), _.updated(this, value.asInstanceOf[AnyRef]))(v)

def update[B, S](f: A => A)(v: B < S)(using Frame) =
ContextEffect.handle(
Expand Down
12 changes: 6 additions & 6 deletions kyo-prelude/shared/src/main/scala/kyo/Var.scala
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ object Var:
runWith(state)(v)((state, result) => (state, result))

object isolate:
abstract private[kyo] class Base[V: Tag] extends Isolate[Var[V]]:
abstract private[kyo] class Base[V](using Tag[Var[V]]) extends Isolate[Var[V]]:
type State = V
def use[A, S2](f: V => A < S2)(using Frame) = Var.use(f)
def resume[A: Flat, S2](state: State, v: A < (Var[V] & S2))(using Frame) =
Expand All @@ -183,10 +183,10 @@ object Var:
* @return
* An isolate that updates the Var with its isolated value
*/
def update[V: Tag]: Isolate[Var[V]] =
def update[V](using Tag[Var[V]]): Isolate[Var[V]] =
new Base[V]:
def restore[A: Flat, S2](state: V, v: A < S2)(using Frame) =
Var.set(state).andThen(v)
Var.setAndThen(state)(v)

/** Creates an isolate that merges Var values using a combination function.
*
Expand All @@ -200,10 +200,10 @@ object Var:
* @return
* An isolate that merges Var values
*/
def merge[V: Tag](f: (V, V) => V): Isolate[Var[V]] =
def merge[V](f: (V, V) => V)(using Tag[Var[V]]): Isolate[Var[V]] =
new Base[V]:
def restore[A: Flat, S2](state: V, v: A < S2)(using Frame) =
Var.use[V](prev => Var.set(f(prev, state)).andThen(v))
Var.use[V](prev => Var.setAndThen(f(prev, state))(v))

/** Creates an isolate that keeps Var modifications local.
*
Expand All @@ -215,7 +215,7 @@ object Var:
* @return
* An isolate that discards Var modifications
*/
def discard[V: Tag]: Isolate[Var[V]] =
def discard[V](using Tag[Var[V]]): Isolate[Var[V]] =
new Base[V]:
def restore[A: Flat, S2](state: V, v: A < S2)(using Frame) =
v
Expand Down
144 changes: 123 additions & 21 deletions kyo-stm/shared/src/main/scala/kyo/STM.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package kyo

import java.util.Arrays
import kyo.Result.Fail
import scala.annotation.tailrec
import scala.util.control.NoStackTrace

/** A FailedTransaction exception that is thrown when a transaction fails to commit. Contains the frame where the failure occurred.
*/
case class FailedTransaction(frame: Frame) extends Exception(frame.position.show)
case class FailedTransaction(frame: Frame) extends Exception(frame.position.show) with NoStackTrace

/** Software Transactional Memory (STM) provides concurrent access to shared state using optimistic locking. Rather than acquiring locks
* upfront, transactions execute speculatively and automatically retry if conflicts are detected during commit. While this enables better
Expand Down Expand Up @@ -108,34 +111,97 @@ object STM:
// New transaction without a parent, use regular commit flow
Retry[FailedTransaction](retrySchedule) {
TID.useNew { tid =>
TRefLog.runWith(v) { (log, result) =>
IO.Unsafe {
// Attempt to acquire locks and commit the transaction
val (locked, unlocked) =
// Sort references by identity to prevent deadlocks
log.toSeq.sortBy((ref, _) => ref.hashCode)
.span((ref, entry) => ref.lock(entry))

if unlocked.nonEmpty then
// Failed to acquire some locks - rollback and retry
locked.foreach((ref, entry) => ref.unlock(entry))
Abort.fail(FailedTransaction(frame))
else
// Successfully locked all references - commit changes
locked.foreach((ref, entry) => ref.commit(tid, entry))
// Release all locks
locked.foreach((ref, entry) => ref.unlock(entry))
Var.runWith(TRefLog.empty)(v) { (log, result) =>
val logMap = log.toMap
logMap.size match
case 0 =>
// Nothing to commit
result
end if
}
case 1 =>
// Fast-path for a single ref
IO.Unsafe {
val (ref, entry) = logMap.head
// No need to pre-validate since `lock` validates and
// there's a single ref
if ref.lock(entry) then
ref.commit(tid, entry)
ref.unlock(entry)
result
else
Abort.fail(FailedTransaction(frame))
end if
}
case size =>
// Commit multiple refs
IO.Unsafe {
// Flattened representation of the log
val array = new Array[Any](size * 2)

try
def fail = throw new FailedTransaction(frame)

var i = 0
// Pre-validate and dump the log to the flat array
logMap.foreachEntry { (ref, entry) =>
// This code uses exception throwing because
// foreachEntry is the only way to traverse the
// map without allocating tuples, so throwing
// is the workaround to short circuit
if !ref.validate(entry) then fail
array(i) = ref
array(i + 1) = entry
i += 2
}

// Sort references by identity to prevent deadlocks
quickSort(array, size)

// Convenience accessors to the flat log
inline def ref(idx: Int) = array(idx * 2).asInstanceOf[TRef[Any]]
inline def entry(idx: Int) = array(idx * 2 + 1).asInstanceOf[TRefLog.Entry[Any]]

@tailrec def lock(idx: Int): Int =
if idx == size then size
else if !ref(idx).lock(entry(idx)) then idx
else lock(idx + 1)

@tailrec def unlock(idx: Int, upTo: Int): Unit =
if idx < upTo then
ref(idx).unlock(entry(idx))
unlock(idx + 1, upTo)

@tailrec def commit(idx: Int): Unit =
if idx < size then
ref(idx).commit(tid, entry(idx))
commit(idx + 1)

val acquired = lock(0)
if acquired != size then
// Failed to acquire some locks - rollback and retry
unlock(0, acquired)
fail
end if

// Successfully locked all references - commit changes
commit(0)

// Release all locks
unlock(0, size)
result
catch
case ex: FailedTransaction =>
Abort.fail(ex)
end try
}
end match
}
}
}
case parent =>
// Nested transaction inherits parent's transaction context but isolates RefLog.
// On success: changes propagate to parent. On failure: changes are rolled back
// without affecting parent's state.
val result = TRefLog.isolate(v)
val result = TRefLog.isolate.run(v)

// Can't return `result` directly since it has a pending STM effect
// but it's safe to cast because, if there's a parent transaction,
Expand All @@ -145,4 +211,40 @@ object STM:
}

end run

private def quickSort(array: Array[Any], size: Int): Unit =
def swap(i: Int, j: Int): Unit =
val temp = array(i)
array(i) = array(j)
array(j) = temp
val temp2 = array(i + 1)
array(i + 1) = array(j + 1)
array(j + 1) = temp2
end swap

def getHash(idx: Int): Int =
array(idx * 2).hashCode()

@tailrec def partitionLoop(low: Int, hi: Int, pivot: Int, i: Int, j: Int): Int =
if j >= hi then
swap(i * 2, pivot * 2)
i
else if getHash(j) < getHash(pivot) then
swap(i * 2, j * 2)
partitionLoop(low, hi, pivot, i + 1, j + 1)
else
partitionLoop(low, hi, pivot, i, j + 1)

def partition(low: Int, hi: Int): Int =
partitionLoop(low, hi, hi, low, low)

def loop(low: Int, hi: Int): Unit =
if low < hi then
val p = partition(low, hi)
loop(low, p - 1)
loop(p + 1, hi)

if size > 0 then
loop(0, size - 1)
end quickSort
end STM
10 changes: 5 additions & 5 deletions kyo-stm/shared/src/main/scala/kyo/TID.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ private[kyo] object TID:
// Unique transaction ID generation
private val nextTid = AtomicLong.Unsafe.init(0)(using AllowUnsafe.embrace.danger)

private val tidLocal = Local.initIsolated(-1L)
private val tidLocal = Local.initIsolated[java.lang.Long](-1L)

def next(using AllowUnsafe): Long = nextTid.incrementAndGet()

def useNew[A, S](f: Long => A < S)(using Frame): A < (S & IO) =
inline def useNew[A, S](inline f: Long => A < S)(using inline frame: Frame): A < (S & IO) =
IO.Unsafe {
val tid = nextTid.incrementAndGet()
tidLocal.let(tid)(f(tid))
}

def use[A, S](f: Long => A < S)(using Frame): A < S =
tidLocal.use(f)
inline def use[A, S](inline f: Long => A < S)(using inline frame: Frame): A < S =
tidLocal.use(f(_))

def useRequired[A, S](f: Long => A < S)(using Frame): A < S =
inline def useRequired[A, S](inline f: Long => A < S)(using inline frame: Frame): A < S =
tidLocal.use {
case -1L => bug("STM operation attempted outside of STM.run - this should be impossible due to effect typing")
case tid => f(tid)
Expand Down
24 changes: 14 additions & 10 deletions kyo-stm/shared/src/main/scala/kyo/TRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ sealed trait TRef[A]:
final def update[S](f: A => A < S)(using Frame): Unit < (STM & S) = use(f(_).map(set))

private[kyo] def state(using AllowUnsafe): Write[A]
private[kyo] def validate(entry: Entry[A])(using AllowUnsafe): Boolean
private[kyo] def lock(entry: Entry[A])(using AllowUnsafe): Boolean
private[kyo] def commit(tid: Long, entry: Entry[A])(using AllowUnsafe): Unit
private[kyo] def unlock(entry: Entry[A])(using AllowUnsafe): Unit
Expand All @@ -62,7 +63,7 @@ final private class TRefImpl[A] private[kyo] (initialState: Write[A])
private[kyo] def state(using AllowUnsafe): Write[A] = currentState

def use[B, S](f: A => B < S)(using Frame): B < (STM & S) =
TRefLog.use { log =>
Var.use[TRefLog] { log =>
log.get(this) match
case Present(entry) =>
f(entry.value)
Expand All @@ -76,19 +77,19 @@ final private class TRefImpl[A] private[kyo] (initialState: Write[A])
else
// Append Read to the log and return value
val entry = Read(state.tid, state.value)
TRefLog.setAndThen(log.put(this, entry))(f(state.value))
Var.setAndThen(log.put(this, entry))(f(state.value))
end if
}
}
end match
}

def set(v: A)(using Frame): Unit < STM =
TRefLog.use { log =>
Var.use[TRefLog] { log =>
log.get(this) match
case Present(prev) =>
val entry = Write(prev.tid, v)
TRefLog.setDiscard(log.put(this, entry))
Var.setDiscard(log.put(this, entry))
case Absent =>
TID.useRequired { tid =>
IO {
Expand All @@ -99,15 +100,18 @@ final private class TRefImpl[A] private[kyo] (initialState: Write[A])
else
// Append Write to the log
val entry = Write(state.tid, v)
TRefLog.setDiscard(log.put(this, entry))
Var.setDiscard(log.put(this, entry))
end if
}
}
}

private[kyo] def validate(entry: Entry[A])(using AllowUnsafe): Boolean =
currentState.tid == entry.tid

private[kyo] def lock(entry: Entry[A])(using AllowUnsafe): Boolean =
@tailrec def loop(): Boolean =
currentState.tid == entry.tid && {
validate(entry) && {
val lockState = super.get()
entry match
case Read(tid, value) =>
Expand All @@ -119,9 +123,9 @@ final private class TRefImpl[A] private[kyo] (initialState: Write[A])
end match
}
val locked = loop()
if locked && currentState.tid != entry.tid then
if locked && !validate(entry) then
// This branch handles the race condition where another fiber commits
// after the initial `currentState.tid == entry.tid` check but before the
// after the initial `validate(entry)` check but before the
// lock is acquired. If that's the case, roll back the lock.
unlock(entry)
false
Expand Down Expand Up @@ -160,10 +164,10 @@ object TRef:
*/
def init[A](value: A)(using Frame): TRef[A] < STM =
TID.useRequired { tid =>
TRefLog.use { log =>
Var.use[TRefLog] { log =>
IO.Unsafe {
val ref = TRef.Unsafe.init(tid, value)
TRefLog.setAndThen(log.put(ref, ref.state))(ref)
Var.setAndThen(log.put(ref, ref.state))(ref)
}
}
}
Expand Down
Loading

0 comments on commit d53aa85

Please sign in to comment.