Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core: Limit memory used by ParallelIterable #10691

Merged
merged 13 commits into from
Jul 22, 2024
158 changes: 118 additions & 40 deletions core/src/main/java/org/apache/iceberg/util/ParallelIterable.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,65 +20,69 @@

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.iceberg.exceptions.RuntimeIOException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.iceberg.io.CloseableGroup;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.CloseableIterator;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.io.Closer;

public class ParallelIterable<T> extends CloseableGroup implements CloseableIterable<T> {

private static final int DEFAULT_MAX_QUEUE_SIZE = 10_000;
findepi marked this conversation as resolved.
Show resolved Hide resolved

private final Iterable<? extends Iterable<T>> iterables;
private final ExecutorService workerPool;

// Bound for number of items in the queue to limit memory consumption
// even in the case when input iterables are large.
private final int approximateMaxQueueSize;

public ParallelIterable(Iterable<? extends Iterable<T>> iterables, ExecutorService workerPool) {
this.iterables = iterables;
this.workerPool = workerPool;
this(iterables, workerPool, DEFAULT_MAX_QUEUE_SIZE);
}

public ParallelIterable(
Iterable<? extends Iterable<T>> iterables,
ExecutorService workerPool,
int approximateMaxQueueSize) {
this.iterables = Preconditions.checkNotNull(iterables, "Input iterables cannot be null");
this.workerPool = Preconditions.checkNotNull(workerPool, "Worker pool cannot be null");
this.approximateMaxQueueSize = approximateMaxQueueSize;
}

@Override
public CloseableIterator<T> iterator() {
ParallelIterator<T> iter = new ParallelIterator<>(iterables, workerPool);
ParallelIterator<T> iter =
new ParallelIterator<>(iterables, workerPool, approximateMaxQueueSize);
addCloseable(iter);
return iter;
}

private static class ParallelIterator<T> implements CloseableIterator<T> {
private final Iterator<Runnable> tasks;
private final Iterator<Task<T>> tasks;
private final Deque<Task<T>> yieldedTasks = new ArrayDeque<>();
findepi marked this conversation as resolved.
Show resolved Hide resolved
private final ExecutorService workerPool;
private final Future<?>[] taskFutures;
private final Future<Optional<Task<T>>>[] taskFutures;
private final ConcurrentLinkedQueue<T> queue = new ConcurrentLinkedQueue<>();
private volatile boolean closed = false;
private final AtomicBoolean closed = new AtomicBoolean(false);

private ParallelIterator(
Iterable<? extends Iterable<T>> iterables, ExecutorService workerPool) {
Iterable<? extends Iterable<T>> iterables, ExecutorService workerPool, int maxQueueSize) {
this.tasks =
Iterables.transform(
iterables,
iterable ->
(Runnable)
() -> {
try (Closeable ignored =
(iterable instanceof Closeable) ? (Closeable) iterable : () -> {}) {
for (T item : iterable) {
// exit manually because `ConcurrentLinkedQueue` can't be
// interrupted
if (closed) {
return;
}

queue.add(item);
}
} catch (IOException e) {
throw new RuntimeIOException(e, "Failed to close iterable");
}
})
iterables, iterable -> new Task<>(iterable, queue, closed, maxQueueSize))
.iterator();
this.workerPool = workerPool;
// submit 2 tasks per worker at a time
Expand All @@ -88,16 +92,26 @@ private ParallelIterator(
@Override
public void close() {
// close first, avoid new task submit
this.closed = true;
this.closed.set(true);

try (Closer closer = Closer.create()) {
yieldedTasks.forEach(closer::register);
yieldedTasks.clear();

// cancel background tasks
for (Future<?> taskFuture : taskFutures) {
if (taskFuture != null && !taskFuture.isDone()) {
taskFuture.cancel(true);
// TODO close input iterables that were not started yet
findepi marked this conversation as resolved.
Show resolved Hide resolved

// cancel background tasks
for (Future<?> taskFuture : taskFutures) {
if (taskFuture != null && !taskFuture.isDone()) {
taskFuture.cancel(true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mayInterruptIfRunning flag has no effect in CompletableFuture, so I guess it's possible the thread cannot be released here?

}
}

// clean queue
this.queue.clear();
} catch (IOException e) {
throw new RuntimeException("Close failed", e);
findepi marked this conversation as resolved.
Show resolved Hide resolved
}
// clean queue
this.queue.clear();
}

/**
Expand All @@ -113,9 +127,10 @@ private boolean checkTasks() {
for (int i = 0; i < taskFutures.length; i += 1) {
if (taskFutures[i] == null || taskFutures[i].isDone()) {
if (taskFutures[i] != null) {
Optional<Task<T>> continuation;
// check for task failure and re-throw any exception
findepi marked this conversation as resolved.
Show resolved Hide resolved
try {
taskFutures[i].get();
continuation = taskFutures[i].get();
} catch (ExecutionException e) {
if (e.getCause() instanceof RuntimeException) {
// rethrow a runtime exception
Expand All @@ -126,6 +141,7 @@ private boolean checkTasks() {
} catch (InterruptedException e) {
throw new RuntimeException("Interrupted while running parallel task", e);
}
continuation.ifPresent(yieldedTasks::addLast);
findepi marked this conversation as resolved.
Show resolved Hide resolved
}

taskFutures[i] = submitNextTask();
Expand All @@ -136,19 +152,20 @@ private boolean checkTasks() {
}
}

return !closed && (tasks.hasNext() || hasRunningTask);
return !closed.get() && (tasks.hasNext() || hasRunningTask);
}

private Future<?> submitNextTask() {
if (!closed && tasks.hasNext()) {
return workerPool.submit(tasks.next());
private Future<Optional<Task<T>>> submitNextTask() {
if (!closed.get() && (!yieldedTasks.isEmpty() || tasks.hasNext())) {
return workerPool.submit(
!yieldedTasks.isEmpty() ? yieldedTasks.removeFirst() : tasks.next());
findepi marked this conversation as resolved.
Show resolved Hide resolved
}
return null;
}

@Override
public synchronized boolean hasNext() {
Preconditions.checkState(!closed, "Already closed");
Preconditions.checkState(!closed.get(), "Already closed");

// if the consumer is processing records more slowly than the producers, then this check will
// prevent tasks from being submitted. while the producers are running, this will always
Expand Down Expand Up @@ -192,4 +209,65 @@ public synchronized T next() {
return queue.poll();
}
}

private static class Task<T> implements Callable<Optional<Task<T>>>, Closeable {
private final Iterable<T> input;
private final ConcurrentLinkedQueue<T> queue;
private final AtomicBoolean closed;
private final int approximateMaxQueueSize;

private Iterator<T> iterator;
findepi marked this conversation as resolved.
Show resolved Hide resolved

Task(
Iterable<T> input,
ConcurrentLinkedQueue<T> queue,
AtomicBoolean closed,
int approximateMaxQueueSize) {
this.input = Preconditions.checkNotNull(input, "input cannot be null");
this.queue = Preconditions.checkNotNull(queue, "queue cannot be null");
this.closed = Preconditions.checkNotNull(closed, "closed cannot be null");
this.approximateMaxQueueSize = approximateMaxQueueSize;
}

@Override
public Optional<Task<T>> call() throws Exception {
try {
if (iterator == null) {
findepi marked this conversation as resolved.
Show resolved Hide resolved
iterator = input.iterator();
}
while (iterator.hasNext()) {
findepi marked this conversation as resolved.
Show resolved Hide resolved
if (queue.size() >= approximateMaxQueueSize) {
// yield
findepi marked this conversation as resolved.
Show resolved Hide resolved
return Optional.of(this);
findepi marked this conversation as resolved.
Show resolved Hide resolved
}
T next = iterator.next();
findepi marked this conversation as resolved.
Show resolved Hide resolved
if (closed.get()) {
break;
}
queue.add(next);
}
} catch (Throwable e) {
try {
close();
} catch (IOException closeException) {
// self-suppression is not permitted
// (e and closeException to be the same is unlikely, but possible)
if (closeException != e) {
e.addSuppressed(closeException);
}
}
throw e;
findepi marked this conversation as resolved.
Show resolved Hide resolved
}
close();
findepi marked this conversation as resolved.
Show resolved Hide resolved
return Optional.empty();
findepi marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
public void close() throws IOException {
iterator = null;
if (input instanceof Closeable) {
((Closeable) input).close();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,22 @@
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.CloseableIterator;
import org.apache.iceberg.relocated.com.google.common.collect.HashMultiset;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMultiset;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Multiset;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -133,6 +140,47 @@ public CloseableIterator<Integer> iterator() {
.untilAsserted(() -> assertThat(queue).as("Queue is not empty after cleaning").isEmpty());
}

@Test
public void limitQueueSize() throws IOException, IllegalAccessException, NoSuchFieldException {
findepi marked this conversation as resolved.
Show resolved Hide resolved

List<Iterable<Integer>> iterables =
ImmutableList.of(
() -> IntStream.range(0, 100).iterator(),
() -> IntStream.range(0, 100).iterator(),
() -> IntStream.range(0, 100).iterator());

Multiset<Integer> expectedValues =
IntStream.range(0, 100)
.boxed()
.flatMap(i -> Stream.of(i, i, i))
.collect(ImmutableMultiset.toImmutableMultiset());

int maxQueueSize = 20;
ExecutorService executor = Executors.newCachedThreadPool();
findepi marked this conversation as resolved.
Show resolved Hide resolved
ParallelIterable<Integer> parallelIterable =
new ParallelIterable<>(iterables, executor, maxQueueSize);
CloseableIterator<Integer> iterator = parallelIterable.iterator();
findepi marked this conversation as resolved.
Show resolved Hide resolved
Field queueField = iterator.getClass().getDeclaredField("queue");
queueField.setAccessible(true);
ConcurrentLinkedQueue<?> queue = (ConcurrentLinkedQueue<?>) queueField.get(iterator);

Multiset<Integer> actualValues = HashMultiset.create();

while (iterator.hasNext()) {
findepi marked this conversation as resolved.
Show resolved Hide resolved
assertThat(queue)
.as("iterator internal queue")
.hasSizeLessThanOrEqualTo(maxQueueSize + iterables.size());
actualValues.add(iterator.next());
}

assertThat(actualValues)
.as("multiset of values returned by the iterator")
.isEqualTo(expectedValues);

iterator.close();
executor.shutdownNow();
}

private void queueHasElements(CloseableIterator<Integer> iterator, Queue queue) {
assertThat(iterator.hasNext()).isTrue();
assertThat(iterator.next()).isNotNull();
Expand Down