[FLINK-35156][Runtime] Rework: Make operators of DataStream V2 integrate with async state processing framework
Zakelly committed Dec 3, 2024
1 parent 64af905 commit a7292e3
Showing 22 changed files with 252 additions and 164 deletions.
@@ -19,18 +19,18 @@
 package org.apache.flink.datastream.api.context;
 
 import org.apache.flink.annotation.Experimental;
-import org.apache.flink.api.common.state.AggregatingState;
 import org.apache.flink.api.common.state.AggregatingStateDeclaration;
 import org.apache.flink.api.common.state.BroadcastState;
 import org.apache.flink.api.common.state.BroadcastStateDeclaration;
-import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDeclaration;
-import org.apache.flink.api.common.state.MapState;
 import org.apache.flink.api.common.state.MapStateDeclaration;
-import org.apache.flink.api.common.state.ReducingState;
 import org.apache.flink.api.common.state.ReducingStateDeclaration;
-import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDeclaration;
+import org.apache.flink.api.common.state.v2.AggregatingState;
+import org.apache.flink.api.common.state.v2.ListState;
+import org.apache.flink.api.common.state.v2.MapState;
+import org.apache.flink.api.common.state.v2.ReducingState;
+import org.apache.flink.api.common.state.v2.ValueState;
 
 import java.util.Optional;
 
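The only change in this first file (likely the StateManager interface, given the package org.apache.flink.datastream.api.context and the state types involved) is that the state handles it returns now come from org.apache.flink.api.common.state.v2 instead of org.apache.flink.api.common.state. A rough usage sketch follows, assuming the asynchronous accessors that the State V2 API introduces (asyncValue()/asyncUpdate() returning a StateFuture); those method names are my assumption and are not shown anywhere in this diff.

```java
// Hedged sketch: assumes org.apache.flink.api.common.state.v2.ValueState exposes
// asyncValue()/asyncUpdate() returning StateFuture, as introduced by the State V2
// (async state) API. Method names are assumptions, not taken from this commit.
import org.apache.flink.api.common.state.v2.StateFuture;
import org.apache.flink.api.common.state.v2.ValueState;

public final class AsyncValueStateSketch {

    /** Reads the current value asynchronously and writes back value + 1. */
    public static StateFuture<Void> incrementAsync(ValueState<Long> counter) {
        return counter
                .asyncValue()
                .thenCompose(v -> counter.asyncUpdate(v == null ? 1L : v + 1L));
    }
}
```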
@@ -18,16 +18,16 @@
 
 package org.apache.flink.datastream.impl.context;
 
-import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.datastream.api.context.JobInfo;
 import org.apache.flink.datastream.api.context.PartitionedContext;
 import org.apache.flink.datastream.api.context.ProcessingTimeManager;
 import org.apache.flink.datastream.api.context.RuntimeContext;
 import org.apache.flink.datastream.api.context.TaskInfo;
 import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.state.v2.OperatorStateStore;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 
-import java.util.function.Consumer;
+import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
 /** The default implementation of {@link PartitionedContext}. */
@@ -41,14 +41,14 @@ public class DefaultPartitionedContext implements PartitionedContext {
     public DefaultPartitionedContext(
             RuntimeContext context,
             Supplier<Object> currentKeySupplier,
-            Consumer<Object> currentKeySetter,
+            BiConsumer<Runnable, Object> processorWithKey,
             ProcessingTimeManager processingTimeManager,
             StreamingRuntimeContext operatorContext,
             OperatorStateStore operatorStateStore) {
         this.context = context;
         this.stateManager =
                 new DefaultStateManager(
-                        currentKeySupplier, currentKeySetter, operatorContext, operatorStateStore);
+                        currentKeySupplier, processorWithKey, operatorContext, operatorStateStore);
         this.processingTimeManager = processingTimeManager;
     }
 
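DefaultPartitionedContext, and DefaultStateManager below it, now receive a BiConsumer<Runnable, Object> (processorWithKey) instead of a plain Consumer<Object> key setter. A minimal sketch of what such a processor could look like is shown here; it reproduces the synchronous set-key/run/restore-key behavior that executeInKeyContext used to implement inline (see the DefaultStateManager hunks further down), whereas an operator running on the async state processing framework would instead hand the runnable over to that framework. The factory below is my own illustration, not code from this commit.

```java
// A minimal sketch (not part of the commit) of a synchronous processorWithKey:
// run the given runnable with `key` installed as the current key, then restore
// the previous key, mirroring the logic removed from DefaultStateManager.
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;

public final class SyncProcessorWithKey {

    /** Builds a processor that runs the runnable under {@code key} and restores the old key. */
    public static BiConsumer<Runnable, Object> of(
            Supplier<Object> currentKeySupplier, Consumer<Object> currentKeySetter) {
        return (runnable, key) -> {
            final Object oldKey = currentKeySupplier.get();
            currentKeySetter.accept(key);
            try {
                runnable.run();
            } finally {
                currentKeySetter.accept(oldKey);
            }
        };
    }
}
```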
@@ -18,31 +18,32 @@
 
 package org.apache.flink.datastream.impl.context;
 
-import org.apache.flink.api.common.state.AggregatingState;
 import org.apache.flink.api.common.state.AggregatingStateDeclaration;
-import org.apache.flink.api.common.state.AggregatingStateDescriptor;
 import org.apache.flink.api.common.state.BroadcastState;
 import org.apache.flink.api.common.state.BroadcastStateDeclaration;
-import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDeclaration;
-import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.MapState;
 import org.apache.flink.api.common.state.MapStateDeclaration;
-import org.apache.flink.api.common.state.MapStateDescriptor;
-import org.apache.flink.api.common.state.OperatorStateStore;
-import org.apache.flink.api.common.state.ReducingState;
 import org.apache.flink.api.common.state.ReducingStateDeclaration;
-import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.StateDeclaration;
-import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDeclaration;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.state.v2.AggregatingState;
+import org.apache.flink.api.common.state.v2.ListState;
+import org.apache.flink.api.common.state.v2.MapState;
+import org.apache.flink.api.common.state.v2.ReducingState;
+import org.apache.flink.api.common.state.v2.ValueState;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.datastream.api.context.StateManager;
+import org.apache.flink.runtime.state.v2.AggregatingStateDescriptor;
+import org.apache.flink.runtime.state.v2.ListStateDescriptor;
+import org.apache.flink.runtime.state.v2.MapStateDescriptor;
+import org.apache.flink.runtime.state.v2.OperatorStateStore;
+import org.apache.flink.runtime.state.v2.ReducingStateDescriptor;
+import org.apache.flink.runtime.state.v2.ValueStateDescriptor;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.util.Preconditions;
 
 import java.util.Optional;
-import java.util.function.Consumer;
+import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
 /**
@@ -51,25 +52,22 @@
  */
 public class DefaultStateManager implements StateManager {
 
-    /**
-     * Retrieve the current key. When {@link #currentKeySetter} receives a key, this must return
-     * that key until it is reset.
-     */
+    /** Retrieve the current key. */
    private final Supplier<Object> currentKeySupplier;
 
-    private final Consumer<Object> currentKeySetter;
+    private final BiConsumer<Runnable, Object> processorWithKey;
 
    protected final StreamingRuntimeContext operatorContext;
 
    protected final OperatorStateStore operatorStateStore;
 
    public DefaultStateManager(
            Supplier<Object> currentKeySupplier,
-            Consumer<Object> currentKeySetter,
+            BiConsumer<Runnable, Object> processorWithKey,
            StreamingRuntimeContext operatorContext,
            OperatorStateStore operatorStateStore) {
        this.currentKeySupplier = currentKeySupplier;
-        this.currentKeySetter = currentKeySetter;
+        this.processorWithKey = processorWithKey;
        this.operatorContext = Preconditions.checkNotNull(operatorContext);
        this.operatorStateStore = Preconditions.checkNotNull(operatorStateStore);
    }
@@ -86,8 +84,9 @@ public <T> Optional<ValueState<T>> getState(ValueStateDeclaration<T> stateDeclar
        ValueStateDescriptor<T> valueStateDescriptor =
                new ValueStateDescriptor<>(
                        stateDeclaration.getName(),
-                        stateDeclaration.getTypeDescriptor().getTypeClass());
-        return Optional.ofNullable(operatorContext.getState(valueStateDescriptor));
+                        TypeExtractor.createTypeInfo(
+                                stateDeclaration.getTypeDescriptor().getTypeClass()));
+        return Optional.ofNullable(operatorContext.getValueState(valueStateDescriptor));
    }
 
    @Override
@@ -97,7 +96,8 @@ public <T> Optional<ListState<T>> getState(ListStateDeclarat
        ListStateDescriptor<T> listStateDescriptor =
                new ListStateDescriptor<>(
                        stateDeclaration.getName(),
-                        stateDeclaration.getTypeDescriptor().getTypeClass());
+                        TypeExtractor.createTypeInfo(
+                                stateDeclaration.getTypeDescriptor().getTypeClass()));
 
        if (stateDeclaration.getRedistributionMode()
                == StateDeclaration.RedistributionMode.REDISTRIBUTABLE) {
@@ -119,8 +119,10 @@ public <K, V> Optional<MapState<K, V>> getState(MapStateDeclaration<K, V> stateD
        MapStateDescriptor<K, V> mapStateDescriptor =
                new MapStateDescriptor<>(
                        stateDeclaration.getName(),
-                        stateDeclaration.getKeyTypeDescriptor().getTypeClass(),
-                        stateDeclaration.getValueTypeDescriptor().getTypeClass());
+                        TypeExtractor.createTypeInfo(
+                                stateDeclaration.getKeyTypeDescriptor().getTypeClass()),
+                        TypeExtractor.createTypeInfo(
+                                stateDeclaration.getValueTypeDescriptor().getTypeClass()));
        return Optional.ofNullable(operatorContext.getMapState(mapStateDescriptor));
    }
 
@@ -131,7 +133,8 @@ public <T> Optional<ReducingState<T>> getState(ReducingStateDeclaration<T> state
                new ReducingStateDescriptor<>(
                        stateDeclaration.getName(),
                        stateDeclaration.getReduceFunction(),
-                        stateDeclaration.getTypeDescriptor().getTypeClass());
+                        TypeExtractor.createTypeInfo(
+                                stateDeclaration.getTypeDescriptor().getTypeClass()));
        return Optional.ofNullable(operatorContext.getReducingState(reducingStateDescriptor));
    }
 
@@ -142,7 +145,8 @@ public <IN, ACC, OUT> Optional<AggregatingState<IN, OUT>> getState(
                new AggregatingStateDescriptor<>(
                        stateDeclaration.getName(),
                        stateDeclaration.getAggregateFunction(),
-                        stateDeclaration.getTypeDescriptor().getTypeClass());
+                        TypeExtractor.createTypeInfo(
+                                stateDeclaration.getTypeDescriptor().getTypeClass()));
        return Optional.ofNullable(operatorContext.getAggregatingState(aggregatingStateDescriptor));
    }
 
@@ -152,8 +156,10 @@ public <K, V> Optional<BroadcastState<K, V>> getState(
        MapStateDescriptor<K, V> mapStateDescriptor =
                new MapStateDescriptor<>(
                        stateDeclaration.getName(),
-                        stateDeclaration.getKeyTypeDescriptor().getTypeClass(),
-                        stateDeclaration.getValueTypeDescriptor().getTypeClass());
+                        TypeExtractor.createTypeInfo(
+                                stateDeclaration.getKeyTypeDescriptor().getTypeClass()),
+                        TypeExtractor.createTypeInfo(
+                                stateDeclaration.getValueTypeDescriptor().getTypeClass()));
        return Optional.ofNullable(operatorStateStore.getBroadcastState(mapStateDescriptor));
    }

@@ -162,20 +168,6 @@ public <K, V> Optional<BroadcastState<K, V>> getState(
      * key must be reset after the block is executed.
      */
    public void executeInKeyContext(Runnable runnable, Object key) {
-        final Object oldKey = currentKeySupplier.get();
-        setCurrentKey(key);
-        try {
-            runnable.run();
-        } finally {
-            resetCurrentKey(oldKey);
-        }
-    }
-
-    private void setCurrentKey(Object key) {
-        currentKeySetter.accept(key);
-    }
-
-    private void resetCurrentKey(Object oldKey) {
-        currentKeySetter.accept(oldKey);
+        processorWithKey.accept(runnable, key);
    }
 }
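With this change, executeInKeyContext no longer switches the key itself; it forwards the runnable and the key to the injected processorWithKey, which decides how and when to run them. A plain-JDK sketch of that contract follows, using a stand-in processor that defers work the way an asynchronous framework might; all names in it are my own, not Flink's.

```java
// Behavioral sketch of the new contract: executeInKeyContext(runnable, key)
// reduces to processorWithKey.accept(runnable, key). The processor here simply
// records the key and defers the runnable instead of executing it inline.
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;

public final class ProcessorWithKeyContractSketch {
    public static void main(String[] args) {
        List<Object> seenKeys = new ArrayList<>();
        List<Runnable> deferred = new ArrayList<>();

        // Stand-in for what an async-state operator might inject.
        BiConsumer<Runnable, Object> processorWithKey =
                (runnable, key) -> {
                    seenKeys.add(key);
                    deferred.add(runnable);
                };

        // What DefaultStateManager.executeInKeyContext now boils down to:
        Runnable work = () -> System.out.println("state access under key context");
        processorWithKey.accept(work, "key-1");

        // The framework can later drain the deferred work under the right key.
        deferred.forEach(Runnable::run);
        System.out.println("keys observed: " + seenKeys);
    }
}
```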
@@ -40,8 +40,6 @@
 import java.util.HashSet;
 import java.util.Set;
 
-import static org.apache.flink.util.Preconditions.checkNotNull;
-
 /** Operator for {@link OneInputStreamProcessFunction} in {@link KeyedPartitionStream}. */
 public class KeyedProcessOperator<KEY, IN, OUT> extends ProcessOperator<IN, OUT>
         implements Triggerable<KEY, VoidNamespace> {
@@ -91,16 +89,8 @@ public void onEventTime(InternalTimer<KEY, VoidNamespace> timer) throws Exceptio
 
    @Override
    public void onProcessingTime(InternalTimer<KEY, VoidNamespace> timer) throws Exception {
-        // align the key context with the registered timer.
-        partitionedContext
-                .getStateManager()
-                .executeInKeyContext(
-                        () ->
-                                userFunction.onProcessingTimer(
-                                        timer.getTimestamp(),
-                                        getOutputCollector(),
-                                        partitionedContext),
-                        timer.getKey());
+        userFunction.onProcessingTimer(
+                timer.getTimestamp(), getOutputCollector(), partitionedContext);
    }
 
    @Override
@@ -115,16 +105,14 @@ protected NonPartitionedContext<OUT> getNonPartitionedContext() {
    }
 
    @Override
-    @SuppressWarnings({"unchecked", "rawtypes"})
+    @SuppressWarnings({"rawtypes"})
    public void setKeyContextElement1(StreamRecord record) throws Exception {
-        setKeyContextElement(record, getStateKeySelector1());
+        super.setKeyContextElement1(record);
+        keySet.add(getCurrentKey());
    }
 
-    private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
-            throws Exception {
-        checkNotNull(selector);
-        Object key = selector.getKey(record.getValue());
-        setCurrentKey(key);
-        keySet.add(key);
+    @Override
+    public boolean isAsyncStateProcessingEnabled() {
+        return true;
    }
 }
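The same operator-side pattern repeats in the two operators that follow: the processing-time timer callback no longer wraps the user function in executeInKeyContext, key extraction in setKeyContextElement is delegated to the superclass, and isAsyncStateProcessingEnabled() now returns true. The explicit wrapping could be dropped because, with async state processing, the component firing the timer is expected to establish the timer's key context before invoking the callback; the toy model below (plain Java, not Flink code, built on that assumption) illustrates the division of responsibility.

```java
// Toy model: the "framework" sets the key context from the fired timer before
// invoking the callback, so the callback itself never has to switch keys.
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.Consumer;

public final class TimerKeyContextSketch {

    record Timer(long timestamp, Object key) {}

    private static Object currentKey;

    public static void main(String[] args) {
        Deque<Timer> firedTimers = new ArrayDeque<>();
        firedTimers.add(new Timer(42L, "user-7"));

        // Operator-side callback: reads the already-established key context.
        Consumer<Timer> callback =
                timer -> System.out.println(
                        "timer " + timer.timestamp() + " runs with key " + currentKey);

        // Framework side: install the timer's key, then invoke the callback.
        while (!firedTimers.isEmpty()) {
            Timer timer = firedTimers.poll();
            currentKey = timer.key();
            callback.accept(timer);
        }
    }
}
```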
@@ -95,16 +95,8 @@ public void onEventTime(InternalTimer<KEY, VoidNamespace> timer) throws Exceptio
 
    @Override
    public void onProcessingTime(InternalTimer<KEY, VoidNamespace> timer) throws Exception {
-        // align the key context with the registered timer.
-        partitionedContext
-                .getStateManager()
-                .executeInKeyContext(
-                        () ->
-                                userFunction.onProcessingTimer(
-                                        timer.getTimestamp(),
-                                        getOutputCollector(),
-                                        partitionedContext),
-                        timer.getKey());
+        userFunction.onProcessingTimer(
+                timer.getTimestamp(), getOutputCollector(), partitionedContext);
    }
 
    @Override
@@ -114,19 +106,15 @@ protected NonPartitionedContext<OUT> getNonPartitionedContext() {
    }
 
    @Override
-    @SuppressWarnings({"unchecked", "rawtypes"})
+    @SuppressWarnings({"rawtypes"})
    // Only element from input1 should be considered as the other side is broadcast input.
    public void setKeyContextElement1(StreamRecord record) throws Exception {
-        setKeyContextElement(record, getStateKeySelector1());
+        super.setKeyContextElement1(record);
+        keySet.add(getCurrentKey());
    }
 
-    private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
-            throws Exception {
-        if (selector == null) {
-            return;
-        }
-        Object key = selector.getKey(record.getValue());
-        setCurrentKey(key);
-        keySet.add(key);
+    @Override
+    public boolean isAsyncStateProcessingEnabled() {
+        return true;
    }
 }
@@ -97,16 +97,8 @@ public void onEventTime(InternalTimer<KEY, VoidNamespace> timer) throws Exceptio
 
    @Override
    public void onProcessingTime(InternalTimer<KEY, VoidNamespace> timer) throws Exception {
-        // align the key context with the registered timer.
-        partitionedContext
-                .getStateManager()
-                .executeInKeyContext(
-                        () ->
-                                userFunction.onProcessingTimer(
-                                        timer.getTimestamp(),
-                                        getOutputCollector(),
-                                        partitionedContext),
-                        timer.getKey());
+        userFunction.onProcessingTimer(
+                timer.getTimestamp(), getOutputCollector(), partitionedContext);
    }
 
    @Override
@@ -116,24 +108,21 @@ protected NonPartitionedContext<OUT> getNonPartitionedContext() {
    }
 
    @Override
-    @SuppressWarnings({"unchecked", "rawtypes"})
+    @SuppressWarnings({"rawtypes"})
    public void setKeyContextElement1(StreamRecord record) throws Exception {
-        setKeyContextElement(record, getStateKeySelector1());
+        super.setKeyContextElement1(record);
+        keySet.add(getCurrentKey());
    }
 
    @Override
-    @SuppressWarnings({"unchecked", "rawtypes"})
+    @SuppressWarnings({"rawtypes"})
    public void setKeyContextElement2(StreamRecord record) throws Exception {
-        setKeyContextElement(record, getStateKeySelector2());
+        super.setKeyContextElement2(record);
+        keySet.add(getCurrentKey());
    }
 
-    private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
-            throws Exception {
-        if (selector == null) {
-            return;
-        }
-        Object key = selector.getKey(record.getValue());
-        setCurrentKey(key);
-        keySet.add(key);
+    @Override
+    public boolean isAsyncStateProcessingEnabled() {
+        return true;
    }
 }