diff --git a/formula-coroutines/src/main/java/com/instacart/formula/coroutines/FlowRuntime.kt b/formula-coroutines/src/main/java/com/instacart/formula/coroutines/FlowRuntime.kt index ab4a8d77..4bc6c16e 100644 --- a/formula-coroutines/src/main/java/com/instacart/formula/coroutines/FlowRuntime.kt +++ b/formula-coroutines/src/main/java/com/instacart/formula/coroutines/FlowRuntime.kt @@ -3,7 +3,6 @@ package com.instacart.formula.coroutines import com.instacart.formula.FormulaRuntime import com.instacart.formula.IFormula import com.instacart.formula.Inspector -import com.instacart.formula.internal.ThreadChecker import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.channels.trySendBlocking @@ -22,12 +21,8 @@ object FlowRuntime { inspector: Inspector? = null, isValidationEnabled: Boolean = false, ): Flow { - val threadChecker = ThreadChecker(formula) return callbackFlow { - threadChecker.check("Need to subscribe on main thread.") - val runtime = FormulaRuntime( - threadChecker = threadChecker, formula = formula, onOutput = this::trySendBlocking, onError = this::close, @@ -35,7 +30,7 @@ object FlowRuntime { isValidationEnabled = isValidationEnabled, ) - input.onEach { input -> runtime.onInput(input) }.launchIn(this) + input.onEach(runtime::onInput).launchIn(this) awaitClose { runtime.terminate() diff --git a/formula-rxjava3/src/main/java/com/instacart/formula/rxjava3/RxJavaRuntime.kt b/formula-rxjava3/src/main/java/com/instacart/formula/rxjava3/RxJavaRuntime.kt index ded6770d..0fee03b5 100644 --- a/formula-rxjava3/src/main/java/com/instacart/formula/rxjava3/RxJavaRuntime.kt +++ b/formula-rxjava3/src/main/java/com/instacart/formula/rxjava3/RxJavaRuntime.kt @@ -1,10 +1,8 @@ package com.instacart.formula.rxjava3 -import com.instacart.formula.FormulaPlugins import com.instacart.formula.FormulaRuntime import com.instacart.formula.IFormula import com.instacart.formula.Inspector -import com.instacart.formula.internal.ThreadChecker import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.disposables.CompositeDisposable import io.reactivex.rxjava3.disposables.FormulaDisposableHelper @@ -16,12 +14,8 @@ object RxJavaRuntime { inspector: Inspector? = null, isValidationEnabled: Boolean = false, ): Observable { - val threadChecker = ThreadChecker(formula) - return Observable.create { emitter -> - threadChecker.check("Need to subscribe on main thread.") - + return Observable.create { emitter -> val runtime = FormulaRuntime( - threadChecker = threadChecker, formula = formula, onOutput = emitter::onNext, onError = emitter::onError, @@ -30,11 +24,9 @@ object RxJavaRuntime { ) val disposables = CompositeDisposable() - disposables.add(input.subscribe({ input -> - runtime.onInput(input) - }, emitter::onError)) - + disposables.add(input.subscribe(runtime::onInput, emitter::onError)) disposables.add(FormulaDisposableHelper.fromRunnable(runtime::terminate)) + emitter.setDisposable(disposables) }.distinctUntilChanged() } diff --git a/formula/src/main/java/com/instacart/formula/FormulaRuntime.kt b/formula/src/main/java/com/instacart/formula/FormulaRuntime.kt index 5fafcd1a..5375c6d4 100644 --- a/formula/src/main/java/com/instacart/formula/FormulaRuntime.kt +++ b/formula/src/main/java/com/instacart/formula/FormulaRuntime.kt @@ -3,30 +3,27 @@ package com.instacart.formula import com.instacart.formula.internal.FormulaManager import com.instacart.formula.internal.FormulaManagerImpl import com.instacart.formula.internal.ManagerDelegate -import com.instacart.formula.internal.ThreadChecker +import com.instacart.formula.internal.SynchronizedUpdateQueue import java.util.LinkedList /** * Takes a [Formula] and creates an Observable from it. */ class FormulaRuntime( - private val threadChecker: ThreadChecker, private val formula: IFormula, private val onOutput: (Output) -> Unit, private val onError: (Throwable) -> Unit, private val isValidationEnabled: Boolean = false, inspector: Inspector? = null, ) : ManagerDelegate { + private val synchronizedUpdateQueue = SynchronizedUpdateQueue() + private val inspector = FormulaPlugins.inspector(type = formula.type(), local = inspector) private val implementation = formula.implementation() + private var manager: FormulaManagerImpl? = null - private val inspector = FormulaPlugins.inspector( - type = formula.type(), - local = inspector, - ) private var emitOutput = false private var lastOutput: Output? = null - private var input: Input? = null private var key: Any? = null @@ -43,8 +40,7 @@ class FormulaRuntime( private var inputId: Int = 0 /** - * Global transition effect queue which executes side-effects - * after all formulas are idle. + * Global transition effect queue which executes side-effects after all formulas are idle. */ private var globalEffectQueue = LinkedList() @@ -66,8 +62,10 @@ class FormulaRuntime( } fun onInput(input: Input) { - threadChecker.check("Input arrived on a wrong thread.") + synchronizedUpdateQueue.postUpdate { onInputInternal(input) } + } + private fun onInputInternal(input: Input) { if (isRuntimeTerminated) return val isKeyValid = isKeyValid(input) @@ -105,8 +103,10 @@ class FormulaRuntime( } fun terminate() { - threadChecker.check("Need to unsubscribe on the main thread.") + synchronizedUpdateQueue.postUpdate(this::terminateInternal) + } + private fun terminateInternal() { if (isRuntimeTerminated) return isRuntimeTerminated = true @@ -127,8 +127,6 @@ class FormulaRuntime( } override fun onPostTransition(effects: Effects?, evaluate: Boolean) { - threadChecker.check("Only thread that created it can post transition result") - effects?.let { globalEffectQueue.addLast(effects) } @@ -271,6 +269,7 @@ class FormulaRuntime( private fun initManager(initialInput: Input): FormulaManagerImpl { return FormulaManagerImpl( + queue = synchronizedUpdateQueue, delegate = this, formula = implementation, initialInput = initialInput, diff --git a/formula/src/main/java/com/instacart/formula/internal/ChildrenManager.kt b/formula/src/main/java/com/instacart/formula/internal/ChildrenManager.kt index 3fcd9376..1148893a 100644 --- a/formula/src/main/java/com/instacart/formula/internal/ChildrenManager.kt +++ b/formula/src/main/java/com/instacart/formula/internal/ChildrenManager.kt @@ -110,6 +110,7 @@ internal class ChildrenManager( val childFormulaHolder = children.findOrInit(key) { val implementation = formula.implementation() FormulaManagerImpl( + queue = delegate.queue, delegate = delegate, formula = implementation, initialInput = input, diff --git a/formula/src/main/java/com/instacart/formula/internal/FormulaManagerImpl.kt b/formula/src/main/java/com/instacart/formula/internal/FormulaManagerImpl.kt index f1c2a998..12a64b89 100644 --- a/formula/src/main/java/com/instacart/formula/internal/FormulaManagerImpl.kt +++ b/formula/src/main/java/com/instacart/formula/internal/FormulaManagerImpl.kt @@ -18,6 +18,7 @@ import kotlin.reflect.KClass * a state change, it will rerun [Formula.evaluate]. */ internal class FormulaManagerImpl( + val queue: SynchronizedUpdateQueue, private val delegate: ManagerDelegate, private val formula: Formula, initialInput: Input, diff --git a/formula/src/main/java/com/instacart/formula/internal/ListenerImpl.kt b/formula/src/main/java/com/instacart/formula/internal/ListenerImpl.kt index 3cf42ca6..886709b9 100644 --- a/formula/src/main/java/com/instacart/formula/internal/ListenerImpl.kt +++ b/formula/src/main/java/com/instacart/formula/internal/ListenerImpl.kt @@ -18,8 +18,10 @@ internal class ListenerImpl(internal var key: Any) : Liste // TODO: log if null listener (it might be due to formula removal or due to callback removal) val manager = manager ?: return - val deferredTransition = DeferredTransition(this, transition, event) - manager.onPendingTransition(deferredTransition) + manager.queue.postUpdate { + val deferredTransition = DeferredTransition(this, transition, event) + manager.onPendingTransition(deferredTransition) + } } fun disable() { diff --git a/formula/src/main/java/com/instacart/formula/internal/SynchronizedUpdateQueue.kt b/formula/src/main/java/com/instacart/formula/internal/SynchronizedUpdateQueue.kt new file mode 100644 index 00000000..45a59cd8 --- /dev/null +++ b/formula/src/main/java/com/instacart/formula/internal/SynchronizedUpdateQueue.kt @@ -0,0 +1,116 @@ +package com.instacart.formula.internal + +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicReference + +/** + * We can only process one formula update at a time. To enable thread-safety we use a + * non-blocking event queue with a serial confinement strategy for queue processing. All external + * formula events use [postUpdate] function which adds the update to [updateQueue], and then, + * tries to start processing the queue via atomic [threadRunning] variable. If another thread + * was first to take over [threadRunning], we let that thread continue and we exit out. Given + * all formula state access is gated via [threadRunning] atomic reference, we are able to ensure + * that there is happens-before relationship between each thread and memory changes are visible + * between them. + */ +class SynchronizedUpdateQueue { + /** + * Defines a thread currently executing formula update. Null value indicates idle queue. + * + * To ensure that memory changes within formula internals are synchronized between threads, + * we piggyback on the internal synchronization of this variable. Modification to this + * variable wraps around every formula update: + * - threadRunning = MyThread + * - formulaUpdate() + * - threadRunning = null + * + * This creates happens-before relationship between multiple threads and makes sure that + * all modifications within formulaUpdate() block are visible to the next thread. + */ + private val threadRunning = AtomicReference() + + /** + * A non-blocking thread-safe FIFO queue that tracks pending updates. + */ + private val updateQueue = ConcurrentLinkedQueue<() -> Unit>() + + /** + * To ensure that we execute one update at a time, all external formula events use this + * function to post updates. We add the update to a queue and then try to start processing. + * Failure to start processing indicates that another thread was first and we allow that + * thread to continue. + */ + fun postUpdate(update: () -> Unit) { + val currentThread = Thread.currentThread() + val owner = threadRunning.get() + if (owner == currentThread) { + // This indicates a nested update where an update triggers another update. Given we + // are already thread gated, we can execute this update immediately without a need + // for any extra synchronization. + update() + return + } + + val success = if (updateQueue.peek() == null) { + // No pending update, let's try to run our update immediately + takeOver(currentThread) { update() } + } else { + false + } + + if (!success) { + updateQueue.add(update) + } + tryToDrainQueue(currentThread) + } + + /** + * Tries to drain the update queue. It will process one update at a time until + * queue is empty or another thread takes over processing. + */ + private fun tryToDrainQueue(currentThread: Thread) { + while (true) { + // First, we peek to see if there is a value to process. + val peekUpdate = updateQueue.peek() + if (peekUpdate != null) { + // Since there is a pending update, we try to process it. + val success = takeOver(currentThread) { + // We successfully set ourselves as the running thread + // We poll the queue to get the latest value (it could have changed). It + // also removes the value from the queue. + val actualUpdate = updateQueue.poll() + actualUpdate?.invoke() + } + + if (!success) { + return + } + } else { + return + } + } + } + + /** + * Tries to take over the processing and execute an [update]. + * + * Returns true if it was able to successfully claim the ownership and execute the + * update. Otherwise, returns false (this indicates another thread claimed the right first). + */ + private inline fun takeOver(currentThread: Thread, crossinline update: () -> Unit): Boolean { + return if (threadRunning.compareAndSet(null, currentThread)) { + // We took over the processing, let's execute the [update] + try { + update() + } finally { + // We reset the running thread. To ensure happens-before relationship, this must + // always happen after the [update]. + threadRunning.set(null) + } + true + } else { + // Another thread is running, so we return false. + false + } + } +} \ No newline at end of file diff --git a/formula/src/main/java/com/instacart/formula/internal/ThreadChecker.kt b/formula/src/main/java/com/instacart/formula/internal/ThreadChecker.kt deleted file mode 100644 index bab7e8f4..00000000 --- a/formula/src/main/java/com/instacart/formula/internal/ThreadChecker.kt +++ /dev/null @@ -1,19 +0,0 @@ -package com.instacart.formula.internal - -import com.instacart.formula.IFormula - -/** - * A poor man's thread checker. - */ -class ThreadChecker(private val formula: IFormula<*, *>) { - private val formulaType = formula::class.qualifiedName - private val threadName = Thread.currentThread().name - private val id = Thread.currentThread().id - - fun check(errorMessage: String) { - val thread = Thread.currentThread() - if (thread.id != id) { - throw IllegalStateException("$formulaType - $errorMessage Expected: $threadName, Was: ${thread.name}") - } - } -} diff --git a/formula/src/test/java/com/instacart/formula/FormulaRuntimeTest.kt b/formula/src/test/java/com/instacart/formula/FormulaRuntimeTest.kt index 16b8ea9b..f8c6bfee 100644 --- a/formula/src/test/java/com/instacart/formula/FormulaRuntimeTest.kt +++ b/formula/src/test/java/com/instacart/formula/FormulaRuntimeTest.kt @@ -35,6 +35,7 @@ import com.instacart.formula.subjects.KeyFormula import com.instacart.formula.subjects.KeyUsingListFormula import com.instacart.formula.subjects.MessageFormula import com.instacart.formula.subjects.MixingCallbackUseWithKeyUse +import com.instacart.formula.subjects.MultiThreadRobot import com.instacart.formula.subjects.MultipleChildEvents import com.instacart.formula.subjects.NestedCallbackCallRobot import com.instacart.formula.subjects.NestedChildTransitionAfterNoEvaluationPass @@ -62,6 +63,7 @@ import com.instacart.formula.subjects.TestKey import com.instacart.formula.subjects.TransitionAfterNoEvaluationPass import com.instacart.formula.subjects.UseInputFormula import com.instacart.formula.subjects.ReusableFunctionCreatesUniqueListeners +import com.instacart.formula.subjects.SleepFormula import com.instacart.formula.subjects.UniqueListenersWithinLoop import com.instacart.formula.subjects.UsingKeyToScopeCallbacksWithinAnotherFunction import com.instacart.formula.subjects.UsingKeyToScopeChildFormula @@ -83,6 +85,9 @@ import org.junit.rules.RuleChain import org.junit.rules.TestName import org.junit.runner.RunWith import org.junit.runners.Parameterized +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Executors +import java.util.concurrent.ThreadFactory import java.util.concurrent.TimeUnit import kotlin.reflect.KClass @@ -633,6 +638,7 @@ class FormulaRuntimeTest(val runtime: TestableRuntime, val name: String) { assertThat(eventCallback.values()).containsExactly("a", "b").inOrder() } + @Ignore("Not valid anymore since we enabled thread-safety") @Test fun `when action returns value on background thread, we emit an error`() { val bgAction = EventOnBgThreadAction() @@ -646,7 +652,9 @@ class FormulaRuntimeTest(val runtime: TestableRuntime, val name: String) { } val observer = runtime.test(formula, Unit) - bgAction.latch.await(50, TimeUnit.MILLISECONDS) + if (!bgAction.latch.await(50, TimeUnit.MILLISECONDS)) { + throw IllegalStateException("Timeout") + } assertThat(bgAction.errors.values().firstOrNull()?.message).contains( "com.instacart.formula.subjects.OnlyUpdateFormula - Only thread that created it can post transition result Expected:" ) @@ -1294,6 +1302,24 @@ class FormulaRuntimeTest(val runtime: TestableRuntime, val name: String) { .assertValue(0) } + @Test + fun `formula multi thread handoff`() { + with(MultiThreadRobot(runtime)) { + threadA(50) + threadB(10) + awaitCompletion() + threadB(10) + + awaitEvents( + SleepFormula.SleepEvent(50, "thread-a"), + // First thread-b event is handed-off to thread-a + SleepFormula.SleepEvent(10, "thread-a"), + // Second thread-b event is handled by thread-b + SleepFormula.SleepEvent(10, "thread-b") + ) + } + } + @Test fun `inspector events`() { val globalInspector = TestInspector() diff --git a/formula/src/test/java/com/instacart/formula/internal/ThreadCheckerTest.kt b/formula/src/test/java/com/instacart/formula/internal/ThreadCheckerTest.kt deleted file mode 100644 index 2800eb21..00000000 --- a/formula/src/test/java/com/instacart/formula/internal/ThreadCheckerTest.kt +++ /dev/null @@ -1,29 +0,0 @@ -package com.instacart.formula.internal - -import com.google.common.truth.Truth -import com.instacart.formula.subjects.DynamicStreamSubject -import com.instacart.formula.subjects.KeyFormula -import org.junit.Test -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - -class ThreadCheckerTest { - - @Test - fun `detects incorrect thread`() { - val checker = ThreadChecker(KeyFormula()) - - val latch = CountDownLatch(1) - Thread { - try { - checker.check("error message") - error("thread checker should fail") - } catch (e: Exception) { - Truth.assertThat(e.message).startsWith("com.instacart.formula.subjects.KeyFormula - error message") - latch.countDown() - } - }.start() - - latch.await(1, TimeUnit.SECONDS) - } -} \ No newline at end of file diff --git a/formula/src/test/java/com/instacart/formula/subjects/MultiThreadRobot.kt b/formula/src/test/java/com/instacart/formula/subjects/MultiThreadRobot.kt new file mode 100644 index 00000000..f98f8f49 --- /dev/null +++ b/formula/src/test/java/com/instacart/formula/subjects/MultiThreadRobot.kt @@ -0,0 +1,88 @@ +package com.instacart.formula.subjects + +import com.google.common.truth.Truth +import com.instacart.formula.test.TestableRuntime +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Executor +import java.util.concurrent.Executors +import java.util.concurrent.ThreadFactory +import java.util.concurrent.TimeUnit + +class MultiThreadRobot(val runtime: TestableRuntime) { + class NamedThreadFactory(private val name: String): ThreadFactory { + override fun newThread(r: Runnable): Thread { + return Thread(r, name) + } + } + + private val executorA = Executors.newSingleThreadExecutor(NamedThreadFactory("thread-a")) + private val executorB = Executors.newSingleThreadExecutor(NamedThreadFactory("thread-b")) + + private val threadFormula = SleepFormula() + private val observer = runtime.test(threadFormula, Unit) + + @Volatile private var nextEventStartedLatch: CountDownLatch? = null + + private val eventCompletionLatches = ConcurrentLinkedQueue() + + private fun execute(executor: Executor, sleepDuration: Long) { + // Creating a latch and adding it to a list to make sure we are able to + // wait for all event completion. + val completionLatch = CountDownLatch(1) + eventCompletionLatches.add(completionLatch) + + // To ensure predictable order, we wait for previous executor to + // start, before we start ourselves. + val previousLatch = nextEventStartedLatch + if (previousLatch != null) { + await(previousLatch, 100, TimeUnit.MILLISECONDS) + } + + val localStartLatch = CountDownLatch(1) + nextEventStartedLatch = localStartLatch + + executor.execute { + if (previousLatch != null) { + // We give a little extra time for the other executor to pick up the event + Thread.sleep(10) + } + + localStartLatch.countDown() + + observer.output { + this.onSleep(sleepDuration) + + completionLatch.countDown() + } + } + } + + fun threadA(sleepDuration: Long) = apply { + execute(executorA, sleepDuration) + } + + fun threadB(sleepDuration: Long) = apply { + execute(executorB, sleepDuration) + } + + fun awaitCompletion() = apply { + for (latch in eventCompletionLatches) { + await(latch, 1, TimeUnit.SECONDS) + } + } + + fun awaitEvents(vararg sleepEvents: SleepFormula.SleepEvent) = apply { + awaitCompletion() + + observer.output { + Truth.assertThat(this.sleepEvents).containsExactly(*sleepEvents).inOrder() + } + } + + private fun await(latch: CountDownLatch, timeout: Long, unit: TimeUnit) { + if (!latch.await(timeout, unit)) { + throw IllegalStateException("Timeout") + } + } +} \ No newline at end of file diff --git a/formula/src/test/java/com/instacart/formula/subjects/SleepFormula.kt b/formula/src/test/java/com/instacart/formula/subjects/SleepFormula.kt new file mode 100644 index 00000000..08ddecef --- /dev/null +++ b/formula/src/test/java/com/instacart/formula/subjects/SleepFormula.kt @@ -0,0 +1,55 @@ +package com.instacart.formula.subjects + +import com.instacart.formula.Action +import com.instacart.formula.Evaluation +import com.instacart.formula.Formula +import com.instacart.formula.Snapshot + +class SleepFormula : Formula() { + + data class SleepEvent( + val duration: Long, + val threadName: String, + ) + + data class State( + val sleepEvents: List = emptyList(), + val pendingEvent: SleepEvent? = null, + ) + + data class Output( + val sleepEvents: List, + val onSleep: (Long) -> Unit, + ) + + override fun initialState(input: Unit): State { + return State() + } + + override fun Snapshot.evaluate(): Evaluation { + return Evaluation( + output = Output( + sleepEvents = state.sleepEvents, + onSleep = context.onEvent { + val newEvent = SleepEvent( + duration = it, + threadName = Thread.currentThread().name, + ) + transition(state.copy(pendingEvent = newEvent)) + } + ), + actions = context.actions { + state.pendingEvent?.let { + Action.onData(it).onEvent { event -> + // Using sleep to control multi-threaded events + Thread.sleep(event.duration) + val newState = state.copy( + sleepEvents = state.sleepEvents + event, + ) + transition(newState) + } + } + } + ) + } +} \ No newline at end of file