Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SAMZA-2796: Introduce config knob for framework thread sub DAG execution #1691

Merged
merged 2 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/learn/documentation/versioned/jobs/configuration-table.html
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ <h1>Samza Configuration Reference</h1>
</td>
</tr>

<tr>
<td class="property" id="job.operator.framework.executor.enabled">job.operator.framework.executor.enabled</td>
<td class="default">false</td>
<td class="description">
If enabled, framework thread pool will be used for message hand off and sub DAG execution. Otherwise, the
execution will fall back to using caller thread or java fork join pool depending on the type of work
chained as part of message hand off.
</td>
</tr>

<tr>
<!-- change link to StandAlone design/tutorial doc. SAMZA-1299 -->
<th colspan="3" class="section" id="ZkBasedJobCoordination"><a href="../index.html">Zookeeper-based job configuration</a></th>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ public class JobConfig extends MapConfig {
public static final String JOB_ELASTICITY_FACTOR = "job.elasticity.factor";
public static final int DEFAULT_JOB_ELASTICITY_FACTOR = 1;

public static final String JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = "job.operator.framework.executor.enabled";

public static final boolean DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = false;

public JobConfig(Config config) {
super(config);
}
Expand Down Expand Up @@ -527,4 +531,8 @@ public int getElasticityFactor() {
public String getCoordinatorExecuteCommand() {
return get(COORDINATOR_EXECUTE_COMMAND, DEFAULT_COORDINATOR_EXECUTE_COMMAND);
}

public boolean getOperatorFrameworkExecutorEnabled() {
return getBoolean(JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED, DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.samza.SamzaException;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
Expand Down Expand Up @@ -95,6 +97,7 @@ public abstract class OperatorImpl<M, RM> {
private ControlMessageSender controlMessageSender;
private int elasticityFactor;
private ExecutorService operatorExecutor;
private boolean operatorExecutorEnabled;

/**
* Initialize this {@link OperatorImpl} and its user-defined functions.
Expand Down Expand Up @@ -136,7 +139,9 @@ public final void init(InternalTaskContext internalTaskContext) {
this.taskModel = taskContext.getTaskModel();
this.callbackScheduler = taskContext.getCallbackScheduler();
handleInit(context);
this.elasticityFactor = new JobConfig(config).getElasticityFactor();
JobConfig jobConfig = new JobConfig(config);
this.elasticityFactor = jobConfig.getElasticityFactor();
this.operatorExecutorEnabled = jobConfig.getOperatorFrameworkExecutorEnabled();
this.operatorExecutor = context.getTaskContext().getOperatorExecutor();

initialized = true;
Expand Down Expand Up @@ -192,21 +197,20 @@ public final CompletionStage<Void> onMessageAsync(M message, MessageCollector co
getOpImplId(), getOperatorSpec().getSourceLocation(), expectedType, actualType), e);
}

CompletionStage<Void> result = completableResultsFuture.thenComposeAsync(results -> {
CompletionStage<Void> result = composeFutureWithExecutor(completableResultsFuture, results -> {
long endNs = this.highResClock.nanoTime();
this.handleMessageNs.update(endNs - startNs);

return CompletableFuture.allOf(results.stream()
.flatMap(r -> this.registeredOperators.stream()
.map(op -> op.onMessageAsync(r, collector, coordinator)))
.flatMap(r -> this.registeredOperators.stream().map(op -> op.onMessageAsync(r, collector, coordinator)))
.toArray(CompletableFuture[]::new));
}, operatorExecutor);
});

WatermarkFunction watermarkFn = getOperatorSpec().getWatermarkFn();
if (watermarkFn != null) {
// check whether there is new watermark emitted from the user function
Long outputWm = watermarkFn.getOutputWatermark();
return result.thenComposeAsync(ignored -> propagateWatermark(outputWm, collector, coordinator), operatorExecutor);
return composeFutureWithExecutor(result, ignored -> propagateWatermark(outputWm, collector, coordinator));
}

return result;
Expand Down Expand Up @@ -245,11 +249,9 @@ public final CompletionStage<Void> onTimer(MessageCollector collector, TaskCoord
.map(op -> op.onMessageAsync(r, collector, coordinator)))
.toArray(CompletableFuture[]::new));

return resultFuture.thenComposeAsync(x ->
CompletableFuture.allOf(this.registeredOperators
.stream()
.map(op -> op.onTimer(collector, coordinator))
.toArray(CompletableFuture[]::new)), operatorExecutor);
return composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(this.registeredOperators.stream()
.map(op -> op.onTimer(collector, coordinator))
.toArray(CompletableFuture[]::new)));
}

/**
Expand Down Expand Up @@ -315,15 +317,14 @@ public final CompletionStage<Void> aggregateEndOfStream(EndOfStreamMessage eos,
}

// populate the end-of-stream through the dag
endOfStreamFuture = onEndOfStream(collector, coordinator)
.thenAcceptAsync(result -> {
if (eosStates.allEndOfStream()) {
// all inputs have been end-of-stream, shut down the task
LOG.info("All input streams have reached the end for task {}", taskName.getTaskName());
coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
}
}, operatorExecutor);
endOfStreamFuture = acceptFutureWithExecutor(onEndOfStream(collector, coordinator), result -> {
if (eosStates.allEndOfStream()) {
// all inputs have been end-of-stream, shut down the task
LOG.info("All input streams have reached the end for task {}", taskName.getTaskName());
coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
}
});
}

return endOfStreamFuture;
Expand All @@ -347,10 +348,10 @@ private CompletionStage<Void> onEndOfStream(MessageCollector collector, TaskCoor
.map(op -> op.onMessageAsync(r, collector, coordinator)))
.toArray(CompletableFuture[]::new));

endOfStreamFuture = resultFuture.thenComposeAsync(x ->
CompletableFuture.allOf(this.registeredOperators.stream()
endOfStreamFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(
this.registeredOperators.stream()
.map(op -> op.onEndOfStream(collector, coordinator))
.toArray(CompletableFuture[]::new)), operatorExecutor);
.toArray(CompletableFuture[]::new)));
}

return endOfStreamFuture;
Expand Down Expand Up @@ -406,15 +407,14 @@ public final CompletionStage<Void> aggregateDrainMessages(DrainMessage drainMess
controlMessageSender.broadcastToOtherPartitions(new DrainMessage(drainMessage.getRunId()), ssp, collector);
}

drainFuture = onDrainOfStream(collector, coordinator)
.thenAcceptAsync(result -> {
if (drainStates.areAllStreamsDrained()) {
// All input streams have been drained, shut down the task
LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName());
coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
}
}, operatorExecutor);
drainFuture = acceptFutureWithExecutor(onDrainOfStream(collector, coordinator), result -> {
if (drainStates.areAllStreamsDrained()) {
// All input streams have been drained, shut down the task
LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName());
coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
}
});
}

return drainFuture;
Expand All @@ -439,10 +439,10 @@ private CompletionStage<Void> onDrainOfStream(MessageCollector collector, TaskCo
.toArray(CompletableFuture[]::new));

// propagate DrainMessage to downstream operators
drainFuture = resultFuture.thenComposeAsync(x ->
CompletableFuture.allOf(this.registeredOperators.stream()
drainFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(
this.registeredOperators.stream()
.map(op -> op.onDrainOfStream(collector, coordinator))
.toArray(CompletableFuture[]::new)), operatorExecutor);
.toArray(CompletableFuture[]::new)));
}

return drainFuture;
Expand Down Expand Up @@ -474,8 +474,8 @@ public final CompletionStage<Void> aggregateWatermark(WatermarkMessage watermark
controlMessageSender.broadcastToOtherPartitions(new WatermarkMessage(watermark), ssp, collector);
}
// populate the watermark through the dag
watermarkFuture = onWatermark(watermark, collector, coordinator)
.thenAcceptAsync(ignored -> watermarkStates.updateAggregateMetric(ssp, watermark), operatorExecutor);
watermarkFuture = acceptFutureWithExecutor(onWatermark(watermark, collector, coordinator),
ignored -> watermarkStates.updateAggregateMetric(ssp, watermark));
}

return watermarkFuture;
Expand Down Expand Up @@ -530,8 +530,8 @@ private CompletionStage<Void> onWatermark(long watermark, MessageCollector colle
.toArray(CompletableFuture[]::new));
}

watermarkFuture = watermarkFuture.thenComposeAsync(res -> propagateWatermark(outputWm, collector, coordinator),
operatorExecutor);
watermarkFuture =
composeFutureWithExecutor(watermarkFuture, res -> propagateWatermark(outputWm, collector, coordinator));
}

return watermarkFuture;
Expand Down Expand Up @@ -679,6 +679,20 @@ final Collection<RM> handleMessage(M message, MessageCollector collector, TaskCo
.toCompletableFuture().join();
}

@VisibleForTesting
final <T, U> CompletionStage<U> composeFutureWithExecutor(CompletionStage<T> futureToChain,
Function<? super T, ? extends CompletionStage<U>> fn) {
return operatorExecutorEnabled ? futureToChain.thenComposeAsync(fn, operatorExecutor)
: futureToChain.thenCompose(fn);
}

@VisibleForTesting
final <T> CompletionStage<Void> acceptFutureWithExecutor(CompletionStage<T> futureToChain,
Consumer<? super T> consumer) {
return operatorExecutorEnabled ? futureToChain.thenAcceptAsync(consumer, operatorExecutor)
: futureToChain.thenAccept(consumer);
}

private HighResolutionClock createHighResClock(Config config) {
MetricsConfig metricsConfig = new MetricsConfig(config);
// The timer metrics calculation here is only enabled for debugging
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,13 @@ class TaskInstance(
val jobConfig = new JobConfig(jobContext.getConfig)
val taskExecutorFactory = ReflectionUtil.getObj(jobConfig.getTaskExecutorFactory, classOf[TaskExecutorFactory])

var operatorExecutor = Option.empty[java.util.concurrent.ExecutorService].orNull
if (jobConfig.getOperatorFrameworkExecutorEnabled) {
operatorExecutor = taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig)
}
new TaskContextImpl(taskModel, metrics.registry, kvStoreSupplier, tableManager,
new CallbackSchedulerImpl(epochTimeScheduler), offsetManager, jobModel, streamMetadataCache,
systemStreamPartitions, taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig))
systemStreamPartitions, operatorExecutor)
}
// need separate field for this instead of using it through Context, since Context throws an exception if it is null
private val applicationTaskContextOption = applicationTaskContextFactoryOption
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,24 @@
*/
package org.apache.samza.operators.impl;

import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.samza.config.Config;
import org.apache.samza.config.MapConfig;
import org.apache.samza.context.ContainerContext;
import org.apache.samza.context.Context;
import org.apache.samza.context.InternalTaskContext;
import org.apache.samza.context.MockContext;
import org.apache.samza.context.JobContext;
import org.apache.samza.context.TaskContext;
import org.apache.samza.job.model.TaskModel;
import org.apache.samza.metrics.Counter;
import org.apache.samza.metrics.MetricsRegistryMap;
Expand All @@ -44,33 +52,111 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;


public class TestOperatorImpl {
private Context context;
private InternalTaskContext internalTaskContext;

private JobContext jobContext;

private TaskContext taskContext;

private ContainerContext containerContext;

@Before
public void setup() {
this.context = new MockContext();
this.context = mock(Context.class);
this.internalTaskContext = mock(InternalTaskContext.class);
this.jobContext = mock(JobContext.class);
this.taskContext = mock(TaskContext.class);
this.containerContext = mock(ContainerContext.class);
when(this.internalTaskContext.getContext()).thenReturn(this.context);
// might be necessary in the future
when(this.internalTaskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(mock(EndOfStreamStates.class));
when(this.internalTaskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
when(this.context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class));
when(this.context.getTaskContext().getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
when(this.context.getContainerContext().getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());
when(this.context.getJobContext()).thenReturn(jobContext);
when(this.context.getTaskContext()).thenReturn(taskContext);
when(this.taskContext.getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
when(this.taskContext.getTaskModel()).thenReturn(mock(TaskModel.class));
when(this.taskContext.getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
when(this.context.getContainerContext()).thenReturn(containerContext);
when(containerContext.getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());
}

@Test
public void testComposeFutureWithExecutorWithFrameworkExecutorEnabled() {
OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
ExecutorService mockExecutor = mock(ExecutorService.class);
CompletionStage<Object> mockFuture = mock(CompletionStage.class);
Function<Object, CompletionStage<Object>> mockFunction = mock(Function.class);

Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true"));

when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
when(this.jobContext.getConfig()).thenReturn(config);

opImpl.init(this.internalTaskContext);
opImpl.composeFutureWithExecutor(mockFuture, mockFunction);

verify(mockFuture).thenComposeAsync(eq(mockFunction), eq(mockExecutor));
}

@Test
public void testComposeFutureWithExecutorWithFrameworkExecutorDisabled() {
OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
ExecutorService mockExecutor = mock(ExecutorService.class);
CompletionStage<Object> mockFuture = mock(CompletionStage.class);
Function<Object, CompletionStage<Object>> mockFunction = mock(Function.class);

Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false"));

when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
when(this.jobContext.getConfig()).thenReturn(config);

opImpl.init(this.internalTaskContext);
opImpl.composeFutureWithExecutor(mockFuture, mockFunction);

verify(mockFuture).thenCompose(eq(mockFunction));
}

@Test
public void testAcceptFutureWithExecutorWithFrameworkExecutorDisabled() {
OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
ExecutorService mockExecutor = mock(ExecutorService.class);
CompletionStage<Object> mockFuture = mock(CompletionStage.class);
Consumer<Object> mockConsumer = mock(Consumer.class);

Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false"));

when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
when(this.jobContext.getConfig()).thenReturn(config);

opImpl.init(this.internalTaskContext);
opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer);

verify(mockFuture).thenAccept(eq(mockConsumer));
}

@Test
public void testAcceptFutureWithExecutorWithFrameworkExecutorEnabled() {
OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
ExecutorService mockExecutor = mock(ExecutorService.class);
CompletionStage<Object> mockFuture = mock(CompletionStage.class);
Consumer<Object> mockConsumer = mock(Consumer.class);

Config config = new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true"));

when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
when(this.jobContext.getConfig()).thenReturn(config);

opImpl.init(this.internalTaskContext);
opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer);

verify(mockFuture).thenAcceptAsync(eq(mockConsumer), eq(mockExecutor));
}
@Test(expected = IllegalStateException.class)
public void testMultipleInitShouldThrow() {
OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
Expand Down
Loading