diff --git a/changelog/fragments/1733936340-add-retries-for-download-upgrade-verifiers.yaml b/changelog/fragments/1733936340-add-retries-for-download-upgrade-verifiers.yaml new file mode 100644 index 00000000000..29ece746651 --- /dev/null +++ b/changelog/fragments/1733936340-add-retries-for-download-upgrade-verifiers.yaml @@ -0,0 +1,30 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: bug-fix + +# Change summary; a 80ish characters long description of the change. +summary: added retries for requesting download verifiers when upgrading the agent + +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment. +#description: + +# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc. +component: "elastic-agent" +# PR URL; optional; the PR number that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +pr: https://github.com/elastic/elastic-agent/pull/6276 +# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +#issue: https://github.com/owner/repo/1234 diff --git a/internal/pkg/agent/application/upgrade/artifact/download/composed/verifier.go b/internal/pkg/agent/application/upgrade/artifact/download/composed/verifier.go index 7a3f0509a8a..6ca3ea015bb 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/composed/verifier.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/composed/verifier.go @@ -5,6 +5,7 @@ package composed import ( + "context" goerrors "errors" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact" @@ -39,11 +40,11 @@ func NewVerifier(log *logger.Logger, verifiers ...download.Verifier) *Verifier { } // Verify checks the package from configured source. -func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error { +func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error { var errs []error for _, verifier := range v.vv { - e := verifier.Verify(a, version, skipDefaultPgp, pgpBytes...) + e := verifier.Verify(ctx, a, version, skipDefaultPgp, pgpBytes...) if e == nil { // Success return nil diff --git a/internal/pkg/agent/application/upgrade/artifact/download/composed/verifier_test.go b/internal/pkg/agent/application/upgrade/artifact/download/composed/verifier_test.go index dcad62b7cef..ad3e6ffe749 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/composed/verifier_test.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/composed/verifier_test.go @@ -5,6 +5,7 @@ package composed import ( + "context" "errors" "testing" @@ -24,7 +25,7 @@ func (d *ErrorVerifier) Name() string { return "error" } -func (d *ErrorVerifier) Verify(artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error { +func (d *ErrorVerifier) Verify(context.Context, artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error { d.called = true return errors.New("failing") } @@ -39,7 +40,7 @@ func (d *FailVerifier) Name() string { return "fail" } -func (d *FailVerifier) Verify(artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error { +func (d *FailVerifier) Verify(context.Context, artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error { d.called = true return &download.InvalidSignatureError{File: "", Err: errors.New("invalid signature")} } @@ -54,7 +55,7 @@ func (d *SuccVerifier) Name() string { return "succ" } -func (d *SuccVerifier) Verify(artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error { +func (d *SuccVerifier) Verify(context.Context, artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error { d.called = true return nil } @@ -90,7 +91,7 @@ func TestVerifier(t *testing.T) { testVersion := agtversion.NewParsedSemVer(1, 2, 3, "", "") for _, tc := range testCases { d := NewVerifier(log, tc.verifiers[0], tc.verifiers[1], tc.verifiers[2]) - err := d.Verify(artifact.Artifact{Name: "a", Cmd: "a", Artifact: "a/a"}, *testVersion, false) + err := d.Verify(context.Background(), artifact.Artifact{Name: "a", Cmd: "a", Artifact: "a/a"}, *testVersion, false) assert.Equal(t, tc.expectedResult, err == nil) diff --git a/internal/pkg/agent/application/upgrade/artifact/download/fs/verifier.go b/internal/pkg/agent/application/upgrade/artifact/download/fs/verifier.go index 210905f2047..4d52e61d48e 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/fs/verifier.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/fs/verifier.go @@ -5,6 +5,7 @@ package fs import ( + "context" "fmt" "net/http" "os" @@ -65,7 +66,7 @@ func NewVerifier(log *logger.Logger, config *artifact.Config, pgp []byte) (*Veri // Verify checks downloaded package on preconfigured // location against a key stored on elastic.co website. -func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error { +func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error { filename, err := artifact.GetArtifactName(a, version, v.config.OS(), v.config.Arch()) if err != nil { return fmt.Errorf("could not get artifact name: %w", err) diff --git a/internal/pkg/agent/application/upgrade/artifact/download/fs/verifier_test.go b/internal/pkg/agent/application/upgrade/artifact/download/fs/verifier_test.go index f27cd899a84..18db67f7b5d 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/fs/verifier_test.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/fs/verifier_test.go @@ -29,12 +29,11 @@ import ( var testVersion = agtversion.NewParsedSemVer(7, 5, 1, "", "") -var ( - agentSpec = artifact.Artifact{ - Name: "Elastic Agent", - Cmd: "elastic-agent", - Artifact: "beat/elastic-agent"} -) +var agentSpec = artifact.Artifact{ + Name: "Elastic Agent", + Cmd: "elastic-agent", + Artifact: "beat/elastic-agent", +} func TestFetchVerify(t *testing.T) { // See docs/pgp-sign-verify-artifact.md for how to generate a key, export @@ -47,7 +46,8 @@ func TestFetchVerify(t *testing.T) { targetPath := filepath.Join("testdata", "download") ctx := context.Background() a := artifact.Artifact{ - Name: "elastic-agent", Cmd: "elastic-agent", Artifact: "beats/elastic-agent"} + Name: "elastic-agent", Cmd: "elastic-agent", Artifact: "beats/elastic-agent", + } version := agtversion.NewParsedSemVer(8, 0, 0, "", "") filename := "elastic-agent-8.0.0-darwin-x86_64.tar.gz" @@ -80,7 +80,7 @@ func TestFetchVerify(t *testing.T) { // first download verify should fail: // download skipped, as invalid package is prepared upfront // verify fails and cleans download - err = verifier.Verify(a, *version, false) + err = verifier.Verify(ctx, a, *version, false) var checksumErr *download.ChecksumMismatchError require.ErrorAs(t, err, &checksumErr) @@ -109,7 +109,7 @@ func TestFetchVerify(t *testing.T) { _, err = os.Stat(ascTargetFilePath) require.NoError(t, err) - err = verifier.Verify(a, *version, false) + err = verifier.Verify(ctx, a, *version, false) require.NoError(t, err) // Bad GPG public key. @@ -126,7 +126,7 @@ func TestFetchVerify(t *testing.T) { // Missing .asc file. { - err = verifier.Verify(a, *version, false) + err = verifier.Verify(ctx, a, *version, false) require.Error(t, err) // Don't delete these files when GPG validation failure. @@ -139,7 +139,7 @@ func TestFetchVerify(t *testing.T) { err = os.WriteFile(targetFilePath+".asc", []byte("bad sig"), 0o600) require.NoError(t, err) - err = verifier.Verify(a, *version, false) + err = verifier.Verify(ctx, a, *version, false) var invalidSigErr *download.InvalidSignatureError assert.ErrorAs(t, err, &invalidSigErr) @@ -157,7 +157,8 @@ func prepareFetchVerifyTests( targetDir, filename, targetFilePath, - hashTargetFilePath string) error { + hashTargetFilePath string, +) error { sourceFilePath := filepath.Join(dropPath, filename) hashSourceFilePath := filepath.Join(dropPath, filename+".sha512") @@ -202,6 +203,7 @@ func TestVerify(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { + ctx := context.Background() log, obs := loggertest.New("TestVerify") targetDir := t.TempDir() @@ -220,7 +222,7 @@ func TestVerify(t *testing.T) { pgpKey := prepareTestCase(t, agentSpec, testVersion, config) testClient := NewDownloader(config) - artifactPath, err := testClient.Download(context.Background(), agentSpec, testVersion) + artifactPath, err := testClient.Download(ctx, agentSpec, testVersion) require.NoError(t, err, "fs.Downloader could not download artifacts") _, err = testClient.DownloadAsc(context.Background(), agentSpec, *testVersion) require.NoError(t, err, "fs.Downloader could not download artifacts .asc file") @@ -231,7 +233,7 @@ func TestVerify(t *testing.T) { testVerifier, err := NewVerifier(log, config, pgpKey) require.NoError(t, err) - err = testVerifier.Verify(agentSpec, *testVersion, false, tc.RemotePGPUris...) + err = testVerifier.Verify(ctx, agentSpec, *testVersion, false, tc.RemotePGPUris...) require.NoError(t, err) // log message informing remote PGP was skipped @@ -246,7 +248,6 @@ func TestVerify(t *testing.T) { // It creates the necessary key to sing the artifact and returns the public key // to verify the signature. func prepareTestCase(t *testing.T, a artifact.Artifact, version *agtversion.ParsedSemVer, cfg *artifact.Config) []byte { - filename, err := artifact.GetArtifactName(a, *version, cfg.OperatingSystem, cfg.Architecture) require.NoErrorf(t, err, "could not get artifact name") diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/common_test.go b/internal/pkg/agent/application/upgrade/artifact/download/http/common_test.go index 0a3decbab87..84eb82dec84 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/http/common_test.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/common_test.go @@ -55,7 +55,28 @@ func getTestCases() []testCase { } } -func getElasticCoServer(t *testing.T) (*httptest.Server, []byte) { +type extResCode map[string]struct { + resCode int + count int +} + +type testDials struct { + extResCode +} + +func (td *testDials) withExtResCode(k string, statusCode int, count int) { + td.extResCode[k] = struct { + resCode int + count int + }{statusCode, count} +} + +func (td *testDials) reset() { + *td = testDials{extResCode: make(extResCode)} +} + +func getElasticCoServer(t *testing.T) (*httptest.Server, []byte, *testDials) { + td := testDials{extResCode: make(extResCode)} correctValues := map[string]struct{}{ fmt.Sprintf("%s-%s-%s", beatSpec.Cmd, version, "i386.deb"): {}, fmt.Sprintf("%s-%s-%s", beatSpec.Cmd, version, "amd64.deb"): {}, @@ -81,7 +102,6 @@ func getElasticCoServer(t *testing.T) (*httptest.Server, []byte) { ext = ".tar.gz" } packageName = strings.TrimSuffix(packageName, ext) - switch ext { case ".sha512": resp = []byte(fmt.Sprintf("%x %s", hash, packageName)) @@ -103,11 +123,17 @@ func getElasticCoServer(t *testing.T) (*httptest.Server, []byte) { return } + if v, ok := td.extResCode[ext]; ok && v.count != 0 { + w.WriteHeader(v.resCode) + v.count-- + td.extResCode[ext] = v + } + _, err := w.Write(resp) assert.NoErrorf(t, err, "mock elastic.co server: failes writing response") }) - return httptest.NewServer(handler), pub + return httptest.NewServer(handler), pub, &td } func getElasticCoClient(server *httptest.Server) http.Client { diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/downloader_test.go b/internal/pkg/agent/application/upgrade/artifact/download/http/downloader_test.go index 1d59da4e977..8cf21a86818 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/http/downloader_test.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/downloader_test.go @@ -44,7 +44,7 @@ func TestDownload(t *testing.T) { log, _ := logger.New("", false) timeout := 30 * time.Second testCases := getTestCases() - server, _ := getElasticCoServer(t) + server, _, _ := getElasticCoServer(t) elasticClient := getElasticCoClient(server) config := &artifact.Config{ @@ -359,7 +359,6 @@ type downloadHttpResponse struct { } func TestDownloadVersion(t *testing.T) { - type fields struct { config *artifact.Config } @@ -485,7 +484,6 @@ func TestDownloadVersion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - targetDirPath := t.TempDir() handleDownload := func(rw http.ResponseWriter, req *http.Request) { @@ -527,5 +525,4 @@ func TestDownloadVersion(t *testing.T) { assert.Equalf(t, filepath.Join(targetDirPath, tt.want), got, "Download(%v, %v)", tt.args.a, tt.args.version) }) } - } diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/verifier.go b/internal/pkg/agent/application/upgrade/artifact/download/http/verifier.go index 4657f92659a..e0abbcc97c6 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/http/verifier.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/verifier.go @@ -53,6 +53,9 @@ func NewVerifier(log *logger.Logger, config *artifact.Config, pgp []byte) (*Veri httpcommon.WithModRoundtripper(func(rt http.RoundTripper) http.RoundTripper { return download.WithHeaders(rt, download.Headers) }), + httpcommon.WithModRoundtripper(func(rt http.RoundTripper) http.RoundTripper { + return WithBackoff(rt, log) + }), ) if err != nil { return nil, err @@ -88,7 +91,7 @@ func (v *Verifier) Reload(c *artifact.Config) error { // Verify checks downloaded package on preconfigured // location against a key stored on elastic.co website. -func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error { +func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error { artifactPath, err := artifact.GetArtifactPath(a, version, v.config.OS(), v.config.Arch(), v.config.TargetDirectory) if err != nil { return errors.New(err, "retrieving package path") @@ -98,7 +101,7 @@ func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, return fmt.Errorf("failed to verify SHA512 hash: %w", err) } - if err = v.verifyAsc(a, version, skipDefaultPgp, pgpBytes...); err != nil { + if err = v.verifyAsc(ctx, a, version, skipDefaultPgp, pgpBytes...); err != nil { var invalidSignatureErr *download.InvalidSignatureError if errors.As(err, &invalidSignatureErr) { if err := os.Remove(artifactPath); err != nil { @@ -116,7 +119,7 @@ func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, return nil } -func (v *Verifier) verifyAsc(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultKey bool, pgpSources ...string) error { +func (v *Verifier) verifyAsc(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultKey bool, pgpSources ...string) error { filename, err := artifact.GetArtifactName(a, version, v.config.OS(), v.config.Arch()) if err != nil { return errors.New(err, "retrieving package name") @@ -132,7 +135,7 @@ func (v *Verifier) verifyAsc(a artifact.Artifact, version agtversion.ParsedSemVe return errors.New(err, "composing URI for fetching asc file", errors.TypeNetwork) } - ascBytes, err := v.getPublicAsc(ascURI) + ascBytes, err := v.getPublicAsc(ctx, ascURI) if err != nil { return errors.New(err, fmt.Sprintf("fetching asc file from %s", ascURI), errors.TypeNetwork, errors.M(errors.MetaKeyURI, ascURI)) } @@ -163,8 +166,8 @@ func (v *Verifier) composeURI(filename, artifactName string) (string, error) { return uri.String(), nil } -func (v *Verifier) getPublicAsc(sourceURI string) ([]byte, error) { - ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second) +func (v *Verifier) getPublicAsc(ctx context.Context, sourceURI string) ([]byte, error) { + ctx, cancelFn := context.WithTimeout(ctx, 30*time.Second) defer cancelFn() req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURI, nil) if err != nil { diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/verifier_test.go b/internal/pkg/agent/application/upgrade/artifact/download/http/verifier_test.go index e477db3e227..2923a4d3845 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/http/verifier_test.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/verifier_test.go @@ -9,6 +9,7 @@ import ( "fmt" "math/rand/v2" "net/http" + "net/http/httptest" "net/url" "os" @@ -30,7 +31,7 @@ func TestVerify(t *testing.T) { log, _ := logger.New("", false) timeout := 30 * time.Second testCases := getRandomTestCases()[0:1] - server, pub := getElasticCoServer(t) + server, pub, td := getElasticCoServer(t) config := &artifact.Config{ SourceURI: server.URL + "/downloads", @@ -41,7 +42,7 @@ func TestVerify(t *testing.T) { } t.Run("without proxy", func(t *testing.T) { - runTests(t, testCases, config, log, pub) + runTests(t, testCases, td, config, log, pub) }) t.Run("with proxy", func(t *testing.T) { @@ -72,14 +73,21 @@ func TestVerify(t *testing.T) { URL: (*httpcommon.ProxyURI)(proxyURL), } - runTests(t, testCases, &config, log, pub) + runTests(t, testCases, td, &config, log, pub) }) } -func runTests(t *testing.T, testCases []testCase, config *artifact.Config, log *logger.Logger, pub []byte) { +func runTests(t *testing.T, testCases []testCase, td *testDials, config *artifact.Config, log *logger.Logger, pub []byte) { for _, tc := range testCases { testName := fmt.Sprintf("%s-binary-%s", tc.system, tc.arch) t.Run(testName, func(t *testing.T) { + td.withExtResCode(".asc", 500, 2) + defer td.reset() + + cancelDeadline := time.Now().Add(config.Timeout) + cancelCtx, cancel := context.WithDeadline(context.Background(), cancelDeadline) + defer cancel() + config.OperatingSystem = tc.system config.Architecture = tc.arch @@ -88,7 +96,7 @@ func runTests(t *testing.T, testCases []testCase, config *artifact.Config, log * downloader, err := NewDownloader(log, config, upgradeDetails) require.NoError(t, err, "could not create new downloader") - pkgPath, err := downloader.Download(context.Background(), beatSpec, version) + pkgPath, err := downloader.Download(cancelCtx, beatSpec, version) require.NoErrorf(t, err, "failed downloading %s v%s", beatSpec.Artifact, version) @@ -102,7 +110,7 @@ func runTests(t *testing.T, testCases []testCase, config *artifact.Config, log * t.Fatal(err) } - err = testVerifier.Verify(beatSpec, *version, false) + err = testVerifier.Verify(cancelCtx, beatSpec, *version, false) require.NoError(t, err) }) } diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/verify_backoff_rtt.go b/internal/pkg/agent/application/upgrade/artifact/download/http/verify_backoff_rtt.go new file mode 100644 index 00000000000..0f90e7ae657 --- /dev/null +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/verify_backoff_rtt.go @@ -0,0 +1,89 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License 2.0; +// you may not use this file except in compliance with the Elastic License 2.0. + +package http + +import ( + "bytes" + "fmt" + "io" + "net/http" + "time" + + "github.com/cenkalti/backoff/v4" + + "github.com/elastic/elastic-agent/internal/pkg/agent/errors" + "github.com/elastic/elastic-agent/pkg/core/logger" +) + +func WithBackoff(rtt http.RoundTripper, logger *logger.Logger) http.RoundTripper { + if rtt == nil { + rtt = http.DefaultTransport + } + + return &BackoffRoundTripper{next: rtt, logger: logger} +} + +type BackoffRoundTripper struct { + next http.RoundTripper + logger *logger.Logger +} + +func (btr *BackoffRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + exp := backoff.NewExponentialBackOff() + boCtx := backoff.WithContext(exp, req.Context()) + + opNotify := func(err error, retryAfter time.Duration) { + btr.logger.Warnf("request failed: %s, retrying in %s", err, retryAfter) + } + + var resp *http.Response + var err error + var resettableBody *bytes.Reader + + if req.Body != nil { + data, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + req.Body.Close() + + resettableBody = bytes.NewReader(data) + req.Body = io.NopCloser(resettableBody) + } + // opFunc implements the retry logic for the backoff mechanism. + // + // - For each attempt, the request body is reset (if non-nil) to allow reuse. + // - Requests with errors or responses with status >= 400 trigger retries. + // - The response body is closed for failed requests to free resources. + // - A successful request (status < 400) stops the retries and returns the response. + attempt := 1 + opFunc := func() error { + if resettableBody != nil { + _, err = resettableBody.Seek(0, io.SeekStart) + if err != nil { + btr.logger.Errorf("error while resetting request body: %w", err) + } + } + + attempt++ + resp, err = btr.next.RoundTrip(req) //nolint:bodyclose // the response body is closed when status code >= 400 or it is closed by the caller + if err != nil { + btr.logger.Errorf("attempt %d: error round-trip: %w", err) + return err + } + + if resp.StatusCode >= 400 { + if err := resp.Body.Close(); err != nil { + btr.logger.Errorf("attempt %d: error closing the response body: %w", attempt, err) + } + btr.logger.Errorf("attempt %d: received response status: %d", attempt, resp.StatusCode) + return errors.New(fmt.Sprintf("received response status: %d", resp.StatusCode)) + } + + return nil + } + + return resp, backoff.RetryNotify(opFunc, boCtx, opNotify) +} diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/verify_backoff_rtt_test.go b/internal/pkg/agent/application/upgrade/artifact/download/http/verify_backoff_rtt_test.go new file mode 100644 index 00000000000..436911ceceb --- /dev/null +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/verify_backoff_rtt_test.go @@ -0,0 +1,83 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License 2.0; +// you may not use this file except in compliance with the Elastic License 2.0. + +package http + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent-libs/logp" +) + +func TestVerifyBackoffRoundtripper(t *testing.T) { + t.Run("test get request retry", func(t *testing.T) { + failedResCounter := 2 + handler := func(rw http.ResponseWriter, req *http.Request) { + if failedResCounter > 0 { + rw.WriteHeader(http.StatusInternalServerError) + failedResCounter-- + } + _, err := rw.Write([]byte("hello")) + require.NoError(t, err) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + client := http.Client{ + Transport: WithBackoff(&http.Transport{}, logp.NewLogger("testing")), + Timeout: 10 * time.Second, + } + + res, err := client.Get(server.URL) //nolint:noctx // test code + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Equal(t, string(body), "hello") + require.Equal(t, res.StatusCode, 200) + require.Equal(t, failedResCounter, 0) + }) + + t.Run("test post request with body", func(t *testing.T) { + failedResCounter := 2 + handler := func(rw http.ResponseWriter, req *http.Request) { + if failedResCounter > 0 { + rw.WriteHeader(http.StatusInternalServerError) + failedResCounter-- + } + + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + defer req.Body.Close() + + _, err = rw.Write(body) + require.NoError(t, err) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + client := http.Client{ + Transport: WithBackoff(&http.Transport{}, logp.NewLogger("testing")), + Timeout: 10 * time.Second, + } + + reqReader := bytes.NewReader([]byte("hello")) + + resp, err := client.Post(server.URL, "text/html", reqReader) //nolint:noctx // test code + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equal(t, string(body), "hello") + require.Equal(t, resp.StatusCode, 200) + require.Equal(t, failedResCounter, 0) + }) +} diff --git a/internal/pkg/agent/application/upgrade/artifact/download/snapshot/verifier.go b/internal/pkg/agent/application/upgrade/artifact/download/snapshot/verifier.go index 9910e0e06da..7ce05055d60 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/snapshot/verifier.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/snapshot/verifier.go @@ -30,7 +30,6 @@ func (v *Verifier) Name() string { // NewVerifier creates a downloader which first checks local directory // and then fallbacks to remote if configured. func NewVerifier(log *logger.Logger, config *artifact.Config, pgp []byte, versionOverride *agtversion.ParsedSemVer) (download.Verifier, error) { - client, err := config.HTTPTransportSettings.Client(httpcommon.WithAPMHTTPInstrumentation()) if err != nil { return nil, err @@ -54,9 +53,9 @@ func NewVerifier(log *logger.Logger, config *artifact.Config, pgp []byte, versio } // Verify checks the package from configured source. -func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error { +func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error { strippedVersion := agtversion.NewParsedSemVer(version.Major(), version.Minor(), version.Patch(), version.Prerelease(), "") - return v.verifier.Verify(a, *strippedVersion, skipDefaultPgp, pgpBytes...) + return v.verifier.Verify(ctx, a, *strippedVersion, skipDefaultPgp, pgpBytes...) } func (v *Verifier) Reload(c *artifact.Config) error { diff --git a/internal/pkg/agent/application/upgrade/artifact/download/verifier.go b/internal/pkg/agent/application/upgrade/artifact/download/verifier.go index 3c2cf06715c..67d16076f4e 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/verifier.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/verifier.go @@ -2,6 +2,8 @@ // or more contributor license agreements. Licensed under the Elastic License 2.0; // you may not use this file except in compliance with the Elastic License 2.0. +// you may not use this file except in compliance with the Elastic License 2.0. + package download import ( @@ -84,7 +86,7 @@ type Verifier interface { // If the checksum does no match Verify returns a *download.ChecksumMismatchError. // If the PGP signature check fails then Verify returns a // *download.InvalidSignatureError. - Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error + Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error } // VerifySHA512HashWithCleanup calls VerifySHA512Hash and, in case of a @@ -210,7 +212,8 @@ func readChecksumFile(checksumFile, filename string) (string, error) { } func VerifyPGPSignatureWithKeys( - log infoWarnLogger, file string, asciiArmorSignature []byte, publicKeys [][]byte) error { + log infoWarnLogger, file string, asciiArmorSignature []byte, publicKeys [][]byte, +) error { var err error for i, key := range publicKeys { err = VerifyPGPSignature(file, asciiArmorSignature, key) diff --git a/internal/pkg/agent/application/upgrade/step_download.go b/internal/pkg/agent/application/upgrade/step_download.go index fb38e93972c..58d56c81f52 100644 --- a/internal/pkg/agent/application/upgrade/step_download.go +++ b/internal/pkg/agent/application/upgrade/step_download.go @@ -116,7 +116,7 @@ func (u *Upgrader) downloadArtifact(ctx context.Context, parsedVersion *agtversi } } - if err := verifier.Verify(agentArtifact, *parsedVersion, skipDefaultPgp, pgpBytes...); err != nil { + if err := verifier.Verify(ctx, agentArtifact, *parsedVersion, skipDefaultPgp, pgpBytes...); err != nil { return "", errors.New(err, "failed verification of agent binary") } return path, nil