diff --git a/flink-core/src/main/java/org/apache/flink/core/state/StateFutureImpl.java b/flink-core/src/main/java/org/apache/flink/core/state/StateFutureImpl.java index 63de46cc91e638..d7db0be5311d21 100644 --- a/flink-core/src/main/java/org/apache/flink/core/state/StateFutureImpl.java +++ b/flink-core/src/main/java/org/apache/flink/core/state/StateFutureImpl.java @@ -166,7 +166,8 @@ public StateFuture thenCombine( } /** - * Make a new future based on context of this future. + * Make a new future based on context of this future. Subclasses need to overload this method to + * generate their own instances (if needed). * * @return the new created future. */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/ContextStateFutureImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/ContextStateFutureImpl.java index 42d9dd96a6ac26..8d3338e957946c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/ContextStateFutureImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/ContextStateFutureImpl.java @@ -32,6 +32,8 @@ *
  • 2. -1 when future completed. *
  • 3. +1 when callback registered. *
  • 4. -1 when callback finished. + *
  • Please refer to {@code ContextStateFutureImplTest} where the reference counting is carefully + * tested. */ public class ContextStateFutureImpl extends StateFutureImpl { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java index 07565458a23f5e..e7a0cf83275fb3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java @@ -158,7 +158,7 @@ void testBasicRun() { * An AsyncExecutionController for testing purpose, which integrates with basic buffer * mechanism. */ - class TestAsyncExecutionController extends AsyncExecutionController { + static class TestAsyncExecutionController extends AsyncExecutionController { LinkedList> activeBuffer; @@ -203,7 +203,7 @@ void migrateBlockingToActive() { } /** Simulate the underlying state that is actually used to execute the request. */ - class TestUnderlyingState { + static class TestUnderlyingState { private HashMap hashMap; @@ -220,7 +220,7 @@ public void update(String key, Integer val) { } } - class TestValueState implements ValueState { + static class TestValueState implements ValueState { private AsyncExecutionController asyncExecutionController; @@ -258,7 +258,7 @@ public StateFuture asyncUpdate(Integer value) { * A brief implementation of {@link StateExecutor}, to illustrate the interaction between AEC * and StateExecutor. */ - class TestStateExecutor implements StateExecutor { + static class TestStateExecutor implements StateExecutor { public TestStateExecutor() {} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/ContextStateFutureImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/ContextStateFutureImplTest.java new file mode 100644 index 00000000000000..ca18cb73c00efc --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/ContextStateFutureImplTest.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.asyncprocessing; + +import org.apache.flink.core.state.StateFutureUtils; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedList; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link ContextStateFutureImpl}. */ +public class ContextStateFutureImplTest { + + @Test + public void testThenApply() { + SingleStepRunner runner = new SingleStepRunner(); + KeyAccountingUnit keyAccountingUnit = new KeyAccountingUnit<>(); + RecordContext recordContext = + new RecordContext<>(keyAccountingUnit, "a", "b"); + + // validate + ContextStateFutureImpl future = + new ContextStateFutureImpl<>(runner::submit, recordContext); + assertThat(recordContext.getReferenceCount()).isEqualTo(1); + future.thenApply((v) -> 1L); + future.complete(null); + assertThat(runner.runThrough()).isTrue(); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + + // validate completion before callback + future = new ContextStateFutureImpl<>(runner::submit, recordContext); + assertThat(recordContext.getReferenceCount()).isEqualTo(1); + future.complete(null); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + future.thenApply((v) -> 1L); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + assertThat(runner.runThrough()).isFalse(); + } + + @Test + public void testThenAccept() { + SingleStepRunner runner = new SingleStepRunner(); + KeyAccountingUnit keyAccountingUnit = new KeyAccountingUnit<>(); + RecordContext recordContext = + new RecordContext<>(keyAccountingUnit, "a", "b"); + + // validate + ContextStateFutureImpl future = + new ContextStateFutureImpl<>(runner::submit, recordContext); + assertThat(recordContext.getReferenceCount()).isEqualTo(1); + future.thenAccept((v) -> {}); + future.complete(null); + assertThat(runner.runThrough()).isTrue(); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + + // validate completion before callback + future = new ContextStateFutureImpl<>(runner::submit, recordContext); + assertThat(recordContext.getReferenceCount()).isEqualTo(1); + future.complete(null); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + future.thenAccept((v) -> {}); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + assertThat(runner.runThrough()).isFalse(); + } + + @Test + public void testThenCompose() { + SingleStepRunner runner = new SingleStepRunner(); + KeyAccountingUnit keyAccountingUnit = new KeyAccountingUnit<>(); + RecordContext recordContext = + new RecordContext<>(keyAccountingUnit, "a", "b"); + + // validate + ContextStateFutureImpl future = + new ContextStateFutureImpl<>(runner::submit, recordContext); + assertThat(recordContext.getReferenceCount()).isEqualTo(1); + future.thenCompose((v) -> StateFutureUtils.completedFuture(1L)); + future.complete(null); + assertThat(runner.runThrough()).isTrue(); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + + // validate completion before callback + future = new ContextStateFutureImpl<>(runner::submit, recordContext); + assertThat(recordContext.getReferenceCount()).isEqualTo(1); + future.complete(null); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + future.thenCompose((v) -> StateFutureUtils.completedFuture(1L)); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + assertThat(runner.runThrough()).isFalse(); + } + + @Test + public void testThenCombine() { + SingleStepRunner runner = new SingleStepRunner(); + KeyAccountingUnit keyAccountingUnit = new KeyAccountingUnit<>(); + RecordContext recordContext = + new RecordContext<>(keyAccountingUnit, "a", "b"); + + // validate + ContextStateFutureImpl future1 = + new ContextStateFutureImpl<>(runner::submit, recordContext); + ContextStateFutureImpl future2 = + new ContextStateFutureImpl<>(runner::submit, recordContext); + assertThat(recordContext.getReferenceCount()).isEqualTo(2); + future1.thenCombine(future2, (v1, v2) -> 1L); + future1.complete(null); + future2.complete(null); + assertThat(runner.runThrough()).isTrue(); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + + // validate future1 completion before callback + future1 = new ContextStateFutureImpl<>(runner::submit, recordContext); + future2 = new ContextStateFutureImpl<>(runner::submit, recordContext); + assertThat(recordContext.getReferenceCount()).isEqualTo(2); + future1.complete(null); + future1.thenCombine(future2, (v1, v2) -> 1L); + assertThat(recordContext.getReferenceCount()).isGreaterThan(1); + future2.complete(null); + assertThat(runner.runThrough()).isTrue(); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + + // validate future2 completion before callback + future1 = new ContextStateFutureImpl<>(runner::submit, recordContext); + future2 = new ContextStateFutureImpl<>(runner::submit, recordContext); + assertThat(recordContext.getReferenceCount()).isEqualTo(2); + future2.complete(null); + future1.thenCombine(future2, (v1, v2) -> 1L); + assertThat(recordContext.getReferenceCount()).isGreaterThan(1); + future1.complete(null); + assertThat(runner.runThrough()).isTrue(); + assertThat(recordContext.getReferenceCount()).isEqualTo(0); + } + + @Test + public void testComplex() { + SingleStepRunner runner = new SingleStepRunner(); + KeyAccountingUnit keyAccountingUnit = new KeyAccountingUnit<>(); + RecordContext recordContext = + new RecordContext<>(keyAccountingUnit, "a", "b"); + + for (int i = 0; i < 32; i++) { // 2^5 for completion status combination + ArrayList> futures = new ArrayList<>(6); + for (int j = 0; j < 5; j++) { + ContextStateFutureImpl future = + new ContextStateFutureImpl<>(runner::submit, recordContext); + futures.add(future); + if (((i >>> j) & 1) == 1) { + future.complete(null); + } + } + + StateFutureUtils.combineAll( + Arrays.asList(futures.get(0), futures.get(1), futures.get(2))) + .thenCombine( + futures.get(3), + (a, b) -> { + return 1L; + }) + .thenCompose( + (a) -> { + return futures.get(4); + }) + .thenApply( + (e) -> { + return 2L; + }) + .thenAccept((b) -> {}); + + for (int j = 0; j < 5; j++) { + if (((i >>> j) & 1) == 0) { + futures.get(j).complete(null); + } + } + + if (i == 31) { + // all completed + assertThat(recordContext.getReferenceCount()) + .withFailMessage("The reference counted tests fail for profile id %d", i) + .isEqualTo(0); + assertThat(runner.runThrough()) + .withFailMessage("The reference counted tests fail for profile id %d", i) + .isFalse(); + } else { + assertThat(recordContext.getReferenceCount()) + .withFailMessage("The reference counted tests fail for profile id %d", i) + .isGreaterThan(0); + assertThat(runner.runThrough()) + .withFailMessage("The reference counted tests fail for profile id %d", i) + .isTrue(); + assertThat(recordContext.getReferenceCount()) + .withFailMessage("The reference counted tests fail for profile id %d", i) + .isEqualTo(0); + } + } + } + + /** A runner that performs single-step debugging. */ + public static class SingleStepRunner { + private final LinkedList runnables = new LinkedList<>(); + + public void submit(Runnable runnable) { + runnables.add(runnable); + } + + public boolean runThrough() { + boolean run = false; + while (!runnables.isEmpty()) { + runnables.poll().run(); + run = true; + } + return run; + } + } +}