From 8afc3f0a565708cf232f3de20b5186e16b232ea5 Mon Sep 17 00:00:00 2001 From: dillonstreator Date: Fri, 11 Nov 2022 18:31:37 -0600 Subject: [PATCH] initial --- .github/workflows/coverage.yaml | 27 +++++++ .gitignore | 1 + README.md | 71 ++++++++++++++++++ examples/loggingtransport/main.go | 61 +++++++++++++++ go.mod | 11 +++ go.sum | 17 +++++ options.go | 25 +++++++ transport.go | 43 +++++++++++ transport_test.go | 120 ++++++++++++++++++++++++++++++ 9 files changed, 376 insertions(+) create mode 100644 .github/workflows/coverage.yaml create mode 100644 .gitignore create mode 100644 README.md create mode 100644 examples/loggingtransport/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 options.go create mode 100644 transport.go create mode 100644 transport_test.go diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml new file mode 100644 index 0000000..5abeefd --- /dev/null +++ b/.github/workflows/coverage.yaml @@ -0,0 +1,27 @@ +name: Go + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Set up Go + uses: actions/setup-go@v2 + with: + go-version: 1.19 + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v -coverprofile=coverage.out -covermode=atomic ./... + + - name: Upload coverage to Codecov + run: bash <(curl -s https://codecov.io/bash) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1b3ac10 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +coverage.out \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..f720870 --- /dev/null +++ b/README.md @@ -0,0 +1,71 @@ +# roundtriphook + +[![codecov](https://codecov.io/gh/dillonstreator/roundtriphook/branch/main/graph/badge.svg?token=T4KLDSR6FH)](https://codecov.io/gh/dillonstreator/roundtriphook) + +utility package which provides a simple before & after hook interface for an `http.RoundTripper` + +## Install + +```sh +go get github.com/dillonstreator/roundtriphook +``` + +## Usage + +```go +var wrappedTransport = roundtriphook.NewTransport( + // This call to roundtriphook.WithBaseRoundTripper is unnecessary + // since the default behavior is to set the base round tripper to http.DefaultTransport if none is provided + roundtriphook.WithBaseRoundTripper(http.DefaultTransport), + roundtriphook.WithBefore(func(req *http.Request) *http.Request { + fmt.Println("before request") + // mutate request or add context here + return req + }), + roundtriphook.WithAfter(func(req *http.Request, res *http.Response, err error) { + fmt.Println("after request") + }), +) + +var httpClient = &http.Client{ + Transport: wrappedTransport, +} +``` + +### logging transport + +```go +var loggingTransport = roundtriphook.NewTransport( + roundtriphook.WithBefore(func(req *http.Request) *http.Request { + startTime := time.Now() + id := startTime.UnixNano() + + fmt.Printf("[%d] -> %s %s\n", id, req.Method, req.URL) + + ctx := req.Context() + ctx = context.WithValue(ctx, timeStartKey, startTime) + ctx = context.WithValue(ctx, idKey, id) + + return req.WithContext(ctx) + }), + roundtriphook.WithAfter(func(req *http.Request, res *http.Response, err error) { + startTime := req.Context().Value(timeStartKey).(time.Time) + id := req.Context().Value(idKey).(int64) + + sb := strings.Builder{} + sb.WriteString(fmt.Sprintf("[%d] <- %s %s %s", id, req.Method, req.URL, time.Since(startTime))) + + if res != nil { + sb.WriteString(" " + res.Status) + } + + if err != nil { + sb.WriteString(fmt.Sprintf(" %s", err.Error())) + } + + fmt.Printf("%s\n", sb.String()) + }), +) +``` + +[full logging transport example](./examples/loggingtransport/main.go) diff --git a/examples/loggingtransport/main.go b/examples/loggingtransport/main.go new file mode 100644 index 0000000..a906439 --- /dev/null +++ b/examples/loggingtransport/main.go @@ -0,0 +1,61 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/dillonstreator/roundtriphook" +) + +type rthCtxKey string + +var ( + timeStartKey = rthCtxKey("startTime") + idKey = rthCtxKey("id") +) + +var loggingTransport = roundtriphook.NewTransport( + // This call to roundtriphook.WithBaseRoundTripper is unnecessary + // since the default behavior is to set the base round tripper to http.DefaultTransport if none is provided + roundtriphook.WithBaseRoundTripper(http.DefaultTransport), + roundtriphook.WithBefore(func(req *http.Request) *http.Request { + startTime := time.Now() + id := startTime.UnixNano() + + fmt.Printf("[%d] -> %s %s\n", id, req.Method, req.URL) + + ctx := req.Context() + ctx = context.WithValue(ctx, timeStartKey, startTime) + ctx = context.WithValue(ctx, idKey, id) + + return req.WithContext(ctx) + }), + roundtriphook.WithAfter(func(req *http.Request, res *http.Response, err error) { + startTime := req.Context().Value(timeStartKey).(time.Time) + id := req.Context().Value(idKey).(int64) + + sb := strings.Builder{} + sb.WriteString(fmt.Sprintf("[%d] <- %s %s %s", id, req.Method, req.URL, time.Since(startTime))) + + if res != nil { + sb.WriteString(" " + res.Status) + } + + if err != nil { + sb.WriteString(fmt.Sprintf(" %s", err.Error())) + } + + fmt.Printf("%s\n", sb.String()) + }), +) + +func main() { + httpClient := &http.Client{ + Transport: loggingTransport, + } + + httpClient.Get("https://www.google.com") +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6776b8e --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/dillonstreator/roundtriphook + +go 1.19 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/testify v1.8.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..d9320db --- /dev/null +++ b/go.sum @@ -0,0 +1,17 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/options.go b/options.go new file mode 100644 index 0000000..7d7a2df --- /dev/null +++ b/options.go @@ -0,0 +1,25 @@ +package roundtriphook + +import ( + "net/http" +) + +type option func(t *transport) + +func WithBaseRoundTripper(base http.RoundTripper) option { + return func(t *transport) { + t.base = base + } +} + +func WithBefore(fn ...BeforeFn) option { + return func(t *transport) { + t.befores = append(t.befores, fn...) + } +} + +func WithAfter(fn ...AfterFn) option { + return func(t *transport) { + t.afters = append(t.afters, fn...) + } +} diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..af33b68 --- /dev/null +++ b/transport.go @@ -0,0 +1,43 @@ +package roundtriphook + +import ( + "net/http" +) + +type BeforeFn func(req *http.Request) *http.Request +type AfterFn func(req *http.Request, res *http.Response, err error) + +type transport struct { + base http.RoundTripper + befores []BeforeFn + afters []AfterFn +} + +var _ http.RoundTripper = (*transport)(nil) + +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + for _, fn := range t.befores { + req = fn(req) + } + + res, err := t.base.RoundTrip(req) + + for _, fn := range t.afters { + fn(req, res, err) + } + + return res, err +} + +func NewTransport(opts ...option) *transport { + t := &transport{} + for _, opt := range opts { + opt(t) + } + + if t.base == nil { + t.base = http.DefaultTransport + } + + return t +} diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 0000000..ca9eded --- /dev/null +++ b/transport_test.go @@ -0,0 +1,120 @@ +package roundtriphook + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type mockTransport struct { + mock.Mock +} + +var _ http.RoundTripper = (*mockTransport)(nil) + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + args := m.Called(req) + return args.Get(0).(*http.Response), args.Error(1) +} + +func TestNewTransport(t *testing.T) { + assert := assert.New(t) + + tpt := NewTransport() + + assert.Equal(http.DefaultTransport, tpt.base) +} + +func TestNewTransport_Options(t *testing.T) { + assert := assert.New(t) + + baseTpt := &http.Transport{} + + var before1 BeforeFn = func(req *http.Request) *http.Request { return req } + var before2 BeforeFn = func(req *http.Request) *http.Request { return req } + var after1 AfterFn = func(req *http.Request, res *http.Response, err error) {} + var after2 AfterFn = func(req *http.Request, res *http.Response, err error) {} + + tpt := NewTransport( + WithBaseRoundTripper(baseTpt), + WithBefore(before1), + WithBefore(before2), + WithAfter(after1), + WithAfter(after2), + ) + + assert.Equal(baseTpt, tpt.base) + assert.Len(tpt.befores, 2) + assert.Len(tpt.afters, 2) +} + +func TestRoundTrip(t *testing.T) { + assert := assert.New(t) + + mockTpt := &mockTransport{} + req1 := &http.Request{ + Header: http.Header{ + "header1": []string{"value1"}, + }, + } + req2 := &http.Request{ + Header: http.Header{ + "header1": []string{"value1"}, + "header2": []string{"value2"}, + }, + } + expectedRes := &http.Response{} + + mockTpt.On("RoundTrip", req2).Return(expectedRes, nil).Once() + + before1CalledAt := time.Time{} + before2CalledAt := time.Time{} + after1CalledAt := time.Time{} + after2CalledAt := time.Time{} + before1 := func(req *http.Request) *http.Request { + time.Sleep(time.Millisecond) + before1CalledAt = time.Now() + assert.Equal(req1, req) + return req2 + } + before2 := func(req *http.Request) *http.Request { + time.Sleep(time.Millisecond) + before2CalledAt = time.Now() + assert.Equal(req2, req) + return req2 + } + after1 := func(req *http.Request, res *http.Response, err error) { + time.Sleep(time.Millisecond) + after1CalledAt = time.Now() + assert.Equal(req2, req) + assert.Equal(expectedRes, res) + } + after2 := func(req *http.Request, res *http.Response, err error) { + time.Sleep(time.Millisecond) + after2CalledAt = time.Now() + assert.Equal(req2, req) + assert.Equal(expectedRes, res) + } + tpt := NewTransport( + WithBaseRoundTripper(mockTpt), + WithBefore(before1), + WithBefore(before2), + WithAfter(after1), + WithAfter(after2), + ) + + res, err := tpt.RoundTrip(req1) + assert.NoError(err) + + assert.Equal(expectedRes, res) + assert.NotZero(before1CalledAt) + assert.NotZero(before2CalledAt) + assert.NotZero(after1CalledAt) + assert.NotZero(after2CalledAt) + assert.True(before1CalledAt.Before(before2CalledAt)) + assert.True(before2CalledAt.Before(after1CalledAt)) + assert.True(after1CalledAt.Before(after2CalledAt)) +}