Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
yehaolan committed Sep 27, 2024
1 parent a52ef9d commit 2bcaac6
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ public class TaskConfig extends MapConfig {
public static final String COMMIT_TIMEOUT_MS = "task.commit.timeout.ms";
static final long DEFAULT_COMMIT_TIMEOUT_MS = Duration.ofMinutes(30).toMillis();

public static final String SKIP_COMMIT_DURING_FAILURES_ENABLED = "task.commit.skip.commit.during.failures.enabled";
private static final boolean DEFAULT_SKIP_COMMIT_DURING_FAILURES_ENABLED = false;

// how long to wait for a clean shutdown
public static final String TASK_SHUTDOWN_MS = "task.shutdown.ms";
static final long DEFAULT_TASK_SHUTDOWN_MS = 30000L;
Expand Down Expand Up @@ -418,4 +421,8 @@ public long getWatermarkIdleTimeoutMs() {
public double getWatermarkQuorumSizePercentage() {
return getDouble(WATERMARK_QUORUM_SIZE_PERCENTAGE, DEFAULT_WATERMARK_QUORUM_SIZE_PERCENTAGE);
}

public boolean getSkipCommitDuringFailuresEnabled() {
return getBoolean(SKIP_COMMIT_DURING_FAILURES_ENABLED, DEFAULT_SKIP_COMMIT_DURING_FAILURES_ENABLED);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class TaskInstance(
@volatile var lastCommitStartTimeMs = System.currentTimeMillis()
val commitMaxDelayMs = taskConfig.getCommitMaxDelayMs
val commitTimeoutMs = taskConfig.getCommitTimeoutMs
val skipCommitDuringFailureEnabled = taskConfig.getSkipCommitDuringFailuresEnabled
val commitInProgress = new Semaphore(1)
val commitException = new AtomicReference[Exception]()

Expand Down Expand Up @@ -312,10 +313,18 @@ class TaskInstance(

val commitStartNs = System.nanoTime()
// first check if there were any unrecoverable errors during the async stage of the pending commit
// and if so, shut down the container.
// If there is unrecoverable error and skipCommitDuringFailureEnabled is enabled, ignore the error.
// Otherwise, shut down the container.
if (commitException.get() != null) {
throw new SamzaException("Unrecoverable error during pending commit for taskName: %s." format taskName,
commitException.get())
if (skipCommitDuringFailureEnabled) {
warn("Ignored the commit failure for taskName %s: %s" format (taskName, commitException.get().getMessage))
metrics.commitExceptionIgnored.set(metrics.commitExceptionIgnored.getValue + 1)
commitException.set(null)
commitInProgress.release()
} else {
throw new SamzaException("Unrecoverable error during pending commit for taskName: %s." format taskName,
commitException.get())
}
}

// if no commit is in progress for this task, continue with this commit.
Expand All @@ -339,10 +348,18 @@ class TaskInstance(
if (!commitInProgress.tryAcquire(commitTimeoutMs, TimeUnit.MILLISECONDS)) {
val timeSinceLastCommit = System.currentTimeMillis() - lastCommitStartTimeMs
metrics.commitsTimedOut.set(metrics.commitsTimedOut.getValue + 1)
throw new SamzaException("Timeout waiting for pending commit for taskName: %s to finish. " +
"%s ms have elapsed since the pending commit started. Max allowed commit delay is %s ms " +
"and commit timeout beyond that is %s ms" format (taskName, timeSinceLastCommit,
commitMaxDelayMs, commitTimeoutMs))
if (skipCommitDuringFailureEnabled) {
warn("Ignoring commit timeout for taskName: %s. %s ms have elapsed since another commit started. " +
"Max allowed commit delay is %s ms and commit timeout beyond that is %s ms."
format (taskName, timeSinceLastCommit, commitMaxDelayMs, commitTimeoutMs))
commitInProgress.release()
return
} else {
throw new SamzaException("Timeout waiting for pending commit for taskName: %s to finish. " +
"%s ms have elapsed since the pending commit started. Max allowed commit delay is %s ms " +
"and commit timeout beyond that is %s ms" format (taskName, timeSinceLastCommit,
commitMaxDelayMs, commitTimeoutMs))
}
}
}
}
Expand Down Expand Up @@ -426,7 +443,7 @@ class TaskInstance(
}
})

metrics.lastCommitNs.set(System.nanoTime() - commitStartNs)
metrics.lastCommitNs.set(System.nanoTime())
metrics.commitSyncNs.update(System.nanoTime() - commitStartNs)
debug("Finishing sync stage of commit for taskName: %s checkpointId: %s" format (taskName, checkpointId))
}
Expand Down Expand Up @@ -533,6 +550,7 @@ class TaskInstance(
} else {
metrics.commitAsyncNs.update(System.nanoTime() - asyncStageStartNs)
metrics.commitNs.update(System.nanoTime() - commitStartNs)
metrics.lastAsyncCommitNs.set(System.nanoTime())
}
} finally {
// release the permit indicating that previous commit is complete.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ class TaskInstanceMetrics(
val asyncCallbackCompleted = newCounter("async-callback-complete-calls")
val commitsTimedOut = newGauge("commits-timed-out", 0)
val commitsSkipped = newGauge("commits-skipped", 0)
val commitExceptionIgnored = newGauge("commit-exceptions-ignored", 0)
val commitNs = newTimer("commit-ns")
val lastCommitNs = newGauge("last-commit-ns", 0L)
val lastAsyncCommitNs = newGauge("last-async-commit-ns", 0L)
val commitSyncNs = newTimer("commit-sync-ns")
val commitAsyncNs = newTimer("commit-async-ns")
val snapshotNs = newTimer("snapshot-ns")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,131 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
verify(snapshotTimer, times(2)).update(anyLong())
}

@Test
def testSkipExceptionFromFirstCommitAndContinueSecondCommit(): Unit = {
val commitsCounter = mock[Counter]
when(this.metrics.commits).thenReturn(commitsCounter)
val snapshotTimer = mock[Timer]
when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
val uploadTimer = mock[Timer]
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val commitTimer = mock[Timer]
when(this.metrics.commitNs).thenReturn(commitTimer)
val commitSyncTimer = mock[Timer]
when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
val commitAsyncTimer = mock[Timer]
when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
val skippedCounter = mock[Gauge[Int]]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
val commitExceptionIgnoredCounter = mock[Gauge[Int]]
when(this.metrics.commitExceptionIgnored).thenReturn(commitExceptionIgnoredCounter)

val taskConfigsMap = new util.HashMap[String, String]()
taskConfigsMap.put("task.commit.ms", "-1")
taskConfigsMap.put("task.commit.max.delay.ms", "-1")
taskConfigsMap.put("task.commit.timeout.ms", "2000000")
// skip commit if exception occurs during the commit
taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled", "true")
when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap))
setupTaskInstance(None, ForkJoinPool.commonPool())

val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
inputOffsets.put(SYSTEM_STREAM_PARTITION, "4")
val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
// Ensure the second commit proceeds without exceptions
when(this.taskCommitManager.upload(any(), any()))
.thenReturn(CompletableFuture.completedFuture(
Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)))
// exception during the first commit
when(this.taskCommitManager.upload(any(), any()))
.thenReturn(FutureUtil.failedFuture[util.Map[String, util.Map[String, String]]](new RuntimeException))

taskInstance.commit
verify(commitsCounter).inc()
verify(snapshotTimer).update(anyLong())
verifyZeroInteractions(uploadTimer)
verifyZeroInteractions(commitTimer)
verifyZeroInteractions(skippedCounter)
Thread.sleep(1000) // ensure the commitException is updated by the previous commit
taskInstance.commit
verify(commitsCounter, times(2)).inc() // should only have been incremented twice - once for each commit
verify(commitExceptionIgnoredCounter).set(1)
}

@Test
def testIgnoreTimeoutAndContinueCommitIfPreviousAsyncCommitInProgressAfterMaxCommitDelayAndBlockTime(): Unit = {
val commitsCounter = mock[Counter]
when(this.metrics.commits).thenReturn(commitsCounter)
val snapshotTimer = mock[Timer]
when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
val commitTimer = mock[Timer]
when(this.metrics.commitNs).thenReturn(commitTimer)
val commitSyncTimer = mock[Timer]
when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
val commitAsyncTimer = mock[Timer]
when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
val uploadTimer = mock[Timer]
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
val skippedCounter = mock[Gauge[Int]]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val commitsTimedOutCounter = mock[Gauge[Int]]
when(this.metrics.commitsTimedOut).thenReturn(commitsTimedOutCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
val commitExceptionIgnoredCounter = mock[Gauge[Int]]
when(this.metrics.commitExceptionIgnored).thenReturn(commitExceptionIgnoredCounter)

val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-changelog-stream"), new Partition(0))

val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
val stateCheckpointMarker = KafkaStateCheckpointMarker.serialize(new KafkaStateCheckpointMarker(changelogSSP, "5"))
stateCheckpointMarkers.put("storeName", stateCheckpointMarker)
when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)

val snapshotSCMs = ImmutableMap.of(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)
when(this.taskCommitManager.snapshot(any())).thenReturn(snapshotSCMs)
val snapshotSCMFuture: CompletableFuture[util.Map[String, util.Map[String, String]]] =
CompletableFuture.completedFuture(snapshotSCMs)

when(this.taskCommitManager.upload(any(), Matchers.eq(snapshotSCMs))).thenReturn(snapshotSCMFuture) // kafka is no-op

val cleanUpFuture = new CompletableFuture[Void]()
when(this.taskCommitManager.cleanUp(any(), any())).thenReturn(cleanUpFuture)

// use a separate executor to perform async operations on to test caller thread blocking behavior
val taskConfigsMap = new util.HashMap[String, String]()
taskConfigsMap.put("task.commit.ms", "-1")
// "block" immediately if previous commit async stage not complete
taskConfigsMap.put("task.commit.max.delay.ms", "-1")
taskConfigsMap.put("task.commit.timeout.ms", "0") // throw exception immediately if blocked
taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled", "true")
when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap)) // override default behavior

setupTaskInstance(None, ForkJoinPool.commonPool())

taskInstance.commit // async stage will not complete until cleanUpFuture is completed
taskInstance.commit // second commit found commit timeout and release the semaphore
cleanUpFuture.complete(null) // just to unblock shared executor

verifyZeroInteractions(commitExceptionIgnoredCounter)
verifyZeroInteractions(skippedCounter)
verify(commitsTimedOutCounter).set(1)
verify(commitsCounter, times(1)).inc() // should only have been incremented once now - second commit was skipped

taskInstance.commit // third commit should proceed without any issues

verify(commitsCounter, times(2)).inc() // should only have been incremented twice - second commit was skipped
}


/**
* Given that no application task context factory is provided, then no lifecycle calls should be made.
Expand Down

0 comments on commit 2bcaac6

Please sign in to comment.