diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionController.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionController.java index fc2d23c176d889..eb8d6731d380ce 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionController.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionController.java @@ -111,6 +111,9 @@ public class AsyncExecutionController implements StateRequestHandler { /** The reference of epoch manager. */ final EpochManager epochManager; + /** The listener of context switch. */ + final SwitchContextListener switchContextListener; + /** * The parallel mode of epoch execution. Keep this field internal for now, until we could see * the concrete need for {@link ParallelMode#PARALLEL_BETWEEN_EPOCH} from average users. @@ -124,7 +127,8 @@ public AsyncExecutionController( int maxParallelism, int batchSize, long bufferTimeout, - int maxInFlightRecords) { + int maxInFlightRecords, + SwitchContextListener switchContextListener) { this.keyAccountingUnit = new KeyAccountingUnit<>(maxInFlightRecords); this.mailboxExecutor = mailboxExecutor; this.exceptionHandler = exceptionHandler; @@ -148,6 +152,7 @@ public AsyncExecutionController( "AEC-buffer-timeout")); this.epochManager = new EpochManager(this); + this.switchContextListener = switchContextListener; LOG.info( "Create AsyncExecutionController: batchSize {}, bufferTimeout {}, maxInFlightRecordNum {}, epochParallelMode {}", this.batchSize, @@ -189,6 +194,9 @@ public RecordContext buildContext(Object record, K key) { */ public void setCurrentContext(RecordContext switchingContext) { currentContext = switchingContext; + if (switchContextListener != null) { + switchContextListener.switchContext(switchingContext); + } } /** @@ -374,4 +382,9 @@ public StateExecutor getStateExecutor() { public int getInFlightRecordNum() { return inFlightRecordNum.get(); } + + /** A listener listens the key context switch. */ + public interface SwitchContextListener { + void switchContext(RecordContext context); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AbstractAsyncStateStreamOperator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AbstractAsyncStateStreamOperator.java index 6ea16456c5b475..c74befba770df8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AbstractAsyncStateStreamOperator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AbstractAsyncStateStreamOperator.java @@ -95,7 +95,8 @@ public void initializeState(StreamTaskStateInitializer streamTaskStateManager) maxParallelism, asyncBufferSize, asyncBufferTimeout, - inFlightRecordsLimit); + inFlightRecordsLimit, + asyncKeyedStateBackend); asyncKeyedStateBackend.setup(asyncExecutionController); } else if (stateHandler.getKeyedStateBackend() != null) { throw new UnsupportedOperationException( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AbstractAsyncStateStreamOperatorV2.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AbstractAsyncStateStreamOperatorV2.java index fbf5acd30d735e..e686efab5ae72d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AbstractAsyncStateStreamOperatorV2.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AbstractAsyncStateStreamOperatorV2.java @@ -92,7 +92,8 @@ public final void initializeState(StreamTaskStateInitializer streamTaskStateMana maxParallelism, asyncBufferSize, asyncBufferTimeout, - inFlightRecordsLimit); + inFlightRecordsLimit, + asyncKeyedStateBackend); asyncKeyedStateBackend.setup(asyncExecutionController); } else if (stateHandler.getKeyedStateBackend() != null) { throw new UnsupportedOperationException( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index 6225c1dac12591..248771fe26599d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -250,6 +250,13 @@ public void setCurrentKey(K newKey) { KeyGroupRangeAssignment.assignToKeyGroup(newKey, numberOfKeyGroups)); } + /** Only used in {@code AsyncKeyedStateBackendAdaptor}. */ + public void setCurrentKeyAndKeyGroupIndex(K newKey, int newKeyGroupIndex) { + notifyKeySelected(newKey); + this.keyContext.setCurrentKey(newKey); + this.keyContext.setCurrentKeyGroupIndex(newKeyGroupIndex); + } + private void notifyKeySelected(K newKey) { // we prefer a for-loop over other iteration schemes for performance reasons here. for (int i = 0; i < keySelectionListeners.size(); ++i) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java index 096fbb25327465..a819703cab49bb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java @@ -22,6 +22,8 @@ import org.apache.flink.api.common.state.InternalCheckpointListener; import org.apache.flink.api.common.state.v2.State; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController; +import org.apache.flink.runtime.asyncprocessing.RecordContext; import org.apache.flink.runtime.asyncprocessing.StateExecutor; import org.apache.flink.runtime.asyncprocessing.StateRequestHandler; import org.apache.flink.runtime.state.v2.StateDescriptor; @@ -36,11 +38,12 @@ * in batch. */ @Internal -public interface AsyncKeyedStateBackend +public interface AsyncKeyedStateBackend extends Snapshotable>, InternalCheckpointListener, Disposable, - Closeable { + Closeable, + AsyncExecutionController.SwitchContextListener { /** * Initializes with some contexts. @@ -80,6 +83,10 @@ S createState( @Nonnull StateExecutor createStateExecutor(); + /** By default, a state backend does nothing when a key is switched in async processing. */ + @Override + default void switchContext(RecordContext context) {} + @Override void dispose(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java index c0c20e8193b92b..544704adc2cb11 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java @@ -117,7 +117,7 @@ CheckpointableKeyedStateBackend createKeyedStateBackend( * backend. */ @Experimental - default AsyncKeyedStateBackend createAsyncKeyedStateBackend( + default AsyncKeyedStateBackend createAsyncKeyedStateBackend( KeyedStateBackendParameters parameters) throws Exception { throw new UnsupportedOperationException( "Don't support createAsyncKeyedStateBackend by default"); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStoreV2.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStoreV2.java index e8b6ee8661ff8c..fd019905d05acf 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStoreV2.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStoreV2.java @@ -33,7 +33,7 @@ /** Default implementation of KeyedStateStoreV2. */ public class DefaultKeyedStateStoreV2 implements KeyedStateStoreV2 { - private final AsyncKeyedStateBackend asyncKeyedStateBackend; + private final AsyncKeyedStateBackend asyncKeyedStateBackend; public DefaultKeyedStateStoreV2(@Nonnull AsyncKeyedStateBackend asyncKeyedStateBackend) { this.asyncKeyedStateBackend = Preconditions.checkNotNull(asyncKeyedStateBackend); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java index 30dea2acd14859..80826e3f7dcdf5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java @@ -22,9 +22,11 @@ import org.apache.flink.api.common.state.InternalCheckpointListener; import org.apache.flink.api.common.state.v2.State; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.asyncprocessing.RecordContext; import org.apache.flink.runtime.asyncprocessing.StateExecutor; import org.apache.flink.runtime.asyncprocessing.StateRequestHandler; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AsyncKeyedStateBackend; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.CheckpointableKeyedStateBackend; @@ -49,7 +51,7 @@ * * @param The key by which state is keyed. */ -public class AsyncKeyedStateBackendAdaptor implements AsyncKeyedStateBackend { +public class AsyncKeyedStateBackendAdaptor implements AsyncKeyedStateBackend { private final CheckpointableKeyedStateBackend keyedStateBackend; public AsyncKeyedStateBackendAdaptor(CheckpointableKeyedStateBackend keyedStateBackend) { @@ -95,6 +97,16 @@ public StateExecutor createStateExecutor() { return null; } + @Override + public void switchContext(RecordContext context) { + if (keyedStateBackend instanceof AbstractKeyedStateBackend) { + ((AbstractKeyedStateBackend) keyedStateBackend) + .setCurrentKeyAndKeyGroupIndex(context.getKey(), context.getKeyGroup()); + } else { + keyedStateBackend.setCurrentKey(context.getKey()); + } + } + @Override public void dispose() {} diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java index e530febe4bb1c7..d1f9b3950fed04 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java @@ -83,7 +83,7 @@ public class StreamOperatorStateHandler { protected static final Logger LOG = LoggerFactory.getLogger(StreamOperatorStateHandler.class); - @Nullable private final AsyncKeyedStateBackend asyncKeyedStateBackend; + @Nullable private final AsyncKeyedStateBackend asyncKeyedStateBackend; @Nullable private final KeyedStateStoreV2 keyedStateStoreV2; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AbstractStateIteratorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AbstractStateIteratorTest.java index 5acbf9a8100411..6438cc1e7d00eb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AbstractStateIteratorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AbstractStateIteratorTest.java @@ -46,7 +46,14 @@ public void testPartialLoading() { TestIteratorStateExecutor stateExecutor = new TestIteratorStateExecutor(100, 3); AsyncExecutionController aec = new AsyncExecutionController( - new SyncMailboxExecutor(), (a, b) -> {}, stateExecutor, 1, 100, 1000, 1); + new SyncMailboxExecutor(), + (a, b) -> {}, + stateExecutor, + 1, + 100, + 1000, + 1, + null); stateExecutor.bindAec(aec); RecordContext recordContext = aec.buildContext("1", "key1"); aec.setCurrentContext(recordContext); @@ -77,7 +84,14 @@ public void testPartialLoadingWithReturnValue() { TestIteratorStateExecutor stateExecutor = new TestIteratorStateExecutor(100, 3); AsyncExecutionController aec = new AsyncExecutionController( - new SyncMailboxExecutor(), (a, b) -> {}, stateExecutor, 1, 100, 1000, 1); + new SyncMailboxExecutor(), + (a, b) -> {}, + stateExecutor, + 1, + 100, + 1000, + 1, + null); stateExecutor.bindAec(aec); RecordContext recordContext = aec.buildContext("1", "key1"); aec.setCurrentContext(recordContext); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java index b5f722e63e0ec2..86e4b6b035c16a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java @@ -55,7 +55,7 @@ /** Test for {@link AsyncExecutionController}. */ class AsyncExecutionControllerTest { - AsyncExecutionController aec; + AsyncExecutionController aec; AtomicInteger output; TestValueState valueState; @@ -90,7 +90,7 @@ void setup( StateBackend testAsyncStateBackend = StateBackendTestUtils.buildAsyncStateBackend(stateSupplier, stateExecutor); assertThat(testAsyncStateBackend.supportsAsyncKeyedStateBackend()).isTrue(); - AsyncKeyedStateBackend asyncKeyedStateBackend; + AsyncKeyedStateBackend asyncKeyedStateBackend; try { asyncKeyedStateBackend = testAsyncStateBackend.createAsyncKeyedStateBackend(null); } catch (Exception e) { @@ -106,7 +106,8 @@ void setup( 128, batchSize, timeout, - maxInFlight); + maxInFlight, + null); asyncKeyedStateBackend.setup(aec); try { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java index 67faa449e4a82f..2b7732433a2912 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java @@ -110,7 +110,7 @@ public OperatorStateBackend createOperatorStateBackend( } } - private static class TestAsyncKeyedStateBackend implements AsyncKeyedStateBackend { + private static class TestAsyncKeyedStateBackend implements AsyncKeyedStateBackend { private final Supplier innerStateSupplier; private final StateExecutor stateExecutor; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java index a6a62991da5b32..44792146fe4642 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java @@ -75,7 +75,8 @@ void setup() { 1, 1, 1000, - 1); + 1, + null); exception = new AtomicReference<>(null); } @@ -124,9 +125,9 @@ public boolean supportsAsyncKeyedStateBackend() { } @Override - public AsyncKeyedStateBackend createAsyncKeyedStateBackend( + public AsyncKeyedStateBackend createAsyncKeyedStateBackend( KeyedStateBackendParameters parameters) { - return new AsyncKeyedStateBackend() { + return new AsyncKeyedStateBackend() { @Nonnull @Override public RunnableFuture> snapshot( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractReducingStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractReducingStateTest.java index 8ceb44d0899f5c..22096bdc36233b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractReducingStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractReducingStateTest.java @@ -79,15 +79,16 @@ public void testMergeNamespace() throws Exception { ReduceFunction reducer = Integer::sum; ReducingStateDescriptor descriptor = new ReducingStateDescriptor<>("testState", reducer, BasicTypeInfo.INT_TYPE_INFO); - AsyncExecutionController aec = - new AsyncExecutionController( + AsyncExecutionController aec = + new AsyncExecutionController<>( new SyncMailboxExecutor(), (a, b) -> {}, new ReducingStateExecutor(), 1, 100, 10000, - 1); + 1, + null); AbstractReducingState reducingState = new AbstractReducingState<>(aec, descriptor); aec.setCurrentContext(aec.buildContext("test", "test")); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceAsyncImplTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceAsyncImplTest.java index 4f78cd2035966f..1bf396bad81dea 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceAsyncImplTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceAsyncImplTest.java @@ -43,7 +43,7 @@ /** Test for {@link InternalTimerServiceAsyncImpl}. */ class InternalTimerServiceAsyncImplTest { - private AsyncExecutionController asyncExecutionController; + private AsyncExecutionController asyncExecutionController; private TestKeyContext keyContext; private TestProcessingTimeService processingTimeService; private InternalTimerServiceAsyncImpl service; @@ -59,14 +59,15 @@ public void handleException(String message, Throwable exception) { @BeforeEach void setup() throws Exception { asyncExecutionController = - new AsyncExecutionController( + new AsyncExecutionController<>( new SyncMailboxExecutor(), exceptionHandler, new MockStateExecutor(), 128, 2, 1000L, - 10); + 10, + null); // ensure arbitrary key is in the key group int totalKeyGroups = 128; KeyGroupRange testKeyGroupList = new KeyGroupRange(0, totalKeyGroups - 1); diff --git a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java index 84f09edd5fe38a..9ad360358c14a4 100644 --- a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java +++ b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java @@ -69,7 +69,7 @@ * A KeyedStateBackend that stores its state in {@code ForSt}. This state backend can store very * large state that exceeds memory even disk to remote storage. */ -public class ForStKeyedStateBackend implements AsyncKeyedStateBackend { +public class ForStKeyedStateBackend implements AsyncKeyedStateBackend { private static final Logger LOG = LoggerFactory.getLogger(ForStKeyedStateBackend.class);