From b2f1181072ddaf9f2238cedb623ac6f6d2398075 Mon Sep 17 00:00:00 2001 From: Brandon Duffany Date: Wed, 28 Feb 2024 12:06:05 -0500 Subject: [PATCH] Update replay_action to allow overriding the command (#6011) --- .../tools/replay_action/replay_action.go | 63 ++++++++++++++----- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/enterprise/tools/replay_action/replay_action.go b/enterprise/tools/replay_action/replay_action.go index ac67b51289a..8e5eaeaf2fa 100644 --- a/enterprise/tools/replay_action/replay_action.go +++ b/enterprise/tools/replay_action/replay_action.go @@ -37,9 +37,10 @@ var ( targetRemoteInstanceName = flag.String("target_remote_instance_name", "", "The remote instance name used in the source action") // Less common options below. - targetHeaders = flag.Slice("target_headers", []string{}, "A list of headers to set (format: 'key=val'") - n = flag.Int("n", 1, "Number of times to replay the action. By default they'll be replayed in serial. Set --jobs to 2 or higher to run concurrently.") - jobs = flag.Int("jobs", 1, "Max number of concurrent jobs that can execute actions at once.") + overrideCommand = flag.String("override_command", "", "If set, run this script (with 'sh -c') instead of the original action command line. All other properties such as environment variables and platform properties will be preserved from the original command.") + targetHeaders = flag.Slice("target_headers", []string{}, "A list of headers to set (format: 'key=val'") + n = flag.Int("n", 1, "Number of times to replay the action. By default they'll be replayed in serial. Set --jobs to 2 or higher to run concurrently.") + jobs = flag.Int("jobs", 1, "Max number of concurrent jobs that can execute actions at once.") ) // Example usage: @@ -179,14 +180,10 @@ func main() { targetCtx = metadata.AppendToOutgoingContext(targetCtx, headersToSet...) } - log.Infof("Connecting to %q", *sourceExecutor) - var sourceBSClient, destBSClient bspb.ByteStreamClient - var sourceCASClient, destCASClient repb.ContentAddressableStorageClient - var execClient repb.ExecutionClient - sourceBSClient, execClient, sourceCASClient = getClients(*sourceExecutor) - if inCopyMode() { - destBSClient, execClient, destCASClient = getClients(*targetExecutor) - } + log.Infof("Connecting to source %q", *sourceExecutor) + sourceBSClient, _, sourceCASClient := getClients(*sourceExecutor) + log.Infof("Connecting to target %q", *targetExecutor) + destBSClient, execClient, destCASClient := getClients(*targetExecutor) // For backwards compatibility, attempt to fixup old style digest // strings that don't start with a '/blobs/' prefix. @@ -202,7 +199,6 @@ func main() { // Fetch the action to ensure it exists. action := &repb.Action{} - d := actionInstanceDigest.GetDigest() if err := cachetools.GetBlobAsProto(srcCtx, sourceBSClient, actionInstanceDigest, action); err != nil { log.Fatalf("Error fetching action: %s", err.Error()) } @@ -211,7 +207,7 @@ func main() { fmb := NewFindMissingBatcher(targetCtx, *targetRemoteInstanceName, destCASClient, FindMissingBatcherOpts{}) eg, targetCtx := errgroup.WithContext(targetCtx) eg.Go(func() error { - if err := copyFile(srcCtx, targetCtx, fmb, destBSClient, sourceBSClient, d, actionInstanceDigest.GetDigestFunction()); err != nil { + if err := copyFile(srcCtx, targetCtx, fmb, destBSClient, sourceBSClient, actionInstanceDigest.GetDigest(), actionInstanceDigest.GetDigestFunction()); err != nil { return status.WrapError(err, "copy action") } return nil @@ -238,10 +234,39 @@ func main() { } log.Infof("Finished copying files.") } + + // If we're overriding the command, do that now. + if *overrideCommand != "" { + // Download the command and update arguments. + sourceCRN := digest.NewResourceName(action.GetCommandDigest(), *sourceRemoteInstanceName, rspb.CacheType_CAS, actionInstanceDigest.GetDigestFunction()) + cmd := &repb.Command{} + if err := cachetools.GetBlobAsProto(srcCtx, sourceBSClient, sourceCRN, cmd); err != nil { + log.Fatalf("Failed to get command: %s", err) + } + cmd.Arguments = []string{"sh", "-c", *overrideCommand} + + // Upload the new command and action. + cd, err := cachetools.UploadProto(targetCtx, destBSClient, *targetRemoteInstanceName, actionInstanceDigest.GetDigestFunction(), cmd) + if err != nil { + log.Fatalf("Failed to upload new command: %s", err) + } + action = action.CloneVT() + action.CommandDigest = cd + ad, err := cachetools.UploadProto(targetCtx, destBSClient, *targetRemoteInstanceName, actionInstanceDigest.GetDigestFunction(), action) + if err != nil { + log.Fatalf("Failed to upload new action: %s", err) + } + + actionInstanceDigest = digest.NewResourceName(ad, *targetRemoteInstanceName, rspb.CacheType_CAS, actionInstanceDigest.GetDigestFunction()) + } + + if str, err := actionInstanceDigest.DownloadString(); err == nil { + log.Infof("Action resource name: %s", str) + } execReq := &repb.ExecuteRequest{ InstanceName: *targetRemoteInstanceName, SkipCacheLookup: true, - ActionDigest: d, + ActionDigest: actionInstanceDigest.GetDigest(), DigestFunction: actionInstanceDigest.GetDigestFunction(), } eg := &errgroup.Group{} @@ -273,7 +298,7 @@ func execute(ctx context.Context, execClient repb.ExecutionClient, bsClient bspb for { op, err := stream.Recv() if err != nil { - log.Fatalf("Error on stream: %s", err.Error()) + log.Fatalf("Execute stream recv failed: %s", err.Error()) } if !printedExecutionID { log.Infof("Started task %q", op.GetName()) @@ -306,8 +331,12 @@ func execute(ctx context.Context, execClient repb.ExecutionClient, bsClient bspb } // Print stdout and stderr but only when running a single action. if *n == 1 { - printOutputFile(ctx, bsClient, result.GetStdoutDigest(), rn.GetDigestFunction(), "stdout") - printOutputFile(ctx, bsClient, result.GetStderrDigest(), rn.GetDigestFunction(), "stderr") + if err := printOutputFile(ctx, bsClient, result.GetStdoutDigest(), rn.GetDigestFunction(), "stdout"); err != nil { + log.Warningf("Failed to get stdout: %s", err) + } + if err := printOutputFile(ctx, bsClient, result.GetStderrDigest(), rn.GetDigestFunction(), "stderr"); err != nil { + log.Warningf("Failed to get stderr: %s", err) + } } logExecutionMetadata(i, response.GetResult().GetExecutionMetadata()) break