Skip to content

Commit

Permalink
Core: Limit ParallelIterable memory consumption by yielding in tasks (a…
Browse files Browse the repository at this point in the history
…pache#10691)

ParallelIterable schedules 2 * WORKER_THREAD_POOL_SIZE tasks for
processing input iterables. This defaults to 2 * # CPU cores.  When one
or some of the input iterables are considerable in size and the
ParallelIterable consumer is not quick enough, this could result in
unbounded allocation inside `ParallelIterator.queue`. This commit bounds
the queue. When queue is full, the tasks yield and get removed from the
executor. They are resumed when consumer catches up.
  • Loading branch information
findepi authored and zachdisc committed Dec 12, 2024
1 parent 3746091 commit c1b0d52
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 55 deletions.
222 changes: 167 additions & 55 deletions core/src/main/java/org/apache/iceberg/util/ParallelIterable.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,84 +20,117 @@

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.iceberg.exceptions.RuntimeIOException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import org.apache.iceberg.io.CloseableGroup;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.CloseableIterator;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.io.Closer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelIterable<T> extends CloseableGroup implements CloseableIterable<T> {

private static final Logger LOG = LoggerFactory.getLogger(ParallelIterable.class);

// Logic behind default value: ParallelIterable is often used for file planning.
// Assuming that a DataFile or DeleteFile is about 500 bytes, a 30k limit uses 14.3 MB of memory.
private static final int DEFAULT_MAX_QUEUE_SIZE = 30_000;

private final Iterable<? extends Iterable<T>> iterables;
private final ExecutorService workerPool;

// Bound for number of items in the queue to limit memory consumption
// even in the case when input iterables are large.
private final int approximateMaxQueueSize;

public ParallelIterable(Iterable<? extends Iterable<T>> iterables, ExecutorService workerPool) {
this.iterables = iterables;
this.workerPool = workerPool;
this(iterables, workerPool, DEFAULT_MAX_QUEUE_SIZE);
}

public ParallelIterable(
Iterable<? extends Iterable<T>> iterables,
ExecutorService workerPool,
int approximateMaxQueueSize) {
this.iterables = Preconditions.checkNotNull(iterables, "Input iterables cannot be null");
this.workerPool = Preconditions.checkNotNull(workerPool, "Worker pool cannot be null");
this.approximateMaxQueueSize = approximateMaxQueueSize;
}

@Override
public CloseableIterator<T> iterator() {
ParallelIterator<T> iter = new ParallelIterator<>(iterables, workerPool);
ParallelIterator<T> iter =
new ParallelIterator<>(iterables, workerPool, approximateMaxQueueSize);
addCloseable(iter);
return iter;
}

private static class ParallelIterator<T> implements CloseableIterator<T> {
private final Iterator<Runnable> tasks;
private final Iterator<Task<T>> tasks;
private final Deque<Task<T>> yieldedTasks = new ArrayDeque<>();
private final ExecutorService workerPool;
private final Future<?>[] taskFutures;
private final CompletableFuture<Optional<Task<T>>>[] taskFutures;
private final ConcurrentLinkedQueue<T> queue = new ConcurrentLinkedQueue<>();
private volatile boolean closed = false;
private final int maxQueueSize;
private final AtomicBoolean closed = new AtomicBoolean(false);

private ParallelIterator(
Iterable<? extends Iterable<T>> iterables, ExecutorService workerPool) {
Iterable<? extends Iterable<T>> iterables, ExecutorService workerPool, int maxQueueSize) {
this.tasks =
Iterables.transform(
iterables,
iterable ->
(Runnable)
() -> {
try (Closeable ignored =
(iterable instanceof Closeable) ? (Closeable) iterable : () -> {}) {
for (T item : iterable) {
// exit manually because `ConcurrentLinkedQueue` can't be
// interrupted
if (closed) {
return;
}

queue.add(item);
}
} catch (IOException e) {
throw new RuntimeIOException(e, "Failed to close iterable");
}
})
iterables, iterable -> new Task<>(iterable, queue, closed, maxQueueSize))
.iterator();
this.workerPool = workerPool;
this.maxQueueSize = maxQueueSize;
// submit 2 tasks per worker at a time
this.taskFutures = new Future[2 * ThreadPools.WORKER_THREAD_POOL_SIZE];
this.taskFutures = new CompletableFuture[2 * ThreadPools.WORKER_THREAD_POOL_SIZE];
}

@Override
public void close() {
// close first, avoid new task submit
this.closed = true;
this.closed.set(true);

// cancel background tasks
for (Future<?> taskFuture : taskFutures) {
if (taskFuture != null && !taskFuture.isDone()) {
taskFuture.cancel(true);
try (Closer closer = Closer.create()) {
synchronized (this) {
yieldedTasks.forEach(closer::register);
yieldedTasks.clear();
}

// cancel background tasks and close continuations if any
for (CompletableFuture<Optional<Task<T>>> taskFuture : taskFutures) {
if (taskFuture != null) {
taskFuture.cancel(true);
taskFuture.thenAccept(
continuation -> {
if (continuation.isPresent()) {
try {
continuation.get().close();
} catch (IOException e) {
LOG.error("Task close failed", e);
}
}
});
}
}

// clean queue
this.queue.clear();
} catch (IOException e) {
throw new UncheckedIOException("Close failed", e);
}
// clean queue
this.queue.clear();
}

/**
Expand All @@ -107,15 +140,17 @@ public void close() {
*
* @return true if there are pending tasks, false otherwise
*/
private boolean checkTasks() {
private synchronized boolean checkTasks() {
Preconditions.checkState(!closed.get(), "Already closed");
boolean hasRunningTask = false;

for (int i = 0; i < taskFutures.length; i += 1) {
if (taskFutures[i] == null || taskFutures[i].isDone()) {
if (taskFutures[i] != null) {
// check for task failure and re-throw any exception
// check for task failure and re-throw any exception. Enqueue continuation if any.
try {
taskFutures[i].get();
Optional<Task<T>> continuation = taskFutures[i].get();
continuation.ifPresent(yieldedTasks::addLast);
} catch (ExecutionException e) {
if (e.getCause() instanceof RuntimeException) {
// rethrow a runtime exception
Expand All @@ -136,30 +171,33 @@ private boolean checkTasks() {
}
}

return !closed && (tasks.hasNext() || hasRunningTask);
return !closed.get() && (tasks.hasNext() || hasRunningTask);
}

private Future<?> submitNextTask() {
if (!closed && tasks.hasNext()) {
return workerPool.submit(tasks.next());
private CompletableFuture<Optional<Task<T>>> submitNextTask() {
if (!closed.get()) {
if (!yieldedTasks.isEmpty()) {
return CompletableFuture.supplyAsync(yieldedTasks.removeFirst(), workerPool);
} else if (tasks.hasNext()) {
return CompletableFuture.supplyAsync(tasks.next(), workerPool);
}
}
return null;
}

@Override
public synchronized boolean hasNext() {
Preconditions.checkState(!closed, "Already closed");

// if the consumer is processing records more slowly than the producers, then this check will
// prevent tasks from being submitted. while the producers are running, this will always
// return here before running checkTasks. when enough of the tasks are finished that the
// consumer catches up, then lots of new tasks will be submitted at once. this behavior is
// okay because it ensures that records are not stacking up waiting to be consumed and taking
// up memory.
//
// consumers that process results quickly will periodically exhaust the queue and submit new
// tasks when checkTasks runs. fast consumers should not be delayed.
if (!queue.isEmpty()) {
Preconditions.checkState(!closed.get(), "Already closed");

// If the consumer is processing records more slowly than the producers, the producers will
// eventually fill the queue and yield, returning continuations. Continuations and new tasks
// are started by checkTasks(). The check here prevents us from restarting continuations or
// starting new tasks too early (when queue is almost full) or too late (when queue is already
// emptied). Restarting too early would lead to tasks yielding very quickly (CPU waste on
// scheduling). Restarting too late would mean the consumer may need to wait for the tasks
// to produce new items. A consumer slower than producers shouldn't need to wait.
int queueLowWaterMark = maxQueueSize / 2;
if (queue.size() > queueLowWaterMark) {
return true;
}

Expand Down Expand Up @@ -192,4 +230,78 @@ public synchronized T next() {
return queue.poll();
}
}

private static class Task<T> implements Supplier<Optional<Task<T>>>, Closeable {
private final Iterable<T> input;
private final ConcurrentLinkedQueue<T> queue;
private final AtomicBoolean closed;
private final int approximateMaxQueueSize;

private Iterator<T> iterator = null;

Task(
Iterable<T> input,
ConcurrentLinkedQueue<T> queue,
AtomicBoolean closed,
int approximateMaxQueueSize) {
this.input = Preconditions.checkNotNull(input, "input cannot be null");
this.queue = Preconditions.checkNotNull(queue, "queue cannot be null");
this.closed = Preconditions.checkNotNull(closed, "closed cannot be null");
this.approximateMaxQueueSize = approximateMaxQueueSize;
}

@Override
public Optional<Task<T>> get() {
try {
if (iterator == null) {
iterator = input.iterator();
}

while (iterator.hasNext()) {
if (queue.size() >= approximateMaxQueueSize) {
// Yield when queue is over the size limit. Task will be resubmitted later and continue
// the work.
return Optional.of(this);
}

T next = iterator.next();
if (closed.get()) {
break;
}

queue.add(next);
}
} catch (Throwable e) {
try {
close();
} catch (IOException closeException) {
// self-suppression is not permitted
// (e and closeException to be the same is unlikely, but possible)
if (closeException != e) {
e.addSuppressed(closeException);
}
}

throw e;
}

try {
close();
} catch (IOException e) {
throw new UncheckedIOException("Close failed", e);
}

// The task is complete. Returning empty means there is no continuation that should be
// executed.
return Optional.empty();
}

@Override
public void close() throws IOException {
iterator = null;
if (input instanceof Closeable) {
((Closeable) input).close();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,22 @@
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.CloseableIterator;
import org.apache.iceberg.relocated.com.google.common.collect.HashMultiset;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMultiset;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Multiset;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -133,6 +140,47 @@ public CloseableIterator<Integer> iterator() {
.untilAsserted(() -> assertThat(queue).as("Queue is not empty after cleaning").isEmpty());
}

@Test
public void limitQueueSize() throws IOException, IllegalAccessException, NoSuchFieldException {

List<Iterable<Integer>> iterables =
ImmutableList.of(
() -> IntStream.range(0, 100).iterator(),
() -> IntStream.range(0, 100).iterator(),
() -> IntStream.range(0, 100).iterator());

Multiset<Integer> expectedValues =
IntStream.range(0, 100)
.boxed()
.flatMap(i -> Stream.of(i, i, i))
.collect(ImmutableMultiset.toImmutableMultiset());

int maxQueueSize = 20;
ExecutorService executor = Executors.newCachedThreadPool();
ParallelIterable<Integer> parallelIterable =
new ParallelIterable<>(iterables, executor, maxQueueSize);
CloseableIterator<Integer> iterator = parallelIterable.iterator();
Field queueField = iterator.getClass().getDeclaredField("queue");
queueField.setAccessible(true);
ConcurrentLinkedQueue<?> queue = (ConcurrentLinkedQueue<?>) queueField.get(iterator);

Multiset<Integer> actualValues = HashMultiset.create();

while (iterator.hasNext()) {
assertThat(queue)
.as("iterator internal queue")
.hasSizeLessThanOrEqualTo(maxQueueSize + iterables.size());
actualValues.add(iterator.next());
}

assertThat(actualValues)
.as("multiset of values returned by the iterator")
.isEqualTo(expectedValues);

iterator.close();
executor.shutdownNow();
}

private void queueHasElements(CloseableIterator<Integer> iterator, Queue queue) {
assertThat(iterator.hasNext()).isTrue();
assertThat(iterator.next()).isNotNull();
Expand Down

0 comments on commit c1b0d52

Please sign in to comment.