From 2bd70be79b3dd36fec6f6d7b48e81b0da5579972 Mon Sep 17 00:00:00 2001
From: Cr <631807682@qq.com>
Date: Mon, 17 Jul 2023 21:01:46 +0800
Subject: [PATCH 01/17] chore(http1): remove duplicated code (#817)
---
pkg/protocol/http1/server.go | 5 -----
1 file changed, 5 deletions(-)
diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go
index ba8915a1d..8aaed6db7 100644
--- a/pkg/protocol/http1/server.go
+++ b/pkg/protocol/http1/server.go
@@ -350,11 +350,6 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
}
if hijackHandler != nil {
- if zr != nil {
- zr.Release() //nolint:errcheck
- zr = nil
- }
-
// Hijacked conn process the timeout by itself
err = ctx.GetConn().SetReadTimeout(0)
if err != nil {
From 7d25d5a774a34d9dd7a6c9a58beedc01bc2513ed Mon Sep 17 00:00:00 2001
From: Cr <631807682@qq.com>
Date: Tue, 18 Jul 2023 17:27:35 +0800
Subject: [PATCH 02/17] test: add test for protocol http1 (#804)
Co-authored-by: kinggo
---
pkg/common/test/assert/assert.go | 2 +
pkg/common/test/mock/body_data.go | 28 ++++
pkg/common/test/mock/body_data_test.go | 12 ++
pkg/common/test/mock/network.go | 4 +
pkg/common/test/mock/network_test.go | 1 +
pkg/protocol/http1/client_test.go | 108 ++++++++++++++
pkg/protocol/http1/client_unix_test.go | 56 ++++++++
pkg/protocol/http1/ext/common_test.go | 94 +++++++++++++
pkg/protocol/http1/ext/headerscanner_test.go | 60 ++++++++
pkg/protocol/http1/ext/stream_test.go | 139 +++++++++++++++++++
pkg/protocol/http1/req/header_test.go | 9 ++
pkg/protocol/http1/req/request_test.go | 74 ++++++----
pkg/protocol/http1/resp/response_test.go | 6 +
pkg/protocol/http1/server_test.go | 115 ++++++++++++++-
14 files changed, 679 insertions(+), 29 deletions(-)
diff --git a/pkg/common/test/assert/assert.go b/pkg/common/test/assert/assert.go
index 9ed544f35..ed157835c 100644
--- a/pkg/common/test/assert/assert.go
+++ b/pkg/common/test/assert/assert.go
@@ -86,10 +86,12 @@ func NotEqual(t testing.TB, expected, actual interface{}) {
}
func True(t testing.TB, obj interface{}) {
+ t.Helper()
DeepEqual(t, true, obj)
}
func False(t testing.TB, obj interface{}) {
+ t.Helper()
DeepEqual(t, false, obj)
}
diff --git a/pkg/common/test/mock/body_data.go b/pkg/common/test/mock/body_data.go
index 808e25e99..807c585c2 100644
--- a/pkg/common/test/mock/body_data.go
+++ b/pkg/common/test/mock/body_data.go
@@ -16,6 +16,8 @@
package mock
+import "fmt"
+
func CreateFixedBody(bodySize int) []byte {
var b []byte
for i := 0; i < bodySize; i++ {
@@ -23,3 +25,29 @@ func CreateFixedBody(bodySize int) []byte {
}
return b
}
+
+func CreateChunkedBody(body []byte, trailer map[string]string, hasTrailer bool) []byte {
+ var b []byte
+ chunkSize := 1
+ for len(body) > 0 {
+ if chunkSize > len(body) {
+ chunkSize = len(body)
+ }
+ b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...)
+ b = append(b, body[:chunkSize]...)
+ b = append(b, []byte("\r\n")...)
+ body = body[chunkSize:]
+ chunkSize++
+ }
+ if hasTrailer {
+ b = append(b, "0\r\n"...)
+ for k, v := range trailer {
+ b = append(b, k...)
+ b = append(b, ": "...)
+ b = append(b, v...)
+ b = append(b, "\r\n"...)
+ }
+ b = append(b, "\r\n"...)
+ }
+ return b
+}
diff --git a/pkg/common/test/mock/body_data_test.go b/pkg/common/test/mock/body_data_test.go
index a783a47c4..62ebf1bcd 100644
--- a/pkg/common/test/mock/body_data_test.go
+++ b/pkg/common/test/mock/body_data_test.go
@@ -18,6 +18,8 @@ package mock
import (
"testing"
+
+ "github.com/cloudwego/hertz/pkg/common/test/assert"
)
func TestGenerateCreateFixedBody(t *testing.T) {
@@ -33,3 +35,13 @@ func TestGenerateCreateFixedBody(t *testing.T) {
t.Fatalf("Unexpected %s. Expecting a nil", nilFixedBody)
}
}
+
+func TestGenerateCreateChunkedBody(t *testing.T) {
+ bodySize := 10
+ b := CreateFixedBody(bodySize)
+ trailer := map[string]string{"Foo": "chunked shit"}
+ expectCb := "1\r\n0\r\n2\r\n12\r\n3\r\n345\r\n4\r\n6789\r\n0\r\nFoo: chunked shit\r\n\r\n"
+
+ cb := CreateChunkedBody(b, trailer, true)
+ assert.DeepEqual(t, expectCb, string(cb))
+}
diff --git a/pkg/common/test/mock/network.go b/pkg/common/test/mock/network.go
index 3d4747563..7ad083647 100644
--- a/pkg/common/test/mock/network.go
+++ b/pkg/common/test/mock/network.go
@@ -125,6 +125,10 @@ func (m *Conn) WriterRecorder() Recorder {
return &recorder{c: m, Reader: m.zw}
}
+func (m *Conn) GetReadTimeout() time.Duration {
+ return m.readTimeout
+}
+
type recorder struct {
c *Conn
network.Reader
diff --git a/pkg/common/test/mock/network_test.go b/pkg/common/test/mock/network_test.go
index 84df1cdae..2d4ff973a 100644
--- a/pkg/common/test/mock/network_test.go
+++ b/pkg/common/test/mock/network_test.go
@@ -34,6 +34,7 @@ func TestConn(t *testing.T) {
assert.DeepEqual(t, nil, err)
err = conn1.SetReadTimeout(time.Millisecond * 100)
assert.DeepEqual(t, nil, err)
+ assert.DeepEqual(t, time.Millisecond*100, conn1.GetReadTimeout())
// Peek Skip Read
b, _ := conn1.Peek(1)
diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go
index 1e4e0139f..19154453c 100644
--- a/pkg/protocol/http1/client_test.go
+++ b/pkg/protocol/http1/client_test.go
@@ -55,12 +55,15 @@ import (
"testing"
"time"
+ "github.com/cloudwego/hertz/pkg/app/client/retry"
"github.com/cloudwego/hertz/pkg/common/config"
errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/test/mock"
+ "github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/protocol"
+ "github.com/cloudwego/hertz/pkg/protocol/client"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/cloudwego/hertz/pkg/protocol/http1/resp"
"github.com/cloudwego/netpoll"
@@ -363,6 +366,103 @@ func TestDialTimeoutPriority(t *testing.T) {
}
}
+func TestStateObserve(t *testing.T) {
+ syncState := struct {
+ mu sync.Mutex
+ state config.ConnPoolState
+ }{}
+ c := &HostClient{
+ ClientOptions: &ClientOptions{
+ Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
+ return mock.SlowReadDialer(addr)
+ }),
+ StateObserve: func(hcs config.HostClientState) {
+ syncState.mu.Lock()
+ defer syncState.mu.Unlock()
+ syncState.state = hcs.ConnPoolState()
+ },
+ ObservationInterval: 50 * time.Millisecond,
+ },
+ Addr: "foobar",
+ closed: make(chan struct{}),
+ }
+
+ c.SetDynamicConfig(&client.DynamicConfig{
+ Addr: utils.AddMissingPort(c.Addr, true),
+ })
+
+ time.Sleep(500 * time.Millisecond)
+ assert.Nil(t, c.Close())
+ syncState.mu.Lock()
+ assert.DeepEqual(t, "foobar:443", syncState.state.Addr)
+ syncState.mu.Unlock()
+}
+
+func TestCachedTLSConfig(t *testing.T) {
+ c := &HostClient{
+ ClientOptions: &ClientOptions{
+ Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
+ return mock.SlowReadDialer(addr)
+ }),
+ TLSConfig: &tls.Config{
+ InsecureSkipVerify: true,
+ },
+ },
+ Addr: "foobar",
+ IsTLS: true,
+ }
+
+ cfg1 := c.cachedTLSConfig("foobar")
+ cfg2 := c.cachedTLSConfig("baz")
+ assert.NotEqual(t, cfg1, cfg2)
+ cfg3 := c.cachedTLSConfig("foobar")
+ assert.DeepEqual(t, cfg1, cfg3)
+}
+
+func TestRetry(t *testing.T) {
+ var times int32
+ c := &HostClient{
+ ClientOptions: &ClientOptions{
+ Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
+ times++
+ if times < 3 {
+ return &retryConn{
+ Conn: mock.NewConn(""),
+ }, nil
+ }
+ return mock.NewConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil
+ }),
+ RetryConfig: &retry.Config{
+ MaxAttemptTimes: 5,
+ Delay: time.Millisecond * 10,
+ },
+ RetryIfFunc: func(req *protocol.Request, resp *protocol.Response, err error) bool {
+ return true
+ },
+ },
+ Addr: "foobar",
+ }
+
+ req := protocol.AcquireRequest()
+ req.SetRequestURI("http://foobar/baz")
+ req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100))
+ resp := protocol.AcquireResponse()
+
+ ch := make(chan error, 1)
+ go func() {
+ ch <- c.Do(context.Background(), req, resp)
+ }()
+ select {
+ case <-time.After(time.Second * 2):
+ t.Fatalf("should use writeTimeout in request options")
+ case err := <-ch:
+ assert.Nil(t, err)
+ assert.True(t, times == 3)
+ assert.DeepEqual(t, resp.StatusCode(), 200)
+ assert.DeepEqual(t, resp.Body(), []byte("0123456789"))
+ }
+}
+
// mockConn for getting error when write binary data.
type writeErrConn struct {
network.Conn
@@ -371,3 +471,11 @@ type writeErrConn struct {
func (w writeErrConn) WriteBinary(b []byte) (n int, err error) {
return 0, errs.ErrConnectionClosed
}
+
+type retryConn struct {
+ network.Conn
+}
+
+func (w retryConn) SetWriteTimeout(t time.Duration) error {
+ return errors.New("should retry")
+}
diff --git a/pkg/protocol/http1/client_unix_test.go b/pkg/protocol/http1/client_unix_test.go
index ba616ad54..6e417ba18 100644
--- a/pkg/protocol/http1/client_unix_test.go
+++ b/pkg/protocol/http1/client_unix_test.go
@@ -19,11 +19,15 @@ package http1
import (
"context"
+ "errors"
"net/http"
"runtime"
+ "sync"
+ "sync/atomic"
"testing"
"time"
+ errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/network/netpoll"
"github.com/cloudwego/hertz/pkg/protocol"
@@ -63,3 +67,55 @@ func TestGcBodyStream(t *testing.T) {
c.CloseIdleConnections()
assert.DeepEqual(t, 0, c.ConnPoolState().TotalConnNum)
}
+
+func TestMaxConn(t *testing.T) {
+ srv := &http.Server{Addr: ":11002", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ w.Write([]byte("hello world\n"))
+ })}
+ go srv.ListenAndServe()
+ time.Sleep(100 * time.Millisecond)
+
+ c := &HostClient{
+ ClientOptions: &ClientOptions{
+ Dialer: netpoll.NewDialer(),
+ ResponseBodyStream: true,
+ MaxConnWaitTimeout: time.Millisecond * 100,
+ MaxConns: 5,
+ },
+ Addr: "127.0.0.1:11002",
+ }
+
+ var successCount int32
+ var noFreeCount int32
+ wg := sync.WaitGroup{}
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ req, resp := protocol.AcquireRequest(), protocol.AcquireResponse()
+ req.SetRequestURI("http://127.0.0.1:11002")
+ req.SetMethod(consts.MethodPost)
+ err := c.Do(context.Background(), req, resp)
+ if err != nil {
+ if errors.Is(err, errs.ErrNoFreeConns) {
+ atomic.AddInt32(&noFreeCount, 1)
+ return
+ }
+ t.Errorf("client Do error=%v", err.Error())
+ }
+ atomic.AddInt32(&successCount, 1)
+ }()
+ }
+ wg.Wait()
+
+ assert.True(t, atomic.LoadInt32(&successCount) == 5)
+ assert.True(t, atomic.LoadInt32(&noFreeCount) == 5)
+ assert.DeepEqual(t, 0, c.ConnectionCount())
+ assert.DeepEqual(t, 5, c.WantConnectionCount())
+
+ runtime.GC()
+ // wait for gc
+ time.Sleep(100 * time.Millisecond)
+ c.CloseIdleConnections()
+ assert.DeepEqual(t, 0, c.WantConnectionCount())
+}
diff --git a/pkg/protocol/http1/ext/common_test.go b/pkg/protocol/http1/ext/common_test.go
index 5ad9c001a..824ccf4ca 100644
--- a/pkg/protocol/http1/ext/common_test.go
+++ b/pkg/protocol/http1/ext/common_test.go
@@ -18,12 +18,16 @@ package ext
import (
"bytes"
+ "errors"
+ "io"
"strings"
"testing"
+ errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/test/mock"
"github.com/cloudwego/hertz/pkg/protocol"
+ "github.com/cloudwego/netpoll"
)
func Test_stripSpace(t *testing.T) {
@@ -77,6 +81,12 @@ func TestReadTrailerError(t *testing.T) {
if err == nil {
t.Fatalf("expecting error.")
}
+
+ // eof
+ er := mock.EOFReader{}
+ trailer = protocol.Trailer{}
+ err = ReadTrailer(&trailer, &er)
+ assert.DeepEqual(t, io.EOF, err)
}
func TestReadTrailer1(t *testing.T) {
@@ -95,3 +105,87 @@ func TestReadTrailer1(t *testing.T) {
}
}
}
+
+func TestReadRawHeaders(t *testing.T) {
+ s := "HTTP/1.1 200 OK\r\n" +
+ "EmptyValue1:\r\n" +
+ "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" +
+ "Foo: Bar\r\n" +
+ "Multi-Line: one;\r\n two\r\n" +
+ "Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" +
+ "Content-Length: 5\r\n\r\n" +
+ "HELLOaaa"
+
+ var dst []byte
+ rawHeaders, index, err := ReadRawHeaders(dst, []byte(s))
+ assert.Nil(t, err)
+ assert.DeepEqual(t, s[:index], string(rawHeaders))
+}
+
+func TestBodyChunked(t *testing.T) {
+ body := "foobar baz aaa bbb ccc"
+ chunk := "16\r\nfoobar baz aaa bbb ccc\r\n0\r\n"
+ b := bytes.NewBufferString(body)
+
+ var w bytes.Buffer
+ zw := netpoll.NewWriter(&w)
+ WriteBodyChunked(zw, b)
+
+ assert.DeepEqual(t, chunk, w.String())
+
+ zr := mock.NewZeroCopyReader(chunk)
+ rb, err := ReadBody(zr, -1, 0, nil)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, body, string(rb))
+}
+
+func TestBodyFixedSize(t *testing.T) {
+ body := mock.CreateFixedBody(10)
+ b := bytes.NewBuffer(body)
+
+ var w bytes.Buffer
+ zw := netpoll.NewWriter(&w)
+ WriteBodyFixedSize(zw, b, int64(len(body)))
+
+ assert.DeepEqual(t, body, w.Bytes())
+
+ zr := mock.NewZeroCopyReader(string(body))
+ rb, err := ReadBody(zr, len(body), 0, nil)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, body, rb)
+}
+
+func TestBodyIdentity(t *testing.T) {
+ body := mock.CreateFixedBody(1024)
+ zr := mock.NewZeroCopyReader(string(body))
+ rb, err := ReadBody(zr, -2, 0, nil)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, string(body), string(rb))
+}
+
+func TestBodySkipTrailer(t *testing.T) {
+ t.Run("TestBodySkipTrailer", func(t *testing.T) {
+ body := mock.CreateFixedBody(10)
+ trailer := map[string]string{"Foo": "chunked shit"}
+ chunkedBody := mock.CreateChunkedBody(body, trailer, true)
+ r := mock.NewSlowReadConn(string(chunkedBody))
+ err := SkipTrailer(r)
+ assert.Nil(t, err)
+ _, err = r.ReadByte()
+ assert.NotNil(t, err)
+ assert.True(t, errors.Is(err, netpoll.ErrEOF))
+ })
+
+ t.Run("TestBodySkipTrailerError", func(t *testing.T) {
+ // timeout error
+ sr := mock.NewSlowReadConn("")
+ err := SkipTrailer(sr)
+ assert.NotNil(t, err)
+ assert.True(t, errors.Is(err, errs.ErrTimeout))
+ // EOF error
+ er := &mock.EOFReader{}
+ err = SkipTrailer(er)
+ assert.NotNil(t, err)
+ assert.True(t, errors.Is(err, io.EOF))
+ })
+}
diff --git a/pkg/protocol/http1/ext/headerscanner_test.go b/pkg/protocol/http1/ext/headerscanner_test.go
index fc1f92126..0f8874d81 100644
--- a/pkg/protocol/http1/ext/headerscanner_test.go
+++ b/pkg/protocol/http1/ext/headerscanner_test.go
@@ -42,8 +42,13 @@
package ext
import (
+ "bufio"
+ "errors"
+ "net/http"
+ "strings"
"testing"
+ errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/test/assert"
)
@@ -52,3 +57,58 @@ func TestHasHeaderValue(t *testing.T) {
assert.True(t, HasHeaderValue(s, []byte("Connection: Keep-Alive")))
assert.False(t, HasHeaderValue(s, []byte("Connection: Keep-Alive1")))
}
+
+func TestResponseHeaderMultiLineValue(t *testing.T) {
+ firstLine := "HTTP/1.1 200 OK\r\n"
+ rawHeaders := "EmptyValue1:\r\n" +
+ "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" +
+ "Foo: Bar\r\n" +
+ "Multi-Line: one;\r\n two\r\n" +
+ "Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" +
+ "\r\n"
+
+ // compared with http response
+ response, err := http.ReadResponse(bufio.NewReader(strings.NewReader(firstLine+rawHeaders)), nil)
+ assert.Nil(t, err)
+ defer func() { response.Body.Close() }()
+
+ hs := &HeaderScanner{}
+ hs.B = []byte(rawHeaders)
+ hs.DisableNormalizing = false
+ hmap := make(map[string]string, len(response.Header))
+ for hs.Next() {
+ if len(hs.Key) > 0 {
+ hmap[string(hs.Key)] = string(hs.Value)
+ }
+ }
+
+ for name, vals := range response.Header {
+ got := hmap[name]
+ want := vals[0]
+ assert.DeepEqual(t, want, got)
+ }
+}
+
+func TestHeaderScannerError(t *testing.T) {
+ t.Run("TestHeaderScannerErrorInvalidName", func(t *testing.T) {
+ rawHeaders := "Host: go.dev\r\nGopher-New-\r\n Line: This is a header on multiple lines\r\n\r\n"
+ testTestHeaderScannerError(t, rawHeaders, errInvalidName)
+ })
+ t.Run("TestHeaderScannerErrorNeedMore", func(t *testing.T) {
+ rawHeaders := "This is a header on multiple lines"
+ testTestHeaderScannerError(t, rawHeaders, errs.ErrNeedMore)
+
+ rawHeaders = "Gopher-New-\r\n Line"
+ testTestHeaderScannerError(t, rawHeaders, errs.ErrNeedMore)
+ })
+}
+
+func testTestHeaderScannerError(t *testing.T, rawHeaders string, expectError error) {
+ hs := &HeaderScanner{}
+ hs.B = []byte(rawHeaders)
+ hs.DisableNormalizing = false
+ for hs.Next() {
+ }
+ assert.NotNil(t, hs.Err)
+ assert.True(t, errors.Is(hs.Err, expectError))
+}
diff --git a/pkg/protocol/http1/ext/stream_test.go b/pkg/protocol/http1/ext/stream_test.go
index 188d234eb..ccec83dab 100644
--- a/pkg/protocol/http1/ext/stream_test.go
+++ b/pkg/protocol/http1/ext/stream_test.go
@@ -17,6 +17,7 @@ package ext
import (
"bytes"
+ "errors"
"fmt"
"io"
"testing"
@@ -112,3 +113,141 @@ func TestBodyStream_Reset(t *testing.T) {
assert.DeepEqual(t, 0, bs.chunkLeft)
assert.False(t, bs.chunkEOF)
}
+
+func TestReadBodyWithStreaming(t *testing.T) {
+ t.Run("TestBodyFixedSize", func(t *testing.T) {
+ bodySize := 1024
+ body := mock.CreateFixedBody(bodySize)
+ reader := mock.NewZeroCopyReader(string(body))
+ dst, err := ReadBodyWithStreaming(reader, bodySize, -1, nil)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, body, dst)
+ })
+
+ t.Run("TestBodyFixedSizeMaxContentLength", func(t *testing.T) {
+ bodySize := 8 * 1024 * 2
+ body := mock.CreateFixedBody(bodySize)
+ reader := mock.NewZeroCopyReader(string(body))
+ dst, err := ReadBodyWithStreaming(reader, bodySize, 8*1024*10, nil)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, body[:maxContentLengthInStream], dst)
+ })
+
+ t.Run("TestBodyIdentity", func(t *testing.T) {
+ bodySize := 1024
+ body := mock.CreateFixedBody(bodySize)
+ reader := mock.NewZeroCopyReader(string(body))
+ dst, err := ReadBodyWithStreaming(reader, -2, 512, nil)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, body, dst)
+ })
+
+ t.Run("TestErrBodyTooLarge", func(t *testing.T) {
+ bodySize := 2048
+ body := mock.CreateFixedBody(bodySize)
+ reader := mock.NewZeroCopyReader(string(body))
+ dst, err := ReadBodyWithStreaming(reader, bodySize, 1024, nil)
+ assert.True(t, errors.Is(err, errBodyTooLarge))
+ assert.DeepEqual(t, body[:len(dst)], dst)
+ })
+
+ t.Run("TestErrChunkedStream", func(t *testing.T) {
+ bodySize := 1024
+ body := mock.CreateFixedBody(bodySize)
+ reader := mock.NewZeroCopyReader(string(body))
+ dst, err := ReadBodyWithStreaming(reader, -1, bodySize, nil)
+ assert.True(t, errors.Is(err, errChunkedStream))
+ assert.Nil(t, dst)
+ })
+}
+
+func TestBodyStream(t *testing.T) {
+ t.Run("TestBodyStreamPrereadBuffer", func(t *testing.T) {
+ bodySize := 1024
+ body := mock.CreateFixedBody(bodySize)
+ byteBuffer := &bytebufferpool.ByteBuffer{}
+ byteBuffer.Set(body)
+
+ bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(""), nil, len(body))
+ defer func() {
+ ReleaseBodyStream(bs)
+ }()
+
+ b := make([]byte, bodySize)
+ err := bodyStreamRead(bs, b)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, len(body), len(b))
+ assert.DeepEqual(t, string(body), string(b))
+ })
+
+ t.Run("TestBodyStreamRelease", func(t *testing.T) {
+ bodySize := 1024
+ body := mock.CreateFixedBody(bodySize)
+ byteBuffer := &bytebufferpool.ByteBuffer{}
+ byteBuffer.Set(body)
+ bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(body)), nil, bodySize*2)
+ err := ReleaseBodyStream(bs)
+ assert.Nil(t, err)
+ })
+
+ t.Run("TestBodyStreamChunked", func(t *testing.T) {
+ bodySize := 5
+ body := mock.CreateFixedBody(bodySize)
+ expectedTrailer := map[string]string{"Foo": "chunked shit"}
+ chunkedBody := mock.CreateChunkedBody(body, expectedTrailer, true)
+
+ byteBuffer := &bytebufferpool.ByteBuffer{}
+ byteBuffer.Set(chunkedBody)
+
+ bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(chunkedBody)), &protocol.Trailer{}, -1)
+ defer func() {
+ ReleaseBodyStream(bs)
+ }()
+
+ b := make([]byte, bodySize)
+ err := bodyStreamRead(bs, b)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, len(body), len(b))
+ assert.DeepEqual(t, string(body), string(b))
+ })
+
+ t.Run("TestBodyStreamReadFromWire", func(t *testing.T) {
+ bodySize := 1024
+ body := mock.CreateFixedBody(bodySize)
+ byteBuffer := &bytebufferpool.ByteBuffer{}
+ byteBuffer.Set(body)
+
+ rcBodySize := 128
+ rcBody := mock.CreateFixedBody(rcBodySize)
+ bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(rcBody)), nil, -2)
+ defer func() {
+ ReleaseBodyStream(bs)
+ }()
+
+ b := make([]byte, bodySize)
+ err := bodyStreamRead(bs, b)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, len(body), len(b))
+ assert.DeepEqual(t, string(body), string(b))
+ })
+}
+
+func bodyStreamRead(bs io.Reader, b []byte) (err error) {
+ nb := 0
+ for {
+ p := make([]byte, 64)
+ n, rErr := bs.Read(p)
+ if n > 0 {
+ copy(b[nb:], p[:])
+ nb = nb + n
+ }
+
+ if rErr != nil {
+ if rErr != io.EOF {
+ err = rErr
+ }
+ break
+ }
+ }
+ return
+}
diff --git a/pkg/protocol/http1/req/header_test.go b/pkg/protocol/http1/req/header_test.go
index fcf5f4177..489e0aabf 100644
--- a/pkg/protocol/http1/req/header_test.go
+++ b/pkg/protocol/http1/req/header_test.go
@@ -44,11 +44,13 @@ package req
import (
"bufio"
"bytes"
+ "errors"
"fmt"
"net/http"
"strings"
"testing"
+ errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/test/mock"
"github.com/cloudwego/hertz/pkg/protocol"
@@ -412,3 +414,10 @@ func TestRequestHeader_PeekIfExists(t *testing.T) {
assert.DeepEqual(t, []byte{}, rh.Peek("exists"))
assert.DeepEqual(t, []byte(nil), rh.Peek("non-exists"))
}
+
+func TestRequestHeaderError(t *testing.T) {
+ er := mock.EOFReader{}
+ rh := protocol.RequestHeader{}
+ err := ReadHeader(&rh, &er)
+ assert.True(t, errors.Is(err, errs.ErrNothingRead))
+}
diff --git a/pkg/protocol/http1/req/request_test.go b/pkg/protocol/http1/req/request_test.go
index e123da406..0411187a5 100644
--- a/pkg/protocol/http1/req/request_test.go
+++ b/pkg/protocol/http1/req/request_test.go
@@ -44,6 +44,7 @@ package req
import (
"bufio"
"bytes"
+ "encoding/base64"
"errors"
"fmt"
"io"
@@ -54,6 +55,7 @@ import (
"testing"
"github.com/cloudwego/hertz/internal/bytesconv"
+ "github.com/cloudwego/hertz/internal/bytestr"
"github.com/cloudwego/hertz/pkg/common/bytebufferpool"
"github.com/cloudwego/hertz/pkg/common/compress"
errs "github.com/cloudwego/hertz/pkg/common/errors"
@@ -451,7 +453,7 @@ func TestRequestWriteRequestURINoHost(t *testing.T) {
t.Parallel()
var req protocol.Request
- req.Header.SetRequestURI("http://google.com/foo/bar?baz=aaa")
+ req.Header.SetRequestURI("http://user:pass@google.com/foo/bar?baz=aaa")
var w bytes.Buffer
zw := netpoll.NewWriter(&w)
if err := Write(&req, zw); err != nil {
@@ -474,6 +476,16 @@ func TestRequestWriteRequestURINoHost(t *testing.T) {
if string(req.Header.RequestURI()) != "/foo/bar?baz=aaa" {
t.Fatalf("unexpected requestURI: %q. Expecting %q", req.Header.RequestURI(), "/foo/bar?baz=aaa")
}
+ // authorization
+ authorization := req.Header.Get(string(bytestr.StrAuthorization))
+ author, err := base64.StdEncoding.DecodeString(authorization[len(bytestr.StrBasicSpace):])
+ if err != nil {
+ t.Fatalf("expecting error")
+ }
+
+ if string(author) != "user:pass" {
+ t.Fatalf("unexpected Authorization: %q. Expecting %q", authorization, "user:pass")
+ }
// verify that Write returns error on non-absolute RequestURI
req.Reset()
@@ -484,6 +496,38 @@ func TestRequestWriteRequestURINoHost(t *testing.T) {
}
}
+func TestRequestWriteMultipartFile(t *testing.T) {
+ t.Parallel()
+
+ var req protocol.Request
+ req.Header.SetHost("foobar.com")
+ req.Header.SetMethod(consts.MethodPost)
+ req.SetFileReader("filea", "filea.txt", bytes.NewReader([]byte("This is filea.")))
+ req.SetMultipartField("fileb", "fileb.txt", "text/plain", bytes.NewReader([]byte("This is fileb.")))
+
+ var w bytes.Buffer
+ zw := netpoll.NewWriter(&w)
+ if err := Write(&req, zw); err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if err := zw.Flush(); err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+
+ var req1 protocol.Request
+ zr := mock.NewZeroCopyReader(w.String())
+ if err := Read(&req1, zr); err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+
+ filea, err := req1.FormFile("filea")
+ assert.Nil(t, err)
+ assert.DeepEqual(t, "filea.txt", filea.Filename)
+ fileb, err := req1.FormFile("fileb")
+ assert.Nil(t, err)
+ assert.DeepEqual(t, "fileb.txt", fileb.Filename)
+}
+
func TestSetRequestBodyStreamChunked(t *testing.T) {
t.Parallel()
@@ -1037,7 +1081,7 @@ func testRequestMultipartFormNotPreParse(t *testing.T, boundary string, formData
func testReadBodyChunked(t *testing.T, bodySize int) {
body := mock.CreateFixedBody(bodySize)
expectedTrailer := map[string]string{"Foo": "chunked shit"}
- chunkedBody := createChunkedBody(body, expectedTrailer, true)
+ chunkedBody := mock.CreateChunkedBody(body, expectedTrailer, true)
zr := mock.NewZeroCopyReader(string(chunkedBody))
@@ -1052,32 +1096,6 @@ func testReadBodyChunked(t *testing.T, bodySize int) {
verifyTrailer(t, zr, expectedTrailer)
}
-func createChunkedBody(body []byte, trailer map[string]string, hasTrailer bool) []byte {
- var b []byte
- chunkSize := 1
- for len(body) > 0 {
- if chunkSize > len(body) {
- chunkSize = len(body)
- }
- b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...)
- b = append(b, body[:chunkSize]...)
- b = append(b, []byte("\r\n")...)
- body = body[chunkSize:]
- chunkSize++
- }
- if hasTrailer {
- b = append(b, "0\r\n"...)
- for k, v := range trailer {
- b = append(b, k...)
- b = append(b, ": "...)
- b = append(b, v...)
- b = append(b, "\r\n"...)
- }
- b = append(b, "\r\n"...)
- }
- return b
-}
-
func testReadBodyFixedSize(t *testing.T, bodySize int) {
body := mock.CreateFixedBody(bodySize)
diff --git a/pkg/protocol/http1/resp/response_test.go b/pkg/protocol/http1/resp/response_test.go
index cff0a5783..0ff010fd7 100644
--- a/pkg/protocol/http1/resp/response_test.go
+++ b/pkg/protocol/http1/resp/response_test.go
@@ -820,3 +820,9 @@ func TestResponseReadBodyStreamBadTrailer(t *testing.T) {
testResponseReadBodyStreamBadTrailer(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\ncontent-type: bar\r\n\r\n")
testResponseReadBodyStreamBadTrailer(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\nproxy-connection: bar2\r\n\r\n")
}
+
+func TestResponseString(t *testing.T) {
+ resp := protocol.Response{}
+ resp.Header.Set("Location", "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n")
+ assert.True(t, strings.Contains(GetHTTP1Response(&resp).String(), "Location: foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n"))
+}
diff --git a/pkg/protocol/http1/server_test.go b/pkg/protocol/http1/server_test.go
index dfa8b78d3..e34bb3f9a 100644
--- a/pkg/protocol/http1/server_test.go
+++ b/pkg/protocol/http1/server_test.go
@@ -20,8 +20,10 @@ import (
"bytes"
"context"
"errors"
+ "strings"
"sync"
"testing"
+ "time"
inStats "github.com/cloudwego/hertz/internal/stats"
"github.com/cloudwego/hertz/pkg/app"
@@ -33,6 +35,7 @@ import (
"github.com/cloudwego/hertz/pkg/common/tracer/traceinfo"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/protocol"
+ "github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/cloudwego/hertz/pkg/protocol/http1/resp"
)
@@ -247,14 +250,124 @@ func TestHijackResponseWriter(t *testing.T) {
assert.True(t, isFinal)
}
+func TestHijackHandler(t *testing.T) {
+ server := NewServer()
+ reqCtx := &app.RequestContext{}
+ originReadTimeout := time.Second
+ hijackReadTimeout := 200 * time.Millisecond
+ reqCtx.SetHijackHandler(func(c network.Conn) {
+ c.SetReadTimeout(hijackReadTimeout) // hijack read timeout
+ })
+
+ server.Core = &mockCore{
+ ctxPool: &sync.Pool{New: func() interface{} {
+ return reqCtx
+ }},
+ }
+
+ server.HijackConnHandle = func(c network.Conn, h app.HijackHandler) {
+ h(c)
+ }
+
+ defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
+ defaultConn.SetReadTimeout(originReadTimeout)
+ assert.DeepEqual(t, originReadTimeout, defaultConn.GetReadTimeout())
+ err := server.Serve(context.TODO(), defaultConn)
+ assert.True(t, errors.Is(err, errs.ErrHijacked))
+ assert.DeepEqual(t, hijackReadTimeout, defaultConn.GetReadTimeout())
+}
+
+func TestKeepAlive(t *testing.T) {
+ server := NewServer()
+ reqCtx := &app.RequestContext{}
+ times := 0
+ server.Core = &mockCore{
+ ctxPool: &sync.Pool{New: func() interface{} {
+ return reqCtx
+ }},
+ isRunning: true,
+ mockHandler: func(c context.Context, ctx *app.RequestContext) {
+ times++
+ if string(ctx.Path()) == "/close" {
+ ctx.SetConnectionClose()
+ }
+ },
+ }
+ server.IdleTimeout = time.Second
+
+ var s strings.Builder
+ s.WriteString("GET / HTTP/1.1\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n")
+ s.WriteString("GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") // set connection close
+
+ defaultConn := mock.NewConn(s.String())
+ err := server.Serve(context.TODO(), defaultConn)
+ assert.True(t, errors.Is(err, errs.ErrShortConnection))
+ assert.DeepEqual(t, times, 2)
+}
+
+func TestExpect100Continue(t *testing.T) {
+ server := &Server{}
+ reqCtx := &app.RequestContext{}
+ server.Core = &mockCore{
+ ctxPool: &sync.Pool{New: func() interface{} {
+ return reqCtx
+ }},
+ mockHandler: func(c context.Context, ctx *app.RequestContext) {
+ data, err := ctx.Body()
+ if err == nil {
+ ctx.Write(data)
+ }
+ },
+ }
+
+ defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
+ err := server.Serve(context.TODO(), defaultConn)
+ assert.True(t, errors.Is(err, errs.ErrShortConnection))
+ defaultResponseResult := defaultConn.WriterRecorder()
+ assert.DeepEqual(t, 0, defaultResponseResult.Len())
+ response := protocol.AcquireResponse()
+ resp.Read(response, defaultResponseResult)
+ assert.DeepEqual(t, "12345", string(response.Body()))
+}
+
+func TestExpect100ContinueHandler(t *testing.T) {
+ server := &Server{}
+ reqCtx := &app.RequestContext{}
+ server.Core = &mockCore{
+ ctxPool: &sync.Pool{New: func() interface{} {
+ return reqCtx
+ }},
+ mockHandler: func(c context.Context, ctx *app.RequestContext) {
+ data, err := ctx.Body()
+ if err == nil {
+ ctx.Write(data)
+ }
+ },
+ }
+ server.ContinueHandler = func(header *protocol.RequestHeader) bool {
+ return false
+ }
+
+ defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
+ err := server.Serve(context.TODO(), defaultConn)
+ assert.True(t, errors.Is(err, errs.ErrShortConnection))
+ defaultResponseResult := defaultConn.WriterRecorder()
+ assert.DeepEqual(t, 0, defaultResponseResult.Len())
+ response := protocol.AcquireResponse()
+ resp.Read(response, defaultResponseResult)
+ assert.DeepEqual(t, consts.StatusExpectationFailed, response.StatusCode())
+ assert.DeepEqual(t, "", string(response.Body()))
+}
+
type mockCore struct {
ctxPool *sync.Pool
controller tracer.Controller
mockHandler func(c context.Context, ctx *app.RequestContext)
+ isRunning bool
}
func (m *mockCore) IsRunning() bool {
- return false
+ return m.isRunning
}
func (m *mockCore) GetCtxPool() *sync.Pool {
From e1d541bb909376012ec9db6ac59f952424549ce8 Mon Sep 17 00:00:00 2001
From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com>
Date: Tue, 18 Jul 2023 20:34:59 +0800
Subject: [PATCH 03/17] feat(hz): remove tag (#811)
---
_typos.toml | 3 ++-
cmd/hz/app/app.go | 5 +++++
cmd/hz/config/argument.go | 2 ++
cmd/hz/protobuf/plugin.go | 40 ++++++++++++++++++++++++++++++---------
cmd/hz/thrift/plugin.go | 13 ++++++++++---
5 files changed, 50 insertions(+), 13 deletions(-)
diff --git a/_typos.toml b/_typos.toml
index 9808229cc..3d3103e19 100644
--- a/_typos.toml
+++ b/_typos.toml
@@ -17,4 +17,5 @@ referer = "referer"
HeaderReferer = "HeaderReferer"
expectedReferer = "expectedReferer"
Referer = "Referer"
-O_WRONLY = "O_WRONLY"
\ No newline at end of file
+O_WRONLY = "O_WRONLY"
+WRONLY = "WRONLY"
\ No newline at end of file
diff --git a/cmd/hz/app/app.go b/cmd/hz/app/app.go
index 239b4df21..269e5bad9 100644
--- a/cmd/hz/app/app.go
+++ b/cmd/hz/app/app.go
@@ -184,6 +184,7 @@ func Init() *cli.App {
unsetOmitemptyFlag := cli.BoolFlag{Name: "unset_omitempty", Usage: "Remove 'omitempty' tag for generated struct.", Destination: &globalArgs.UnsetOmitempty}
protoCamelJSONTag := cli.BoolFlag{Name: "pb_camel_json_tag", Usage: "Convert Name style for json tag to camel(Only works protobuf).", Destination: &globalArgs.ProtobufCamelJSONTag}
snakeNameFlag := cli.BoolFlag{Name: "snake_tag", Usage: "Use snake_case style naming for tags. (Only works for 'form', 'query', 'json')", Destination: &globalArgs.SnakeName}
+ rmTagFlag := cli.StringSliceFlag{Name: "rm_tag", Usage: "Remove the specified tag"}
customLayout := cli.StringFlag{Name: "customize_layout", Usage: "Specify the path for layout template.", Destination: &globalArgs.CustomizeLayout}
customLayoutData := cli.StringFlag{Name: "customize_layout_data_path", Usage: "Specify the path for layout template render data.", Destination: &globalArgs.CustomizeLayoutData}
customPackage := cli.StringFlag{Name: "customize_package", Usage: "Specify the path for package template.", Destination: &globalArgs.CustomizePackage}
@@ -230,6 +231,7 @@ func Init() *cli.App {
&unsetOmitemptyFlag,
&protoCamelJSONTag,
&snakeNameFlag,
+ &rmTagFlag,
&excludeFilesFlag,
&customLayout,
&customLayoutData,
@@ -263,6 +265,7 @@ func Init() *cli.App {
&unsetOmitemptyFlag,
&protoCamelJSONTag,
&snakeNameFlag,
+ &rmTagFlag,
&excludeFilesFlag,
&customPackage,
&handlerByMethod,
@@ -289,6 +292,7 @@ func Init() *cli.App {
&unsetOmitemptyFlag,
&protoCamelJSONTag,
&snakeNameFlag,
+ &rmTagFlag,
&excludeFilesFlag,
},
Action: Model,
@@ -315,6 +319,7 @@ func Init() *cli.App {
&unsetOmitemptyFlag,
&protoCamelJSONTag,
&snakeNameFlag,
+ &rmTagFlag,
&excludeFilesFlag,
&customLayout,
&customLayoutData,
diff --git a/cmd/hz/config/argument.go b/cmd/hz/config/argument.go
index 8eef041d2..b8bf41cac 100644
--- a/cmd/hz/config/argument.go
+++ b/cmd/hz/config/argument.go
@@ -64,6 +64,7 @@ type Argument struct {
ProtobufPlugins []string
ThriftPlugins []string
SnakeName bool
+ RmTags []string
Excludes []string
NoRecurse bool
HandlerByMethod bool
@@ -121,6 +122,7 @@ func (arg *Argument) parseStringSlice(c *cli.Context) {
arg.ProtocOptions = c.StringSlice("protoc")
arg.ThriftPlugins = c.StringSlice("thrift-plugins")
arg.ProtobufPlugins = c.StringSlice("protoc-plugins")
+ arg.RmTags = c.StringSlice("rm_tag")
}
func (arg *Argument) UpdateByManifest(m *meta.Manifest) {
diff --git a/cmd/hz/protobuf/plugin.go b/cmd/hz/protobuf/plugin.go
index 82b01cd6a..e25803860 100644
--- a/cmd/hz/protobuf/plugin.go
+++ b/cmd/hz/protobuf/plugin.go
@@ -76,10 +76,22 @@ type Plugin struct {
ModelDir string
UseDir string
IdlClientDir string
+ RmTags RemoveTags
PkgMap map[string]string
logger *logs.StdLogger
}
+type RemoveTags []string
+
+func (rm *RemoveTags) Exist(tag string) bool {
+ for _, rmTag := range *rm {
+ if rmTag == tag {
+ return true
+ }
+ }
+ return false
+}
+
func (plugin *Plugin) Run() int {
plugin.setLogger()
args := &config.Argument{}
@@ -192,6 +204,7 @@ func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Ar
opts := protogen.Options{}
gen, err := opts.New(req)
plugin.Plugin = gen
+ plugin.RmTags = args.RmTags
if err != nil {
return fmt.Errorf("new protoc plugin failed: %s", err.Error())
}
@@ -322,7 +335,6 @@ func (plugin *Plugin) fixModelPathAndPackage(pkg string) (impt, path string) {
func (plugin *Plugin) GenerateFiles(pluginPb *protogen.Plugin) error {
idl := pluginPb.Request.FileToGenerate[len(pluginPb.Request.FileToGenerate)-1]
pluginPb.SupportedFeatures = gengo.SupportedFeatures
-
for _, f := range pluginPb.Files {
if f.Proto.GetName() == idl {
err := plugin.GenerateFile(pluginPb, f)
@@ -358,7 +370,7 @@ func (plugin *Plugin) GenerateFile(gen *protogen.Plugin, f *protogen.File) error
if len(plugin.UseDir) != 0 {
return nil
}
- file, err := generateFile(gen, f)
+ file, err := generateFile(gen, f, plugin.RmTags)
if err != nil || file == nil {
return fmt.Errorf("generate file %s failed: %s", f.Proto.GetName(), err.Error())
}
@@ -366,7 +378,7 @@ func (plugin *Plugin) GenerateFile(gen *protogen.Plugin, f *protogen.File) error
}
// generateFile generates the contents of a .pb.go file.
-func generateFile(gen *protogen.Plugin, file *protogen.File) (*protogen.GeneratedFile, error) {
+func generateFile(gen *protogen.Plugin, file *protogen.File, rmTags RemoveTags) (*protogen.GeneratedFile, error) {
filename := file.GeneratedFilenamePrefix + ".pb.go"
g := gen.NewGeneratedFile(filename, file.GoImportPath)
f := newFileInfo(file)
@@ -398,7 +410,7 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) (*protogen.Generate
}
var err error
for _, message := range f.allMessages {
- err = genMessage(g, f, message)
+ err = genMessage(g, f, message, rmTags)
if err != nil {
return nil, err
}
@@ -410,7 +422,7 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) (*protogen.Generate
return g, nil
}
-func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) error {
+func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error {
if m.Desc.IsMapEntry() {
return nil
}
@@ -421,7 +433,7 @@ func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) error {
m.Desc.Options().(*descriptorpb.MessageOptions).GetDeprecated())
g.P(leadingComments,
"type ", m.GoIdent, " struct {")
- err := genMessageFields(g, f, m)
+ err := genMessageFields(g, f, m, rmTags)
if err != nil {
return err
}
@@ -435,12 +447,12 @@ func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) error {
return nil
}
-func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) error {
+func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error {
sf := f.allMessageFieldsByPtr[m]
genMessageInternalFields(g, f, m, sf)
var err error
for _, field := range m.Fields {
- err = genMessageField(g, f, m, field, sf)
+ err = genMessageField(g, f, m, field, sf, rmTags)
if err != nil {
return err
}
@@ -448,7 +460,7 @@ func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) er
return nil
}
-func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, sf *structFields) error {
+func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, sf *structFields, rmTags RemoveTags) error {
if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() {
// It would be a bit simpler to iterate over the oneofs below,
// but generating the field here keeps the contents of the Go
@@ -506,6 +518,16 @@ func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, fie
tags = append(tags, gotrackTags...)
}
+ if len(rmTags) > 0 {
+ tmp := structTags{}
+ for _, tag := range tags {
+ if !rmTags.Exist(tag[0]) {
+ tmp = append(tmp, tag)
+ }
+ }
+ tags = tmp
+ }
+
name := field.GoName
if field.Desc.IsWeak() {
name = WeakFieldPrefix_goname + name
diff --git a/cmd/hz/thrift/plugin.go b/cmd/hz/thrift/plugin.go
index 23c5c7597..ff4fdf65f 100644
--- a/cmd/hz/thrift/plugin.go
+++ b/cmd/hz/thrift/plugin.go
@@ -41,6 +41,7 @@ type Plugin struct {
req *thriftgo_plugin.Request
args *config.Argument
logger *logs.StdLogger
+ rmTags []string
}
func (plugin *Plugin) Run() int {
@@ -74,6 +75,7 @@ func (plugin *Plugin) Run() int {
logs.Errorf("parse args failed: %s", err.Error())
return meta.PluginError
}
+ plugin.rmTags = args.RmTags
if args.CmdType == meta.CmdModel {
res, err := plugin.GetResponse(nil, args.OutDir)
if err != nil {
@@ -334,7 +336,7 @@ func (plugin *Plugin) InsertTag() ([]*thriftgo_plugin.Generated, error) {
stName := st.GetName()
for _, f := range st.Fields {
fieldName := f.GetName()
- tagString, err := getTagString(f)
+ tagString, err := getTagString(f, plugin.rmTags)
if err != nil {
return nil, err
}
@@ -360,7 +362,7 @@ func (plugin *Plugin) InsertTag() ([]*thriftgo_plugin.Generated, error) {
stName := st.GetName()
for _, f := range st.Fields {
fieldName := f.GetName()
- tagString, err := getTagString(f)
+ tagString, err := getTagString(f, plugin.rmTags)
if err != nil {
return nil, err
}
@@ -400,7 +402,7 @@ func (plugin *Plugin) GetResponse(files []generator.File, outputDir string) (*th
}, nil
}
-func getTagString(f *parser.Field) (string, error) {
+func getTagString(f *parser.Field, rmTags []string) (string, error) {
field := model.Field{}
err := injectTags(f, &field, true, false)
if err != nil {
@@ -412,6 +414,11 @@ func getTagString(f *parser.Field) (string, error) {
disableTag = true
}
}
+
+ for _, rmTag := range rmTags {
+ field.Tags.Remove(rmTag)
+ }
+
var tagString string
tags := field.Tags
for idx, tag := range tags {
From 1018b5dba033144c784f034a2a9a89c64750fe1c Mon Sep 17 00:00:00 2001
From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com>
Date: Tue, 18 Jul 2023 20:35:30 +0800
Subject: [PATCH 04/17] feat(hz): add validator anno (#799)
---
cmd/hz/protobuf/api/api.proto | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/cmd/hz/protobuf/api/api.proto b/cmd/hz/protobuf/api/api.proto
index a41e2f8d6..72edf41ff 100644
--- a/cmd/hz/protobuf/api/api.proto
+++ b/cmd/hz/protobuf/api/api.proto
@@ -24,6 +24,8 @@ extend google.protobuf.FieldOptions {
optional string js_conv_compatible = 50132;
optional string file_name_compatible = 50133;
optional string none_compatible = 50134;
+ // 50135 is reserved to vt_compatible
+ // optional FieldRules vt_compatible = 50135;
optional string go_tag = 51001;
}
@@ -62,4 +64,12 @@ extend google.protobuf.ServiceOptions {
// 50731~50760 used to extend service option by hz
optional string base_domain_compatible = 50731;
+}
+
+extend google.protobuf.MessageOptions {
+ // optional FieldRules msg_vt = 50111;
+
+ optional string reserve = 50830;
+ // 550831 is reserved to msg_vt_compatible
+ // optional FieldRules msg_vt_compatible = 50831;
}
\ No newline at end of file
From 3ecd9868bf0c03a43ae304295b9660fc83767b82 Mon Sep 17 00:00:00 2001
From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com>
Date: Tue, 18 Jul 2023 20:36:44 +0800
Subject: [PATCH 05/17] feat(hz): add more template parameters (#795)
---
cmd/hz/generator/custom_files.go | 49 ++++++++++++++++++--------------
1 file changed, 27 insertions(+), 22 deletions(-)
diff --git a/cmd/hz/generator/custom_files.go b/cmd/hz/generator/custom_files.go
index 390a40d60..1efaae07a 100644
--- a/cmd/hz/generator/custom_files.go
+++ b/cmd/hz/generator/custom_files.go
@@ -48,16 +48,17 @@ type IDLPackageRenderInfo struct {
type CustomizedFileForMethod struct {
*HttpMethod
- FilePath string
- FilePackage string
- ServiceInfo *Service // service info for that method
+ FilePath string
+ FilePackage string
+ ServiceInfo *Service // service info for this method
+ IDLPackageInfo *IDLPackageRenderInfo // IDL info for this service
}
type CustomizedFileForService struct {
*Service
FilePath string
FilePackage string
- IDLPackageInfo *IDLPackageRenderInfo // IDL info for that service
+ IDLPackageInfo *IDLPackageRenderInfo // IDL info for this service
}
type CustomizedFileForIDL struct {
@@ -110,7 +111,7 @@ func (pkgGen *HttpPackageGenerator) genCustomizedFile(pkg *HttpPackage) error {
filePathRenderInfo.ServiceName = service.Name
filePathRenderInfo.MethodName = method.Name
filePathRenderInfo.HandlerGenPath = method.OutputDir
- err := pkgGen.genLoopMethod(tplInfo, filePathRenderInfo, method, service)
+ err := pkgGen.genLoopMethod(tplInfo, filePathRenderInfo, method, service, &idlPackageRenderInfo)
if err != nil {
return err
}
@@ -323,10 +324,11 @@ func (pkgGen *HttpPackageGenerator) genLoopService(tplInfo *Template, filePathRe
if tplInfo.UpdateBehavior.AppendKey == "method" {
for _, method := range service.Methods {
data := CustomizedFileForMethod{
- HttpMethod: method,
- FilePath: filePath,
- FilePackage: util.SplitPackage(filepath.Dir(filePath), ""),
- ServiceInfo: service,
+ HttpMethod: method,
+ FilePath: filePath,
+ FilePackage: util.SplitPackage(filepath.Dir(filePath), ""),
+ ServiceInfo: service,
+ IDLPackageInfo: idlPackageRenderInfo,
}
insertKey, err := renderInsertKey(tplInfo, data)
if err != nil {
@@ -395,7 +397,7 @@ func (pkgGen *HttpPackageGenerator) genLoopService(tplInfo *Template, filePathRe
}
// genLoopMethod used to generate files by 'method'
-func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, method *HttpMethod, service *Service) error {
+func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, method *HttpMethod, service *Service, idlPackageRenderInfo *IDLPackageRenderInfo) error {
filePath, err := renderFilePath(tplInfo, filePathRenderInfo)
if err != nil {
return err
@@ -408,10 +410,11 @@ func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRen
if !exist { // create file
data := CustomizedFileForMethod{
- HttpMethod: method,
- FilePath: filePath,
- FilePackage: util.SplitPackage(filepath.Dir(filePath), ""),
- ServiceInfo: service,
+ HttpMethod: method,
+ FilePath: filePath,
+ FilePackage: util.SplitPackage(filepath.Dir(filePath), ""),
+ ServiceInfo: service,
+ IDLPackageInfo: idlPackageRenderInfo,
}
err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false)
if err != nil {
@@ -426,10 +429,11 @@ func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRen
// re-generate
logs.Infof("re-generate file '%s', because the update behavior is 'Regenerate'", filePath)
data := CustomizedFileForMethod{
- HttpMethod: method,
- FilePath: filePath,
- FilePackage: util.SplitPackage(filepath.Dir(filePath), ""),
- ServiceInfo: service,
+ HttpMethod: method,
+ FilePath: filePath,
+ FilePackage: util.SplitPackage(filepath.Dir(filePath), ""),
+ ServiceInfo: service,
+ IDLPackageInfo: idlPackageRenderInfo,
}
err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false)
if err != nil {
@@ -497,10 +501,11 @@ func (pkgGen *HttpPackageGenerator) genSingleCustomizedFile(tplInfo *Template, f
for _, service := range idlPackageRenderInfo.ServiceInfos.Services {
for _, method := range service.Methods {
data := CustomizedFileForMethod{
- HttpMethod: method,
- FilePath: filePath,
- FilePackage: util.SplitPackage(filepath.Dir(filePath), ""),
- ServiceInfo: service,
+ HttpMethod: method,
+ FilePath: filePath,
+ FilePackage: util.SplitPackage(filepath.Dir(filePath), ""),
+ ServiceInfo: service,
+ IDLPackageInfo: &idlPackageRenderInfo,
}
insertKey, err := renderInsertKey(tplInfo, data)
if err != nil {
From 90e1ff059ae52b8051fd029fcab649088ef9d782 Mon Sep 17 00:00:00 2001
From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com>
Date: Tue, 18 Jul 2023 20:38:01 +0800
Subject: [PATCH 06/17] fix(hz): client snake tag name (#851)
---
cmd/hz/protobuf/ast.go | 12 ++++++------
cmd/hz/thrift/ast.go | 12 ++++++------
2 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/cmd/hz/protobuf/ast.go b/cmd/hz/protobuf/ast.go
index 2abe83372..292c0a461 100644
--- a/cmd/hz/protobuf/ast.go
+++ b/cmd/hz/protobuf/ast.go
@@ -269,14 +269,14 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen
if proto.HasExtension(f.Desc.Options(), api.E_Query) {
hasAnnotation = true
queryAnnos := proto.GetExtension(f.Desc.Options(), api.E_Query)
- val := queryAnnos.(string)
+ val := checkSnakeName(queryAnnos.(string))
clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
}
if proto.HasExtension(f.Desc.Options(), api.E_Path) {
hasAnnotation = true
pathAnnos := proto.GetExtension(f.Desc.Options(), api.E_Path)
- val := pathAnnos.(string)
+ val := checkSnakeName(pathAnnos.(string))
if isStringFieldType {
clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
} else {
@@ -287,7 +287,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen
if proto.HasExtension(f.Desc.Options(), api.E_Header) {
hasAnnotation = true
headerAnnos := proto.GetExtension(f.Desc.Options(), api.E_Header)
- val := headerAnnos.(string)
+ val := checkSnakeName(headerAnnos.(string))
if isStringFieldType {
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
} else {
@@ -298,7 +298,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen
if formAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_Form, api.E_FormCompatible); formAnnos != nil {
hasAnnotation = true
hasFormAnnotation = true
- val := formAnnos.(string)
+ val := checkSnakeName(formAnnos.(string))
if isStringFieldType {
clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
} else {
@@ -314,11 +314,11 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen
if fileAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_FileName, api.E_FileNameCompatible); fileAnnos != nil {
hasAnnotation = true
hasFormAnnotation = true
- val := fileAnnos.(string)
+ val := checkSnakeName(fileAnnos.(string))
clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
}
if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") {
- clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", f.GoName, f.GoName)
+ clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(f.GoName), f.GoName)
}
}
clientMethod.BodyParamsCode = meta.SetBodyParam
diff --git a/cmd/hz/thrift/ast.go b/cmd/hz/thrift/ast.go
index 298bad006..af7b05401 100644
--- a/cmd/hz/thrift/ast.go
+++ b/cmd/hz/thrift/ast.go
@@ -217,13 +217,13 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ
}
if anno := getAnnotation(field.Annotations, AnnotationQuery); len(anno) > 0 {
hasAnnotation = true
- query := anno[0]
+ query := checkSnakeName(anno[0])
clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", query, field.GoName().String())
}
if anno := getAnnotation(field.Annotations, AnnotationPath); len(anno) > 0 {
hasAnnotation = true
- path := anno[0]
+ path := checkSnakeName(anno[0])
if isStringFieldType {
clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", path, field.GoName().String())
} else {
@@ -233,7 +233,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ
if anno := getAnnotation(field.Annotations, AnnotationHeader); len(anno) > 0 {
hasAnnotation = true
- header := anno[0]
+ header := checkSnakeName(anno[0])
if isStringFieldType {
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", header, field.GoName().String())
} else {
@@ -243,7 +243,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ
if anno := getAnnotation(field.Annotations, AnnotationForm); len(anno) > 0 {
hasAnnotation = true
- form := anno[0]
+ form := checkSnakeName(anno[0])
hasFormAnnotation = true
if isStringFieldType {
clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", form, field.GoName().String())
@@ -259,12 +259,12 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ
if anno := getAnnotation(field.Annotations, AnnotationFileName); len(anno) > 0 {
hasAnnotation = true
- fileName := anno[0]
+ fileName := checkSnakeName(anno[0])
hasFormAnnotation = true
clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", fileName, field.GoName().String())
}
if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") {
- clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", field.GoName().String(), field.GoName().String())
+ clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(field.GoName().String()), field.GoName().String())
}
}
clientMethod.BodyParamsCode = meta.SetBodyParam
From 05d9ab4a7a3ae47b5389084be56045299ecfb301 Mon Sep 17 00:00:00 2001
From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com>
Date: Tue, 18 Jul 2023 20:38:52 +0800
Subject: [PATCH 07/17] feat(hz): add default service name (#852)
---
cmd/hz/generator/layout.go | 7 ++++++-
cmd/hz/meta/const.go | 2 ++
2 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/cmd/hz/generator/layout.go b/cmd/hz/generator/layout.go
index 3791955b5..d007fa82e 100644
--- a/cmd/hz/generator/layout.go
+++ b/cmd/hz/generator/layout.go
@@ -25,6 +25,7 @@ import (
"reflect"
"strings"
+ "github.com/cloudwego/hertz/cmd/hz/meta"
"github.com/cloudwego/hertz/cmd/hz/util"
"gopkg.in/yaml.v2"
)
@@ -169,10 +170,14 @@ func serviceToLayoutData(service Layout) (map[string]interface{}, error) {
if len(service.RouterDir) != 0 {
routerPkg = filepath.Base(service.RouterDir)
}
+ serviceName := service.ServiceName
+ if len(serviceName) == 0 {
+ serviceName = meta.DefaultServiceName
+ }
return map[string]interface{}{
"GoModule": goMod,
- "ServiceName": service.ServiceName,
+ "ServiceName": serviceName,
"UseApacheThrift": service.UseApacheThrift,
"HandlerPkg": handlerPkg,
"RouterPkg": routerPkg,
diff --git a/cmd/hz/meta/const.go b/cmd/hz/meta/const.go
index 6ec98d6e2..66f0a1fb4 100644
--- a/cmd/hz/meta/const.go
+++ b/cmd/hz/meta/const.go
@@ -21,6 +21,8 @@ import "runtime"
// Version hz version
const Version = "v0.6.5"
+const DefaultServiceName = "hertz_service"
+
// Mode hz run modes
type Mode int
From 1c49d8b8fe7e8d95a6998c5022a45ff22671b94f Mon Sep 17 00:00:00 2001
From: kinggo
Date: Tue, 18 Jul 2023 20:52:21 +0800
Subject: [PATCH 08/17] optimize: modify judgment conditions when use retry and
err = nil (#860)
---
pkg/app/client/client_test.go | 51 +++++++++++++++++++++++++++++++++++
pkg/protocol/http1/client.go | 3 ++-
2 files changed, 53 insertions(+), 1 deletion(-)
diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go
index 84c6a918e..ba00f78a8 100644
--- a/pkg/app/client/client_test.go
+++ b/pkg/app/client/client_test.go
@@ -2296,3 +2296,54 @@ func TestClientState(t *testing.T) {
client.Get(context.Background(), nil, "http://127.0.0.1:11000")
time.Sleep(time.Second * 22)
}
+
+func TestClientRetryErr(t *testing.T) {
+ t.Run("200", func(t *testing.T) {
+ opt := config.NewOptions([]config.Option{})
+ opt.Addr = "127.0.0.1:10136"
+ engine := route.NewEngine(opt)
+ var l sync.Mutex
+ retryNum := 0
+ engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
+ l.Lock()
+ defer l.Unlock()
+ retryNum += 1
+ ctx.SetStatusCode(200)
+ })
+ go engine.Run()
+ time.Sleep(100 * time.Millisecond)
+ c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3)))
+ _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:10136/ping")
+ assert.Nil(t, err)
+ l.Lock()
+ assert.DeepEqual(t, 1, retryNum)
+ l.Unlock()
+ engine.Close()
+ })
+
+ t.Run("502", func(t *testing.T) {
+ opt := config.NewOptions([]config.Option{})
+ opt.Addr = "127.0.0.1:10137"
+ engine := route.NewEngine(opt)
+ var l sync.Mutex
+ retryNum := 0
+ engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
+ l.Lock()
+ defer l.Unlock()
+ retryNum += 1
+ ctx.SetStatusCode(502)
+ })
+ go engine.Run()
+ time.Sleep(100 * time.Millisecond)
+ c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3)))
+ c.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool {
+ return resp.StatusCode() == 502
+ })
+ _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:10137/ping")
+ assert.Nil(t, err)
+ l.Lock()
+ assert.DeepEqual(t, 3, retryNum)
+ l.Unlock()
+ engine.Close()
+ })
+}
diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go
index 13f07a149..34bb5f014 100644
--- a/pkg/protocol/http1/client.go
+++ b/pkg/protocol/http1/client.go
@@ -381,7 +381,8 @@ func (c *HostClient) Do(ctx context.Context, req *protocol.Request, resp *protoc
req.Options().StartRequest()
for {
canIdempotentRetry, err = c.do(req, resp)
- if err == nil {
+ // If there is no custom retry and err is equal to nil, the loop simply exits.
+ if err == nil && isDefaultRetryFunc {
break
}
From 9014502d33ba3f1d46937173953ce6741d95f8d1 Mon Sep 17 00:00:00 2001
From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com>
Date: Wed, 19 Jul 2023 10:45:14 +0800
Subject: [PATCH 09/17] feat(hz): add more template parameter (#854)
---
cmd/hz/generator/handler.go | 26 +++++++++++++----------
cmd/hz/protobuf/ast.go | 12 +++++++++--
cmd/hz/thrift/ast.go | 42 +++++++++++++++++++++++++++++--------
3 files changed, 58 insertions(+), 22 deletions(-)
diff --git a/cmd/hz/generator/handler.go b/cmd/hz/generator/handler.go
index cecb0d038..4d797d26e 100644
--- a/cmd/hz/generator/handler.go
+++ b/cmd/hz/generator/handler.go
@@ -29,17 +29,21 @@ import (
)
type HttpMethod struct {
- Name string
- HTTPMethod string
- Comment string
- RequestTypeName string
- ReturnTypeName string
- Path string
- Serializer string
- OutputDir string
- RefPackage string
- RefPackageAlias string
- ModelPackage map[string]string
+ Name string
+ HTTPMethod string
+ Comment string
+ RequestTypeName string
+ RequestTypePackage string
+ RequestTypeRawName string
+ ReturnTypeName string
+ ReturnTypePackage string
+ ReturnTypeRawName string
+ Path string
+ Serializer string
+ OutputDir string
+ RefPackage string // handler import dir
+ RefPackageAlias string // handler import alias
+ ModelPackage map[string]string
// Annotations map[string]string
Models map[string]*model.Model
}
diff --git a/cmd/hz/protobuf/ast.go b/cmd/hz/protobuf/ast.go
index 292c0a461..fcb8197eb 100644
--- a/cmd/hz/protobuf/ast.go
+++ b/cmd/hz/protobuf/ast.go
@@ -158,16 +158,20 @@ func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmd
reqName := m.GetInputType()
sb, err := resolver.ResolveIdentifier(reqName)
- reqName = util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") + "." + inputGoType.GoIdent.GoName
if err != nil {
return nil, err
}
+ reqName = util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") + "." + inputGoType.GoIdent.GoName
+ reqRawName := inputGoType.GoIdent.GoName
+ reqPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "")
respName := m.GetOutputType()
st, err := resolver.ResolveIdentifier(respName)
- respName = util.BaseName(st.Scope.GetOptions().GetGoPackage(), "") + "." + outputGoType.GoIdent.GoName
if err != nil {
return nil, err
}
+ respName = util.BaseName(st.Scope.GetOptions().GetGoPackage(), "") + "." + outputGoType.GoIdent.GoName
+ respRawName := outputGoType.GoIdent.GoName
+ respPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "")
var serializer string
sl, sv := checkFirstOptions(SerializerOptions, m.GetOptions())
@@ -212,7 +216,11 @@ func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmd
respName = goOptMapAlias[st.Scope.GetOptions().GetGoPackage()] + "." + outputGoType.GoIdent.GoName
}
method.RequestTypeName = reqName
+ method.RequestTypeRawName = reqRawName
+ method.RequestTypePackage = reqPackage
method.ReturnTypeName = respName
+ method.ReturnTypeRawName = respRawName
+ method.ReturnTypePackage = respPackage
methods = append(methods, method)
diff --git a/cmd/hz/thrift/ast.go b/cmd/hz/thrift/ast.go
index af7b05401..557b50a41 100644
--- a/cmd/hz/thrift/ast.go
+++ b/cmd/hz/thrift/ast.go
@@ -122,7 +122,7 @@ func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument)
return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", hmethod, s.Name, m.Name, path)
}
- var reqName string
+ var reqName, reqRawName, reqPackage string
if len(m.Arguments) >= 1 {
if len(m.Arguments) > 1 {
logs.Warnf("function '%s' has more than one argument, but only the first can be used in hertz now", m.GetName())
@@ -132,25 +132,49 @@ func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument)
if err != nil {
return nil, err
}
+ if strings.Contains(reqName, ".") && !m.Arguments[0].GetType().Category.IsContainerType() {
+ // If reqName contains "." , then it must be of the form "pkg.name".
+ // so reqRawName='name', reqPackage='pkg'
+ names := strings.Split(reqName, ".")
+ if len(names) != 2 {
+ return nil, fmt.Errorf("request name: %s is wrong", reqName)
+ }
+ reqRawName = names[1]
+ reqPackage = names[0]
+ }
}
- var respName string
+ var respName, respRawName, respPackage string
if !m.Oneway {
var err error
respName, err = resolver.ResolveTypeName(m.GetFunctionType())
if err != nil {
return nil, err
}
+ if strings.Contains(respName, ".") && !m.GetFunctionType().Category.IsContainerType() {
+ names := strings.Split(respName, ".")
+ if len(names) != 2 {
+ return nil, fmt.Errorf("response name: %s is wrong", respName)
+ }
+ // If respName contains "." , then it must be of the form "pkg.name".
+ // so respRawName='name', respPackage='pkg'
+ respRawName = names[1]
+ respPackage = names[0]
+ }
}
sr, _ := util.GetFirstKV(getAnnotations(m.Annotations, SerializerTags))
method := &generator.HttpMethod{
- Name: util.CamelString(m.GetName()),
- HTTPMethod: hmethod,
- RequestTypeName: reqName,
- ReturnTypeName: respName,
- Path: path[0],
- Serializer: sr,
- OutputDir: handlerOutDir,
+ Name: util.CamelString(m.GetName()),
+ HTTPMethod: hmethod,
+ RequestTypeName: reqName,
+ RequestTypeRawName: reqRawName,
+ RequestTypePackage: reqPackage,
+ ReturnTypeName: respName,
+ ReturnTypeRawName: respRawName,
+ ReturnTypePackage: respPackage,
+ Path: path[0],
+ Serializer: sr,
+ OutputDir: handlerOutDir,
// Annotations: m.Annotations,
}
refs := resolver.ExportReferred(false, true)
From eca144b85291fe0f0976ddf7f92d15c559fd6257 Mon Sep 17 00:00:00 2001
From: Guangming Luo
Date: Thu, 20 Jul 2023 20:38:27 +0800
Subject: [PATCH 10/17] chore: remove wechat group in readme (#864)
---
README.md | 2 --
README_cn.md | 4 +---
2 files changed, 1 insertion(+), 5 deletions(-)
diff --git a/README.md b/README.md
index ea347fdfa..ae09350ea 100644
--- a/README.md
+++ b/README.md
@@ -112,9 +112,7 @@ Hertz is distributed under the [Apache License, version 2.0](https://github.com/
- Lark: Scan the QR code below with [Lark](https://www.larksuite.com/zh_cn/download) to join our CloudWeGo/hertz user group.
![LarkGroup](images/lark_group.png)
-- WeChat: CloudWeGo community WeChat group.
-![WechatGroup](images/wechat_group_cn.png)
## Contributors
Thank you for your contribution to Hertz!
diff --git a/README_cn.md b/README_cn.md
index a1c13fbd6..8ff215d04 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -113,9 +113,7 @@ Hertz 基于[Apache License 2.0](https://github.com/cloudwego/hertz/blob/main/LI
- 飞书用户群([注册飞书](https://www.larksuite.com/zh_cn/download)进群)
![LarkGroup](images/lark_group_cn.png)
-- 微信: CloudWeGo community
- ![WechatGroup](images/wechat_group_cn.png)
## 贡献者
感谢您对 Hertz 作出的贡献!
@@ -126,4 +124,4 @@ Hertz 基于[Apache License 2.0](https://github.com/cloudwego/hertz/blob/main/LI
CloudWeGo 丰富了 CNCF 云原生生态。
-
\ No newline at end of file
+
From 3161f043a97d062f887fa2d8a9af9ca9bc96fc95 Mon Sep 17 00:00:00 2001
From: Cr <631807682@qq.com>
Date: Tue, 25 Jul 2023 20:10:32 +0800
Subject: [PATCH 11/17] test(http1): fix retry test (#866)
---
pkg/protocol/http1/client_test.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go
index 19154453c..4dd6fcbf2 100644
--- a/pkg/protocol/http1/client_test.go
+++ b/pkg/protocol/http1/client_test.go
@@ -437,7 +437,7 @@ func TestRetry(t *testing.T) {
Delay: time.Millisecond * 10,
},
RetryIfFunc: func(req *protocol.Request, resp *protocol.Response, err error) bool {
- return true
+ return resp.Header.ContentLength() != 10
},
},
Addr: "foobar",
From 4a68f8488361bc796a047af4133b515ac309468e Mon Sep 17 00:00:00 2001
From: Wenju Gao
Date: Mon, 31 Jul 2023 16:06:28 +0800
Subject: [PATCH 12/17] fix: keep isTLS field in keepalive when reset the
request (#876)
---
pkg/app/context.go | 3 ++-
pkg/app/context_test.go | 4 +++-
pkg/protocol/request.go | 19 ++++++++++++++++++-
3 files changed, 23 insertions(+), 3 deletions(-)
diff --git a/pkg/app/context.go b/pkg/app/context.go
index e4fcd0af9..559b84d03 100644
--- a/pkg/app/context.go
+++ b/pkg/app/context.go
@@ -754,7 +754,8 @@ func (ctx *RequestContext) ResetWithoutConn() {
close(ctx.finished)
ctx.finished = nil
}
- ctx.Request.Reset()
+
+ ctx.Request.ResetWithoutConn()
ctx.Response.Reset()
if ctx.IsEnableTrace() {
ctx.traceInfo.Reset()
diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go
index b9e9d2f5f..8bf52f437 100644
--- a/pkg/app/context_test.go
+++ b/pkg/app/context_test.go
@@ -654,8 +654,10 @@ func TestContextReset(t *testing.T) {
c.Params = param.Params{param.Param{}}
c.Error(errors.New("test")) // nolint: errcheck
c.Set("foo", "bar")
+ c.Request.SetIsTLS(true)
c.ResetWithoutConn()
-
+ c.Request.URI()
+ assert.DeepEqual(t, "https", string(c.Request.Scheme()))
assert.False(t, c.IsAborted())
assert.DeepEqual(t, 0, len(c.Errors))
assert.Nil(t, c.Errors.Errors())
diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go
index af29db1ba..ef70f92ca 100644
--- a/pkg/protocol/request.go
+++ b/pkg/protocol/request.go
@@ -176,16 +176,25 @@ func (req *Request) MayContinue() bool {
return bytes.Equal(req.Header.peek(bytestr.StrExpect), bytestr.Str100Continue)
}
+// Scheme returns the scheme of the request.
+// uri will be parsed in ServeHTTP(before user's process), so that there is no need for uri nil-check.
func (req *Request) Scheme() []byte {
return req.uri.Scheme()
}
-func (req *Request) ResetSkipHeader() {
+// For keepalive connection reuse.
+// It is roughly the same as ResetSkipHeader, except that the connection related fields are removed:
+// - req.isTLS
+func (req *Request) resetSkipHeaderAndConn() {
req.ResetBody()
req.uri.Reset()
req.parsedURI = false
req.parsedPostArgs = false
req.postArgs.Reset()
+}
+
+func (req *Request) ResetSkipHeader() {
+ req.resetSkipHeaderAndConn()
req.isTLS = false
}
@@ -814,6 +823,14 @@ func (req *Request) SetConnectionClose() {
req.Header.SetConnectionClose(true)
}
+func (req *Request) ResetWithoutConn() {
+ req.Header.Reset()
+ req.resetSkipHeaderAndConn()
+ req.CloseBodyStream()
+
+ req.options = nil
+}
+
// AcquireRequest returns an empty Request instance from request pool.
//
// The returned Request instance may be passed to ReleaseRequest when it is
From f4cde92c4781a355113b92e7a9e444b69e640fbb Mon Sep 17 00:00:00 2001
From: Wenju Gao
Date: Mon, 31 Jul 2023 16:14:47 +0800
Subject: [PATCH 13/17] optimize(http1): retry if the error is caused by broken
pooled conn (#871)
---
pkg/app/client/client_test.go | 6 ++-
pkg/app/server/hertz_test.go | 7 +--
pkg/common/errors/errors.go | 6 +--
pkg/common/test/mock/network.go | 65 +++++++++++++++++++++++-
pkg/common/test/mock/network_test.go | 58 ++++++++++++++++++++-
pkg/protocol/client/client.go | 11 ----
pkg/protocol/http1/client.go | 70 +++++++++++++++++++++-----
pkg/protocol/http1/client_test.go | 67 ++++++++++++++++++++++--
pkg/protocol/http1/client_unix_test.go | 4 +-
pkg/protocol/http1/server.go | 2 +-
10 files changed, 251 insertions(+), 45 deletions(-)
diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go
index ba00f78a8..8449cd4d8 100644
--- a/pkg/app/client/client_test.go
+++ b/pkg/app/client/client_test.go
@@ -453,7 +453,9 @@ func TestClientReadTimeout(t *testing.T) {
req.SetConnectionClose()
if err := c.Do(context.Background(), req, res); !errors.Is(err, errs.ErrTimeout) {
- t.Errorf("expected ErrTimeout got %#v", err)
+ if !strings.Contains(err.Error(), "timeout") {
+ t.Errorf("expected ErrTimeout got %#v", err)
+ }
}
protocol.ReleaseRequest(req)
@@ -2267,7 +2269,7 @@ func TestClientDoWithDialFunc(t *testing.T) {
func TestClientState(t *testing.T) {
opt := config.NewOptions([]config.Option{})
- opt.Addr = ":11000"
+ opt.Addr = "127.0.0.1:11000"
engine := route.NewEngine(opt)
go engine.Run()
diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go
index b9a07fbce..bc3e1a0e4 100644
--- a/pkg/app/server/hertz_test.go
+++ b/pkg/app/server/hertz_test.go
@@ -311,16 +311,17 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) {
}
func TestWithBasePath(t *testing.T) {
- engine := New(WithBasePath("/hertz"), WithHostPorts("127.0.0.1:9898"))
+ engine := New(WithBasePath("/hertz"), WithHostPorts("127.0.0.1:19898"))
engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
})
go engine.Run()
- time.Sleep(200 * time.Microsecond)
+ time.Sleep(500 * time.Microsecond)
var r http.Request
r.ParseForm()
r.Form.Add("xxxxxx", "xxx")
body := strings.NewReader(r.Form.Encode())
- resp, _ := http.Post("http://127.0.0.1:9898/hertz/test", "application/x-www-form-urlencoded", body)
+ resp, err := http.Post("http://127.0.0.1:19898/hertz/test", "application/x-www-form-urlencoded", body)
+ assert.Nil(t, err)
assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
}
diff --git a/pkg/common/errors/errors.go b/pkg/common/errors/errors.go
index fdaf2d56b..918c7fa7a 100644
--- a/pkg/common/errors/errors.go
+++ b/pkg/common/errors/errors.go
@@ -53,17 +53,15 @@ var (
ErrChunkedStream = errors.New("chunked stream")
ErrBodyTooLarge = errors.New("body size exceeds the given limit")
ErrHijacked = errors.New("connection has been hijacked")
- ErrIdleTimeout = errors.New("idle timeout")
ErrTimeout = errors.New("timeout")
- ErrReadTimeout = errors.New("read timeout")
- ErrWriteTimeout = errors.New("write timeout")
- ErrDialTimeout = errors.New("dial timeout")
+ ErrIdleTimeout = errors.New("idle timeout")
ErrNothingRead = errors.New("nothing read")
ErrShortConnection = errors.New("short connection")
ErrNoFreeConns = errors.New("no free connections available to host")
ErrConnectionClosed = errors.New("connection closed")
ErrNotSupportProtocol = errors.New("not support protocol")
ErrNoMultipartForm = errors.New("request has no multipart/form-data Content-Type")
+ ErrBadPoolConn = errors.New("connection is closed by peer while being in the connection pool")
)
// ErrorType is an unsigned 64-bit error code as defined in the hertz spec.
diff --git a/pkg/common/test/mock/network.go b/pkg/common/test/mock/network.go
index 7ad083647..57293cb2b 100644
--- a/pkg/common/test/mock/network.go
+++ b/pkg/common/test/mock/network.go
@@ -18,6 +18,7 @@ package mock
import (
"bytes"
+ "io"
"net"
"strings"
"time"
@@ -27,6 +28,11 @@ import (
"github.com/cloudwego/netpoll"
)
+var (
+ ErrReadTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read timeout")
+ ErrWriteTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "write timeout")
+)
+
type Conn struct {
readTimeout time.Duration
zr network.Reader
@@ -146,7 +152,7 @@ func (m *SlowReadConn) Peek(i int) ([]byte, error) {
time.Sleep(100 * time.Millisecond)
}
if err != nil || len(b) != i {
- return nil, errs.ErrReadTimeout
+ return nil, ErrReadTimeout
}
return b, err
}
@@ -161,6 +167,61 @@ func NewConn(source string) *Conn {
}
}
+type BrokenConn struct {
+ *Conn
+}
+
+func (o *BrokenConn) Peek(i int) ([]byte, error) {
+ return nil, io.EOF
+}
+
+func (o *BrokenConn) Flush() error {
+ return errs.ErrConnectionClosed
+}
+
+func NewBrokenConn(source string) *BrokenConn {
+ return &BrokenConn{Conn: NewConn(source)}
+}
+
+type OneTimeConn struct {
+ isRead bool
+ isFlushed bool
+ contentLength int
+ *Conn
+}
+
+func (o *OneTimeConn) Peek(n int) ([]byte, error) {
+ if o.isRead {
+ return nil, io.EOF
+ }
+ return o.Conn.Peek(n)
+}
+
+func (o *OneTimeConn) Skip(n int) error {
+ if o.isRead {
+ return io.EOF
+ }
+ o.contentLength -= n
+
+ if o.contentLength == 0 {
+ o.isRead = true
+ }
+
+ return o.Conn.Skip(n)
+}
+
+func (o *OneTimeConn) Flush() error {
+ if o.isFlushed {
+ return errs.ErrConnectionClosed
+ }
+ o.isFlushed = true
+ return o.Conn.Flush()
+}
+
+func NewOneTimeConn(source string) *OneTimeConn {
+ return &OneTimeConn{isRead: false, isFlushed: false, Conn: NewConn(source), contentLength: len(source)}
+}
+
func NewSlowReadConn(source string) *SlowReadConn {
return &SlowReadConn{Conn: NewConn(source)}
}
@@ -200,7 +261,7 @@ func (m *SlowWriteConn) Flush() error {
time.Sleep(100 * time.Millisecond)
if err == nil {
time.Sleep(m.writeTimeout)
- return errs.ErrWriteTimeout
+ return ErrWriteTimeout
}
return err
}
diff --git a/pkg/common/test/mock/network_test.go b/pkg/common/test/mock/network_test.go
index 2d4ff973a..bd6ad5eb9 100644
--- a/pkg/common/test/mock/network_test.go
+++ b/pkg/common/test/mock/network_test.go
@@ -18,6 +18,7 @@ package mock
import (
"context"
+ "io"
"testing"
"time"
@@ -30,6 +31,7 @@ func TestConn(t *testing.T) {
t.Run("TestReader", func(t *testing.T) {
s1 := "abcdef4343"
conn1 := NewConn(s1)
+ assert.Nil(t, conn1.SetWriteTimeout(1))
err := conn1.SetReadDeadline(time.Now().Add(time.Millisecond * 100))
assert.DeepEqual(t, nil, err)
err = conn1.SetReadTimeout(time.Millisecond * 100)
@@ -135,12 +137,16 @@ func TestSlowConn(t *testing.T) {
t.Run("TestSlowReadConn", func(t *testing.T) {
s1 := "abcdefg"
conn := NewSlowReadConn(s1)
+ assert.Nil(t, conn.SetWriteTimeout(1))
+ assert.Nil(t, conn.SetReadTimeout(1))
+ assert.DeepEqual(t, time.Duration(1), conn.readTimeout)
+
b, err := conn.Peek(4)
assert.DeepEqual(t, nil, err)
assert.DeepEqual(t, s1[:4], string(b))
conn.Skip(len(s1))
_, err = conn.Peek(1)
- assert.DeepEqual(t, errs.ErrReadTimeout, err)
+ assert.DeepEqual(t, ErrReadTimeout, err)
_, err = SlowReadDialer("")
assert.DeepEqual(t, nil, err)
})
@@ -150,7 +156,7 @@ func TestSlowConn(t *testing.T) {
assert.DeepEqual(t, nil, err)
conn.SetWriteTimeout(time.Millisecond * 100)
err = conn.Flush()
- assert.DeepEqual(t, errs.ErrWriteTimeout, err)
+ assert.DeepEqual(t, ErrWriteTimeout, err)
})
}
@@ -183,3 +189,51 @@ func TestStreamConn(t *testing.T) {
})
})
}
+
+func TestBrokenConn_Flush(t *testing.T) {
+ conn := NewBrokenConn("")
+ n, err := conn.Writer().WriteBinary([]byte("Foo"))
+ assert.DeepEqual(t, 3, n)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, errs.ErrConnectionClosed, conn.Flush())
+}
+
+func TestBrokenConn_Peek(t *testing.T) {
+ conn := NewBrokenConn("Foo")
+ buf, err := conn.Peek(3)
+ assert.Nil(t, buf)
+ assert.DeepEqual(t, io.EOF, err)
+}
+
+func TestOneTimeConn_Flush(t *testing.T) {
+ conn := NewOneTimeConn("")
+ n, err := conn.Writer().WriteBinary([]byte("Foo"))
+ assert.DeepEqual(t, 3, n)
+ assert.Nil(t, err)
+ assert.Nil(t, conn.Flush())
+ n, err = conn.Writer().WriteBinary([]byte("Bar"))
+ assert.DeepEqual(t, 3, n)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, errs.ErrConnectionClosed, conn.Flush())
+}
+
+func TestOneTimeConn_Skip(t *testing.T) {
+ conn := NewOneTimeConn("FooBar")
+ buf, err := conn.Peek(3)
+ assert.DeepEqual(t, "Foo", string(buf))
+ assert.Nil(t, err)
+ assert.Nil(t, conn.Skip(3))
+ assert.DeepEqual(t, 3, conn.contentLength)
+
+ buf, err = conn.Peek(3)
+ assert.DeepEqual(t, "Bar", string(buf))
+ assert.Nil(t, err)
+ assert.Nil(t, conn.Skip(3))
+ assert.DeepEqual(t, 0, conn.contentLength)
+
+ buf, err = conn.Peek(3)
+ assert.DeepEqual(t, 0, len(buf))
+ assert.DeepEqual(t, io.EOF, err)
+ assert.DeepEqual(t, io.EOF, conn.Skip(3))
+ assert.DeepEqual(t, 0, conn.contentLength)
+}
diff --git a/pkg/protocol/client/client.go b/pkg/protocol/client/client.go
index 3715255e4..777f55cdd 100644
--- a/pkg/protocol/client/client.go
+++ b/pkg/protocol/client/client.go
@@ -43,7 +43,6 @@ package client
import (
"context"
- "io"
"sync"
"time"
@@ -88,16 +87,6 @@ func DefaultRetryIf(req *protocol.Request, resp *protocol.Response, err error) b
if isIdempotent(req, resp, err) {
return true
}
- // Retry non-idempotent requests if the server closes
- // the connection before sending the response.
- //
- // This case is possible if the server closes the idle
- // keep-alive connection on timeout.
- //
- // Apache and nginx usually do this.
- if err == io.EOF {
- return true
- }
return false
}
diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go
index 34bb5f014..c306f071f 100644
--- a/pkg/protocol/http1/client.go
+++ b/pkg/protocol/http1/client.go
@@ -51,6 +51,7 @@ import (
"strings"
"sync"
"sync/atomic"
+ "syscall"
"time"
"github.com/cloudwego/hertz/internal/bytesconv"
@@ -362,6 +363,7 @@ func (c *HostClient) Do(ctx context.Context, req *protocol.Request, resp *protoc
canIdempotentRetry bool
isDefaultRetryFunc = true
attempts uint = 0
+ connAttempts uint = 0
maxAttempts uint = 1
isRequestRetryable client.RetryIfFunc = client.DefaultRetryIf
)
@@ -380,17 +382,38 @@ func (c *HostClient) Do(ctx context.Context, req *protocol.Request, resp *protoc
atomic.AddInt32(&c.pendingRequests, 1)
req.Options().StartRequest()
for {
+ select {
+ case <-ctx.Done():
+ req.CloseBodyStream() //nolint:errcheck
+ return ctx.Err()
+ default:
+ }
+
canIdempotentRetry, err = c.do(req, resp)
// If there is no custom retry and err is equal to nil, the loop simply exits.
if err == nil && isDefaultRetryFunc {
+ if connAttempts != 0 {
+ hlog.SystemLogger().Warnf("Client connection attempt times: %d, url: %s. "+
+ "This is mainly because the connection in pool is closed by peer in advance. "+
+ "If this number is too high which indicates that long-connection are basically unavailable, "+
+ "try to change the request to short-connection.\n", connAttempts, req.URI().FullURI())
+ }
break
}
+ // This connection is closed by the peer when it is in the connection pool.
+ //
+ // This case is possible if the server closes the idle
+ // keep-alive connection on timeout.
+ //
+ // Apache and nginx usually do this.
+ if canIdempotentRetry && client.DefaultRetryIf(req, resp, err) && errors.Is(err, errs.ErrBadPoolConn) {
+ connAttempts++
+ continue
+ }
+
if isDefaultRetryFunc {
- // canIdempotentRetry only makes sense if the user hasn't provided a custom retry function.
- if !canIdempotentRetry {
- break
- }
+ break
}
attempts++
@@ -516,7 +539,7 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo
if reqTimeout < dialTimeout || dialTimeout == 0 {
dialTimeout = reqTimeout
}
- cc, err := c.acquireConn(dialTimeout)
+ cc, inPool, err := c.acquireConn(dialTimeout)
// if getting connection error, fast fail
if err != nil {
return false, err
@@ -588,13 +611,16 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo
// Only if the connection is closed while writing the request. Try to parse the response and return.
// In this case, the request/response is considered as successful.
// Otherwise, return the former error.
-
zr := c.acquireReader(conn)
defer zr.Release()
if respI.ReadHeaderAndLimitBody(resp, zr, c.MaxResponseBodySize) == nil {
return false, nil
}
+ if inPool {
+ err = errs.ErrBadPoolConn
+ }
+
return true, err
}
@@ -620,6 +646,24 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo
}
zr := c.acquireReader(conn)
+ // errs.ErrBadPoolConn error are returned when the
+ // 1 byte peek read fails, and we're actually anticipating a response.
+ // Usually this is just due to the inherent keep-alive shut down race,
+ // where the server closed the connection at the same time the client
+ // wrote. The underlying err field is usually io.EOF or some
+ // ECONNRESET sort of thing which varies by platform.
+ _, err = zr.Peek(1)
+ if err != nil {
+ zr.Release() //nolint:errcheck
+ c.closeConn(cc)
+ if inPool && (err == io.EOF || err == syscall.ECONNRESET) {
+ return true, errs.ErrBadPoolConn
+ }
+ // if this is not a pooled connection,
+ // we should not retry to avoid getting stuck in an endless retry loop.
+ return false, err
+ }
+
// init here for passing in ReadBodyStream's closure
// and this value will be assigned after reading Response's Header
//
@@ -676,7 +720,7 @@ func (c *HostClient) SetMaxConns(newMaxConns int) {
c.connsLock.Unlock()
}
-func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, err error) {
+func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, inPool bool, err error) {
createConn := false
startCleaner := false
@@ -705,11 +749,11 @@ func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, err
c.connsLock.Unlock()
if cc != nil {
- return cc, nil
+ return cc, true, nil
}
if !createConn {
if c.MaxConnWaitTimeout <= 0 {
- return nil, errs.ErrNoFreeConns
+ return nil, true, errs.ErrNoFreeConns
}
timeout := c.MaxConnWaitTimeout
@@ -736,9 +780,9 @@ func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, err
select {
case <-w.ready:
- return w.conn, w.err
+ return w.conn, true, w.err
case <-tc.C:
- return nil, errs.ErrNoFreeConns
+ return nil, true, errs.ErrNoFreeConns
}
}
@@ -749,11 +793,11 @@ func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, err
conn, err := c.dialHostHard(dialTimeout)
if err != nil {
c.decConnsCount()
- return nil, err
+ return nil, false, err
}
cc = acquireClientConn(conn)
- return cc, nil
+ return cc, false, nil
}
func (c *HostClient) queueForIdle(w *wantConn) {
diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go
index 4dd6fcbf2..7bc389bf9 100644
--- a/pkg/protocol/http1/client_test.go
+++ b/pkg/protocol/http1/client_test.go
@@ -58,6 +58,7 @@ import (
"github.com/cloudwego/hertz/pkg/app/client/retry"
"github.com/cloudwego/hertz/pkg/common/config"
errs "github.com/cloudwego/hertz/pkg/common/errors"
+ "github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/test/mock"
"github.com/cloudwego/hertz/pkg/common/utils"
@@ -69,6 +70,8 @@ import (
"github.com/cloudwego/netpoll"
)
+var errDialTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "dial timeout")
+
func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
var (
emptyBodyCount uint8
@@ -235,7 +238,7 @@ type slowDialer struct {
func (s *slowDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) {
time.Sleep(timeout)
- return nil, errs.ErrDialTimeout
+ return nil, errDialTimeout
}
func TestReadTimeoutPriority(t *testing.T) {
@@ -264,7 +267,7 @@ func TestReadTimeoutPriority(t *testing.T) {
case <-time.After(time.Second * 2):
t.Fatalf("should use readTimeout in request options")
case err := <-ch:
- assert.DeepEqual(t, errs.ErrTimeout.Error(), err.Error())
+ assert.DeepEqual(t, mock.ErrReadTimeout, err)
}
}
@@ -334,7 +337,7 @@ func TestWriteTimeoutPriority(t *testing.T) {
case <-time.After(time.Second * 2):
t.Fatalf("should use writeTimeout in request options")
case err := <-ch:
- assert.DeepEqual(t, errs.ErrWriteTimeout.Error(), err.Error())
+ assert.DeepEqual(t, mock.ErrWriteTimeout, err)
}
}
@@ -359,10 +362,10 @@ func TestDialTimeoutPriority(t *testing.T) {
ch <- c.Do(context.Background(), req, resp)
}()
select {
- case <-time.After(time.Second * 2000):
+ case <-time.After(time.Second * 2):
t.Fatalf("should use dialTimeout in request options")
case err := <-ch:
- assert.DeepEqual(t, errs.ErrDialTimeout.Error(), err.Error())
+ assert.DeepEqual(t, errDialTimeout, err)
}
}
@@ -479,3 +482,57 @@ type retryConn struct {
func (w retryConn) SetWriteTimeout(t time.Duration) error {
return errors.New("should retry")
}
+
+func TestConnInPoolRetry(t *testing.T) {
+ c := &HostClient{
+ ClientOptions: &ClientOptions{
+ Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
+ return mock.NewOneTimeConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil
+ }),
+ },
+ Addr: "foobar",
+ }
+
+ req := protocol.AcquireRequest()
+ req.SetRequestURI("http://foobar/baz")
+ req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100))
+ resp := protocol.AcquireResponse()
+
+ logbuf := &bytes.Buffer{}
+ hlog.SetOutput(logbuf)
+
+ err := c.Do(context.Background(), req, resp)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, resp.StatusCode(), 200)
+ assert.DeepEqual(t, string(resp.Body()), "0123456789")
+ assert.True(t, logbuf.String() == "")
+ protocol.ReleaseResponse(resp)
+ resp = protocol.AcquireResponse()
+ err = c.Do(context.Background(), req, resp)
+ assert.Nil(t, err)
+ assert.DeepEqual(t, resp.StatusCode(), 200)
+ assert.DeepEqual(t, string(resp.Body()), "0123456789")
+ assert.True(t, strings.Contains(logbuf.String(), "Client connection attempt times: 1"))
+}
+
+func TestConnNotRetry(t *testing.T) {
+ c := &HostClient{
+ ClientOptions: &ClientOptions{
+ Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
+ return mock.NewBrokenConn(""), nil
+ }),
+ },
+ Addr: "foobar",
+ }
+
+ req := protocol.AcquireRequest()
+ req.SetRequestURI("http://foobar/baz")
+ req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100))
+ resp := protocol.AcquireResponse()
+ logbuf := &bytes.Buffer{}
+ hlog.SetOutput(logbuf)
+ err := c.Do(context.Background(), req, resp)
+ assert.DeepEqual(t, errs.ErrConnectionClosed, err)
+ assert.True(t, logbuf.String() == "")
+ protocol.ReleaseResponse(resp)
+}
diff --git a/pkg/protocol/http1/client_unix_test.go b/pkg/protocol/http1/client_unix_test.go
index 6e417ba18..1e88502d3 100644
--- a/pkg/protocol/http1/client_unix_test.go
+++ b/pkg/protocol/http1/client_unix_test.go
@@ -35,7 +35,7 @@ import (
)
func TestGcBodyStream(t *testing.T) {
- srv := &http.Server{Addr: ":11001", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ srv := &http.Server{Addr: "127.0.0.1:11001", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
for range [1024]int{} {
w.Write([]byte("hello world\n"))
}
@@ -69,7 +69,7 @@ func TestGcBodyStream(t *testing.T) {
}
func TestMaxConn(t *testing.T) {
- srv := &http.Server{Addr: ":11002", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ srv := &http.Server{Addr: "127.0.0.1:11002", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("hello world\n"))
})}
go srv.ListenAndServe()
diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go
index 8aaed6db7..77cc74e18 100644
--- a/pkg/protocol/http1/server.go
+++ b/pkg/protocol/http1/server.go
@@ -48,7 +48,7 @@ const NextProtoTLS = suite.HTTP1
var (
errHijacked = errs.New(errs.ErrHijacked, errs.ErrorTypePublic, nil)
- errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePublic, nil)
+ errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil)
errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection")
errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request")
)
From 41c781eb0755258d8b0d1b6b87d7495654a292ee Mon Sep 17 00:00:00 2001
From: Wenju Gao
Date: Mon, 31 Jul 2023 16:40:54 +0800
Subject: [PATCH 14/17] optimize(client): do not write body in stream mode if
content-length is 0 (#875)
---
pkg/common/test/mock/network.go | 6 +++++-
pkg/common/test/mock/network_test.go | 2 +-
pkg/protocol/http1/ext/common.go | 3 +++
pkg/protocol/http1/ext/common_test.go | 6 ++++++
4 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/pkg/common/test/mock/network.go b/pkg/common/test/mock/network.go
index 57293cb2b..fe671ae56 100644
--- a/pkg/common/test/mock/network.go
+++ b/pkg/common/test/mock/network.go
@@ -172,7 +172,11 @@ type BrokenConn struct {
}
func (o *BrokenConn) Peek(i int) ([]byte, error) {
- return nil, io.EOF
+ return nil, io.ErrUnexpectedEOF
+}
+
+func (o *BrokenConn) Read(b []byte) (n int, err error) {
+ return 0, io.ErrUnexpectedEOF
}
func (o *BrokenConn) Flush() error {
diff --git a/pkg/common/test/mock/network_test.go b/pkg/common/test/mock/network_test.go
index bd6ad5eb9..4c9c4cf5b 100644
--- a/pkg/common/test/mock/network_test.go
+++ b/pkg/common/test/mock/network_test.go
@@ -202,7 +202,7 @@ func TestBrokenConn_Peek(t *testing.T) {
conn := NewBrokenConn("Foo")
buf, err := conn.Peek(3)
assert.Nil(t, buf)
- assert.DeepEqual(t, io.EOF, err)
+ assert.DeepEqual(t, io.ErrUnexpectedEOF, err)
}
func TestOneTimeConn_Flush(t *testing.T) {
diff --git a/pkg/protocol/http1/ext/common.go b/pkg/protocol/http1/ext/common.go
index b31b40cdd..864988317 100644
--- a/pkg/protocol/http1/ext/common.go
+++ b/pkg/protocol/http1/ext/common.go
@@ -135,6 +135,9 @@ func WriteBodyChunked(w network.Writer, r io.Reader) error {
}
func WriteBodyFixedSize(w network.Writer, r io.Reader, size int64) error {
+ if size == 0 {
+ return nil
+ }
if size > consts.MaxSmallFileSize {
if err := w.Flush(); err != nil {
return err
diff --git a/pkg/protocol/http1/ext/common_test.go b/pkg/protocol/http1/ext/common_test.go
index 824ccf4ca..78a9fede0 100644
--- a/pkg/protocol/http1/ext/common_test.go
+++ b/pkg/protocol/http1/ext/common_test.go
@@ -155,6 +155,12 @@ func TestBodyFixedSize(t *testing.T) {
assert.DeepEqual(t, body, rb)
}
+func TestBodyFixedSizeQuickPath(t *testing.T) {
+ conn := mock.NewBrokenConn("")
+ err := WriteBodyFixedSize(conn.Writer(), conn, 0)
+ assert.Nil(t, err)
+}
+
func TestBodyIdentity(t *testing.T) {
body := mock.CreateFixedBody(1024)
zr := mock.NewZeroCopyReader(string(body))
From fc36f6f69e7f29edd640aeac51f6eae8ab11e9d3 Mon Sep 17 00:00:00 2001
From: Wenju Gao
Date: Thu, 3 Aug 2023 16:29:49 +0800
Subject: [PATCH 15/17] optimize: unify timeout error (#884)
---
pkg/network/netpoll/connection.go | 5 +++++
pkg/network/standard/connection.go | 4 ++++
pkg/protocol/http1/client.go | 4 ++++
3 files changed, 13 insertions(+)
diff --git a/pkg/network/netpoll/connection.go b/pkg/network/netpoll/connection.go
index 7c26ef914..e6df5030c 100644
--- a/pkg/network/netpoll/connection.go
+++ b/pkg/network/netpoll/connection.go
@@ -35,6 +35,11 @@ func (c *Conn) ToHertzError(err error) error {
if errors.Is(err, netpoll.ErrConnClosed) || errors.Is(err, syscall.EPIPE) {
return errs.ErrConnectionClosed
}
+
+ // only unify read timeout for now
+ if errors.Is(err, netpoll.ErrReadTimeout) {
+ return errs.ErrTimeout
+ }
return err
}
diff --git a/pkg/network/standard/connection.go b/pkg/network/standard/connection.go
index e38ee9e7b..13c10b479 100644
--- a/pkg/network/standard/connection.go
+++ b/pkg/network/standard/connection.go
@@ -54,6 +54,10 @@ func (c *Conn) ToHertzError(err error) error {
if errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ENOTCONN) {
return errs.ErrConnectionClosed
}
+ if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
+ return errs.ErrTimeout
+ }
+
return err
}
diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go
index c306f071f..c2b7cf5a4 100644
--- a/pkg/protocol/http1/client.go
+++ b/pkg/protocol/http1/client.go
@@ -661,6 +661,10 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo
}
// if this is not a pooled connection,
// we should not retry to avoid getting stuck in an endless retry loop.
+ errNorm, ok := conn.(network.ErrorNormalization)
+ if ok {
+ err = errNorm.ToHertzError(err)
+ }
return false, err
}
From 2b0a165b60f9361c20ffc36b551250310c293107 Mon Sep 17 00:00:00 2001
From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com>
Date: Fri, 4 Aug 2023 15:01:45 +0800
Subject: [PATCH 16/17] chore(hz): release v0.6.6 (#886)
---
cmd/hz/meta/const.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/cmd/hz/meta/const.go b/cmd/hz/meta/const.go
index 66f0a1fb4..8e7490dfe 100644
--- a/cmd/hz/meta/const.go
+++ b/cmd/hz/meta/const.go
@@ -19,7 +19,7 @@ package meta
import "runtime"
// Version hz version
-const Version = "v0.6.5"
+const Version = "v0.6.6"
const DefaultServiceName = "hertz_service"
From 69c28889a12c680f4493548b48f444ee6fa96aa0 Mon Sep 17 00:00:00 2001
From: alice <90381261+alice-yyds@users.noreply.github.com>
Date: Fri, 4 Aug 2023 15:02:30 +0800
Subject: [PATCH 17/17] chore: update version v0.6.7
---
version.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/version.go b/version.go
index f56cbc31a..2f7a82fb5 100644
--- a/version.go
+++ b/version.go
@@ -19,5 +19,5 @@ package hertz
// Name and Version info of this framework, used for statistics and debug
const (
Name = "Hertz"
- Version = "v0.6.6"
+ Version = "v0.6.7"
)