diff --git a/enterprise/server/remote_execution/action_merger/action_merger.go b/enterprise/server/remote_execution/action_merger/action_merger.go index 5c3e1baad5a..dba1871e26a 100644 --- a/enterprise/server/remote_execution/action_merger/action_merger.go +++ b/enterprise/server/remote_execution/action_merger/action_merger.go @@ -22,12 +22,27 @@ const ( // set to twice the length of `remote_execution.lease_duration` to give a // short grace period in the event of missed leases. DefaultClaimedExecutionTTL = 20 * time.Second + + // Redis Hash keys for storing the canonical execution ID and the pending + // execution count. + executionIDKey = "execution-id" + executionCountKey = "execution-count" + + // The redis keys storing action merging data are versioned to support + // making backwards-incompatible changes to the storage representation. + // Increment this version to cycle to new keys (and discard all old + // action-merging data) during the next rollout. + keyVersion = 1 ) var ( enableActionMerging = flag.Bool("remote_execution.enable_action_merging", true, "If enabled, identical actions being executed concurrently are merged into a single execution.") + hedgedActionCount = flag.Int("remote_execution.action_merging_hedge_count", 0, "When action merging is enabled, this flag controls how many additional, 'hedged' attempts an action is run in the background. Note that even hedged actions are run at most once per execution request.") ) +// Returns the redis key pointing to the hash storing action merging state. The +// value stored here is a hash containing the canonical (first-submitted) +// execution ID and the count of executions run for this action (for hedging). func redisKeyForPendingExecutionID(ctx context.Context, adResource *digest.ResourceName) (string, error) { userPrefix, err := prefix.UserPrefixFromContext(ctx) if err != nil { @@ -37,11 +52,11 @@ func redisKeyForPendingExecutionID(ctx context.Context, adResource *digest.Resou if err != nil { return "", err } - return fmt.Sprintf("pendingExecution/%s%s", userPrefix, downloadString), nil + return fmt.Sprintf("pendingExecution/%d/%s%s", keyVersion, userPrefix, downloadString), nil } func redisKeyForPendingExecutionDigest(executionID string) string { - return fmt.Sprintf("pendingExecutionDigest/%s", executionID) + return fmt.Sprintf("pendingExecutionDigest/%d/%s", keyVersion, executionID) } // Action merging is an optimization that detects when an execution is @@ -69,7 +84,9 @@ func RecordQueuedExecution(ctx context.Context, rdb redis.UniversalClient, execu } reverseKey := redisKeyForPendingExecutionDigest(executionID) pipe := rdb.TxPipeline() - pipe.Set(ctx, forwardKey, executionID, queuedExecutionTTL) + pipe.HSet(ctx, forwardKey, executionIDKey, executionID) + pipe.HIncrBy(ctx, forwardKey, executionCountKey, 1) + pipe.Expire(ctx, forwardKey, queuedExecutionTTL) pipe.Set(ctx, reverseKey, forwardKey, queuedExecutionTTL) _, err = pipe.Exec(ctx) return err @@ -93,30 +110,47 @@ func RecordClaimedExecution(ctx context.Context, rdb redis.UniversalClient, exec } pipe := rdb.TxPipeline() - pipe.Set(ctx, forwardKey, executionID, ttl) + pipe.HSet(ctx, forwardKey, executionIDKey, executionID, ttl) pipe.Set(ctx, reverseKey, forwardKey, ttl) _, err = pipe.Exec(ctx) return err } +// Records a hedged execution in Redis. +func RecordHedgedExecution(ctx context.Context, rdb redis.UniversalClient, adResource *digest.ResourceName) error { + key, err := redisKeyForPendingExecutionID(ctx, adResource) + if err != nil { + return err + } + return rdb.HIncrBy(ctx, key, executionCountKey, 1).Err() +} + // Returns the execution ID of a pending execution working on the action with -// the provided action digest, or an empty string and possibly an error if no -// pending execution was found. -func FindPendingExecution(ctx context.Context, rdb redis.UniversalClient, schedulerService interfaces.SchedulerService, adResource *digest.ResourceName) (string, error) { +// the provided action digest, or an empty string, as well as a boolean if the +// provided action should be run additionally in the background ("hedged"), or +// an error if no pending execution was found. +func FindPendingExecution(ctx context.Context, rdb redis.UniversalClient, schedulerService interfaces.SchedulerService, adResource *digest.ResourceName) (string, bool, error) { if !*enableActionMerging { - return "", nil + return "", false, nil } - executionIDKey, err := redisKeyForPendingExecutionID(ctx, adResource) + forwardKey, err := redisKeyForPendingExecutionID(ctx, adResource) if err != nil { - return "", err + return "", false, err } - executionID, err := rdb.Get(ctx, executionIDKey).Result() + executionID, err := rdb.HGet(ctx, forwardKey, executionIDKey).Result() if err == redis.Nil { - return "", nil + return "", false, nil } if err != nil { - return "", err + return "", false, err + } + count, err := rdb.HGet(ctx, forwardKey, executionCountKey).Int() + if err == redis.Nil { + return "", false, nil + } + if err != nil { + return "", false, err } // Validate that the reverse mapping exists as well. The reverse mapping is @@ -124,23 +158,23 @@ func FindPendingExecution(ctx context.Context, rdb redis.UniversalClient, schedu // Bail out if it doesn't exist. err = rdb.Get(ctx, redisKeyForPendingExecutionDigest(executionID)).Err() if err == redis.Nil { - return "", nil + return "", false, nil } if err != nil { - return "", err + return "", false, err } // Finally, confirm this execution exists in the scheduler and hasn't been // lost somehow. ok, err := schedulerService.ExistsTask(ctx, executionID) if err != nil { - return "", err + return "", false, err } if !ok { log.CtxWarningf(ctx, "Pending execution %q does not exist in the scheduler", executionID) - return "", nil + return "", false, nil } - return executionID, nil + return executionID, count <= *hedgedActionCount, nil } // Deletes the pending execution with the provided execution ID. diff --git a/enterprise/server/remote_execution/execution_server/execution_server.go b/enterprise/server/remote_execution/execution_server/execution_server.go index b42728fe35c..20a8c3c30a9 100644 --- a/enterprise/server/remote_execution/execution_server/execution_server.go +++ b/enterprise/server/remote_execution/execution_server/execution_server.go @@ -405,6 +405,18 @@ type streamLike interface { Send(*longrunning.Operation) error } +// A streamLike that returns a background context and ignores all packets sent +// to it, used for waiting on pending executions in the background. +type dummyStream struct{} + +func (s dummyStream) Context() context.Context { + return context.Background() +} + +func (s dummyStream) Send(*longrunning.Operation) error { + return nil +} + func (s *ExecutionServer) Execute(req *repb.ExecuteRequest, stream repb.Execution_ExecuteServer) error { return s.execute(req, stream) } @@ -577,6 +589,7 @@ func (s *ExecutionServer) execute(req *repb.ExecuteRequest, stream streamLike) e } invocationID := bazel_request.GetInvocationID(stream.Context()) + hedge := false executionID := "" if !req.GetSkipCacheLookup() { if actionResult, err := s.getActionResultFromCache(ctx, adInstanceDigest); err == nil { @@ -597,7 +610,8 @@ func (s *ExecutionServer) execute(req *repb.ExecuteRequest, stream streamLike) e // Check if there's already an identical action pending execution. If // so, wait on the result of that execution instead of starting a new // one. - ee, err := action_merger.FindPendingExecution(ctx, s.rdb, s.env.GetSchedulerService(), adInstanceDigest) + ee, h, err := action_merger.FindPendingExecution(ctx, s.rdb, s.env.GetSchedulerService(), adInstanceDigest) + hedge = h if err != nil { log.CtxWarningf(ctx, "could not check for existing execution: %s", err) } @@ -629,6 +643,17 @@ func (s *ExecutionServer) execute(req *repb.ExecuteRequest, stream streamLike) e tracing.AddStringAttributeToCurrentSpan(ctx, "execution_result", "merged") tracing.AddStringAttributeToCurrentSpan(ctx, "execution_id", executionID) } + // If the action_merger said to hedge this action, run another execution + // in the background. + if hedge { + action_merger.RecordHedgedExecution(ctx, s.rdb, adInstanceDigest) + hedgedExecutionID, err := s.Dispatch(ctx, req) + if err != nil { + log.CtxWarningf(ctx, "Error dispatching execution for action %q and invocation %q: %s", downloadString, invocationID, err) + return err + } + log.CtxInfof(ctx, "Dispatched new hedged execution %q for action %q and invocation %q", hedgedExecutionID, downloadString, invocationID) + } waitReq := repb.WaitExecutionRequest{ Name: executionID, diff --git a/enterprise/server/test/integration/remote_execution/remote_execution_test.go b/enterprise/server/test/integration/remote_execution/remote_execution_test.go index c47a814f1ee..014e2c1bed8 100644 --- a/enterprise/server/test/integration/remote_execution/remote_execution_test.go +++ b/enterprise/server/test/integration/remote_execution/remote_execution_test.go @@ -1696,7 +1696,7 @@ func WaitForPendingExecution(rdb redis.UniversalClient, opID string) error { return status.NotFoundError("No forward key for pending execution") } -func TestActionMerging(t *testing.T) { +func TestActionMerging_Success(t *testing.T) { rbe := rbetest.NewRBETestEnv(t) rbe.AddBuildBuddyServer() @@ -1802,6 +1802,81 @@ func TestActionMerging_ClaimingAppDies(t *testing.T) { require.NotEqual(t, op1, op2, "unexpected action merge: dead app shouldn't block future actions") } +func TestActionMerging_FirstRunStuck(t *testing.T) { + flags.Set(t, "remote_execution.action_merging_hedge_count", 1) + rbe := rbetest.NewRBETestEnv(t) + rbe.AddBuildBuddyServer() + rbe.AddExecutor(t) + + // This script takes ~10min the first time it's run, but <1s subsequently. + fname := fmt.Sprintf("/tmp/%s", uuid.New().String()) + flakyScript := fmt.Sprintf(`FILE="%s" +if [ -f $FILE ]; then + echo "FAST" +else + echo "SLOW" > $FILE + sleep 600 +fi`, fname) + flakyCmd := &repb.Command{ + Arguments: []string{"sh", "-c", flakyScript}, + Platform: &repb.Platform{ + Properties: []*repb.Platform_Property{ + {Name: "OSFamily", Value: runtime.GOOS}, + {Name: "Arch", Value: runtime.GOARCH}, + }, + }, + } + + // This script can be used to check the status of the remote worker to see + // if the /tmp/uuid file exists, meaning flakyCmd will be fast. + existsScript := fmt.Sprintf(`FILE="%s" +if [ -f $FILE ]; then + echo "found" +fi`, fname) + existsCmd := &repb.Command{ + Arguments: []string{"sh", "-c", existsScript}, + Platform: &repb.Platform{ + Properties: []*repb.Platform_Property{ + {Name: "OSFamily", Value: runtime.GOOS}, + {Name: "Arch", Value: runtime.GOARCH}, + }, + }, + } + + // Run the command the first time, this will take about a minute. + cmd1 := rbe.Execute(flakyCmd, &rbetest.ExecuteOpts{CheckCache: true, InvocationID: "invocation1"}) + op1 := cmd1.WaitAccepted() + WaitForPendingExecution(rbe.GetRedisClient(), op1) + + // Wait for the /tmp/uuid file to exist before re-submitting flakyCmd. + found := false + for i := 0; i < 10; i++ { + probeCmd := rbe.Execute(existsCmd, &rbetest.ExecuteOpts{CheckCache: false, InvocationID: fmt.Sprintf("existence-probe-%d", i)}) + probeOp := probeCmd.Wait() + if strings.Trim(probeOp.Stdout, "\n") == "found" { + found = true + break + } + time.Sleep(100 * time.Millisecond) + } + require.True(t, found, "Expected remote execution system to begin cmd1") + + // The next action will merge against the slow op1 above, but also run a + // hedged action which should finish quickly. + cmd2 := rbe.Execute(flakyCmd, &rbetest.ExecuteOpts{CheckCache: true, InvocationID: "invocation2"}) + op2 := cmd2.WaitAccepted() + require.Equal(t, op1, op2, "expected actions to be merged") + + // Let the hedged execution finish. + time.Sleep(100 * time.Millisecond) + + // The action result should be cached, even though op1 is still running. + cmd3 := rbe.Execute(flakyCmd, &rbetest.ExecuteOpts{CheckCache: true, InvocationID: "invocation3"}) + op3 := cmd3.WaitAccepted() + require.NotEqual(t, op1, op3, "expected action to be hedged") + require.Equal(t, "FAST", strings.Trim(cmd3.Wait().Stdout, "\n")) +} + func TestAppShutdownDuringExecution_PublishOperationRetried(t *testing.T) { // Set a short progress publish interval since we want to test killing an // app while an update stream is in progress, and want to catch the error