From ccf57e51907aa7609b8ddeb684d02345e0ccad69 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 12 Jul 2024 14:34:52 +0200 Subject: [PATCH] Limit memory used by ParallelIterable ParallelIterable schedules 2 * WORKER_THREAD_POOL_SIZE tasks for processing input iterables. This defaults to 2 * # CPU cores. When one or some of the input iterables are considerable in size and the ParallelIterable consumer is not quick enough, this could result in unbounded allocation inside `ParallelIterator.queue`. This commit bounds the queue. When queue is full, the tasks yield and get removed from the executor. They are resumed when consumer catches up. --- .../apache/iceberg/util/ParallelIterable.java | 142 +++++++++++++----- .../iceberg/util/TestParallelIterable.java | 48 ++++++ 2 files changed, 156 insertions(+), 34 deletions(-) diff --git a/core/src/main/java/org/apache/iceberg/util/ParallelIterable.java b/core/src/main/java/org/apache/iceberg/util/ParallelIterable.java index d7221e7d4545..af961fe8d266 100644 --- a/core/src/main/java/org/apache/iceberg/util/ParallelIterable.java +++ b/core/src/main/java/org/apache/iceberg/util/ParallelIterable.java @@ -20,13 +20,17 @@ import java.io.Closeable; import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; import java.util.Iterator; import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; -import org.apache.iceberg.exceptions.RuntimeIOException; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.iceberg.io.CloseableGroup; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.CloseableIterator; @@ -34,51 +38,50 @@ import org.apache.iceberg.relocated.com.google.common.collect.Iterables; public class ParallelIterable extends CloseableGroup implements CloseableIterable { + + private static final int DEFAULT_MAX_QUEUE_SIZE = 10_000; + private final Iterable> iterables; private final ExecutorService workerPool; + // Bound for number of items in the queue to limit memory consumption + // even in the case when input iterables are + private final int approximateMaxQueueSize; + public ParallelIterable(Iterable> iterables, ExecutorService workerPool) { - this.iterables = iterables; - this.workerPool = workerPool; + this(iterables, workerPool, DEFAULT_MAX_QUEUE_SIZE); + } + + public ParallelIterable( + Iterable> iterables, + ExecutorService workerPool, + int approximateMaxQueueSize) { + this.iterables = Preconditions.checkNotNull(iterables, "Input iterables cannot be null"); + this.workerPool = Preconditions.checkNotNull(workerPool, "Worker pool cannot be null"); + this.approximateMaxQueueSize = approximateMaxQueueSize; } @Override public CloseableIterator iterator() { - ParallelIterator iter = new ParallelIterator<>(iterables, workerPool); + ParallelIterator iter = + new ParallelIterator<>(iterables, workerPool, approximateMaxQueueSize); addCloseable(iter); return iter; } private static class ParallelIterator implements CloseableIterator { - private final Iterator tasks; + private final Iterator> tasks; + private final Deque> yieldedTasks = new ArrayDeque<>(); private final ExecutorService workerPool; - private final Future[] taskFutures; + private final Future>>[] taskFutures; private final ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - private volatile boolean closed = false; + private final AtomicBoolean closed = new AtomicBoolean(false); private ParallelIterator( - Iterable> iterables, ExecutorService workerPool) { + Iterable> iterables, ExecutorService workerPool, int maxQueueSize) { this.tasks = Iterables.transform( - iterables, - iterable -> - (Runnable) - () -> { - try (Closeable ignored = - (iterable instanceof Closeable) ? (Closeable) iterable : () -> {}) { - for (T item : iterable) { - // exit manually because `ConcurrentLinkedQueue` can't be - // interrupted - if (closed) { - return; - } - - queue.add(item); - } - } catch (IOException e) { - throw new RuntimeIOException(e, "Failed to close iterable"); - } - }) + iterables, iterable -> new Task<>(iterable, queue, closed, maxQueueSize)) .iterator(); this.workerPool = workerPool; // submit 2 tasks per worker at a time @@ -88,7 +91,18 @@ private ParallelIterator( @Override public void close() { // close first, avoid new task submit - this.closed = true; + this.closed.set(true); + + for (Task task : yieldedTasks) { + try { + task.close(); + } catch (Exception e) { + throw new RuntimeException("Close failed", e); + } + } + yieldedTasks.clear(); + + // TODO close input iterables that were not started yet // cancel background tasks for (Future taskFuture : taskFutures) { @@ -113,9 +127,10 @@ private boolean checkTasks() { for (int i = 0; i < taskFutures.length; i += 1) { if (taskFutures[i] == null || taskFutures[i].isDone()) { if (taskFutures[i] != null) { + Optional> continuation; // check for task failure and re-throw any exception try { - taskFutures[i].get(); + continuation = taskFutures[i].get(); } catch (ExecutionException e) { if (e.getCause() instanceof RuntimeException) { // rethrow a runtime exception @@ -126,6 +141,7 @@ private boolean checkTasks() { } catch (InterruptedException e) { throw new RuntimeException("Interrupted while running parallel task", e); } + continuation.ifPresent(yieldedTasks::addLast); } taskFutures[i] = submitNextTask(); @@ -136,19 +152,20 @@ private boolean checkTasks() { } } - return !closed && (tasks.hasNext() || hasRunningTask); + return !closed.get() && (tasks.hasNext() || hasRunningTask); } - private Future submitNextTask() { - if (!closed && tasks.hasNext()) { - return workerPool.submit(tasks.next()); + private Future>> submitNextTask() { + if (!closed.get() && (!yieldedTasks.isEmpty() || tasks.hasNext())) { + return workerPool.submit( + !yieldedTasks.isEmpty() ? yieldedTasks.removeFirst() : tasks.next()); } return null; } @Override public synchronized boolean hasNext() { - Preconditions.checkState(!closed, "Already closed"); + Preconditions.checkState(!closed.get(), "Already closed"); // if the consumer is processing records more slowly than the producers, then this check will // prevent tasks from being submitted. while the producers are running, this will always @@ -192,4 +209,61 @@ public synchronized T next() { return queue.poll(); } } + + private static class Task implements Callable>>, AutoCloseable { + private final Iterable input; + private final ConcurrentLinkedQueue queue; + private final AtomicBoolean closed; + private final int approximateMaxQueueSize; + + private Iterator iterator; + + public Task( + Iterable input, + ConcurrentLinkedQueue queue, + AtomicBoolean closed, + int approximateMaxQueueSize) { + this.input = Preconditions.checkNotNull(input, "input cannot be null"); + this.queue = Preconditions.checkNotNull(queue, "queue cannot be null"); + this.closed = Preconditions.checkNotNull(closed, "closed cannot be null"); + this.approximateMaxQueueSize = approximateMaxQueueSize; + } + + @Override + public Optional> call() throws Exception { + try { + if (iterator == null) { + iterator = input.iterator(); + } + while (!closed.get() && iterator.hasNext()) { + if (queue.size() >= approximateMaxQueueSize) { + // yield + return Optional.of(this); + } + queue.add(iterator.next()); + } + } catch (Throwable e) { + try { + close(); + } catch (IOException closeException) { + // self-suppression is not permitted + // (e and closeException to be the same is unlikely, but possible) + if (closeException != e) { + e.addSuppressed(closeException); + } + } + throw e; + } + close(); + return Optional.empty(); + } + + @Override + public void close() throws Exception { + iterator = null; + if (input instanceof Closeable) { + ((Closeable) input).close(); + } + } + } } diff --git a/core/src/test/java/org/apache/iceberg/util/TestParallelIterable.java b/core/src/test/java/org/apache/iceberg/util/TestParallelIterable.java index af9c6ec5212c..d955962e2751 100644 --- a/core/src/test/java/org/apache/iceberg/util/TestParallelIterable.java +++ b/core/src/test/java/org/apache/iceberg/util/TestParallelIterable.java @@ -20,15 +20,22 @@ import static org.assertj.core.api.Assertions.assertThat; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMultiset; +import com.google.common.collect.Multiset; import java.io.IOException; import java.lang.reflect.Field; import java.util.Collections; import java.util.Iterator; +import java.util.List; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; +import java.util.stream.Stream; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.CloseableIterator; import org.apache.iceberg.relocated.com.google.common.collect.Iterables; @@ -133,6 +140,47 @@ public CloseableIterator iterator() { .untilAsserted(() -> assertThat(queue).as("Queue is not empty after cleaning").isEmpty()); } + @Test + public void limitQueueSize() throws IOException, IllegalAccessException, NoSuchFieldException { + + List> iterables = + ImmutableList.of( + () -> IntStream.range(0, 100).iterator(), + () -> IntStream.range(0, 100).iterator(), + () -> IntStream.range(0, 100).iterator()); + + Multiset expectedValues = + IntStream.range(0, 100) + .boxed() + .flatMap(i -> Stream.of(i, i, i)) + .collect(ImmutableMultiset.toImmutableMultiset()); + + int maxQueueSize = 20; + ExecutorService executor = Executors.newCachedThreadPool(); + ParallelIterable parallelIterable = + new ParallelIterable<>(iterables, executor, maxQueueSize); + CloseableIterator iterator = parallelIterable.iterator(); + Field queueField = iterator.getClass().getDeclaredField("queue"); + queueField.setAccessible(true); + ConcurrentLinkedQueue queue = (ConcurrentLinkedQueue) queueField.get(iterator); + + Multiset actualValues = HashMultiset.create(); + + while (iterator.hasNext()) { + assertThat(queue) + .as("iterator internal queue") + .hasSizeLessThanOrEqualTo(maxQueueSize + iterables.size()); + actualValues.add(iterator.next()); + } + + assertThat(actualValues) + .as("multiset of values returned by the iterator") + .isEqualTo(expectedValues); + + iterator.close(); + executor.shutdownNow(); + } + private void queueHasElements(CloseableIterator iterator, Queue queue) { assertThat(iterator.hasNext()).isTrue(); assertThat(iterator.next()).isNotNull();