From 2bd70be79b3dd36fec6f6d7b48e81b0da5579972 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Mon, 17 Jul 2023 21:01:46 +0800 Subject: [PATCH 01/17] chore(http1): remove duplicated code (#817) --- pkg/protocol/http1/server.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index ba8915a1d..8aaed6db7 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -350,11 +350,6 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { } if hijackHandler != nil { - if zr != nil { - zr.Release() //nolint:errcheck - zr = nil - } - // Hijacked conn process the timeout by itself err = ctx.GetConn().SetReadTimeout(0) if err != nil { From 7d25d5a774a34d9dd7a6c9a58beedc01bc2513ed Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 18 Jul 2023 17:27:35 +0800 Subject: [PATCH 02/17] test: add test for protocol http1 (#804) Co-authored-by: kinggo --- pkg/common/test/assert/assert.go | 2 + pkg/common/test/mock/body_data.go | 28 ++++ pkg/common/test/mock/body_data_test.go | 12 ++ pkg/common/test/mock/network.go | 4 + pkg/common/test/mock/network_test.go | 1 + pkg/protocol/http1/client_test.go | 108 ++++++++++++++ pkg/protocol/http1/client_unix_test.go | 56 ++++++++ pkg/protocol/http1/ext/common_test.go | 94 +++++++++++++ pkg/protocol/http1/ext/headerscanner_test.go | 60 ++++++++ pkg/protocol/http1/ext/stream_test.go | 139 +++++++++++++++++++ pkg/protocol/http1/req/header_test.go | 9 ++ pkg/protocol/http1/req/request_test.go | 74 ++++++---- pkg/protocol/http1/resp/response_test.go | 6 + pkg/protocol/http1/server_test.go | 115 ++++++++++++++- 14 files changed, 679 insertions(+), 29 deletions(-) diff --git a/pkg/common/test/assert/assert.go b/pkg/common/test/assert/assert.go index 9ed544f35..ed157835c 100644 --- a/pkg/common/test/assert/assert.go +++ b/pkg/common/test/assert/assert.go @@ -86,10 +86,12 @@ func NotEqual(t testing.TB, expected, actual interface{}) { } func True(t testing.TB, obj interface{}) { + t.Helper() DeepEqual(t, true, obj) } func False(t testing.TB, obj interface{}) { + t.Helper() DeepEqual(t, false, obj) } diff --git a/pkg/common/test/mock/body_data.go b/pkg/common/test/mock/body_data.go index 808e25e99..807c585c2 100644 --- a/pkg/common/test/mock/body_data.go +++ b/pkg/common/test/mock/body_data.go @@ -16,6 +16,8 @@ package mock +import "fmt" + func CreateFixedBody(bodySize int) []byte { var b []byte for i := 0; i < bodySize; i++ { @@ -23,3 +25,29 @@ func CreateFixedBody(bodySize int) []byte { } return b } + +func CreateChunkedBody(body []byte, trailer map[string]string, hasTrailer bool) []byte { + var b []byte + chunkSize := 1 + for len(body) > 0 { + if chunkSize > len(body) { + chunkSize = len(body) + } + b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...) + b = append(b, body[:chunkSize]...) + b = append(b, []byte("\r\n")...) + body = body[chunkSize:] + chunkSize++ + } + if hasTrailer { + b = append(b, "0\r\n"...) + for k, v := range trailer { + b = append(b, k...) + b = append(b, ": "...) + b = append(b, v...) + b = append(b, "\r\n"...) + } + b = append(b, "\r\n"...) + } + return b +} diff --git a/pkg/common/test/mock/body_data_test.go b/pkg/common/test/mock/body_data_test.go index a783a47c4..62ebf1bcd 100644 --- a/pkg/common/test/mock/body_data_test.go +++ b/pkg/common/test/mock/body_data_test.go @@ -18,6 +18,8 @@ package mock import ( "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestGenerateCreateFixedBody(t *testing.T) { @@ -33,3 +35,13 @@ func TestGenerateCreateFixedBody(t *testing.T) { t.Fatalf("Unexpected %s. Expecting a nil", nilFixedBody) } } + +func TestGenerateCreateChunkedBody(t *testing.T) { + bodySize := 10 + b := CreateFixedBody(bodySize) + trailer := map[string]string{"Foo": "chunked shit"} + expectCb := "1\r\n0\r\n2\r\n12\r\n3\r\n345\r\n4\r\n6789\r\n0\r\nFoo: chunked shit\r\n\r\n" + + cb := CreateChunkedBody(b, trailer, true) + assert.DeepEqual(t, expectCb, string(cb)) +} diff --git a/pkg/common/test/mock/network.go b/pkg/common/test/mock/network.go index 3d4747563..7ad083647 100644 --- a/pkg/common/test/mock/network.go +++ b/pkg/common/test/mock/network.go @@ -125,6 +125,10 @@ func (m *Conn) WriterRecorder() Recorder { return &recorder{c: m, Reader: m.zw} } +func (m *Conn) GetReadTimeout() time.Duration { + return m.readTimeout +} + type recorder struct { c *Conn network.Reader diff --git a/pkg/common/test/mock/network_test.go b/pkg/common/test/mock/network_test.go index 84df1cdae..2d4ff973a 100644 --- a/pkg/common/test/mock/network_test.go +++ b/pkg/common/test/mock/network_test.go @@ -34,6 +34,7 @@ func TestConn(t *testing.T) { assert.DeepEqual(t, nil, err) err = conn1.SetReadTimeout(time.Millisecond * 100) assert.DeepEqual(t, nil, err) + assert.DeepEqual(t, time.Millisecond*100, conn1.GetReadTimeout()) // Peek Skip Read b, _ := conn1.Peek(1) diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go index 1e4e0139f..19154453c 100644 --- a/pkg/protocol/http1/client_test.go +++ b/pkg/protocol/http1/client_test.go @@ -55,12 +55,15 @@ import ( "testing" "time" + "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" + "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/client" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" "github.com/cloudwego/netpoll" @@ -363,6 +366,103 @@ func TestDialTimeoutPriority(t *testing.T) { } } +func TestStateObserve(t *testing.T) { + syncState := struct { + mu sync.Mutex + state config.ConnPoolState + }{} + c := &HostClient{ + ClientOptions: &ClientOptions{ + Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + return mock.SlowReadDialer(addr) + }), + StateObserve: func(hcs config.HostClientState) { + syncState.mu.Lock() + defer syncState.mu.Unlock() + syncState.state = hcs.ConnPoolState() + }, + ObservationInterval: 50 * time.Millisecond, + }, + Addr: "foobar", + closed: make(chan struct{}), + } + + c.SetDynamicConfig(&client.DynamicConfig{ + Addr: utils.AddMissingPort(c.Addr, true), + }) + + time.Sleep(500 * time.Millisecond) + assert.Nil(t, c.Close()) + syncState.mu.Lock() + assert.DeepEqual(t, "foobar:443", syncState.state.Addr) + syncState.mu.Unlock() +} + +func TestCachedTLSConfig(t *testing.T) { + c := &HostClient{ + ClientOptions: &ClientOptions{ + Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + return mock.SlowReadDialer(addr) + }), + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + Addr: "foobar", + IsTLS: true, + } + + cfg1 := c.cachedTLSConfig("foobar") + cfg2 := c.cachedTLSConfig("baz") + assert.NotEqual(t, cfg1, cfg2) + cfg3 := c.cachedTLSConfig("foobar") + assert.DeepEqual(t, cfg1, cfg3) +} + +func TestRetry(t *testing.T) { + var times int32 + c := &HostClient{ + ClientOptions: &ClientOptions{ + Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + times++ + if times < 3 { + return &retryConn{ + Conn: mock.NewConn(""), + }, nil + } + return mock.NewConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil + }), + RetryConfig: &retry.Config{ + MaxAttemptTimes: 5, + Delay: time.Millisecond * 10, + }, + RetryIfFunc: func(req *protocol.Request, resp *protocol.Response, err error) bool { + return true + }, + }, + Addr: "foobar", + } + + req := protocol.AcquireRequest() + req.SetRequestURI("http://foobar/baz") + req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100)) + resp := protocol.AcquireResponse() + + ch := make(chan error, 1) + go func() { + ch <- c.Do(context.Background(), req, resp) + }() + select { + case <-time.After(time.Second * 2): + t.Fatalf("should use writeTimeout in request options") + case err := <-ch: + assert.Nil(t, err) + assert.True(t, times == 3) + assert.DeepEqual(t, resp.StatusCode(), 200) + assert.DeepEqual(t, resp.Body(), []byte("0123456789")) + } +} + // mockConn for getting error when write binary data. type writeErrConn struct { network.Conn @@ -371,3 +471,11 @@ type writeErrConn struct { func (w writeErrConn) WriteBinary(b []byte) (n int, err error) { return 0, errs.ErrConnectionClosed } + +type retryConn struct { + network.Conn +} + +func (w retryConn) SetWriteTimeout(t time.Duration) error { + return errors.New("should retry") +} diff --git a/pkg/protocol/http1/client_unix_test.go b/pkg/protocol/http1/client_unix_test.go index ba616ad54..6e417ba18 100644 --- a/pkg/protocol/http1/client_unix_test.go +++ b/pkg/protocol/http1/client_unix_test.go @@ -19,11 +19,15 @@ package http1 import ( "context" + "errors" "net/http" "runtime" + "sync" + "sync/atomic" "testing" "time" + errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network/netpoll" "github.com/cloudwego/hertz/pkg/protocol" @@ -63,3 +67,55 @@ func TestGcBodyStream(t *testing.T) { c.CloseIdleConnections() assert.DeepEqual(t, 0, c.ConnPoolState().TotalConnNum) } + +func TestMaxConn(t *testing.T) { + srv := &http.Server{Addr: ":11002", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte("hello world\n")) + })} + go srv.ListenAndServe() + time.Sleep(100 * time.Millisecond) + + c := &HostClient{ + ClientOptions: &ClientOptions{ + Dialer: netpoll.NewDialer(), + ResponseBodyStream: true, + MaxConnWaitTimeout: time.Millisecond * 100, + MaxConns: 5, + }, + Addr: "127.0.0.1:11002", + } + + var successCount int32 + var noFreeCount int32 + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() + req.SetRequestURI("http://127.0.0.1:11002") + req.SetMethod(consts.MethodPost) + err := c.Do(context.Background(), req, resp) + if err != nil { + if errors.Is(err, errs.ErrNoFreeConns) { + atomic.AddInt32(&noFreeCount, 1) + return + } + t.Errorf("client Do error=%v", err.Error()) + } + atomic.AddInt32(&successCount, 1) + }() + } + wg.Wait() + + assert.True(t, atomic.LoadInt32(&successCount) == 5) + assert.True(t, atomic.LoadInt32(&noFreeCount) == 5) + assert.DeepEqual(t, 0, c.ConnectionCount()) + assert.DeepEqual(t, 5, c.WantConnectionCount()) + + runtime.GC() + // wait for gc + time.Sleep(100 * time.Millisecond) + c.CloseIdleConnections() + assert.DeepEqual(t, 0, c.WantConnectionCount()) +} diff --git a/pkg/protocol/http1/ext/common_test.go b/pkg/protocol/http1/ext/common_test.go index 5ad9c001a..824ccf4ca 100644 --- a/pkg/protocol/http1/ext/common_test.go +++ b/pkg/protocol/http1/ext/common_test.go @@ -18,12 +18,16 @@ package ext import ( "bytes" + "errors" + "io" "strings" "testing" + errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/netpoll" ) func Test_stripSpace(t *testing.T) { @@ -77,6 +81,12 @@ func TestReadTrailerError(t *testing.T) { if err == nil { t.Fatalf("expecting error.") } + + // eof + er := mock.EOFReader{} + trailer = protocol.Trailer{} + err = ReadTrailer(&trailer, &er) + assert.DeepEqual(t, io.EOF, err) } func TestReadTrailer1(t *testing.T) { @@ -95,3 +105,87 @@ func TestReadTrailer1(t *testing.T) { } } } + +func TestReadRawHeaders(t *testing.T) { + s := "HTTP/1.1 200 OK\r\n" + + "EmptyValue1:\r\n" + + "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" + + "Foo: Bar\r\n" + + "Multi-Line: one;\r\n two\r\n" + + "Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" + + "Content-Length: 5\r\n\r\n" + + "HELLOaaa" + + var dst []byte + rawHeaders, index, err := ReadRawHeaders(dst, []byte(s)) + assert.Nil(t, err) + assert.DeepEqual(t, s[:index], string(rawHeaders)) +} + +func TestBodyChunked(t *testing.T) { + body := "foobar baz aaa bbb ccc" + chunk := "16\r\nfoobar baz aaa bbb ccc\r\n0\r\n" + b := bytes.NewBufferString(body) + + var w bytes.Buffer + zw := netpoll.NewWriter(&w) + WriteBodyChunked(zw, b) + + assert.DeepEqual(t, chunk, w.String()) + + zr := mock.NewZeroCopyReader(chunk) + rb, err := ReadBody(zr, -1, 0, nil) + assert.Nil(t, err) + assert.DeepEqual(t, body, string(rb)) +} + +func TestBodyFixedSize(t *testing.T) { + body := mock.CreateFixedBody(10) + b := bytes.NewBuffer(body) + + var w bytes.Buffer + zw := netpoll.NewWriter(&w) + WriteBodyFixedSize(zw, b, int64(len(body))) + + assert.DeepEqual(t, body, w.Bytes()) + + zr := mock.NewZeroCopyReader(string(body)) + rb, err := ReadBody(zr, len(body), 0, nil) + assert.Nil(t, err) + assert.DeepEqual(t, body, rb) +} + +func TestBodyIdentity(t *testing.T) { + body := mock.CreateFixedBody(1024) + zr := mock.NewZeroCopyReader(string(body)) + rb, err := ReadBody(zr, -2, 0, nil) + assert.Nil(t, err) + assert.DeepEqual(t, string(body), string(rb)) +} + +func TestBodySkipTrailer(t *testing.T) { + t.Run("TestBodySkipTrailer", func(t *testing.T) { + body := mock.CreateFixedBody(10) + trailer := map[string]string{"Foo": "chunked shit"} + chunkedBody := mock.CreateChunkedBody(body, trailer, true) + r := mock.NewSlowReadConn(string(chunkedBody)) + err := SkipTrailer(r) + assert.Nil(t, err) + _, err = r.ReadByte() + assert.NotNil(t, err) + assert.True(t, errors.Is(err, netpoll.ErrEOF)) + }) + + t.Run("TestBodySkipTrailerError", func(t *testing.T) { + // timeout error + sr := mock.NewSlowReadConn("") + err := SkipTrailer(sr) + assert.NotNil(t, err) + assert.True(t, errors.Is(err, errs.ErrTimeout)) + // EOF error + er := &mock.EOFReader{} + err = SkipTrailer(er) + assert.NotNil(t, err) + assert.True(t, errors.Is(err, io.EOF)) + }) +} diff --git a/pkg/protocol/http1/ext/headerscanner_test.go b/pkg/protocol/http1/ext/headerscanner_test.go index fc1f92126..0f8874d81 100644 --- a/pkg/protocol/http1/ext/headerscanner_test.go +++ b/pkg/protocol/http1/ext/headerscanner_test.go @@ -42,8 +42,13 @@ package ext import ( + "bufio" + "errors" + "net/http" + "strings" "testing" + errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" ) @@ -52,3 +57,58 @@ func TestHasHeaderValue(t *testing.T) { assert.True(t, HasHeaderValue(s, []byte("Connection: Keep-Alive"))) assert.False(t, HasHeaderValue(s, []byte("Connection: Keep-Alive1"))) } + +func TestResponseHeaderMultiLineValue(t *testing.T) { + firstLine := "HTTP/1.1 200 OK\r\n" + rawHeaders := "EmptyValue1:\r\n" + + "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" + + "Foo: Bar\r\n" + + "Multi-Line: one;\r\n two\r\n" + + "Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" + + "\r\n" + + // compared with http response + response, err := http.ReadResponse(bufio.NewReader(strings.NewReader(firstLine+rawHeaders)), nil) + assert.Nil(t, err) + defer func() { response.Body.Close() }() + + hs := &HeaderScanner{} + hs.B = []byte(rawHeaders) + hs.DisableNormalizing = false + hmap := make(map[string]string, len(response.Header)) + for hs.Next() { + if len(hs.Key) > 0 { + hmap[string(hs.Key)] = string(hs.Value) + } + } + + for name, vals := range response.Header { + got := hmap[name] + want := vals[0] + assert.DeepEqual(t, want, got) + } +} + +func TestHeaderScannerError(t *testing.T) { + t.Run("TestHeaderScannerErrorInvalidName", func(t *testing.T) { + rawHeaders := "Host: go.dev\r\nGopher-New-\r\n Line: This is a header on multiple lines\r\n\r\n" + testTestHeaderScannerError(t, rawHeaders, errInvalidName) + }) + t.Run("TestHeaderScannerErrorNeedMore", func(t *testing.T) { + rawHeaders := "This is a header on multiple lines" + testTestHeaderScannerError(t, rawHeaders, errs.ErrNeedMore) + + rawHeaders = "Gopher-New-\r\n Line" + testTestHeaderScannerError(t, rawHeaders, errs.ErrNeedMore) + }) +} + +func testTestHeaderScannerError(t *testing.T, rawHeaders string, expectError error) { + hs := &HeaderScanner{} + hs.B = []byte(rawHeaders) + hs.DisableNormalizing = false + for hs.Next() { + } + assert.NotNil(t, hs.Err) + assert.True(t, errors.Is(hs.Err, expectError)) +} diff --git a/pkg/protocol/http1/ext/stream_test.go b/pkg/protocol/http1/ext/stream_test.go index 188d234eb..ccec83dab 100644 --- a/pkg/protocol/http1/ext/stream_test.go +++ b/pkg/protocol/http1/ext/stream_test.go @@ -17,6 +17,7 @@ package ext import ( "bytes" + "errors" "fmt" "io" "testing" @@ -112,3 +113,141 @@ func TestBodyStream_Reset(t *testing.T) { assert.DeepEqual(t, 0, bs.chunkLeft) assert.False(t, bs.chunkEOF) } + +func TestReadBodyWithStreaming(t *testing.T) { + t.Run("TestBodyFixedSize", func(t *testing.T) { + bodySize := 1024 + body := mock.CreateFixedBody(bodySize) + reader := mock.NewZeroCopyReader(string(body)) + dst, err := ReadBodyWithStreaming(reader, bodySize, -1, nil) + assert.Nil(t, err) + assert.DeepEqual(t, body, dst) + }) + + t.Run("TestBodyFixedSizeMaxContentLength", func(t *testing.T) { + bodySize := 8 * 1024 * 2 + body := mock.CreateFixedBody(bodySize) + reader := mock.NewZeroCopyReader(string(body)) + dst, err := ReadBodyWithStreaming(reader, bodySize, 8*1024*10, nil) + assert.Nil(t, err) + assert.DeepEqual(t, body[:maxContentLengthInStream], dst) + }) + + t.Run("TestBodyIdentity", func(t *testing.T) { + bodySize := 1024 + body := mock.CreateFixedBody(bodySize) + reader := mock.NewZeroCopyReader(string(body)) + dst, err := ReadBodyWithStreaming(reader, -2, 512, nil) + assert.Nil(t, err) + assert.DeepEqual(t, body, dst) + }) + + t.Run("TestErrBodyTooLarge", func(t *testing.T) { + bodySize := 2048 + body := mock.CreateFixedBody(bodySize) + reader := mock.NewZeroCopyReader(string(body)) + dst, err := ReadBodyWithStreaming(reader, bodySize, 1024, nil) + assert.True(t, errors.Is(err, errBodyTooLarge)) + assert.DeepEqual(t, body[:len(dst)], dst) + }) + + t.Run("TestErrChunkedStream", func(t *testing.T) { + bodySize := 1024 + body := mock.CreateFixedBody(bodySize) + reader := mock.NewZeroCopyReader(string(body)) + dst, err := ReadBodyWithStreaming(reader, -1, bodySize, nil) + assert.True(t, errors.Is(err, errChunkedStream)) + assert.Nil(t, dst) + }) +} + +func TestBodyStream(t *testing.T) { + t.Run("TestBodyStreamPrereadBuffer", func(t *testing.T) { + bodySize := 1024 + body := mock.CreateFixedBody(bodySize) + byteBuffer := &bytebufferpool.ByteBuffer{} + byteBuffer.Set(body) + + bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(""), nil, len(body)) + defer func() { + ReleaseBodyStream(bs) + }() + + b := make([]byte, bodySize) + err := bodyStreamRead(bs, b) + assert.Nil(t, err) + assert.DeepEqual(t, len(body), len(b)) + assert.DeepEqual(t, string(body), string(b)) + }) + + t.Run("TestBodyStreamRelease", func(t *testing.T) { + bodySize := 1024 + body := mock.CreateFixedBody(bodySize) + byteBuffer := &bytebufferpool.ByteBuffer{} + byteBuffer.Set(body) + bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(body)), nil, bodySize*2) + err := ReleaseBodyStream(bs) + assert.Nil(t, err) + }) + + t.Run("TestBodyStreamChunked", func(t *testing.T) { + bodySize := 5 + body := mock.CreateFixedBody(bodySize) + expectedTrailer := map[string]string{"Foo": "chunked shit"} + chunkedBody := mock.CreateChunkedBody(body, expectedTrailer, true) + + byteBuffer := &bytebufferpool.ByteBuffer{} + byteBuffer.Set(chunkedBody) + + bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(chunkedBody)), &protocol.Trailer{}, -1) + defer func() { + ReleaseBodyStream(bs) + }() + + b := make([]byte, bodySize) + err := bodyStreamRead(bs, b) + assert.Nil(t, err) + assert.DeepEqual(t, len(body), len(b)) + assert.DeepEqual(t, string(body), string(b)) + }) + + t.Run("TestBodyStreamReadFromWire", func(t *testing.T) { + bodySize := 1024 + body := mock.CreateFixedBody(bodySize) + byteBuffer := &bytebufferpool.ByteBuffer{} + byteBuffer.Set(body) + + rcBodySize := 128 + rcBody := mock.CreateFixedBody(rcBodySize) + bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(rcBody)), nil, -2) + defer func() { + ReleaseBodyStream(bs) + }() + + b := make([]byte, bodySize) + err := bodyStreamRead(bs, b) + assert.Nil(t, err) + assert.DeepEqual(t, len(body), len(b)) + assert.DeepEqual(t, string(body), string(b)) + }) +} + +func bodyStreamRead(bs io.Reader, b []byte) (err error) { + nb := 0 + for { + p := make([]byte, 64) + n, rErr := bs.Read(p) + if n > 0 { + copy(b[nb:], p[:]) + nb = nb + n + } + + if rErr != nil { + if rErr != io.EOF { + err = rErr + } + break + } + } + return +} diff --git a/pkg/protocol/http1/req/header_test.go b/pkg/protocol/http1/req/header_test.go index fcf5f4177..489e0aabf 100644 --- a/pkg/protocol/http1/req/header_test.go +++ b/pkg/protocol/http1/req/header_test.go @@ -44,11 +44,13 @@ package req import ( "bufio" "bytes" + "errors" "fmt" "net/http" "strings" "testing" + errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/protocol" @@ -412,3 +414,10 @@ func TestRequestHeader_PeekIfExists(t *testing.T) { assert.DeepEqual(t, []byte{}, rh.Peek("exists")) assert.DeepEqual(t, []byte(nil), rh.Peek("non-exists")) } + +func TestRequestHeaderError(t *testing.T) { + er := mock.EOFReader{} + rh := protocol.RequestHeader{} + err := ReadHeader(&rh, &er) + assert.True(t, errors.Is(err, errs.ErrNothingRead)) +} diff --git a/pkg/protocol/http1/req/request_test.go b/pkg/protocol/http1/req/request_test.go index e123da406..0411187a5 100644 --- a/pkg/protocol/http1/req/request_test.go +++ b/pkg/protocol/http1/req/request_test.go @@ -44,6 +44,7 @@ package req import ( "bufio" "bytes" + "encoding/base64" "errors" "fmt" "io" @@ -54,6 +55,7 @@ import ( "testing" "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/compress" errs "github.com/cloudwego/hertz/pkg/common/errors" @@ -451,7 +453,7 @@ func TestRequestWriteRequestURINoHost(t *testing.T) { t.Parallel() var req protocol.Request - req.Header.SetRequestURI("http://google.com/foo/bar?baz=aaa") + req.Header.SetRequestURI("http://user:pass@google.com/foo/bar?baz=aaa") var w bytes.Buffer zw := netpoll.NewWriter(&w) if err := Write(&req, zw); err != nil { @@ -474,6 +476,16 @@ func TestRequestWriteRequestURINoHost(t *testing.T) { if string(req.Header.RequestURI()) != "/foo/bar?baz=aaa" { t.Fatalf("unexpected requestURI: %q. Expecting %q", req.Header.RequestURI(), "/foo/bar?baz=aaa") } + // authorization + authorization := req.Header.Get(string(bytestr.StrAuthorization)) + author, err := base64.StdEncoding.DecodeString(authorization[len(bytestr.StrBasicSpace):]) + if err != nil { + t.Fatalf("expecting error") + } + + if string(author) != "user:pass" { + t.Fatalf("unexpected Authorization: %q. Expecting %q", authorization, "user:pass") + } // verify that Write returns error on non-absolute RequestURI req.Reset() @@ -484,6 +496,38 @@ func TestRequestWriteRequestURINoHost(t *testing.T) { } } +func TestRequestWriteMultipartFile(t *testing.T) { + t.Parallel() + + var req protocol.Request + req.Header.SetHost("foobar.com") + req.Header.SetMethod(consts.MethodPost) + req.SetFileReader("filea", "filea.txt", bytes.NewReader([]byte("This is filea."))) + req.SetMultipartField("fileb", "fileb.txt", "text/plain", bytes.NewReader([]byte("This is fileb."))) + + var w bytes.Buffer + zw := netpoll.NewWriter(&w) + if err := Write(&req, zw); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := zw.Flush(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var req1 protocol.Request + zr := mock.NewZeroCopyReader(w.String()) + if err := Read(&req1, zr); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + filea, err := req1.FormFile("filea") + assert.Nil(t, err) + assert.DeepEqual(t, "filea.txt", filea.Filename) + fileb, err := req1.FormFile("fileb") + assert.Nil(t, err) + assert.DeepEqual(t, "fileb.txt", fileb.Filename) +} + func TestSetRequestBodyStreamChunked(t *testing.T) { t.Parallel() @@ -1037,7 +1081,7 @@ func testRequestMultipartFormNotPreParse(t *testing.T, boundary string, formData func testReadBodyChunked(t *testing.T, bodySize int) { body := mock.CreateFixedBody(bodySize) expectedTrailer := map[string]string{"Foo": "chunked shit"} - chunkedBody := createChunkedBody(body, expectedTrailer, true) + chunkedBody := mock.CreateChunkedBody(body, expectedTrailer, true) zr := mock.NewZeroCopyReader(string(chunkedBody)) @@ -1052,32 +1096,6 @@ func testReadBodyChunked(t *testing.T, bodySize int) { verifyTrailer(t, zr, expectedTrailer) } -func createChunkedBody(body []byte, trailer map[string]string, hasTrailer bool) []byte { - var b []byte - chunkSize := 1 - for len(body) > 0 { - if chunkSize > len(body) { - chunkSize = len(body) - } - b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...) - b = append(b, body[:chunkSize]...) - b = append(b, []byte("\r\n")...) - body = body[chunkSize:] - chunkSize++ - } - if hasTrailer { - b = append(b, "0\r\n"...) - for k, v := range trailer { - b = append(b, k...) - b = append(b, ": "...) - b = append(b, v...) - b = append(b, "\r\n"...) - } - b = append(b, "\r\n"...) - } - return b -} - func testReadBodyFixedSize(t *testing.T, bodySize int) { body := mock.CreateFixedBody(bodySize) diff --git a/pkg/protocol/http1/resp/response_test.go b/pkg/protocol/http1/resp/response_test.go index cff0a5783..0ff010fd7 100644 --- a/pkg/protocol/http1/resp/response_test.go +++ b/pkg/protocol/http1/resp/response_test.go @@ -820,3 +820,9 @@ func TestResponseReadBodyStreamBadTrailer(t *testing.T) { testResponseReadBodyStreamBadTrailer(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\ncontent-type: bar\r\n\r\n") testResponseReadBodyStreamBadTrailer(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\nproxy-connection: bar2\r\n\r\n") } + +func TestResponseString(t *testing.T) { + resp := protocol.Response{} + resp.Header.Set("Location", "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n") + assert.True(t, strings.Contains(GetHTTP1Response(&resp).String(), "Location: foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n")) +} diff --git a/pkg/protocol/http1/server_test.go b/pkg/protocol/http1/server_test.go index dfa8b78d3..e34bb3f9a 100644 --- a/pkg/protocol/http1/server_test.go +++ b/pkg/protocol/http1/server_test.go @@ -20,8 +20,10 @@ import ( "bytes" "context" "errors" + "strings" "sync" "testing" + "time" inStats "github.com/cloudwego/hertz/internal/stats" "github.com/cloudwego/hertz/pkg/app" @@ -33,6 +35,7 @@ import ( "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) @@ -247,14 +250,124 @@ func TestHijackResponseWriter(t *testing.T) { assert.True(t, isFinal) } +func TestHijackHandler(t *testing.T) { + server := NewServer() + reqCtx := &app.RequestContext{} + originReadTimeout := time.Second + hijackReadTimeout := 200 * time.Millisecond + reqCtx.SetHijackHandler(func(c network.Conn) { + c.SetReadTimeout(hijackReadTimeout) // hijack read timeout + }) + + server.Core = &mockCore{ + ctxPool: &sync.Pool{New: func() interface{} { + return reqCtx + }}, + } + + server.HijackConnHandle = func(c network.Conn, h app.HijackHandler) { + h(c) + } + + defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") + defaultConn.SetReadTimeout(originReadTimeout) + assert.DeepEqual(t, originReadTimeout, defaultConn.GetReadTimeout()) + err := server.Serve(context.TODO(), defaultConn) + assert.True(t, errors.Is(err, errs.ErrHijacked)) + assert.DeepEqual(t, hijackReadTimeout, defaultConn.GetReadTimeout()) +} + +func TestKeepAlive(t *testing.T) { + server := NewServer() + reqCtx := &app.RequestContext{} + times := 0 + server.Core = &mockCore{ + ctxPool: &sync.Pool{New: func() interface{} { + return reqCtx + }}, + isRunning: true, + mockHandler: func(c context.Context, ctx *app.RequestContext) { + times++ + if string(ctx.Path()) == "/close" { + ctx.SetConnectionClose() + } + }, + } + server.IdleTimeout = time.Second + + var s strings.Builder + s.WriteString("GET / HTTP/1.1\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") + s.WriteString("GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") // set connection close + + defaultConn := mock.NewConn(s.String()) + err := server.Serve(context.TODO(), defaultConn) + assert.True(t, errors.Is(err, errs.ErrShortConnection)) + assert.DeepEqual(t, times, 2) +} + +func TestExpect100Continue(t *testing.T) { + server := &Server{} + reqCtx := &app.RequestContext{} + server.Core = &mockCore{ + ctxPool: &sync.Pool{New: func() interface{} { + return reqCtx + }}, + mockHandler: func(c context.Context, ctx *app.RequestContext) { + data, err := ctx.Body() + if err == nil { + ctx.Write(data) + } + }, + } + + defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") + err := server.Serve(context.TODO(), defaultConn) + assert.True(t, errors.Is(err, errs.ErrShortConnection)) + defaultResponseResult := defaultConn.WriterRecorder() + assert.DeepEqual(t, 0, defaultResponseResult.Len()) + response := protocol.AcquireResponse() + resp.Read(response, defaultResponseResult) + assert.DeepEqual(t, "12345", string(response.Body())) +} + +func TestExpect100ContinueHandler(t *testing.T) { + server := &Server{} + reqCtx := &app.RequestContext{} + server.Core = &mockCore{ + ctxPool: &sync.Pool{New: func() interface{} { + return reqCtx + }}, + mockHandler: func(c context.Context, ctx *app.RequestContext) { + data, err := ctx.Body() + if err == nil { + ctx.Write(data) + } + }, + } + server.ContinueHandler = func(header *protocol.RequestHeader) bool { + return false + } + + defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") + err := server.Serve(context.TODO(), defaultConn) + assert.True(t, errors.Is(err, errs.ErrShortConnection)) + defaultResponseResult := defaultConn.WriterRecorder() + assert.DeepEqual(t, 0, defaultResponseResult.Len()) + response := protocol.AcquireResponse() + resp.Read(response, defaultResponseResult) + assert.DeepEqual(t, consts.StatusExpectationFailed, response.StatusCode()) + assert.DeepEqual(t, "", string(response.Body())) +} + type mockCore struct { ctxPool *sync.Pool controller tracer.Controller mockHandler func(c context.Context, ctx *app.RequestContext) + isRunning bool } func (m *mockCore) IsRunning() bool { - return false + return m.isRunning } func (m *mockCore) GetCtxPool() *sync.Pool { From e1d541bb909376012ec9db6ac59f952424549ce8 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 18 Jul 2023 20:34:59 +0800 Subject: [PATCH 03/17] feat(hz): remove tag (#811) --- _typos.toml | 3 ++- cmd/hz/app/app.go | 5 +++++ cmd/hz/config/argument.go | 2 ++ cmd/hz/protobuf/plugin.go | 40 ++++++++++++++++++++++++++++++--------- cmd/hz/thrift/plugin.go | 13 ++++++++++--- 5 files changed, 50 insertions(+), 13 deletions(-) diff --git a/_typos.toml b/_typos.toml index 9808229cc..3d3103e19 100644 --- a/_typos.toml +++ b/_typos.toml @@ -17,4 +17,5 @@ referer = "referer" HeaderReferer = "HeaderReferer" expectedReferer = "expectedReferer" Referer = "Referer" -O_WRONLY = "O_WRONLY" \ No newline at end of file +O_WRONLY = "O_WRONLY" +WRONLY = "WRONLY" \ No newline at end of file diff --git a/cmd/hz/app/app.go b/cmd/hz/app/app.go index 239b4df21..269e5bad9 100644 --- a/cmd/hz/app/app.go +++ b/cmd/hz/app/app.go @@ -184,6 +184,7 @@ func Init() *cli.App { unsetOmitemptyFlag := cli.BoolFlag{Name: "unset_omitempty", Usage: "Remove 'omitempty' tag for generated struct.", Destination: &globalArgs.UnsetOmitempty} protoCamelJSONTag := cli.BoolFlag{Name: "pb_camel_json_tag", Usage: "Convert Name style for json tag to camel(Only works protobuf).", Destination: &globalArgs.ProtobufCamelJSONTag} snakeNameFlag := cli.BoolFlag{Name: "snake_tag", Usage: "Use snake_case style naming for tags. (Only works for 'form', 'query', 'json')", Destination: &globalArgs.SnakeName} + rmTagFlag := cli.StringSliceFlag{Name: "rm_tag", Usage: "Remove the specified tag"} customLayout := cli.StringFlag{Name: "customize_layout", Usage: "Specify the path for layout template.", Destination: &globalArgs.CustomizeLayout} customLayoutData := cli.StringFlag{Name: "customize_layout_data_path", Usage: "Specify the path for layout template render data.", Destination: &globalArgs.CustomizeLayoutData} customPackage := cli.StringFlag{Name: "customize_package", Usage: "Specify the path for package template.", Destination: &globalArgs.CustomizePackage} @@ -230,6 +231,7 @@ func Init() *cli.App { &unsetOmitemptyFlag, &protoCamelJSONTag, &snakeNameFlag, + &rmTagFlag, &excludeFilesFlag, &customLayout, &customLayoutData, @@ -263,6 +265,7 @@ func Init() *cli.App { &unsetOmitemptyFlag, &protoCamelJSONTag, &snakeNameFlag, + &rmTagFlag, &excludeFilesFlag, &customPackage, &handlerByMethod, @@ -289,6 +292,7 @@ func Init() *cli.App { &unsetOmitemptyFlag, &protoCamelJSONTag, &snakeNameFlag, + &rmTagFlag, &excludeFilesFlag, }, Action: Model, @@ -315,6 +319,7 @@ func Init() *cli.App { &unsetOmitemptyFlag, &protoCamelJSONTag, &snakeNameFlag, + &rmTagFlag, &excludeFilesFlag, &customLayout, &customLayoutData, diff --git a/cmd/hz/config/argument.go b/cmd/hz/config/argument.go index 8eef041d2..b8bf41cac 100644 --- a/cmd/hz/config/argument.go +++ b/cmd/hz/config/argument.go @@ -64,6 +64,7 @@ type Argument struct { ProtobufPlugins []string ThriftPlugins []string SnakeName bool + RmTags []string Excludes []string NoRecurse bool HandlerByMethod bool @@ -121,6 +122,7 @@ func (arg *Argument) parseStringSlice(c *cli.Context) { arg.ProtocOptions = c.StringSlice("protoc") arg.ThriftPlugins = c.StringSlice("thrift-plugins") arg.ProtobufPlugins = c.StringSlice("protoc-plugins") + arg.RmTags = c.StringSlice("rm_tag") } func (arg *Argument) UpdateByManifest(m *meta.Manifest) { diff --git a/cmd/hz/protobuf/plugin.go b/cmd/hz/protobuf/plugin.go index 82b01cd6a..e25803860 100644 --- a/cmd/hz/protobuf/plugin.go +++ b/cmd/hz/protobuf/plugin.go @@ -76,10 +76,22 @@ type Plugin struct { ModelDir string UseDir string IdlClientDir string + RmTags RemoveTags PkgMap map[string]string logger *logs.StdLogger } +type RemoveTags []string + +func (rm *RemoveTags) Exist(tag string) bool { + for _, rmTag := range *rm { + if rmTag == tag { + return true + } + } + return false +} + func (plugin *Plugin) Run() int { plugin.setLogger() args := &config.Argument{} @@ -192,6 +204,7 @@ func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Ar opts := protogen.Options{} gen, err := opts.New(req) plugin.Plugin = gen + plugin.RmTags = args.RmTags if err != nil { return fmt.Errorf("new protoc plugin failed: %s", err.Error()) } @@ -322,7 +335,6 @@ func (plugin *Plugin) fixModelPathAndPackage(pkg string) (impt, path string) { func (plugin *Plugin) GenerateFiles(pluginPb *protogen.Plugin) error { idl := pluginPb.Request.FileToGenerate[len(pluginPb.Request.FileToGenerate)-1] pluginPb.SupportedFeatures = gengo.SupportedFeatures - for _, f := range pluginPb.Files { if f.Proto.GetName() == idl { err := plugin.GenerateFile(pluginPb, f) @@ -358,7 +370,7 @@ func (plugin *Plugin) GenerateFile(gen *protogen.Plugin, f *protogen.File) error if len(plugin.UseDir) != 0 { return nil } - file, err := generateFile(gen, f) + file, err := generateFile(gen, f, plugin.RmTags) if err != nil || file == nil { return fmt.Errorf("generate file %s failed: %s", f.Proto.GetName(), err.Error()) } @@ -366,7 +378,7 @@ func (plugin *Plugin) GenerateFile(gen *protogen.Plugin, f *protogen.File) error } // generateFile generates the contents of a .pb.go file. -func generateFile(gen *protogen.Plugin, file *protogen.File) (*protogen.GeneratedFile, error) { +func generateFile(gen *protogen.Plugin, file *protogen.File, rmTags RemoveTags) (*protogen.GeneratedFile, error) { filename := file.GeneratedFilenamePrefix + ".pb.go" g := gen.NewGeneratedFile(filename, file.GoImportPath) f := newFileInfo(file) @@ -398,7 +410,7 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) (*protogen.Generate } var err error for _, message := range f.allMessages { - err = genMessage(g, f, message) + err = genMessage(g, f, message, rmTags) if err != nil { return nil, err } @@ -410,7 +422,7 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) (*protogen.Generate return g, nil } -func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) error { +func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error { if m.Desc.IsMapEntry() { return nil } @@ -421,7 +433,7 @@ func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) error { m.Desc.Options().(*descriptorpb.MessageOptions).GetDeprecated()) g.P(leadingComments, "type ", m.GoIdent, " struct {") - err := genMessageFields(g, f, m) + err := genMessageFields(g, f, m, rmTags) if err != nil { return err } @@ -435,12 +447,12 @@ func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) error { return nil } -func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) error { +func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error { sf := f.allMessageFieldsByPtr[m] genMessageInternalFields(g, f, m, sf) var err error for _, field := range m.Fields { - err = genMessageField(g, f, m, field, sf) + err = genMessageField(g, f, m, field, sf, rmTags) if err != nil { return err } @@ -448,7 +460,7 @@ func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) er return nil } -func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, sf *structFields) error { +func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, sf *structFields, rmTags RemoveTags) error { if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() { // It would be a bit simpler to iterate over the oneofs below, // but generating the field here keeps the contents of the Go @@ -506,6 +518,16 @@ func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, fie tags = append(tags, gotrackTags...) } + if len(rmTags) > 0 { + tmp := structTags{} + for _, tag := range tags { + if !rmTags.Exist(tag[0]) { + tmp = append(tmp, tag) + } + } + tags = tmp + } + name := field.GoName if field.Desc.IsWeak() { name = WeakFieldPrefix_goname + name diff --git a/cmd/hz/thrift/plugin.go b/cmd/hz/thrift/plugin.go index 23c5c7597..ff4fdf65f 100644 --- a/cmd/hz/thrift/plugin.go +++ b/cmd/hz/thrift/plugin.go @@ -41,6 +41,7 @@ type Plugin struct { req *thriftgo_plugin.Request args *config.Argument logger *logs.StdLogger + rmTags []string } func (plugin *Plugin) Run() int { @@ -74,6 +75,7 @@ func (plugin *Plugin) Run() int { logs.Errorf("parse args failed: %s", err.Error()) return meta.PluginError } + plugin.rmTags = args.RmTags if args.CmdType == meta.CmdModel { res, err := plugin.GetResponse(nil, args.OutDir) if err != nil { @@ -334,7 +336,7 @@ func (plugin *Plugin) InsertTag() ([]*thriftgo_plugin.Generated, error) { stName := st.GetName() for _, f := range st.Fields { fieldName := f.GetName() - tagString, err := getTagString(f) + tagString, err := getTagString(f, plugin.rmTags) if err != nil { return nil, err } @@ -360,7 +362,7 @@ func (plugin *Plugin) InsertTag() ([]*thriftgo_plugin.Generated, error) { stName := st.GetName() for _, f := range st.Fields { fieldName := f.GetName() - tagString, err := getTagString(f) + tagString, err := getTagString(f, plugin.rmTags) if err != nil { return nil, err } @@ -400,7 +402,7 @@ func (plugin *Plugin) GetResponse(files []generator.File, outputDir string) (*th }, nil } -func getTagString(f *parser.Field) (string, error) { +func getTagString(f *parser.Field, rmTags []string) (string, error) { field := model.Field{} err := injectTags(f, &field, true, false) if err != nil { @@ -412,6 +414,11 @@ func getTagString(f *parser.Field) (string, error) { disableTag = true } } + + for _, rmTag := range rmTags { + field.Tags.Remove(rmTag) + } + var tagString string tags := field.Tags for idx, tag := range tags { From 1018b5dba033144c784f034a2a9a89c64750fe1c Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 18 Jul 2023 20:35:30 +0800 Subject: [PATCH 04/17] feat(hz): add validator anno (#799) --- cmd/hz/protobuf/api/api.proto | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cmd/hz/protobuf/api/api.proto b/cmd/hz/protobuf/api/api.proto index a41e2f8d6..72edf41ff 100644 --- a/cmd/hz/protobuf/api/api.proto +++ b/cmd/hz/protobuf/api/api.proto @@ -24,6 +24,8 @@ extend google.protobuf.FieldOptions { optional string js_conv_compatible = 50132; optional string file_name_compatible = 50133; optional string none_compatible = 50134; + // 50135 is reserved to vt_compatible + // optional FieldRules vt_compatible = 50135; optional string go_tag = 51001; } @@ -62,4 +64,12 @@ extend google.protobuf.ServiceOptions { // 50731~50760 used to extend service option by hz optional string base_domain_compatible = 50731; +} + +extend google.protobuf.MessageOptions { + // optional FieldRules msg_vt = 50111; + + optional string reserve = 50830; + // 550831 is reserved to msg_vt_compatible + // optional FieldRules msg_vt_compatible = 50831; } \ No newline at end of file From 3ecd9868bf0c03a43ae304295b9660fc83767b82 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 18 Jul 2023 20:36:44 +0800 Subject: [PATCH 05/17] feat(hz): add more template parameters (#795) --- cmd/hz/generator/custom_files.go | 49 ++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/cmd/hz/generator/custom_files.go b/cmd/hz/generator/custom_files.go index 390a40d60..1efaae07a 100644 --- a/cmd/hz/generator/custom_files.go +++ b/cmd/hz/generator/custom_files.go @@ -48,16 +48,17 @@ type IDLPackageRenderInfo struct { type CustomizedFileForMethod struct { *HttpMethod - FilePath string - FilePackage string - ServiceInfo *Service // service info for that method + FilePath string + FilePackage string + ServiceInfo *Service // service info for this method + IDLPackageInfo *IDLPackageRenderInfo // IDL info for this service } type CustomizedFileForService struct { *Service FilePath string FilePackage string - IDLPackageInfo *IDLPackageRenderInfo // IDL info for that service + IDLPackageInfo *IDLPackageRenderInfo // IDL info for this service } type CustomizedFileForIDL struct { @@ -110,7 +111,7 @@ func (pkgGen *HttpPackageGenerator) genCustomizedFile(pkg *HttpPackage) error { filePathRenderInfo.ServiceName = service.Name filePathRenderInfo.MethodName = method.Name filePathRenderInfo.HandlerGenPath = method.OutputDir - err := pkgGen.genLoopMethod(tplInfo, filePathRenderInfo, method, service) + err := pkgGen.genLoopMethod(tplInfo, filePathRenderInfo, method, service, &idlPackageRenderInfo) if err != nil { return err } @@ -323,10 +324,11 @@ func (pkgGen *HttpPackageGenerator) genLoopService(tplInfo *Template, filePathRe if tplInfo.UpdateBehavior.AppendKey == "method" { for _, method := range service.Methods { data := CustomizedFileForMethod{ - HttpMethod: method, - FilePath: filePath, - FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), - ServiceInfo: service, + HttpMethod: method, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + ServiceInfo: service, + IDLPackageInfo: idlPackageRenderInfo, } insertKey, err := renderInsertKey(tplInfo, data) if err != nil { @@ -395,7 +397,7 @@ func (pkgGen *HttpPackageGenerator) genLoopService(tplInfo *Template, filePathRe } // genLoopMethod used to generate files by 'method' -func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, method *HttpMethod, service *Service) error { +func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, method *HttpMethod, service *Service, idlPackageRenderInfo *IDLPackageRenderInfo) error { filePath, err := renderFilePath(tplInfo, filePathRenderInfo) if err != nil { return err @@ -408,10 +410,11 @@ func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRen if !exist { // create file data := CustomizedFileForMethod{ - HttpMethod: method, - FilePath: filePath, - FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), - ServiceInfo: service, + HttpMethod: method, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + ServiceInfo: service, + IDLPackageInfo: idlPackageRenderInfo, } err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) if err != nil { @@ -426,10 +429,11 @@ func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRen // re-generate logs.Infof("re-generate file '%s', because the update behavior is 'Regenerate'", filePath) data := CustomizedFileForMethod{ - HttpMethod: method, - FilePath: filePath, - FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), - ServiceInfo: service, + HttpMethod: method, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + ServiceInfo: service, + IDLPackageInfo: idlPackageRenderInfo, } err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) if err != nil { @@ -497,10 +501,11 @@ func (pkgGen *HttpPackageGenerator) genSingleCustomizedFile(tplInfo *Template, f for _, service := range idlPackageRenderInfo.ServiceInfos.Services { for _, method := range service.Methods { data := CustomizedFileForMethod{ - HttpMethod: method, - FilePath: filePath, - FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), - ServiceInfo: service, + HttpMethod: method, + FilePath: filePath, + FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), + ServiceInfo: service, + IDLPackageInfo: &idlPackageRenderInfo, } insertKey, err := renderInsertKey(tplInfo, data) if err != nil { From 90e1ff059ae52b8051fd029fcab649088ef9d782 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 18 Jul 2023 20:38:01 +0800 Subject: [PATCH 06/17] fix(hz): client snake tag name (#851) --- cmd/hz/protobuf/ast.go | 12 ++++++------ cmd/hz/thrift/ast.go | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/cmd/hz/protobuf/ast.go b/cmd/hz/protobuf/ast.go index 2abe83372..292c0a461 100644 --- a/cmd/hz/protobuf/ast.go +++ b/cmd/hz/protobuf/ast.go @@ -269,14 +269,14 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen if proto.HasExtension(f.Desc.Options(), api.E_Query) { hasAnnotation = true queryAnnos := proto.GetExtension(f.Desc.Options(), api.E_Query) - val := queryAnnos.(string) + val := checkSnakeName(queryAnnos.(string)) clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } if proto.HasExtension(f.Desc.Options(), api.E_Path) { hasAnnotation = true pathAnnos := proto.GetExtension(f.Desc.Options(), api.E_Path) - val := pathAnnos.(string) + val := checkSnakeName(pathAnnos.(string)) if isStringFieldType { clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } else { @@ -287,7 +287,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen if proto.HasExtension(f.Desc.Options(), api.E_Header) { hasAnnotation = true headerAnnos := proto.GetExtension(f.Desc.Options(), api.E_Header) - val := headerAnnos.(string) + val := checkSnakeName(headerAnnos.(string)) if isStringFieldType { clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } else { @@ -298,7 +298,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen if formAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_Form, api.E_FormCompatible); formAnnos != nil { hasAnnotation = true hasFormAnnotation = true - val := formAnnos.(string) + val := checkSnakeName(formAnnos.(string)) if isStringFieldType { clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } else { @@ -314,11 +314,11 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen if fileAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_FileName, api.E_FileNameCompatible); fileAnnos != nil { hasAnnotation = true hasFormAnnotation = true - val := fileAnnos.(string) + val := checkSnakeName(fileAnnos.(string)) clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { - clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", f.GoName, f.GoName) + clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(f.GoName), f.GoName) } } clientMethod.BodyParamsCode = meta.SetBodyParam diff --git a/cmd/hz/thrift/ast.go b/cmd/hz/thrift/ast.go index 298bad006..af7b05401 100644 --- a/cmd/hz/thrift/ast.go +++ b/cmd/hz/thrift/ast.go @@ -217,13 +217,13 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ } if anno := getAnnotation(field.Annotations, AnnotationQuery); len(anno) > 0 { hasAnnotation = true - query := anno[0] + query := checkSnakeName(anno[0]) clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", query, field.GoName().String()) } if anno := getAnnotation(field.Annotations, AnnotationPath); len(anno) > 0 { hasAnnotation = true - path := anno[0] + path := checkSnakeName(anno[0]) if isStringFieldType { clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", path, field.GoName().String()) } else { @@ -233,7 +233,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ if anno := getAnnotation(field.Annotations, AnnotationHeader); len(anno) > 0 { hasAnnotation = true - header := anno[0] + header := checkSnakeName(anno[0]) if isStringFieldType { clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", header, field.GoName().String()) } else { @@ -243,7 +243,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ if anno := getAnnotation(field.Annotations, AnnotationForm); len(anno) > 0 { hasAnnotation = true - form := anno[0] + form := checkSnakeName(anno[0]) hasFormAnnotation = true if isStringFieldType { clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", form, field.GoName().String()) @@ -259,12 +259,12 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ if anno := getAnnotation(field.Annotations, AnnotationFileName); len(anno) > 0 { hasAnnotation = true - fileName := anno[0] + fileName := checkSnakeName(anno[0]) hasFormAnnotation = true clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", fileName, field.GoName().String()) } if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { - clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", field.GoName().String(), field.GoName().String()) + clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(field.GoName().String()), field.GoName().String()) } } clientMethod.BodyParamsCode = meta.SetBodyParam From 05d9ab4a7a3ae47b5389084be56045299ecfb301 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 18 Jul 2023 20:38:52 +0800 Subject: [PATCH 07/17] feat(hz): add default service name (#852) --- cmd/hz/generator/layout.go | 7 ++++++- cmd/hz/meta/const.go | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/cmd/hz/generator/layout.go b/cmd/hz/generator/layout.go index 3791955b5..d007fa82e 100644 --- a/cmd/hz/generator/layout.go +++ b/cmd/hz/generator/layout.go @@ -25,6 +25,7 @@ import ( "reflect" "strings" + "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" "gopkg.in/yaml.v2" ) @@ -169,10 +170,14 @@ func serviceToLayoutData(service Layout) (map[string]interface{}, error) { if len(service.RouterDir) != 0 { routerPkg = filepath.Base(service.RouterDir) } + serviceName := service.ServiceName + if len(serviceName) == 0 { + serviceName = meta.DefaultServiceName + } return map[string]interface{}{ "GoModule": goMod, - "ServiceName": service.ServiceName, + "ServiceName": serviceName, "UseApacheThrift": service.UseApacheThrift, "HandlerPkg": handlerPkg, "RouterPkg": routerPkg, diff --git a/cmd/hz/meta/const.go b/cmd/hz/meta/const.go index 6ec98d6e2..66f0a1fb4 100644 --- a/cmd/hz/meta/const.go +++ b/cmd/hz/meta/const.go @@ -21,6 +21,8 @@ import "runtime" // Version hz version const Version = "v0.6.5" +const DefaultServiceName = "hertz_service" + // Mode hz run modes type Mode int From 1c49d8b8fe7e8d95a6998c5022a45ff22671b94f Mon Sep 17 00:00:00 2001 From: kinggo Date: Tue, 18 Jul 2023 20:52:21 +0800 Subject: [PATCH 08/17] optimize: modify judgment conditions when use retry and err = nil (#860) --- pkg/app/client/client_test.go | 51 +++++++++++++++++++++++++++++++++++ pkg/protocol/http1/client.go | 3 ++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index 84c6a918e..ba00f78a8 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -2296,3 +2296,54 @@ func TestClientState(t *testing.T) { client.Get(context.Background(), nil, "http://127.0.0.1:11000") time.Sleep(time.Second * 22) } + +func TestClientRetryErr(t *testing.T) { + t.Run("200", func(t *testing.T) { + opt := config.NewOptions([]config.Option{}) + opt.Addr = "127.0.0.1:10136" + engine := route.NewEngine(opt) + var l sync.Mutex + retryNum := 0 + engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) { + l.Lock() + defer l.Unlock() + retryNum += 1 + ctx.SetStatusCode(200) + }) + go engine.Run() + time.Sleep(100 * time.Millisecond) + c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) + _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:10136/ping") + assert.Nil(t, err) + l.Lock() + assert.DeepEqual(t, 1, retryNum) + l.Unlock() + engine.Close() + }) + + t.Run("502", func(t *testing.T) { + opt := config.NewOptions([]config.Option{}) + opt.Addr = "127.0.0.1:10137" + engine := route.NewEngine(opt) + var l sync.Mutex + retryNum := 0 + engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) { + l.Lock() + defer l.Unlock() + retryNum += 1 + ctx.SetStatusCode(502) + }) + go engine.Run() + time.Sleep(100 * time.Millisecond) + c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) + c.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { + return resp.StatusCode() == 502 + }) + _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:10137/ping") + assert.Nil(t, err) + l.Lock() + assert.DeepEqual(t, 3, retryNum) + l.Unlock() + engine.Close() + }) +} diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index 13f07a149..34bb5f014 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -381,7 +381,8 @@ func (c *HostClient) Do(ctx context.Context, req *protocol.Request, resp *protoc req.Options().StartRequest() for { canIdempotentRetry, err = c.do(req, resp) - if err == nil { + // If there is no custom retry and err is equal to nil, the loop simply exits. + if err == nil && isDefaultRetryFunc { break } From 9014502d33ba3f1d46937173953ce6741d95f8d1 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Wed, 19 Jul 2023 10:45:14 +0800 Subject: [PATCH 09/17] feat(hz): add more template parameter (#854) --- cmd/hz/generator/handler.go | 26 +++++++++++++---------- cmd/hz/protobuf/ast.go | 12 +++++++++-- cmd/hz/thrift/ast.go | 42 +++++++++++++++++++++++++++++-------- 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/cmd/hz/generator/handler.go b/cmd/hz/generator/handler.go index cecb0d038..4d797d26e 100644 --- a/cmd/hz/generator/handler.go +++ b/cmd/hz/generator/handler.go @@ -29,17 +29,21 @@ import ( ) type HttpMethod struct { - Name string - HTTPMethod string - Comment string - RequestTypeName string - ReturnTypeName string - Path string - Serializer string - OutputDir string - RefPackage string - RefPackageAlias string - ModelPackage map[string]string + Name string + HTTPMethod string + Comment string + RequestTypeName string + RequestTypePackage string + RequestTypeRawName string + ReturnTypeName string + ReturnTypePackage string + ReturnTypeRawName string + Path string + Serializer string + OutputDir string + RefPackage string // handler import dir + RefPackageAlias string // handler import alias + ModelPackage map[string]string // Annotations map[string]string Models map[string]*model.Model } diff --git a/cmd/hz/protobuf/ast.go b/cmd/hz/protobuf/ast.go index 292c0a461..fcb8197eb 100644 --- a/cmd/hz/protobuf/ast.go +++ b/cmd/hz/protobuf/ast.go @@ -158,16 +158,20 @@ func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmd reqName := m.GetInputType() sb, err := resolver.ResolveIdentifier(reqName) - reqName = util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") + "." + inputGoType.GoIdent.GoName if err != nil { return nil, err } + reqName = util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") + "." + inputGoType.GoIdent.GoName + reqRawName := inputGoType.GoIdent.GoName + reqPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") respName := m.GetOutputType() st, err := resolver.ResolveIdentifier(respName) - respName = util.BaseName(st.Scope.GetOptions().GetGoPackage(), "") + "." + outputGoType.GoIdent.GoName if err != nil { return nil, err } + respName = util.BaseName(st.Scope.GetOptions().GetGoPackage(), "") + "." + outputGoType.GoIdent.GoName + respRawName := outputGoType.GoIdent.GoName + respPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") var serializer string sl, sv := checkFirstOptions(SerializerOptions, m.GetOptions()) @@ -212,7 +216,11 @@ func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmd respName = goOptMapAlias[st.Scope.GetOptions().GetGoPackage()] + "." + outputGoType.GoIdent.GoName } method.RequestTypeName = reqName + method.RequestTypeRawName = reqRawName + method.RequestTypePackage = reqPackage method.ReturnTypeName = respName + method.ReturnTypeRawName = respRawName + method.ReturnTypePackage = respPackage methods = append(methods, method) diff --git a/cmd/hz/thrift/ast.go b/cmd/hz/thrift/ast.go index af7b05401..557b50a41 100644 --- a/cmd/hz/thrift/ast.go +++ b/cmd/hz/thrift/ast.go @@ -122,7 +122,7 @@ func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument) return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", hmethod, s.Name, m.Name, path) } - var reqName string + var reqName, reqRawName, reqPackage string if len(m.Arguments) >= 1 { if len(m.Arguments) > 1 { logs.Warnf("function '%s' has more than one argument, but only the first can be used in hertz now", m.GetName()) @@ -132,25 +132,49 @@ func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument) if err != nil { return nil, err } + if strings.Contains(reqName, ".") && !m.Arguments[0].GetType().Category.IsContainerType() { + // If reqName contains "." , then it must be of the form "pkg.name". + // so reqRawName='name', reqPackage='pkg' + names := strings.Split(reqName, ".") + if len(names) != 2 { + return nil, fmt.Errorf("request name: %s is wrong", reqName) + } + reqRawName = names[1] + reqPackage = names[0] + } } - var respName string + var respName, respRawName, respPackage string if !m.Oneway { var err error respName, err = resolver.ResolveTypeName(m.GetFunctionType()) if err != nil { return nil, err } + if strings.Contains(respName, ".") && !m.GetFunctionType().Category.IsContainerType() { + names := strings.Split(respName, ".") + if len(names) != 2 { + return nil, fmt.Errorf("response name: %s is wrong", respName) + } + // If respName contains "." , then it must be of the form "pkg.name". + // so respRawName='name', respPackage='pkg' + respRawName = names[1] + respPackage = names[0] + } } sr, _ := util.GetFirstKV(getAnnotations(m.Annotations, SerializerTags)) method := &generator.HttpMethod{ - Name: util.CamelString(m.GetName()), - HTTPMethod: hmethod, - RequestTypeName: reqName, - ReturnTypeName: respName, - Path: path[0], - Serializer: sr, - OutputDir: handlerOutDir, + Name: util.CamelString(m.GetName()), + HTTPMethod: hmethod, + RequestTypeName: reqName, + RequestTypeRawName: reqRawName, + RequestTypePackage: reqPackage, + ReturnTypeName: respName, + ReturnTypeRawName: respRawName, + ReturnTypePackage: respPackage, + Path: path[0], + Serializer: sr, + OutputDir: handlerOutDir, // Annotations: m.Annotations, } refs := resolver.ExportReferred(false, true) From eca144b85291fe0f0976ddf7f92d15c559fd6257 Mon Sep 17 00:00:00 2001 From: Guangming Luo Date: Thu, 20 Jul 2023 20:38:27 +0800 Subject: [PATCH 10/17] chore: remove wechat group in readme (#864) --- README.md | 2 -- README_cn.md | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index ea347fdfa..ae09350ea 100644 --- a/README.md +++ b/README.md @@ -112,9 +112,7 @@ Hertz is distributed under the [Apache License, version 2.0](https://github.com/ - Lark: Scan the QR code below with [Lark](https://www.larksuite.com/zh_cn/download) to join our CloudWeGo/hertz user group. ![LarkGroup](images/lark_group.png) -- WeChat: CloudWeGo community WeChat group. -![WechatGroup](images/wechat_group_cn.png) ## Contributors Thank you for your contribution to Hertz! diff --git a/README_cn.md b/README_cn.md index a1c13fbd6..8ff215d04 100644 --- a/README_cn.md +++ b/README_cn.md @@ -113,9 +113,7 @@ Hertz 基于[Apache License 2.0](https://github.com/cloudwego/hertz/blob/main/LI - 飞书用户群([注册飞书](https://www.larksuite.com/zh_cn/download)进群) ![LarkGroup](images/lark_group_cn.png) -- 微信: CloudWeGo community - ![WechatGroup](images/wechat_group_cn.png) ## 贡献者 感谢您对 Hertz 作出的贡献! @@ -126,4 +124,4 @@ Hertz 基于[Apache License 2.0](https://github.com/cloudwego/hertz/blob/main/LI   

CloudWeGo 丰富了 CNCF 云原生生态。 -

\ No newline at end of file +

From 3161f043a97d062f887fa2d8a9af9ca9bc96fc95 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 25 Jul 2023 20:10:32 +0800 Subject: [PATCH 11/17] test(http1): fix retry test (#866) --- pkg/protocol/http1/client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go index 19154453c..4dd6fcbf2 100644 --- a/pkg/protocol/http1/client_test.go +++ b/pkg/protocol/http1/client_test.go @@ -437,7 +437,7 @@ func TestRetry(t *testing.T) { Delay: time.Millisecond * 10, }, RetryIfFunc: func(req *protocol.Request, resp *protocol.Response, err error) bool { - return true + return resp.Header.ContentLength() != 10 }, }, Addr: "foobar", From 4a68f8488361bc796a047af4133b515ac309468e Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Mon, 31 Jul 2023 16:06:28 +0800 Subject: [PATCH 12/17] fix: keep isTLS field in keepalive when reset the request (#876) --- pkg/app/context.go | 3 ++- pkg/app/context_test.go | 4 +++- pkg/protocol/request.go | 19 ++++++++++++++++++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index e4fcd0af9..559b84d03 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -754,7 +754,8 @@ func (ctx *RequestContext) ResetWithoutConn() { close(ctx.finished) ctx.finished = nil } - ctx.Request.Reset() + + ctx.Request.ResetWithoutConn() ctx.Response.Reset() if ctx.IsEnableTrace() { ctx.traceInfo.Reset() diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index b9e9d2f5f..8bf52f437 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -654,8 +654,10 @@ func TestContextReset(t *testing.T) { c.Params = param.Params{param.Param{}} c.Error(errors.New("test")) // nolint: errcheck c.Set("foo", "bar") + c.Request.SetIsTLS(true) c.ResetWithoutConn() - + c.Request.URI() + assert.DeepEqual(t, "https", string(c.Request.Scheme())) assert.False(t, c.IsAborted()) assert.DeepEqual(t, 0, len(c.Errors)) assert.Nil(t, c.Errors.Errors()) diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go index af29db1ba..ef70f92ca 100644 --- a/pkg/protocol/request.go +++ b/pkg/protocol/request.go @@ -176,16 +176,25 @@ func (req *Request) MayContinue() bool { return bytes.Equal(req.Header.peek(bytestr.StrExpect), bytestr.Str100Continue) } +// Scheme returns the scheme of the request. +// uri will be parsed in ServeHTTP(before user's process), so that there is no need for uri nil-check. func (req *Request) Scheme() []byte { return req.uri.Scheme() } -func (req *Request) ResetSkipHeader() { +// For keepalive connection reuse. +// It is roughly the same as ResetSkipHeader, except that the connection related fields are removed: +// - req.isTLS +func (req *Request) resetSkipHeaderAndConn() { req.ResetBody() req.uri.Reset() req.parsedURI = false req.parsedPostArgs = false req.postArgs.Reset() +} + +func (req *Request) ResetSkipHeader() { + req.resetSkipHeaderAndConn() req.isTLS = false } @@ -814,6 +823,14 @@ func (req *Request) SetConnectionClose() { req.Header.SetConnectionClose(true) } +func (req *Request) ResetWithoutConn() { + req.Header.Reset() + req.resetSkipHeaderAndConn() + req.CloseBodyStream() + + req.options = nil +} + // AcquireRequest returns an empty Request instance from request pool. // // The returned Request instance may be passed to ReleaseRequest when it is From f4cde92c4781a355113b92e7a9e444b69e640fbb Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Mon, 31 Jul 2023 16:14:47 +0800 Subject: [PATCH 13/17] optimize(http1): retry if the error is caused by broken pooled conn (#871) --- pkg/app/client/client_test.go | 6 ++- pkg/app/server/hertz_test.go | 7 +-- pkg/common/errors/errors.go | 6 +-- pkg/common/test/mock/network.go | 65 +++++++++++++++++++++++- pkg/common/test/mock/network_test.go | 58 ++++++++++++++++++++- pkg/protocol/client/client.go | 11 ---- pkg/protocol/http1/client.go | 70 +++++++++++++++++++++----- pkg/protocol/http1/client_test.go | 67 ++++++++++++++++++++++-- pkg/protocol/http1/client_unix_test.go | 4 +- pkg/protocol/http1/server.go | 2 +- 10 files changed, 251 insertions(+), 45 deletions(-) diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index ba00f78a8..8449cd4d8 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -453,7 +453,9 @@ func TestClientReadTimeout(t *testing.T) { req.SetConnectionClose() if err := c.Do(context.Background(), req, res); !errors.Is(err, errs.ErrTimeout) { - t.Errorf("expected ErrTimeout got %#v", err) + if !strings.Contains(err.Error(), "timeout") { + t.Errorf("expected ErrTimeout got %#v", err) + } } protocol.ReleaseRequest(req) @@ -2267,7 +2269,7 @@ func TestClientDoWithDialFunc(t *testing.T) { func TestClientState(t *testing.T) { opt := config.NewOptions([]config.Option{}) - opt.Addr = ":11000" + opt.Addr = "127.0.0.1:11000" engine := route.NewEngine(opt) go engine.Run() diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index b9a07fbce..bc3e1a0e4 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -311,16 +311,17 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { } func TestWithBasePath(t *testing.T) { - engine := New(WithBasePath("/hertz"), WithHostPorts("127.0.0.1:9898")) + engine := New(WithBasePath("/hertz"), WithHostPorts("127.0.0.1:19898")) engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() - time.Sleep(200 * time.Microsecond) + time.Sleep(500 * time.Microsecond) var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") body := strings.NewReader(r.Form.Encode()) - resp, _ := http.Post("http://127.0.0.1:9898/hertz/test", "application/x-www-form-urlencoded", body) + resp, err := http.Post("http://127.0.0.1:19898/hertz/test", "application/x-www-form-urlencoded", body) + assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) } diff --git a/pkg/common/errors/errors.go b/pkg/common/errors/errors.go index fdaf2d56b..918c7fa7a 100644 --- a/pkg/common/errors/errors.go +++ b/pkg/common/errors/errors.go @@ -53,17 +53,15 @@ var ( ErrChunkedStream = errors.New("chunked stream") ErrBodyTooLarge = errors.New("body size exceeds the given limit") ErrHijacked = errors.New("connection has been hijacked") - ErrIdleTimeout = errors.New("idle timeout") ErrTimeout = errors.New("timeout") - ErrReadTimeout = errors.New("read timeout") - ErrWriteTimeout = errors.New("write timeout") - ErrDialTimeout = errors.New("dial timeout") + ErrIdleTimeout = errors.New("idle timeout") ErrNothingRead = errors.New("nothing read") ErrShortConnection = errors.New("short connection") ErrNoFreeConns = errors.New("no free connections available to host") ErrConnectionClosed = errors.New("connection closed") ErrNotSupportProtocol = errors.New("not support protocol") ErrNoMultipartForm = errors.New("request has no multipart/form-data Content-Type") + ErrBadPoolConn = errors.New("connection is closed by peer while being in the connection pool") ) // ErrorType is an unsigned 64-bit error code as defined in the hertz spec. diff --git a/pkg/common/test/mock/network.go b/pkg/common/test/mock/network.go index 7ad083647..57293cb2b 100644 --- a/pkg/common/test/mock/network.go +++ b/pkg/common/test/mock/network.go @@ -18,6 +18,7 @@ package mock import ( "bytes" + "io" "net" "strings" "time" @@ -27,6 +28,11 @@ import ( "github.com/cloudwego/netpoll" ) +var ( + ErrReadTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read timeout") + ErrWriteTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "write timeout") +) + type Conn struct { readTimeout time.Duration zr network.Reader @@ -146,7 +152,7 @@ func (m *SlowReadConn) Peek(i int) ([]byte, error) { time.Sleep(100 * time.Millisecond) } if err != nil || len(b) != i { - return nil, errs.ErrReadTimeout + return nil, ErrReadTimeout } return b, err } @@ -161,6 +167,61 @@ func NewConn(source string) *Conn { } } +type BrokenConn struct { + *Conn +} + +func (o *BrokenConn) Peek(i int) ([]byte, error) { + return nil, io.EOF +} + +func (o *BrokenConn) Flush() error { + return errs.ErrConnectionClosed +} + +func NewBrokenConn(source string) *BrokenConn { + return &BrokenConn{Conn: NewConn(source)} +} + +type OneTimeConn struct { + isRead bool + isFlushed bool + contentLength int + *Conn +} + +func (o *OneTimeConn) Peek(n int) ([]byte, error) { + if o.isRead { + return nil, io.EOF + } + return o.Conn.Peek(n) +} + +func (o *OneTimeConn) Skip(n int) error { + if o.isRead { + return io.EOF + } + o.contentLength -= n + + if o.contentLength == 0 { + o.isRead = true + } + + return o.Conn.Skip(n) +} + +func (o *OneTimeConn) Flush() error { + if o.isFlushed { + return errs.ErrConnectionClosed + } + o.isFlushed = true + return o.Conn.Flush() +} + +func NewOneTimeConn(source string) *OneTimeConn { + return &OneTimeConn{isRead: false, isFlushed: false, Conn: NewConn(source), contentLength: len(source)} +} + func NewSlowReadConn(source string) *SlowReadConn { return &SlowReadConn{Conn: NewConn(source)} } @@ -200,7 +261,7 @@ func (m *SlowWriteConn) Flush() error { time.Sleep(100 * time.Millisecond) if err == nil { time.Sleep(m.writeTimeout) - return errs.ErrWriteTimeout + return ErrWriteTimeout } return err } diff --git a/pkg/common/test/mock/network_test.go b/pkg/common/test/mock/network_test.go index 2d4ff973a..bd6ad5eb9 100644 --- a/pkg/common/test/mock/network_test.go +++ b/pkg/common/test/mock/network_test.go @@ -18,6 +18,7 @@ package mock import ( "context" + "io" "testing" "time" @@ -30,6 +31,7 @@ func TestConn(t *testing.T) { t.Run("TestReader", func(t *testing.T) { s1 := "abcdef4343" conn1 := NewConn(s1) + assert.Nil(t, conn1.SetWriteTimeout(1)) err := conn1.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) assert.DeepEqual(t, nil, err) err = conn1.SetReadTimeout(time.Millisecond * 100) @@ -135,12 +137,16 @@ func TestSlowConn(t *testing.T) { t.Run("TestSlowReadConn", func(t *testing.T) { s1 := "abcdefg" conn := NewSlowReadConn(s1) + assert.Nil(t, conn.SetWriteTimeout(1)) + assert.Nil(t, conn.SetReadTimeout(1)) + assert.DeepEqual(t, time.Duration(1), conn.readTimeout) + b, err := conn.Peek(4) assert.DeepEqual(t, nil, err) assert.DeepEqual(t, s1[:4], string(b)) conn.Skip(len(s1)) _, err = conn.Peek(1) - assert.DeepEqual(t, errs.ErrReadTimeout, err) + assert.DeepEqual(t, ErrReadTimeout, err) _, err = SlowReadDialer("") assert.DeepEqual(t, nil, err) }) @@ -150,7 +156,7 @@ func TestSlowConn(t *testing.T) { assert.DeepEqual(t, nil, err) conn.SetWriteTimeout(time.Millisecond * 100) err = conn.Flush() - assert.DeepEqual(t, errs.ErrWriteTimeout, err) + assert.DeepEqual(t, ErrWriteTimeout, err) }) } @@ -183,3 +189,51 @@ func TestStreamConn(t *testing.T) { }) }) } + +func TestBrokenConn_Flush(t *testing.T) { + conn := NewBrokenConn("") + n, err := conn.Writer().WriteBinary([]byte("Foo")) + assert.DeepEqual(t, 3, n) + assert.Nil(t, err) + assert.DeepEqual(t, errs.ErrConnectionClosed, conn.Flush()) +} + +func TestBrokenConn_Peek(t *testing.T) { + conn := NewBrokenConn("Foo") + buf, err := conn.Peek(3) + assert.Nil(t, buf) + assert.DeepEqual(t, io.EOF, err) +} + +func TestOneTimeConn_Flush(t *testing.T) { + conn := NewOneTimeConn("") + n, err := conn.Writer().WriteBinary([]byte("Foo")) + assert.DeepEqual(t, 3, n) + assert.Nil(t, err) + assert.Nil(t, conn.Flush()) + n, err = conn.Writer().WriteBinary([]byte("Bar")) + assert.DeepEqual(t, 3, n) + assert.Nil(t, err) + assert.DeepEqual(t, errs.ErrConnectionClosed, conn.Flush()) +} + +func TestOneTimeConn_Skip(t *testing.T) { + conn := NewOneTimeConn("FooBar") + buf, err := conn.Peek(3) + assert.DeepEqual(t, "Foo", string(buf)) + assert.Nil(t, err) + assert.Nil(t, conn.Skip(3)) + assert.DeepEqual(t, 3, conn.contentLength) + + buf, err = conn.Peek(3) + assert.DeepEqual(t, "Bar", string(buf)) + assert.Nil(t, err) + assert.Nil(t, conn.Skip(3)) + assert.DeepEqual(t, 0, conn.contentLength) + + buf, err = conn.Peek(3) + assert.DeepEqual(t, 0, len(buf)) + assert.DeepEqual(t, io.EOF, err) + assert.DeepEqual(t, io.EOF, conn.Skip(3)) + assert.DeepEqual(t, 0, conn.contentLength) +} diff --git a/pkg/protocol/client/client.go b/pkg/protocol/client/client.go index 3715255e4..777f55cdd 100644 --- a/pkg/protocol/client/client.go +++ b/pkg/protocol/client/client.go @@ -43,7 +43,6 @@ package client import ( "context" - "io" "sync" "time" @@ -88,16 +87,6 @@ func DefaultRetryIf(req *protocol.Request, resp *protocol.Response, err error) b if isIdempotent(req, resp, err) { return true } - // Retry non-idempotent requests if the server closes - // the connection before sending the response. - // - // This case is possible if the server closes the idle - // keep-alive connection on timeout. - // - // Apache and nginx usually do this. - if err == io.EOF { - return true - } return false } diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index 34bb5f014..c306f071f 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -51,6 +51,7 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "time" "github.com/cloudwego/hertz/internal/bytesconv" @@ -362,6 +363,7 @@ func (c *HostClient) Do(ctx context.Context, req *protocol.Request, resp *protoc canIdempotentRetry bool isDefaultRetryFunc = true attempts uint = 0 + connAttempts uint = 0 maxAttempts uint = 1 isRequestRetryable client.RetryIfFunc = client.DefaultRetryIf ) @@ -380,17 +382,38 @@ func (c *HostClient) Do(ctx context.Context, req *protocol.Request, resp *protoc atomic.AddInt32(&c.pendingRequests, 1) req.Options().StartRequest() for { + select { + case <-ctx.Done(): + req.CloseBodyStream() //nolint:errcheck + return ctx.Err() + default: + } + canIdempotentRetry, err = c.do(req, resp) // If there is no custom retry and err is equal to nil, the loop simply exits. if err == nil && isDefaultRetryFunc { + if connAttempts != 0 { + hlog.SystemLogger().Warnf("Client connection attempt times: %d, url: %s. "+ + "This is mainly because the connection in pool is closed by peer in advance. "+ + "If this number is too high which indicates that long-connection are basically unavailable, "+ + "try to change the request to short-connection.\n", connAttempts, req.URI().FullURI()) + } break } + // This connection is closed by the peer when it is in the connection pool. + // + // This case is possible if the server closes the idle + // keep-alive connection on timeout. + // + // Apache and nginx usually do this. + if canIdempotentRetry && client.DefaultRetryIf(req, resp, err) && errors.Is(err, errs.ErrBadPoolConn) { + connAttempts++ + continue + } + if isDefaultRetryFunc { - // canIdempotentRetry only makes sense if the user hasn't provided a custom retry function. - if !canIdempotentRetry { - break - } + break } attempts++ @@ -516,7 +539,7 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo if reqTimeout < dialTimeout || dialTimeout == 0 { dialTimeout = reqTimeout } - cc, err := c.acquireConn(dialTimeout) + cc, inPool, err := c.acquireConn(dialTimeout) // if getting connection error, fast fail if err != nil { return false, err @@ -588,13 +611,16 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo // Only if the connection is closed while writing the request. Try to parse the response and return. // In this case, the request/response is considered as successful. // Otherwise, return the former error. - zr := c.acquireReader(conn) defer zr.Release() if respI.ReadHeaderAndLimitBody(resp, zr, c.MaxResponseBodySize) == nil { return false, nil } + if inPool { + err = errs.ErrBadPoolConn + } + return true, err } @@ -620,6 +646,24 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo } zr := c.acquireReader(conn) + // errs.ErrBadPoolConn error are returned when the + // 1 byte peek read fails, and we're actually anticipating a response. + // Usually this is just due to the inherent keep-alive shut down race, + // where the server closed the connection at the same time the client + // wrote. The underlying err field is usually io.EOF or some + // ECONNRESET sort of thing which varies by platform. + _, err = zr.Peek(1) + if err != nil { + zr.Release() //nolint:errcheck + c.closeConn(cc) + if inPool && (err == io.EOF || err == syscall.ECONNRESET) { + return true, errs.ErrBadPoolConn + } + // if this is not a pooled connection, + // we should not retry to avoid getting stuck in an endless retry loop. + return false, err + } + // init here for passing in ReadBodyStream's closure // and this value will be assigned after reading Response's Header // @@ -676,7 +720,7 @@ func (c *HostClient) SetMaxConns(newMaxConns int) { c.connsLock.Unlock() } -func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, err error) { +func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, inPool bool, err error) { createConn := false startCleaner := false @@ -705,11 +749,11 @@ func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, err c.connsLock.Unlock() if cc != nil { - return cc, nil + return cc, true, nil } if !createConn { if c.MaxConnWaitTimeout <= 0 { - return nil, errs.ErrNoFreeConns + return nil, true, errs.ErrNoFreeConns } timeout := c.MaxConnWaitTimeout @@ -736,9 +780,9 @@ func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, err select { case <-w.ready: - return w.conn, w.err + return w.conn, true, w.err case <-tc.C: - return nil, errs.ErrNoFreeConns + return nil, true, errs.ErrNoFreeConns } } @@ -749,11 +793,11 @@ func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, err conn, err := c.dialHostHard(dialTimeout) if err != nil { c.decConnsCount() - return nil, err + return nil, false, err } cc = acquireClientConn(conn) - return cc, nil + return cc, false, nil } func (c *HostClient) queueForIdle(w *wantConn) { diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go index 4dd6fcbf2..7bc389bf9 100644 --- a/pkg/protocol/http1/client_test.go +++ b/pkg/protocol/http1/client_test.go @@ -58,6 +58,7 @@ import ( "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" + "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/utils" @@ -69,6 +70,8 @@ import ( "github.com/cloudwego/netpoll" ) +var errDialTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "dial timeout") + func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) { var ( emptyBodyCount uint8 @@ -235,7 +238,7 @@ type slowDialer struct { func (s *slowDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) { time.Sleep(timeout) - return nil, errs.ErrDialTimeout + return nil, errDialTimeout } func TestReadTimeoutPriority(t *testing.T) { @@ -264,7 +267,7 @@ func TestReadTimeoutPriority(t *testing.T) { case <-time.After(time.Second * 2): t.Fatalf("should use readTimeout in request options") case err := <-ch: - assert.DeepEqual(t, errs.ErrTimeout.Error(), err.Error()) + assert.DeepEqual(t, mock.ErrReadTimeout, err) } } @@ -334,7 +337,7 @@ func TestWriteTimeoutPriority(t *testing.T) { case <-time.After(time.Second * 2): t.Fatalf("should use writeTimeout in request options") case err := <-ch: - assert.DeepEqual(t, errs.ErrWriteTimeout.Error(), err.Error()) + assert.DeepEqual(t, mock.ErrWriteTimeout, err) } } @@ -359,10 +362,10 @@ func TestDialTimeoutPriority(t *testing.T) { ch <- c.Do(context.Background(), req, resp) }() select { - case <-time.After(time.Second * 2000): + case <-time.After(time.Second * 2): t.Fatalf("should use dialTimeout in request options") case err := <-ch: - assert.DeepEqual(t, errs.ErrDialTimeout.Error(), err.Error()) + assert.DeepEqual(t, errDialTimeout, err) } } @@ -479,3 +482,57 @@ type retryConn struct { func (w retryConn) SetWriteTimeout(t time.Duration) error { return errors.New("should retry") } + +func TestConnInPoolRetry(t *testing.T) { + c := &HostClient{ + ClientOptions: &ClientOptions{ + Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + return mock.NewOneTimeConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil + }), + }, + Addr: "foobar", + } + + req := protocol.AcquireRequest() + req.SetRequestURI("http://foobar/baz") + req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100)) + resp := protocol.AcquireResponse() + + logbuf := &bytes.Buffer{} + hlog.SetOutput(logbuf) + + err := c.Do(context.Background(), req, resp) + assert.Nil(t, err) + assert.DeepEqual(t, resp.StatusCode(), 200) + assert.DeepEqual(t, string(resp.Body()), "0123456789") + assert.True(t, logbuf.String() == "") + protocol.ReleaseResponse(resp) + resp = protocol.AcquireResponse() + err = c.Do(context.Background(), req, resp) + assert.Nil(t, err) + assert.DeepEqual(t, resp.StatusCode(), 200) + assert.DeepEqual(t, string(resp.Body()), "0123456789") + assert.True(t, strings.Contains(logbuf.String(), "Client connection attempt times: 1")) +} + +func TestConnNotRetry(t *testing.T) { + c := &HostClient{ + ClientOptions: &ClientOptions{ + Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + return mock.NewBrokenConn(""), nil + }), + }, + Addr: "foobar", + } + + req := protocol.AcquireRequest() + req.SetRequestURI("http://foobar/baz") + req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100)) + resp := protocol.AcquireResponse() + logbuf := &bytes.Buffer{} + hlog.SetOutput(logbuf) + err := c.Do(context.Background(), req, resp) + assert.DeepEqual(t, errs.ErrConnectionClosed, err) + assert.True(t, logbuf.String() == "") + protocol.ReleaseResponse(resp) +} diff --git a/pkg/protocol/http1/client_unix_test.go b/pkg/protocol/http1/client_unix_test.go index 6e417ba18..1e88502d3 100644 --- a/pkg/protocol/http1/client_unix_test.go +++ b/pkg/protocol/http1/client_unix_test.go @@ -35,7 +35,7 @@ import ( ) func TestGcBodyStream(t *testing.T) { - srv := &http.Server{Addr: ":11001", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + srv := &http.Server{Addr: "127.0.0.1:11001", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { for range [1024]int{} { w.Write([]byte("hello world\n")) } @@ -69,7 +69,7 @@ func TestGcBodyStream(t *testing.T) { } func TestMaxConn(t *testing.T) { - srv := &http.Server{Addr: ":11002", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + srv := &http.Server{Addr: "127.0.0.1:11002", Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello world\n")) })} go srv.ListenAndServe() diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index 8aaed6db7..77cc74e18 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -48,7 +48,7 @@ const NextProtoTLS = suite.HTTP1 var ( errHijacked = errs.New(errs.ErrHijacked, errs.ErrorTypePublic, nil) - errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePublic, nil) + errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil) errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection") errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request") ) From 41c781eb0755258d8b0d1b6b87d7495654a292ee Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Mon, 31 Jul 2023 16:40:54 +0800 Subject: [PATCH 14/17] optimize(client): do not write body in stream mode if content-length is 0 (#875) --- pkg/common/test/mock/network.go | 6 +++++- pkg/common/test/mock/network_test.go | 2 +- pkg/protocol/http1/ext/common.go | 3 +++ pkg/protocol/http1/ext/common_test.go | 6 ++++++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pkg/common/test/mock/network.go b/pkg/common/test/mock/network.go index 57293cb2b..fe671ae56 100644 --- a/pkg/common/test/mock/network.go +++ b/pkg/common/test/mock/network.go @@ -172,7 +172,11 @@ type BrokenConn struct { } func (o *BrokenConn) Peek(i int) ([]byte, error) { - return nil, io.EOF + return nil, io.ErrUnexpectedEOF +} + +func (o *BrokenConn) Read(b []byte) (n int, err error) { + return 0, io.ErrUnexpectedEOF } func (o *BrokenConn) Flush() error { diff --git a/pkg/common/test/mock/network_test.go b/pkg/common/test/mock/network_test.go index bd6ad5eb9..4c9c4cf5b 100644 --- a/pkg/common/test/mock/network_test.go +++ b/pkg/common/test/mock/network_test.go @@ -202,7 +202,7 @@ func TestBrokenConn_Peek(t *testing.T) { conn := NewBrokenConn("Foo") buf, err := conn.Peek(3) assert.Nil(t, buf) - assert.DeepEqual(t, io.EOF, err) + assert.DeepEqual(t, io.ErrUnexpectedEOF, err) } func TestOneTimeConn_Flush(t *testing.T) { diff --git a/pkg/protocol/http1/ext/common.go b/pkg/protocol/http1/ext/common.go index b31b40cdd..864988317 100644 --- a/pkg/protocol/http1/ext/common.go +++ b/pkg/protocol/http1/ext/common.go @@ -135,6 +135,9 @@ func WriteBodyChunked(w network.Writer, r io.Reader) error { } func WriteBodyFixedSize(w network.Writer, r io.Reader, size int64) error { + if size == 0 { + return nil + } if size > consts.MaxSmallFileSize { if err := w.Flush(); err != nil { return err diff --git a/pkg/protocol/http1/ext/common_test.go b/pkg/protocol/http1/ext/common_test.go index 824ccf4ca..78a9fede0 100644 --- a/pkg/protocol/http1/ext/common_test.go +++ b/pkg/protocol/http1/ext/common_test.go @@ -155,6 +155,12 @@ func TestBodyFixedSize(t *testing.T) { assert.DeepEqual(t, body, rb) } +func TestBodyFixedSizeQuickPath(t *testing.T) { + conn := mock.NewBrokenConn("") + err := WriteBodyFixedSize(conn.Writer(), conn, 0) + assert.Nil(t, err) +} + func TestBodyIdentity(t *testing.T) { body := mock.CreateFixedBody(1024) zr := mock.NewZeroCopyReader(string(body)) From fc36f6f69e7f29edd640aeac51f6eae8ab11e9d3 Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Thu, 3 Aug 2023 16:29:49 +0800 Subject: [PATCH 15/17] optimize: unify timeout error (#884) --- pkg/network/netpoll/connection.go | 5 +++++ pkg/network/standard/connection.go | 4 ++++ pkg/protocol/http1/client.go | 4 ++++ 3 files changed, 13 insertions(+) diff --git a/pkg/network/netpoll/connection.go b/pkg/network/netpoll/connection.go index 7c26ef914..e6df5030c 100644 --- a/pkg/network/netpoll/connection.go +++ b/pkg/network/netpoll/connection.go @@ -35,6 +35,11 @@ func (c *Conn) ToHertzError(err error) error { if errors.Is(err, netpoll.ErrConnClosed) || errors.Is(err, syscall.EPIPE) { return errs.ErrConnectionClosed } + + // only unify read timeout for now + if errors.Is(err, netpoll.ErrReadTimeout) { + return errs.ErrTimeout + } return err } diff --git a/pkg/network/standard/connection.go b/pkg/network/standard/connection.go index e38ee9e7b..13c10b479 100644 --- a/pkg/network/standard/connection.go +++ b/pkg/network/standard/connection.go @@ -54,6 +54,10 @@ func (c *Conn) ToHertzError(err error) error { if errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ENOTCONN) { return errs.ErrConnectionClosed } + if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { + return errs.ErrTimeout + } + return err } diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index c306f071f..c2b7cf5a4 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -661,6 +661,10 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo } // if this is not a pooled connection, // we should not retry to avoid getting stuck in an endless retry loop. + errNorm, ok := conn.(network.ErrorNormalization) + if ok { + err = errNorm.ToHertzError(err) + } return false, err } From 2b0a165b60f9361c20ffc36b551250310c293107 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Fri, 4 Aug 2023 15:01:45 +0800 Subject: [PATCH 16/17] chore(hz): release v0.6.6 (#886) --- cmd/hz/meta/const.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/hz/meta/const.go b/cmd/hz/meta/const.go index 66f0a1fb4..8e7490dfe 100644 --- a/cmd/hz/meta/const.go +++ b/cmd/hz/meta/const.go @@ -19,7 +19,7 @@ package meta import "runtime" // Version hz version -const Version = "v0.6.5" +const Version = "v0.6.6" const DefaultServiceName = "hertz_service" From 69c28889a12c680f4493548b48f444ee6fa96aa0 Mon Sep 17 00:00:00 2001 From: alice <90381261+alice-yyds@users.noreply.github.com> Date: Fri, 4 Aug 2023 15:02:30 +0800 Subject: [PATCH 17/17] chore: update version v0.6.7 --- version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.go b/version.go index f56cbc31a..2f7a82fb5 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" - Version = "v0.6.6" + Version = "v0.6.7" )