From fdfa00cd9b8bf3e937790dece47c47f47094784c Mon Sep 17 00:00:00 2001 From: Eren Yeager <92114074+wty-Bryant@users.noreply.github.com> Date: Fri, 18 Aug 2023 12:33:59 -0400 Subject: [PATCH] Modify and Merge protocol test request unit tests codegen logic (#447) * Modify and Merge protocol test request unit tests codegen logic * Modify and Merge protocol test generator syntax * Modify and Merge some unit test syntax * Modify and Merge protocol test codegen code --------- Co-authored-by: Tianyi Wang --- .../smithy/go/codegen/SmithyGoDependency.java | 1 + .../HttpProtocolUnitTestRequestGenerator.java | 39 +++--- .../protocol/middleware_capture_request.go | 47 +++++++ .../middleware_capture_request_test.go | 115 ++++++++++++++++++ 4 files changed, 179 insertions(+), 23 deletions(-) create mode 100644 private/protocol/middleware_capture_request.go create mode 100644 private/protocol/middleware_capture_request_test.go diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java index f62ad9c1d..2d206e909 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java @@ -49,6 +49,7 @@ public final class SmithyGoDependency { public static final GoDependency SMITHY_TRANSPORT = smithy("transport", "smithytransport"); public static final GoDependency SMITHY_HTTP_TRANSPORT = smithy("transport/http", "smithyhttp"); public static final GoDependency SMITHY_MIDDLEWARE = smithy("middleware"); + public static final GoDependency SMITHY_PRIVATE_PROTOCOL = smithy("private/protocol", "smithyprivateprotocol"); public static final GoDependency SMITHY_TIME = smithy("time", "smithytime"); public static final GoDependency SMITHY_HTTP_BINDING = smithy("encoding/httpbinding"); public static final GoDependency SMITHY_JSON = smithy("encoding/json", "smithyjson"); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java index 97bdcb1b4..15d4adf6a 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import java.util.logging.Logger; +import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.go.codegen.GoWriter; import software.amazon.smithy.go.codegen.SmithyGoDependency; import software.amazon.smithy.go.codegen.SymbolUtils; @@ -196,7 +197,7 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC * @param writer writer to write generated code with. */ protected void generateTestBodySetup(GoWriter writer) { - writer.write("var actualReq *http.Request"); + writer.write("actualReq := &http.Request{}"); } /** @@ -205,26 +206,6 @@ protected void generateTestBodySetup(GoWriter writer) { * @param writer writer to write generated code with. */ protected void generateTestServerHandler(GoWriter writer) { - writer.write("actualReq = r.Clone(r.Context())"); - // Go does not set RawPath on http server if nothing is escaped - writer.openBlock("if len(actualReq.URL.RawPath) == 0 {", "}", () -> { - writer.write("actualReq.URL.RawPath = actualReq.URL.Path"); - }); - // Go automatically removes Content-Length header setting it to the member. - writer.addUseImports(SmithyGoDependency.STRCONV); - writer.openBlock("if v := actualReq.ContentLength; v != 0 {", "}", () -> { - writer.write("actualReq.Header.Set(\"Content-Length\", strconv.FormatInt(v, 10))"); - }); - - writer.addUseImports(SmithyGoDependency.BYTES); - writer.write("var buf bytes.Buffer"); - writer.openBlock("if _, err := io.Copy(&buf, r.Body); err != nil {", "}", () -> { - writer.write("t.Errorf(\"failed to read request body, %v\", err)"); - }); - writer.addUseImports(SmithyGoDependency.IOUTIL); - writer.write("actualReq.Body = ioutil.NopCloser(&buf)"); - writer.write(""); - super.generateTestServerHandler(writer); } @@ -236,8 +217,18 @@ protected void generateTestServerHandler(GoWriter writer) { */ @Override protected void generateTestInvokeClientOperation(GoWriter writer, String clientName) { + Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack", + SmithyGoDependency.SMITHY_MIDDLEWARE).build(); writer.addUseImports(SmithyGoDependency.CONTEXT); - writer.write("result, err := $L.$T(context.Background(), c.Params)", clientName, opSymbol); + writer.openBlock("result, err := $L.$T(context.Background(), c.Params, func(options *Options) {", "})", + clientName, opSymbol, () -> { + writer.openBlock("options.APIOptions = append(options.APIOptions, func(stack $P) error {", "})", + stackSymbol, () -> { + writer.write("return $T(stack, actualReq)", + SymbolUtils.createValueSymbolBuilder("AddCaptureRequestMiddleware", + SmithyGoDependency.SMITHY_PRIVATE_PROTOCOL).build()); + }); + }); } /** @@ -254,6 +245,7 @@ protected void generateTestAssertions(GoWriter writer) { writeAssertScalarEqual(writer, "c.ExpectURIPath", "actualReq.URL.RawPath", "path"); writeQueryItemBreakout(writer, "actualReq.URL.RawQuery", "queryItems"); + writeAssertHasQuery(writer, "c.ExpectQuery", "queryItems"); writeAssertRequireQuery(writer, "c.RequireQuery", "queryItems"); writeAssertForbidQuery(writer, "c.ForbidQuery", "queryItems"); @@ -282,7 +274,8 @@ protected void generateTestServer( String name, Consumer handler ) { - super.generateTestServer(writer, name, handler); + // We aren't using a test server, but we do need a URL to set. + writer.write("serverURL := \"http://localhost:8888/\""); writer.pushState(); writer.putContext("parse", SymbolUtils.createValueSymbolBuilder("Parse", SmithyGoDependency.NET_URL) .build()); diff --git a/private/protocol/middleware_capture_request.go b/private/protocol/middleware_capture_request.go new file mode 100644 index 000000000..c812036f0 --- /dev/null +++ b/private/protocol/middleware_capture_request.go @@ -0,0 +1,47 @@ +package protocol + +import ( + "context" + "fmt" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "net/http" + "strconv" +) + +const captureRequestID = "CaptureProtocolTestRequest" + +// AddCaptureRequestMiddleware captures serialized http request during protocol test for check +func AddCaptureRequestMiddleware(stack *middleware.Stack, req *http.Request) error { + return stack.Build.Add(&captureRequestMiddleware{ + req: req, + }, middleware.After) +} + +type captureRequestMiddleware struct { + req *http.Request +} + +func (*captureRequestMiddleware) ID() string { + return captureRequestID +} + +func (m *captureRequestMiddleware) HandleBuild(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler, +) ( + output middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + request, ok := input.Request.(*smithyhttp.Request) + if !ok { + return output, metadata, fmt.Errorf("error while retrieving http request") + } + + *m.req = *request.Build(ctx) + if len(m.req.URL.RawPath) == 0 { + m.req.URL.RawPath = m.req.URL.Path + } + if v := m.req.ContentLength; v != 0 { + m.req.Header.Set("Content-Length", strconv.FormatInt(v, 10)) + } + + return next.HandleBuild(ctx, input) +} diff --git a/private/protocol/middleware_capture_request_test.go b/private/protocol/middleware_capture_request_test.go new file mode 100644 index 000000000..0579260a3 --- /dev/null +++ b/private/protocol/middleware_capture_request_test.go @@ -0,0 +1,115 @@ +package protocol + +import ( + "context" + "github.com/aws/smithy-go/middleware" + smithytesting "github.com/aws/smithy-go/testing" + smithyhttp "github.com/aws/smithy-go/transport/http" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" +) + +// TestAddCaptureRequestMiddleware tests AddCaptureRequestMiddleware +func TestAddCaptureRequestMiddleware(t *testing.T) { + cases := map[string]struct { + Request *http.Request + ExpectRequest *http.Request + ExpectQuery []smithytesting.QueryItem + Stream io.Reader + }{ + "normal request": { + Request: &http.Request{ + Method: "PUT", + Header: map[string][]string{ + "Foo": {"bar", "too"}, + "Checksum": {"SHA256"}, + }, + URL: &url.URL{ + Path: "test/path", + RawQuery: "language=us®ion=us-west+east", + }, + ContentLength: 100, + }, + ExpectRequest: &http.Request{ + Method: "PUT", + Header: map[string][]string{ + "Foo": {"bar", "too"}, + "Checksum": {"SHA256"}, + "Content-Length": {"100"}, + }, + URL: &url.URL{ + Path: "test/path", + RawPath: "test/path", + }, + Body: ioutil.NopCloser(strings.NewReader("hello world.")), + }, + ExpectQuery: []smithytesting.QueryItem{ + { + Key: "language", + Value: "us", + }, + { + Key: "region", + Value: "us-west%20east", + }, + }, + Stream: strings.NewReader("hello world."), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var err error + req := &smithyhttp.Request{ + Request: c.Request, + } + if c.Stream != nil { + req, err = req.SetStream(c.Stream) + if err != nil { + t.Fatalf("Got error while retrieving case stream: %v", err) + } + } + capturedRequest := &http.Request{} + m := captureRequestMiddleware{ + req: capturedRequest, + } + _, _, err = m.HandleBuild(context.Background(), + middleware.BuildInput{Request: req}, + middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error) { + return out, metadata, nil + }), + ) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := c.ExpectRequest.Method, capturedRequest.Method; e != a { + t.Errorf("expect request method %v found, got %v", e, a) + } + if e, a := c.ExpectRequest.URL.Path, capturedRequest.URL.RawPath; e != a { + t.Errorf("expect %v path, got %v", e, a) + } + if c.ExpectRequest.Body != nil { + expect, err := ioutil.ReadAll(c.ExpectRequest.Body) + if capturedRequest.Body == nil { + t.Errorf("Expect request stream %v captured, get nil", string(expect)) + } + actual, err := ioutil.ReadAll(capturedRequest.Body) + if err != nil { + t.Errorf("unable to read captured request body, %v", err) + } + if e, a := string(expect), string(actual); e != a { + t.Errorf("expect request body to be %s, got %s", e, a) + } + } + queryItems := smithytesting.ParseRawQuery(capturedRequest.URL.RawQuery) + smithytesting.AssertHasQuery(t, c.ExpectQuery, queryItems) + smithytesting.AssertHasHeader(t, c.ExpectRequest.Header, capturedRequest.Header) + }) + } +}