From 690fcaaa951dc7d2bafa717b23d9d15e2f8424ec Mon Sep 17 00:00:00 2001 From: Eren Yeager <92114074+wty-Bryant@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:12:51 -0500 Subject: [PATCH] Add Gzip request compression feature (#467) * Add and Merge request compression feature * Modify and Merge request compression codegen * Add and Merge changelog for last commit * Modify logic of request compression middleware * Add request compression algorithm codegen part * resolve METAINFO conflict * Change dependency format * Revert dependency format * Change request compression middleware to operation level * Change codegen comment * Change static middleware import * Change go dependency codegen * Add body compare fn to request compress op unit test * Solve rebase conflict --------- Co-authored-by: Tianyi Wang --- .../80ed28327bcd4301a264f318efaf8216.json | 8 + .../smithy/go/codegen/SmithyGoDependency.java | 2 + .../smithy/go/codegen/SmithyGoTypes.java | 7 + .../HttpProtocolUnitTestRequestGenerator.java | 52 +++++- .../RequestCompression.java | 160 ++++++++++++++++++ ...mithy.go.codegen.integration.GoIntegration | 2 + private/requestcompression/gzip.go | 30 ++++ .../middleware_capture_request_compression.go | 52 ++++++ .../requestcompression/request_compression.go | 103 +++++++++++ .../request_compression_test.go | 125 ++++++++++++++ testing/bytes.go | 35 ++++ testing/gzip.go | 27 +++ 12 files changed, 602 insertions(+), 1 deletion(-) create mode 100644 .changelog/80ed28327bcd4301a264f318efaf8216.json create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/requestcompression/RequestCompression.java create mode 100644 private/requestcompression/gzip.go create mode 100644 private/requestcompression/middleware_capture_request_compression.go create mode 100644 private/requestcompression/request_compression.go create mode 100644 private/requestcompression/request_compression_test.go create mode 100644 testing/gzip.go diff --git a/.changelog/80ed28327bcd4301a264f318efaf8216.json b/.changelog/80ed28327bcd4301a264f318efaf8216.json new file mode 100644 index 000000000..74bf98c30 --- /dev/null +++ b/.changelog/80ed28327bcd4301a264f318efaf8216.json @@ -0,0 +1,8 @@ +{ + "id": "80ed2832-7bcd-4301-a264-f318efaf8216", + "type": "feature", + "description": "Support modeled request compression.", + "modules": [ + "." + ] +} \ No newline at end of file diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java index ec310fd45..4f613e1ad 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java @@ -53,6 +53,8 @@ public final class SmithyGoDependency { public static final GoDependency SMITHY_HTTP_TRANSPORT = smithy("transport/http", "smithyhttp"); public static final GoDependency SMITHY_MIDDLEWARE = smithy("middleware"); public static final GoDependency SMITHY_PRIVATE_PROTOCOL = smithy("private/protocol", "smithyprivateprotocol"); + public static final GoDependency SMITHY_REQUEST_COMPRESSION = + smithy("private/requestcompression", "smithyrequestcompression"); public static final GoDependency SMITHY_TIME = smithy("time", "smithytime"); public static final GoDependency SMITHY_HTTP_BINDING = smithy("encoding/httpbinding"); public static final GoDependency SMITHY_JSON = smithy("encoding/json", "smithyjson"); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoTypes.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoTypes.java index c4e2b75c6..50a34ea42 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoTypes.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoTypes.java @@ -105,4 +105,11 @@ public static final class Bearer { public static final Symbol NewSignHTTPSMessage = SmithyGoDependency.SMITHY_AUTH_BEARER.valueSymbol("NewSignHTTPSMessage"); } } + + public static final class Private { + public static final class RequestCompression { + public static final Symbol AddRequestCompression = SmithyGoDependency.SMITHY_REQUEST_COMPRESSION.valueSymbol("AddRequestCompression"); + public static final Symbol AddCaptureUncompressedRequest = SmithyGoDependency.SMITHY_REQUEST_COMPRESSION.valueSymbol("AddCaptureUncompressedRequestMiddleware"); + } + } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java index 15d4adf6a..3c4f3cea8 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java @@ -17,13 +17,22 @@ package software.amazon.smithy.go.codegen.integration; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.SmithyGoTypes.Private.RequestCompression.AddCaptureUncompressedRequest; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; import java.util.function.Consumer; import java.util.logging.Logger; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.go.codegen.GoWriter; import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.go.codegen.SmithyGoTypes; import software.amazon.smithy.go.codegen.SymbolUtils; +import software.amazon.smithy.model.traits.RequestCompressionTrait; import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase; +import software.amazon.smithy.utils.MapUtils; /** * Generates HTTP protocol unit tests for HTTP request test cases. @@ -31,6 +40,8 @@ public class HttpProtocolUnitTestRequestGenerator extends HttpProtocolUnitTestGenerator { private static final Logger LOGGER = Logger.getLogger(HttpProtocolUnitTestRequestGenerator.class.getName()); + private static final Set ALLOWED_ALGORITHMS = new HashSet<>(Arrays.asList("gzip")); + /** * Initializes the protocol test generator. * @@ -198,6 +209,10 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC */ protected void generateTestBodySetup(GoWriter writer) { writer.write("actualReq := &http.Request{}"); + if (operation.hasTrait(RequestCompressionTrait.class)) { + writer.addUseImports(SmithyGoDependency.BYTES); + writer.write("rawBodyBuf := &bytes.Buffer{}"); + } } /** @@ -227,8 +242,29 @@ protected void generateTestInvokeClientOperation(GoWriter writer, String clientN writer.write("return $T(stack, actualReq)", SymbolUtils.createValueSymbolBuilder("AddCaptureRequestMiddleware", SmithyGoDependency.SMITHY_PRIVATE_PROTOCOL).build()); - }); + }); + if (operation.hasTrait(RequestCompressionTrait.class)) { + writer.write(goTemplate(""" + options.APIOptions = append(options.APIOptions, func(stack $stack:P) error { + return $captureRequest:T(stack, rawBodyBuf) + }) + """, + MapUtils.of( + "stack", SmithyGoTypes.Middleware.Stack, + "captureRequest", AddCaptureUncompressedRequest + ))); + } }); + + if (operation.hasTrait(RequestCompressionTrait.class)) { + writer.write(goTemplate(""" + disable := $client:L.Options().DisableRequestCompression + min := $client:L.Options().RequestMinCompressSizeBytes + """, + MapUtils.of( + "client", clientName + ))); + } } /** @@ -259,6 +295,20 @@ protected void generateTestAssertions(GoWriter writer) { writer.write("t.Errorf(\"expect body equal, got %v\", err)"); }); }); + + if (operation.hasTrait(RequestCompressionTrait.class)) { + String algorithm = operation.expectTrait(RequestCompressionTrait.class).getEncodings() + .stream().filter(it -> ALLOWED_ALGORITHMS.contains(it)).findFirst().get(); + writer.write(goTemplate(""" + if err := smithytesting.CompareCompressedBytes(rawBodyBuf, actualReq.Body, + disable, min, $algorithm:S); err != nil { + t.Errorf("unzipped request body not match: %q", err) + } + """, + MapUtils.of( + "algorithm", algorithm + ))); + } } public static class Builder extends HttpProtocolUnitTestGenerator.Builder { diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/requestcompression/RequestCompression.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/requestcompression/RequestCompression.java new file mode 100644 index 000000000..b6f7779c8 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/requestcompression/RequestCompression.java @@ -0,0 +1,160 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.requestcompression; + +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; + +import java.util.ArrayList; +import java.util.List; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.GoCodegenPlugin; +import software.amazon.smithy.go.codegen.GoDelegator; +import software.amazon.smithy.go.codegen.GoSettings; +import software.amazon.smithy.go.codegen.GoUniverseTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoTypes; +import software.amazon.smithy.go.codegen.SymbolUtils; +import software.amazon.smithy.go.codegen.integration.ConfigField; +import software.amazon.smithy.go.codegen.integration.GoIntegration; +import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar; +import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.TopDownIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.traits.RequestCompressionTrait; +import software.amazon.smithy.utils.ListUtils; +import software.amazon.smithy.utils.MapUtils; + + +public final class RequestCompression implements GoIntegration { + private static final String DISABLE_REQUEST_COMPRESSION = "DisableRequestCompression"; + + private static final String REQUEST_MIN_COMPRESSION_SIZE_BYTES = "RequestMinCompressSizeBytes"; + + private final List runtimeClientPlugins = new ArrayList<>(); + + // Write operation plugin for request compression middleware + @Override + public void processFinalizedModel(GoSettings settings, Model model) { + ServiceShape service = settings.getService(model); + TopDownIndex.of(model) + .getContainedOperations(service).forEach(operation -> { + if (!operation.hasTrait(RequestCompressionTrait.class)) { + return; + } + SymbolProvider symbolProvider = GoCodegenPlugin.createSymbolProvider(model, settings); + String funcName = getAddRequestCompressionMiddlewareFuncName( + symbolProvider.toSymbol(operation).getName() + ); + runtimeClientPlugins.add(RuntimeClientPlugin.builder().operationPredicate((m, s, o) -> { + if (!o.hasTrait(RequestCompressionTrait.class)) { + return false; + } + return o.equals(operation); + }).registerMiddleware(MiddlewareRegistrar.builder() + .resolvedFunction(SymbolUtils.buildPackageSymbol(funcName)) + .useClientOptions().build()) + .build()); + }); + } + + @Override + public void writeAdditionalFiles( + GoSettings settings, + Model model, + SymbolProvider symbolProvider, + GoDelegator goDelegator + ) { + ServiceShape service = settings.getService(model); + for (ShapeId operationID : service.getAllOperations()) { + OperationShape operation = model.expectShape(operationID, OperationShape.class); + if (!operation.hasTrait(RequestCompressionTrait.class)) { + continue; + } + goDelegator.useShapeWriter(operation, writeMiddlewareHelper(symbolProvider, operation)); + } + } + + + public static boolean isRequestCompressionService(Model model, ServiceShape service) { + return TopDownIndex.of(model) + .getContainedOperations(service).stream() + .anyMatch(it -> it.hasTrait(RequestCompressionTrait.class)); + } + + @Override + public List getClientPlugins() { + runtimeClientPlugins.add( + RuntimeClientPlugin.builder() + .servicePredicate(RequestCompression::isRequestCompressionService) + .configFields(ListUtils.of( + ConfigField.builder() + .name(DISABLE_REQUEST_COMPRESSION) + .type(GoUniverseTypes.Bool) + .documentation( + "Whether to disable automatic request compression for supported operations.") + .build(), + ConfigField.builder() + .name(REQUEST_MIN_COMPRESSION_SIZE_BYTES) + .type(GoUniverseTypes.Int64) + .documentation("The minimum request body size, in bytes, at which compression " + + "should occur. The default value is 10 KiB. Values must fall within " + + "[0, 1MiB].") + .build() + )) + .build() + ); + + return runtimeClientPlugins; + } + + private GoWriter.Writable generateAlgorithmList(List algorithms) { + return goTemplate(""" + []string{ + $W + } + """, + GoWriter.ChainWritable.of( + algorithms.stream() + .map(it -> goTemplate("$S,", it)) + .toList() + ).compose(false)); + } + + private static String getAddRequestCompressionMiddlewareFuncName(String operationName) { + return String.format("addOperation%sRequestCompressionMiddleware", operationName); + } + + private GoWriter.Writable writeMiddlewareHelper(SymbolProvider symbolProvider, OperationShape operation) { + String operationName = symbolProvider.toSymbol(operation).getName(); + RequestCompressionTrait trait = operation.expectTrait(RequestCompressionTrait.class); + + return goTemplate(""" + func $add:L(stack $stack:P, options Options) error { + return $addInternal:T(stack, options.DisableRequestCompression, options.RequestMinCompressSizeBytes, + $algorithms:W) + } + """, + MapUtils.of( + "add", getAddRequestCompressionMiddlewareFuncName(operationName), + "stack", SmithyGoTypes.Middleware.Stack, + "addInternal", SmithyGoTypes.Private.RequestCompression.AddRequestCompression, + "algorithms", generateAlgorithmList(trait.getEncodings()) + )); + } +} diff --git a/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration b/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration index af72bfc17..09af27932 100644 --- a/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration +++ b/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration @@ -12,3 +12,5 @@ software.amazon.smithy.go.codegen.endpoints.EndpointClientPluginsGenerator # modeled auth schemes software.amazon.smithy.go.codegen.integration.auth.SigV4AuthScheme software.amazon.smithy.go.codegen.integration.auth.AnonymousAuthScheme + +software.amazon.smithy.go.codegen.requestcompression.RequestCompression diff --git a/private/requestcompression/gzip.go b/private/requestcompression/gzip.go new file mode 100644 index 000000000..004d78f21 --- /dev/null +++ b/private/requestcompression/gzip.go @@ -0,0 +1,30 @@ +package requestcompression + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" +) + +func gzipCompress(input io.Reader) ([]byte, error) { + var b bytes.Buffer + w, err := gzip.NewWriterLevel(&b, gzip.DefaultCompression) + if err != nil { + return nil, fmt.Errorf("failed to create gzip writer, %v", err) + } + + inBytes, err := io.ReadAll(input) + if err != nil { + return nil, fmt.Errorf("failed read payload to compress, %v", err) + } + + if _, err = w.Write(inBytes); err != nil { + return nil, fmt.Errorf("failed to write payload to be compressed, %v", err) + } + if err = w.Close(); err != nil { + return nil, fmt.Errorf("failed to flush payload being compressed, %v", err) + } + + return b.Bytes(), nil +} diff --git a/private/requestcompression/middleware_capture_request_compression.go b/private/requestcompression/middleware_capture_request_compression.go new file mode 100644 index 000000000..06c16afc1 --- /dev/null +++ b/private/requestcompression/middleware_capture_request_compression.go @@ -0,0 +1,52 @@ +package requestcompression + +import ( + "bytes" + "context" + "fmt" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "io" + "net/http" +) + +const captureUncompressedRequestID = "CaptureUncompressedRequest" + +// AddCaptureUncompressedRequestMiddleware captures http request before compress encoding for check +func AddCaptureUncompressedRequestMiddleware(stack *middleware.Stack, buf *bytes.Buffer) error { + return stack.Serialize.Insert(&captureUncompressedRequestMiddleware{ + buf: buf, + }, "RequestCompression", middleware.Before) +} + +type captureUncompressedRequestMiddleware struct { + req *http.Request + buf *bytes.Buffer + bytes []byte +} + +// ID returns id of the captureUncompressedRequestMiddleware +func (*captureUncompressedRequestMiddleware) ID() string { + return captureUncompressedRequestID +} + +// HandleSerialize captures request payload before it is compressed by request compression middleware +func (m *captureUncompressedRequestMiddleware) HandleSerialize(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler, +) ( + output middleware.SerializeOutput, metadata middleware.Metadata, err error, +) { + request, ok := input.Request.(*smithyhttp.Request) + if !ok { + return output, metadata, fmt.Errorf("error when retrieving http request") + } + + _, err = io.Copy(m.buf, request.GetStream()) + if err != nil { + return output, metadata, fmt.Errorf("error when copying http request stream: %q", err) + } + if err = request.RewindStream(); err != nil { + return output, metadata, fmt.Errorf("error when rewinding request stream: %q", err) + } + + return next.HandleSerialize(ctx, input) +} diff --git a/private/requestcompression/request_compression.go b/private/requestcompression/request_compression.go new file mode 100644 index 000000000..cc1a7fc13 --- /dev/null +++ b/private/requestcompression/request_compression.go @@ -0,0 +1,103 @@ +// Package requestcompression implements runtime support for smithy-modeled +// request compression. +// +// This package is designated as private and is intended for use only by the +// smithy client runtime. The exported API therein is not considered stable and +// is subject to breaking changes without notice. +package requestcompression + +import ( + "bytes" + "context" + "fmt" + "github.com/aws/smithy-go/middleware" + "github.com/aws/smithy-go/transport/http" + "io" +) + +const maxRequestMinCompressSizeBytes = 10485760 + +// Enumeration values for supported compress Algorithms. +const ( + GZIP = "gzip" +) + +type compressFunc func(io.Reader) ([]byte, error) + +var allowedAlgorithms = map[string]compressFunc{ + GZIP: gzipCompress, +} + +// AddRequestCompression add requestCompression middleware to op stack +func AddRequestCompression(stack *middleware.Stack, disabled bool, minBytes int64, algorithms []string) error { + return stack.Serialize.Add(&requestCompression{ + disableRequestCompression: disabled, + requestMinCompressSizeBytes: minBytes, + compressAlgorithms: algorithms, + }, middleware.After) +} + +type requestCompression struct { + disableRequestCompression bool + requestMinCompressSizeBytes int64 + compressAlgorithms []string +} + +// ID returns the ID of the middleware +func (m requestCompression) ID() string { + return "RequestCompression" +} + +// HandleSerialize gzip compress the request's stream/body if enabled by config fields +func (m requestCompression) HandleSerialize( + ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, +) ( + out middleware.SerializeOutput, metadata middleware.Metadata, err error, +) { + if m.disableRequestCompression { + return next.HandleSerialize(ctx, in) + } + // still need to check requestMinCompressSizeBytes in case it is out of range after service client config + if m.requestMinCompressSizeBytes < 0 || m.requestMinCompressSizeBytes > maxRequestMinCompressSizeBytes { + return out, metadata, fmt.Errorf("invalid range for min request compression size bytes %d, must be within 0 and 10485760 inclusively", m.requestMinCompressSizeBytes) + } + + req, ok := in.Request.(*http.Request) + if !ok { + return out, metadata, fmt.Errorf("unknown request type %T", req) + } + + for _, algorithm := range m.compressAlgorithms { + compressFunc := allowedAlgorithms[algorithm] + if compressFunc != nil { + if stream := req.GetStream(); stream != nil { + size, found, err := req.StreamLength() + if err != nil { + return out, metadata, fmt.Errorf("error while finding request stream length, %v", err) + } else if !found || size < m.requestMinCompressSizeBytes { + return next.HandleSerialize(ctx, in) + } + + compressedBytes, err := compressFunc(stream) + if err != nil { + return out, metadata, fmt.Errorf("failed to compress request stream, %v", err) + } + + var newReq *http.Request + if newReq, err = req.SetStream(bytes.NewReader(compressedBytes)); err != nil { + return out, metadata, fmt.Errorf("failed to set request stream, %v", err) + } + *req = *newReq + + if val := req.Header.Get("Content-Encoding"); val != "" { + req.Header.Set("Content-Encoding", fmt.Sprintf("%s, %s", val, algorithm)) + } else { + req.Header.Set("Content-Encoding", algorithm) + } + } + break + } + } + + return next.HandleSerialize(ctx, in) +} diff --git a/private/requestcompression/request_compression_test.go b/private/requestcompression/request_compression_test.go new file mode 100644 index 000000000..b29947dfd --- /dev/null +++ b/private/requestcompression/request_compression_test.go @@ -0,0 +1,125 @@ +package requestcompression + +import ( + "bytes" + "compress/gzip" + "context" + "fmt" + "github.com/aws/smithy-go/middleware" + "github.com/aws/smithy-go/transport/http" + "io" + "reflect" + "strings" + "testing" +) + +func TestRequestCompression(t *testing.T) { + cases := map[string]struct { + DisableRequestCompression bool + RequestMinCompressSizeBytes int64 + ContentLength int64 + Header map[string][]string + Stream io.Reader + ExpectedStream []byte + ExpectedHeader map[string][]string + }{ + "GZip request stream": { + Stream: strings.NewReader("Hi, world!"), + ExpectedStream: []byte("Hi, world!"), + ExpectedHeader: map[string][]string{ + "Content-Encoding": {"gzip"}, + }, + }, + "GZip request stream with existing encoding header": { + Stream: strings.NewReader("Hi, world!"), + ExpectedStream: []byte("Hi, world!"), + Header: map[string][]string{ + "Content-Encoding": {"custom"}, + }, + ExpectedHeader: map[string][]string{ + "Content-Encoding": {"custom, gzip"}, + }, + }, + "GZip request stream smaller than min compress request size": { + RequestMinCompressSizeBytes: 100, + Stream: strings.NewReader("Hi, world!"), + ExpectedStream: []byte("Hi, world!"), + ExpectedHeader: map[string][]string{}, + }, + "Disable GZip request stream": { + DisableRequestCompression: true, + Stream: strings.NewReader("Hi, world!"), + ExpectedStream: []byte("Hi, world!"), + ExpectedHeader: map[string][]string{}, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var err error + req := http.NewStackRequest().(*http.Request) + req.ContentLength = c.ContentLength + req, _ = req.SetStream(c.Stream) + if c.Header != nil { + req.Header = c.Header + } + var updatedRequest *http.Request + + m := requestCompression{ + disableRequestCompression: c.DisableRequestCompression, + requestMinCompressSizeBytes: c.RequestMinCompressSizeBytes, + compressAlgorithms: []string{GZIP}, + } + _, _, err = m.HandleSerialize(context.Background(), + middleware.SerializeInput{Request: req}, + middleware.SerializeHandlerFunc(func(ctx context.Context, input middleware.SerializeInput) ( + out middleware.SerializeOutput, metadata middleware.Metadata, err error) { + updatedRequest = input.Request.(*http.Request) + return out, metadata, nil + }), + ) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if stream := updatedRequest.GetStream(); stream != nil { + if err := testUnzipContent(stream, c.ExpectedStream, c.DisableRequestCompression, c.RequestMinCompressSizeBytes); err != nil { + t.Errorf("error while checking request stream: %q", err) + } + } + + if e, a := c.ExpectedHeader, map[string][]string(updatedRequest.Header); !reflect.DeepEqual(e, a) { + t.Errorf("expect request header to be %q, got %q", e, a) + } + }) + } +} + +func testUnzipContent(content io.Reader, expect []byte, disableRequestCompression bool, requestMinCompressionSizeBytes int64) error { + if disableRequestCompression || int64(len(expect)) < requestMinCompressionSizeBytes { + b, err := io.ReadAll(content) + if err != nil { + return fmt.Errorf("error while reading request") + } + if e, a := expect, b; !bytes.Equal(e, a) { + return fmt.Errorf("expect content to be %s, got %s", e, a) + } + } else { + r, err := gzip.NewReader(content) + if err != nil { + return fmt.Errorf("error while reading request") + } + + var actualBytes bytes.Buffer + _, err = actualBytes.ReadFrom(r) + if err != nil { + return fmt.Errorf("error while unzipping request payload") + } + + if e, a := expect, actualBytes.Bytes(); !bytes.Equal(e, a) { + return fmt.Errorf("expect unzipped content to be %s, got %s", e, a) + } + } + + return nil +} diff --git a/testing/bytes.go b/testing/bytes.go index 8f966846d..955612552 100644 --- a/testing/bytes.go +++ b/testing/bytes.go @@ -8,6 +8,17 @@ import ( "io/ioutil" ) +// Enumeration values for supported compress Algorithms. +const ( + GZIP = "gzip" +) + +type compareCompressFunc func([]byte, io.Reader) error + +var allowedAlgorithms = map[string]compareCompressFunc{ + GZIP: GzipCompareCompressBytes, +} + // CompareReaderEmpty checks if the reader is nil, or contains no bytes. // Returns an error if not empty. func CompareReaderEmpty(r io.Reader) error { @@ -94,3 +105,27 @@ func CompareURLFormReaderBytes(r io.Reader, expect []byte) error { } return nil } + +// CompareCompressedBytes compares the request stream before and after possible request compression +func CompareCompressedBytes(expect *bytes.Buffer, actual io.Reader, disable bool, min int64, algorithm string) error { + expectBytes := expect.Bytes() + if disable || int64(len(expectBytes)) < min { + actualBytes, err := io.ReadAll(actual) + if err != nil { + return fmt.Errorf("error while reading request: %q", err) + } + if e, a := expectBytes, actualBytes; !bytes.Equal(e, a) { + return fmt.Errorf("expect content to be %s, got %s", e, a) + } + } else { + compareFn := allowedAlgorithms[algorithm] + if compareFn == nil { + return fmt.Errorf("compress algorithm %s is not allowed", algorithm) + } + if err := compareFn(expectBytes, actual); err != nil { + return fmt.Errorf("error while comparing unzipped content: %q", err) + } + } + + return nil +} diff --git a/testing/gzip.go b/testing/gzip.go new file mode 100644 index 000000000..e6f4d6524 --- /dev/null +++ b/testing/gzip.go @@ -0,0 +1,27 @@ +package testing + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" +) + +func GzipCompareCompressBytes(expect []byte, actual io.Reader) error { + content, err := gzip.NewReader(actual) + if err != nil { + return fmt.Errorf("error while reading request") + } + + var actualBytes bytes.Buffer + _, err = actualBytes.ReadFrom(content) + if err != nil { + return fmt.Errorf("error while unzipping request payload") + } + + if e, a := expect, actualBytes.Bytes(); !bytes.Equal(e, a) { + return fmt.Errorf("expect unzipped content to be %s, got %s", e, a) + } + + return nil +}