Skip to content

Commit

Permalink
quickfix for structured concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
maxcai314 committed Jul 26, 2024
1 parent 0ba1e62 commit db3691e
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.kuriosityrobotics.centerstage.concurrent;

public interface JavaLangAccess {
interface JavaLangAccess {

/**
* Returns the ThreadContainer for a thread, may be null.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package com.kuriosityrobotics.centerstage.concurrent;

import java.util.Collections;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.ConcurrentHashMap;

public class SharedSecrets {
class SharedSecrets {
private static final JavaLangAccess JLA = new JavaLangAccess() {
private final Map<Thread, ThreadContainer> CONTAINERS = new ConcurrentHashMap<>();
private final Map<Thread, ThreadContainer> CONTAINERS = Collections.synchronizedMap(new WeakHashMap<>());
private final ThreadLocal<StackableScope> HEAD_STACKABLE_SCOPE = new ThreadLocal<>();

public ThreadContainer threadContainer(Thread thread) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
* popForcefully methods are used to pop the StackableScope from the current thread's
* scope stack.
*/
public class StackableScope {
class StackableScope {
private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();

private final Thread owner;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,15 +513,12 @@ protected <U extends T> Subtask<U> forkInner(Callable<? extends U> task) {

var subtask = new SubtaskImpl<U>(this, task);
if (s < SHUTDOWN) {
// create thread to run task
Thread thread = factory.newThread(subtask);
if (thread == null) {
throw new RejectedExecutionException("Rejected by thread factory");
}

// attempt to start the thread
try {
flock.start(thread);
Thread thread = flock.start(factory, subtask);
if (thread == null) {
throw new RejectedExecutionException("Rejected by thread factory");
}
} catch (IllegalStateException e) {
// shutdown by another thread, or underlying flock is shutdown due
// to unstructured use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
/**
* A container of threads.
*/
public abstract class ThreadContainer extends StackableScope {
abstract class ThreadContainer extends StackableScope {

/**
* Creates a ThreadContainer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* This class consists exclusively of static methods to support debugging and
* monitoring of threads.
*/
public class ThreadContainers {
class ThreadContainers {
private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();

// the set of thread containers registered with this class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -46,7 +47,7 @@
* in a flock remain in the flock until they terminate.
*
* <p> ThreadFlock defines the {@link #open(String) open} method to open a new flock,
* the {@link #start(Thread) start} method to start a thread in the flock, and the
* the {@link #start(ThreadFactory threadFactory, Runnable target) start} method to start a thread in the flock, and the
* {@link #close() close} method to close the flock. The {@code close} waits for all
* threads in the flock to finish. The {@link #awaitAll() awaitAll} method may be used
* to wait for all threads to finish without closing the flock. The {@link #wakeup()}
Expand Down Expand Up @@ -79,7 +80,7 @@
* <p> Unless otherwise specified, passing a {@code null} argument to a method
* in this class will cause a {@link NullPointerException} to be thrown.
*/
public class ThreadFlock implements AutoCloseable {
class ThreadFlock implements AutoCloseable {
private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();

private final Set<Thread> threads = ConcurrentHashMap.newKeySet();
Expand Down Expand Up @@ -222,17 +223,6 @@ public Thread owner() {
return container.owner();
}

private static void wrapThreadRunnable(Thread thread, Function<Runnable, Runnable> wrapper) {
try {
var targetField = Thread.class.getDeclaredField("target");
targetField.setAccessible(true);
var target = (Runnable) targetField.get(thread);
targetField.set(thread, wrapper.apply(target));
}catch (NoSuchFieldException | IllegalAccessException e) {
// "sneaky throw"
}
}

/**
* Starts the given unstarted thread in this flock.
*
Expand All @@ -242,7 +232,8 @@ private static void wrapThreadRunnable(Thread thread, Function<Runnable, Runnabl
* <p> This method may only be invoked by the flock owner or threads {@linkplain
* #containsThread(Thread) contained} in the flock.
*
* @param thread the unstarted thread
* @param threadFactory the thread factory
* @param target the task to be run on the thread
* @return the thread, started
* @throws IllegalStateException if this flock is shutdown or closed
* @throws IllegalThreadStateException if the given thread was already started
Expand All @@ -251,16 +242,16 @@ private static void wrapThreadRunnable(Thread thread, Function<Runnable, Runnabl
* @throws StructureViolationException if the current
* scoped value bindings are not the same as when the flock was created
*/
public Thread start(Thread thread) {
public Thread start(ThreadFactory threadFactory, Runnable target) {
ensureOwnerOrContainsThread();
// hook thread
var startLatch = new CountDownLatch(1);
var initException = new AtomicReference<Throwable>();

Function<Runnable, Runnable> targetWrapper = target -> () -> {
Runnable wrappedTarget = () -> {
boolean started = false;
try {
onStart(thread);
onStart(Thread.currentThread());

started = true;
startLatch.countDown();
Expand All @@ -273,12 +264,13 @@ public Thread start(Thread thread) {
throw e;
} finally {
if (started)
onExit(thread);
onExit(Thread.currentThread());

startLatch.countDown();
}
};
wrapThreadRunnable(thread, targetWrapper);

Thread thread = threadFactory.newThread(wrappedTarget);
JLA.start(thread, container);
try {
startLatch.await();
Expand Down

0 comments on commit db3691e

Please sign in to comment.