Skip to content

Commit

Permalink
Modify and Merge protocol test request unit tests codegen logic (#447)
Browse files Browse the repository at this point in the history
* Modify and Merge protocol test request unit tests codegen logic

* Modify and Merge protocol test generator syntax

* Modify and Merge some unit test syntax

* Modify and Merge protocol test codegen code

---------

Co-authored-by: Tianyi Wang <[email protected]>
  • Loading branch information
wty-Bryant and Tianyi Wang authored Aug 18, 2023
1 parent 664582d commit fdfa00c
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public final class SmithyGoDependency {
public static final GoDependency SMITHY_TRANSPORT = smithy("transport", "smithytransport");
public static final GoDependency SMITHY_HTTP_TRANSPORT = smithy("transport/http", "smithyhttp");
public static final GoDependency SMITHY_MIDDLEWARE = smithy("middleware");
public static final GoDependency SMITHY_PRIVATE_PROTOCOL = smithy("private/protocol", "smithyprivateprotocol");
public static final GoDependency SMITHY_TIME = smithy("time", "smithytime");
public static final GoDependency SMITHY_HTTP_BINDING = smithy("encoding/httpbinding");
public static final GoDependency SMITHY_JSON = smithy("encoding/json", "smithyjson");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.function.Consumer;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SymbolUtils;
Expand Down Expand Up @@ -196,7 +197,7 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC
* @param writer writer to write generated code with.
*/
protected void generateTestBodySetup(GoWriter writer) {
writer.write("var actualReq *http.Request");
writer.write("actualReq := &http.Request{}");
}

/**
Expand All @@ -205,26 +206,6 @@ protected void generateTestBodySetup(GoWriter writer) {
* @param writer writer to write generated code with.
*/
protected void generateTestServerHandler(GoWriter writer) {
writer.write("actualReq = r.Clone(r.Context())");
// Go does not set RawPath on http server if nothing is escaped
writer.openBlock("if len(actualReq.URL.RawPath) == 0 {", "}", () -> {
writer.write("actualReq.URL.RawPath = actualReq.URL.Path");
});
// Go automatically removes Content-Length header setting it to the member.
writer.addUseImports(SmithyGoDependency.STRCONV);
writer.openBlock("if v := actualReq.ContentLength; v != 0 {", "}", () -> {
writer.write("actualReq.Header.Set(\"Content-Length\", strconv.FormatInt(v, 10))");
});

writer.addUseImports(SmithyGoDependency.BYTES);
writer.write("var buf bytes.Buffer");
writer.openBlock("if _, err := io.Copy(&buf, r.Body); err != nil {", "}", () -> {
writer.write("t.Errorf(\"failed to read request body, %v\", err)");
});
writer.addUseImports(SmithyGoDependency.IOUTIL);
writer.write("actualReq.Body = ioutil.NopCloser(&buf)");
writer.write("");

super.generateTestServerHandler(writer);
}

Expand All @@ -236,8 +217,18 @@ protected void generateTestServerHandler(GoWriter writer) {
*/
@Override
protected void generateTestInvokeClientOperation(GoWriter writer, String clientName) {
Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack",
SmithyGoDependency.SMITHY_MIDDLEWARE).build();
writer.addUseImports(SmithyGoDependency.CONTEXT);
writer.write("result, err := $L.$T(context.Background(), c.Params)", clientName, opSymbol);
writer.openBlock("result, err := $L.$T(context.Background(), c.Params, func(options *Options) {", "})",
clientName, opSymbol, () -> {
writer.openBlock("options.APIOptions = append(options.APIOptions, func(stack $P) error {", "})",
stackSymbol, () -> {
writer.write("return $T(stack, actualReq)",
SymbolUtils.createValueSymbolBuilder("AddCaptureRequestMiddleware",
SmithyGoDependency.SMITHY_PRIVATE_PROTOCOL).build());
});
});
}

/**
Expand All @@ -254,6 +245,7 @@ protected void generateTestAssertions(GoWriter writer) {
writeAssertScalarEqual(writer, "c.ExpectURIPath", "actualReq.URL.RawPath", "path");

writeQueryItemBreakout(writer, "actualReq.URL.RawQuery", "queryItems");

writeAssertHasQuery(writer, "c.ExpectQuery", "queryItems");
writeAssertRequireQuery(writer, "c.RequireQuery", "queryItems");
writeAssertForbidQuery(writer, "c.ForbidQuery", "queryItems");
Expand Down Expand Up @@ -282,7 +274,8 @@ protected void generateTestServer(
String name,
Consumer<GoWriter> handler
) {
super.generateTestServer(writer, name, handler);
// We aren't using a test server, but we do need a URL to set.
writer.write("serverURL := \"http://localhost:8888/\"");
writer.pushState();
writer.putContext("parse", SymbolUtils.createValueSymbolBuilder("Parse", SmithyGoDependency.NET_URL)
.build());
Expand Down
47 changes: 47 additions & 0 deletions private/protocol/middleware_capture_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package protocol

import (
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"net/http"
"strconv"
)

const captureRequestID = "CaptureProtocolTestRequest"

// AddCaptureRequestMiddleware captures serialized http request during protocol test for check
func AddCaptureRequestMiddleware(stack *middleware.Stack, req *http.Request) error {
return stack.Build.Add(&captureRequestMiddleware{
req: req,
}, middleware.After)
}

type captureRequestMiddleware struct {
req *http.Request
}

func (*captureRequestMiddleware) ID() string {
return captureRequestID
}

func (m *captureRequestMiddleware) HandleBuild(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler,
) (
output middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
request, ok := input.Request.(*smithyhttp.Request)
if !ok {
return output, metadata, fmt.Errorf("error while retrieving http request")
}

*m.req = *request.Build(ctx)
if len(m.req.URL.RawPath) == 0 {
m.req.URL.RawPath = m.req.URL.Path
}
if v := m.req.ContentLength; v != 0 {
m.req.Header.Set("Content-Length", strconv.FormatInt(v, 10))
}

return next.HandleBuild(ctx, input)
}
115 changes: 115 additions & 0 deletions private/protocol/middleware_capture_request_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package protocol

import (
"context"
"github.com/aws/smithy-go/middleware"
smithytesting "github.com/aws/smithy-go/testing"
smithyhttp "github.com/aws/smithy-go/transport/http"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"
)

// TestAddCaptureRequestMiddleware tests AddCaptureRequestMiddleware
func TestAddCaptureRequestMiddleware(t *testing.T) {
cases := map[string]struct {
Request *http.Request
ExpectRequest *http.Request
ExpectQuery []smithytesting.QueryItem
Stream io.Reader
}{
"normal request": {
Request: &http.Request{
Method: "PUT",
Header: map[string][]string{
"Foo": {"bar", "too"},
"Checksum": {"SHA256"},
},
URL: &url.URL{
Path: "test/path",
RawQuery: "language=us&region=us-west+east",
},
ContentLength: 100,
},
ExpectRequest: &http.Request{
Method: "PUT",
Header: map[string][]string{
"Foo": {"bar", "too"},
"Checksum": {"SHA256"},
"Content-Length": {"100"},
},
URL: &url.URL{
Path: "test/path",
RawPath: "test/path",
},
Body: ioutil.NopCloser(strings.NewReader("hello world.")),
},
ExpectQuery: []smithytesting.QueryItem{
{
Key: "language",
Value: "us",
},
{
Key: "region",
Value: "us-west%20east",
},
},
Stream: strings.NewReader("hello world."),
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
var err error
req := &smithyhttp.Request{
Request: c.Request,
}
if c.Stream != nil {
req, err = req.SetStream(c.Stream)
if err != nil {
t.Fatalf("Got error while retrieving case stream: %v", err)
}
}
capturedRequest := &http.Request{}
m := captureRequestMiddleware{
req: capturedRequest,
}
_, _, err = m.HandleBuild(context.Background(),
middleware.BuildInput{Request: req},
middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) (
out middleware.BuildOutput, metadata middleware.Metadata, err error) {
return out, metadata, nil
}),
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := c.ExpectRequest.Method, capturedRequest.Method; e != a {
t.Errorf("expect request method %v found, got %v", e, a)
}
if e, a := c.ExpectRequest.URL.Path, capturedRequest.URL.RawPath; e != a {
t.Errorf("expect %v path, got %v", e, a)
}
if c.ExpectRequest.Body != nil {
expect, err := ioutil.ReadAll(c.ExpectRequest.Body)
if capturedRequest.Body == nil {
t.Errorf("Expect request stream %v captured, get nil", string(expect))
}
actual, err := ioutil.ReadAll(capturedRequest.Body)
if err != nil {
t.Errorf("unable to read captured request body, %v", err)
}
if e, a := string(expect), string(actual); e != a {
t.Errorf("expect request body to be %s, got %s", e, a)
}
}
queryItems := smithytesting.ParseRawQuery(capturedRequest.URL.RawQuery)
smithytesting.AssertHasQuery(t, c.ExpectQuery, queryItems)
smithytesting.AssertHasHeader(t, c.ExpectRequest.Header, capturedRequest.Header)
})
}
}

0 comments on commit fdfa00c

Please sign in to comment.