Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonstreator committed Nov 25, 2022
0 parents commit 8afc3f0
Show file tree
Hide file tree
Showing 9 changed files with 376 additions and 0 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/coverage.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Go

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2

- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.19

- name: Build
run: go build -v ./...

- name: Test
run: go test -v -coverprofile=coverage.out -covermode=atomic ./...

- name: Upload coverage to Codecov
run: bash <(curl -s https://codecov.io/bash)
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
coverage.out
71 changes: 71 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# roundtriphook

[![codecov](https://codecov.io/gh/dillonstreator/roundtriphook/branch/main/graph/badge.svg?token=T4KLDSR6FH)](https://codecov.io/gh/dillonstreator/roundtriphook)

utility package which provides a simple before & after hook interface for an `http.RoundTripper`

## Install

```sh
go get github.com/dillonstreator/roundtriphook
```

## Usage

```go
var wrappedTransport = roundtriphook.NewTransport(
// This call to roundtriphook.WithBaseRoundTripper is unnecessary
// since the default behavior is to set the base round tripper to http.DefaultTransport if none is provided
roundtriphook.WithBaseRoundTripper(http.DefaultTransport),
roundtriphook.WithBefore(func(req *http.Request) *http.Request {
fmt.Println("before request")
// mutate request or add context here
return req
}),
roundtriphook.WithAfter(func(req *http.Request, res *http.Response, err error) {
fmt.Println("after request")
}),
)

var httpClient = &http.Client{
Transport: wrappedTransport,
}
```

### logging transport

```go
var loggingTransport = roundtriphook.NewTransport(
roundtriphook.WithBefore(func(req *http.Request) *http.Request {
startTime := time.Now()
id := startTime.UnixNano()

fmt.Printf("[%d] -> %s %s\n", id, req.Method, req.URL)

ctx := req.Context()
ctx = context.WithValue(ctx, timeStartKey, startTime)
ctx = context.WithValue(ctx, idKey, id)

return req.WithContext(ctx)
}),
roundtriphook.WithAfter(func(req *http.Request, res *http.Response, err error) {
startTime := req.Context().Value(timeStartKey).(time.Time)
id := req.Context().Value(idKey).(int64)

sb := strings.Builder{}
sb.WriteString(fmt.Sprintf("[%d] <- %s %s %s", id, req.Method, req.URL, time.Since(startTime)))

if res != nil {
sb.WriteString(" " + res.Status)
}

if err != nil {
sb.WriteString(fmt.Sprintf(" %s", err.Error()))
}

fmt.Printf("%s\n", sb.String())
}),
)
```

[full logging transport example](./examples/loggingtransport/main.go)
61 changes: 61 additions & 0 deletions examples/loggingtransport/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package main

import (
"context"
"fmt"
"net/http"
"strings"
"time"

"github.com/dillonstreator/roundtriphook"
)

type rthCtxKey string

var (
timeStartKey = rthCtxKey("startTime")
idKey = rthCtxKey("id")
)

var loggingTransport = roundtriphook.NewTransport(
// This call to roundtriphook.WithBaseRoundTripper is unnecessary
// since the default behavior is to set the base round tripper to http.DefaultTransport if none is provided
roundtriphook.WithBaseRoundTripper(http.DefaultTransport),
roundtriphook.WithBefore(func(req *http.Request) *http.Request {
startTime := time.Now()
id := startTime.UnixNano()

fmt.Printf("[%d] -> %s %s\n", id, req.Method, req.URL)

ctx := req.Context()
ctx = context.WithValue(ctx, timeStartKey, startTime)
ctx = context.WithValue(ctx, idKey, id)

return req.WithContext(ctx)
}),
roundtriphook.WithAfter(func(req *http.Request, res *http.Response, err error) {
startTime := req.Context().Value(timeStartKey).(time.Time)
id := req.Context().Value(idKey).(int64)

sb := strings.Builder{}
sb.WriteString(fmt.Sprintf("[%d] <- %s %s %s", id, req.Method, req.URL, time.Since(startTime)))

if res != nil {
sb.WriteString(" " + res.Status)
}

if err != nil {
sb.WriteString(fmt.Sprintf(" %s", err.Error()))
}

fmt.Printf("%s\n", sb.String())
}),
)

func main() {
httpClient := &http.Client{
Transport: loggingTransport,
}

httpClient.Get("https://www.google.com")
}
11 changes: 11 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module github.com/dillonstreator/roundtriphook

go 1.19

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/stretchr/testify v1.8.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
17 changes: 17 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
25 changes: 25 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package roundtriphook

import (
"net/http"
)

type option func(t *transport)

func WithBaseRoundTripper(base http.RoundTripper) option {
return func(t *transport) {
t.base = base
}
}

func WithBefore(fn ...BeforeFn) option {
return func(t *transport) {
t.befores = append(t.befores, fn...)
}
}

func WithAfter(fn ...AfterFn) option {
return func(t *transport) {
t.afters = append(t.afters, fn...)
}
}
43 changes: 43 additions & 0 deletions transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package roundtriphook

import (
"net/http"
)

type BeforeFn func(req *http.Request) *http.Request
type AfterFn func(req *http.Request, res *http.Response, err error)

type transport struct {
base http.RoundTripper
befores []BeforeFn
afters []AfterFn
}

var _ http.RoundTripper = (*transport)(nil)

func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
for _, fn := range t.befores {
req = fn(req)
}

res, err := t.base.RoundTrip(req)

for _, fn := range t.afters {
fn(req, res, err)
}

return res, err
}

func NewTransport(opts ...option) *transport {
t := &transport{}
for _, opt := range opts {
opt(t)
}

if t.base == nil {
t.base = http.DefaultTransport
}

return t
}
120 changes: 120 additions & 0 deletions transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package roundtriphook

import (
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

type mockTransport struct {
mock.Mock
}

var _ http.RoundTripper = (*mockTransport)(nil)

func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
args := m.Called(req)
return args.Get(0).(*http.Response), args.Error(1)
}

func TestNewTransport(t *testing.T) {
assert := assert.New(t)

tpt := NewTransport()

assert.Equal(http.DefaultTransport, tpt.base)
}

func TestNewTransport_Options(t *testing.T) {
assert := assert.New(t)

baseTpt := &http.Transport{}

var before1 BeforeFn = func(req *http.Request) *http.Request { return req }
var before2 BeforeFn = func(req *http.Request) *http.Request { return req }
var after1 AfterFn = func(req *http.Request, res *http.Response, err error) {}
var after2 AfterFn = func(req *http.Request, res *http.Response, err error) {}

tpt := NewTransport(
WithBaseRoundTripper(baseTpt),
WithBefore(before1),
WithBefore(before2),
WithAfter(after1),
WithAfter(after2),
)

assert.Equal(baseTpt, tpt.base)
assert.Len(tpt.befores, 2)
assert.Len(tpt.afters, 2)
}

func TestRoundTrip(t *testing.T) {
assert := assert.New(t)

mockTpt := &mockTransport{}
req1 := &http.Request{
Header: http.Header{
"header1": []string{"value1"},
},
}
req2 := &http.Request{
Header: http.Header{
"header1": []string{"value1"},
"header2": []string{"value2"},
},
}
expectedRes := &http.Response{}

mockTpt.On("RoundTrip", req2).Return(expectedRes, nil).Once()

before1CalledAt := time.Time{}
before2CalledAt := time.Time{}
after1CalledAt := time.Time{}
after2CalledAt := time.Time{}
before1 := func(req *http.Request) *http.Request {
time.Sleep(time.Millisecond)
before1CalledAt = time.Now()
assert.Equal(req1, req)
return req2
}
before2 := func(req *http.Request) *http.Request {
time.Sleep(time.Millisecond)
before2CalledAt = time.Now()
assert.Equal(req2, req)
return req2
}
after1 := func(req *http.Request, res *http.Response, err error) {
time.Sleep(time.Millisecond)
after1CalledAt = time.Now()
assert.Equal(req2, req)
assert.Equal(expectedRes, res)
}
after2 := func(req *http.Request, res *http.Response, err error) {
time.Sleep(time.Millisecond)
after2CalledAt = time.Now()
assert.Equal(req2, req)
assert.Equal(expectedRes, res)
}
tpt := NewTransport(
WithBaseRoundTripper(mockTpt),
WithBefore(before1),
WithBefore(before2),
WithAfter(after1),
WithAfter(after2),
)

res, err := tpt.RoundTrip(req1)
assert.NoError(err)

assert.Equal(expectedRes, res)
assert.NotZero(before1CalledAt)
assert.NotZero(before2CalledAt)
assert.NotZero(after1CalledAt)
assert.NotZero(after2CalledAt)
assert.True(before1CalledAt.Before(before2CalledAt))
assert.True(before2CalledAt.Before(after1CalledAt))
assert.True(after1CalledAt.Before(after2CalledAt))
}

0 comments on commit 8afc3f0

Please sign in to comment.