Skip to content

Commit

Permalink
[FLINK-36338] Properly handle KeyContext when using AsyncKeyedStateBa…
Browse files Browse the repository at this point in the history
…ckendAdaptor
  • Loading branch information
Zakelly committed Sep 20, 2024
1 parent 1ba8900 commit 90aa949
Show file tree
Hide file tree
Showing 16 changed files with 84 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ public class AsyncExecutionController<K> implements StateRequestHandler {
/** The reference of epoch manager. */
final EpochManager epochManager;

/** The listener of context switch. */
final SwitchContextListener<K> switchContextListener;

/**
* The parallel mode of epoch execution. Keep this field internal for now, until we could see
* the concrete need for {@link ParallelMode#PARALLEL_BETWEEN_EPOCH} from average users.
Expand All @@ -124,7 +127,8 @@ public AsyncExecutionController(
int maxParallelism,
int batchSize,
long bufferTimeout,
int maxInFlightRecords) {
int maxInFlightRecords,
SwitchContextListener<K> switchContextListener) {
this.keyAccountingUnit = new KeyAccountingUnit<>(maxInFlightRecords);
this.mailboxExecutor = mailboxExecutor;
this.exceptionHandler = exceptionHandler;
Expand All @@ -148,6 +152,7 @@ public AsyncExecutionController(
"AEC-buffer-timeout"));

this.epochManager = new EpochManager(this);
this.switchContextListener = switchContextListener;
LOG.info(
"Create AsyncExecutionController: batchSize {}, bufferTimeout {}, maxInFlightRecordNum {}, epochParallelMode {}",
this.batchSize,
Expand Down Expand Up @@ -189,6 +194,9 @@ public RecordContext<K> buildContext(Object record, K key) {
*/
public void setCurrentContext(RecordContext<K> switchingContext) {
currentContext = switchingContext;
if (switchContextListener != null) {
switchContextListener.switchContext(switchingContext);
}
}

/**
Expand Down Expand Up @@ -374,4 +382,9 @@ public StateExecutor getStateExecutor() {
public int getInFlightRecordNum() {
return inFlightRecordNum.get();
}

/** A listener listens the key context switch. */
public interface SwitchContextListener<K> {
void switchContext(RecordContext<K> context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
maxParallelism,
asyncBufferSize,
asyncBufferTimeout,
inFlightRecordsLimit);
inFlightRecordsLimit,
asyncKeyedStateBackend);
asyncKeyedStateBackend.setup(asyncExecutionController);
} else if (stateHandler.getKeyedStateBackend() != null) {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ public final void initializeState(StreamTaskStateInitializer streamTaskStateMana
maxParallelism,
asyncBufferSize,
asyncBufferTimeout,
inFlightRecordsLimit);
inFlightRecordsLimit,
asyncKeyedStateBackend);
asyncKeyedStateBackend.setup(asyncExecutionController);
} else if (stateHandler.getKeyedStateBackend() != null) {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,13 @@ public void setCurrentKey(K newKey) {
KeyGroupRangeAssignment.assignToKeyGroup(newKey, numberOfKeyGroups));
}

/** Only used in {@code AsyncKeyedStateBackendAdaptor}. */
public void setCurrentKeyAndKeyGroupIndex(K newKey, int newKeyGroupIndex) {
notifyKeySelected(newKey);
this.keyContext.setCurrentKey(newKey);
this.keyContext.setCurrentKeyGroupIndex(newKeyGroupIndex);
}

private void notifyKeySelected(K newKey) {
// we prefer a for-loop over other iteration schemes for performance reasons here.
for (int i = 0; i < keySelectionListeners.size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.flink.api.common.state.InternalCheckpointListener;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
import org.apache.flink.runtime.asyncprocessing.RecordContext;
import org.apache.flink.runtime.asyncprocessing.StateExecutor;
import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
import org.apache.flink.runtime.state.v2.StateDescriptor;
Expand All @@ -36,11 +38,12 @@
* in batch.
*/
@Internal
public interface AsyncKeyedStateBackend
public interface AsyncKeyedStateBackend<K>
extends Snapshotable<SnapshotResult<KeyedStateHandle>>,
InternalCheckpointListener,
Disposable,
Closeable {
Closeable,
AsyncExecutionController.SwitchContextListener<K> {

/**
* Initializes with some contexts.
Expand Down Expand Up @@ -80,6 +83,10 @@ <N, S extends State, SV> S createState(
@Nonnull
StateExecutor createStateExecutor();

/** By default, a state backend does nothing when a key is switched in async processing. */
@Override
default void switchContext(RecordContext<K> context) {}

@Override
void dispose();
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ <K> CheckpointableKeyedStateBackend<K> createKeyedStateBackend(
* backend.
*/
@Experimental
default <K> AsyncKeyedStateBackend createAsyncKeyedStateBackend(
default <K> AsyncKeyedStateBackend<K> createAsyncKeyedStateBackend(
KeyedStateBackendParameters<K> parameters) throws Exception {
throw new UnsupportedOperationException(
"Don't support createAsyncKeyedStateBackend by default");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
/** Default implementation of KeyedStateStoreV2. */
public class DefaultKeyedStateStoreV2 implements KeyedStateStoreV2 {

private final AsyncKeyedStateBackend asyncKeyedStateBackend;
private final AsyncKeyedStateBackend<?> asyncKeyedStateBackend;

public DefaultKeyedStateStoreV2(@Nonnull AsyncKeyedStateBackend asyncKeyedStateBackend) {
this.asyncKeyedStateBackend = Preconditions.checkNotNull(asyncKeyedStateBackend);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import org.apache.flink.api.common.state.InternalCheckpointListener;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.asyncprocessing.RecordContext;
import org.apache.flink.runtime.asyncprocessing.StateExecutor;
import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.AsyncKeyedStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.CheckpointableKeyedStateBackend;
Expand All @@ -49,7 +51,7 @@
*
* @param <K> The key by which state is keyed.
*/
public class AsyncKeyedStateBackendAdaptor<K> implements AsyncKeyedStateBackend {
public class AsyncKeyedStateBackendAdaptor<K> implements AsyncKeyedStateBackend<K> {
private final CheckpointableKeyedStateBackend<K> keyedStateBackend;

public AsyncKeyedStateBackendAdaptor(CheckpointableKeyedStateBackend<K> keyedStateBackend) {
Expand Down Expand Up @@ -95,6 +97,16 @@ public StateExecutor createStateExecutor() {
return null;
}

@Override
public void switchContext(RecordContext<K> context) {
if (keyedStateBackend instanceof AbstractKeyedStateBackend) {
((AbstractKeyedStateBackend<K>) keyedStateBackend)
.setCurrentKeyAndKeyGroupIndex(context.getKey(), context.getKeyGroup());
} else {
keyedStateBackend.setCurrentKey(context.getKey());
}
}

@Override
public void dispose() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public class StreamOperatorStateHandler {

protected static final Logger LOG = LoggerFactory.getLogger(StreamOperatorStateHandler.class);

@Nullable private final AsyncKeyedStateBackend asyncKeyedStateBackend;
@Nullable private final AsyncKeyedStateBackend<?> asyncKeyedStateBackend;

@Nullable private final KeyedStateStoreV2 keyedStateStoreV2;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,14 @@ public void testPartialLoading() {
TestIteratorStateExecutor stateExecutor = new TestIteratorStateExecutor(100, 3);
AsyncExecutionController aec =
new AsyncExecutionController(
new SyncMailboxExecutor(), (a, b) -> {}, stateExecutor, 1, 100, 1000, 1);
new SyncMailboxExecutor(),
(a, b) -> {},
stateExecutor,
1,
100,
1000,
1,
null);
stateExecutor.bindAec(aec);
RecordContext<String> recordContext = aec.buildContext("1", "key1");
aec.setCurrentContext(recordContext);
Expand Down Expand Up @@ -77,7 +84,14 @@ public void testPartialLoadingWithReturnValue() {
TestIteratorStateExecutor stateExecutor = new TestIteratorStateExecutor(100, 3);
AsyncExecutionController aec =
new AsyncExecutionController(
new SyncMailboxExecutor(), (a, b) -> {}, stateExecutor, 1, 100, 1000, 1);
new SyncMailboxExecutor(),
(a, b) -> {},
stateExecutor,
1,
100,
1000,
1,
null);
stateExecutor.bindAec(aec);
RecordContext<String> recordContext = aec.buildContext("1", "key1");
aec.setCurrentContext(recordContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

/** Test for {@link AsyncExecutionController}. */
class AsyncExecutionControllerTest {
AsyncExecutionController aec;
AsyncExecutionController<String> aec;
AtomicInteger output;
TestValueState valueState;

Expand Down Expand Up @@ -90,7 +90,7 @@ void setup(
StateBackend testAsyncStateBackend =
StateBackendTestUtils.buildAsyncStateBackend(stateSupplier, stateExecutor);
assertThat(testAsyncStateBackend.supportsAsyncKeyedStateBackend()).isTrue();
AsyncKeyedStateBackend asyncKeyedStateBackend;
AsyncKeyedStateBackend<String> asyncKeyedStateBackend;
try {
asyncKeyedStateBackend = testAsyncStateBackend.createAsyncKeyedStateBackend(null);
} catch (Exception e) {
Expand All @@ -106,7 +106,8 @@ void setup(
128,
batchSize,
timeout,
maxInFlight);
maxInFlight,
null);
asyncKeyedStateBackend.setup(aec);

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public OperatorStateBackend createOperatorStateBackend(
}
}

private static class TestAsyncKeyedStateBackend implements AsyncKeyedStateBackend {
private static class TestAsyncKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> {

private final Supplier<org.apache.flink.api.common.state.v2.State> innerStateSupplier;
private final StateExecutor stateExecutor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ void setup() {
1,
1,
1000,
1);
1,
null);
exception = new AtomicReference<>(null);
}

Expand Down Expand Up @@ -124,9 +125,9 @@ public boolean supportsAsyncKeyedStateBackend() {
}

@Override
public <K> AsyncKeyedStateBackend createAsyncKeyedStateBackend(
public <K> AsyncKeyedStateBackend<K> createAsyncKeyedStateBackend(
KeyedStateBackendParameters<K> parameters) {
return new AsyncKeyedStateBackend() {
return new AsyncKeyedStateBackend<K>() {
@Nonnull
@Override
public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,16 @@ public void testMergeNamespace() throws Exception {
ReduceFunction<Integer> reducer = Integer::sum;
ReducingStateDescriptor<Integer> descriptor =
new ReducingStateDescriptor<>("testState", reducer, BasicTypeInfo.INT_TYPE_INFO);
AsyncExecutionController aec =
new AsyncExecutionController(
AsyncExecutionController<String> aec =
new AsyncExecutionController<>(
new SyncMailboxExecutor(),
(a, b) -> {},
new ReducingStateExecutor(),
1,
100,
10000,
1);
1,
null);
AbstractReducingState<String, String, Integer> reducingState =
new AbstractReducingState<>(aec, descriptor);
aec.setCurrentContext(aec.buildContext("test", "test"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

/** Test for {@link InternalTimerServiceAsyncImpl}. */
class InternalTimerServiceAsyncImplTest {
private AsyncExecutionController asyncExecutionController;
private AsyncExecutionController<String> asyncExecutionController;
private TestKeyContext keyContext;
private TestProcessingTimeService processingTimeService;
private InternalTimerServiceAsyncImpl<Integer, String> service;
Expand All @@ -59,14 +59,15 @@ public void handleException(String message, Throwable exception) {
@BeforeEach
void setup() throws Exception {
asyncExecutionController =
new AsyncExecutionController(
new AsyncExecutionController<>(
new SyncMailboxExecutor(),
exceptionHandler,
new MockStateExecutor(),
128,
2,
1000L,
10);
10,
null);
// ensure arbitrary key is in the key group
int totalKeyGroups = 128;
KeyGroupRange testKeyGroupList = new KeyGroupRange(0, totalKeyGroups - 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
* A KeyedStateBackend that stores its state in {@code ForSt}. This state backend can store very
* large state that exceeds memory even disk to remote storage.
*/
public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend {
public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> {

private static final Logger LOG = LoggerFactory.getLogger(ForStKeyedStateBackend.class);

Expand Down

0 comments on commit 90aa949

Please sign in to comment.