From 5b0fe234bfadd7bd04e960d9a143d6cd0d2f0f4f Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 4 Dec 2023 13:25:13 +0800 Subject: [PATCH 1/3] Introduce caller ID into the HTTP client Signed-off-by: JmPotato --- client/http/client.go | 55 ++++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/client/http/client.go b/client/http/client.go index 36355a90d19..a7a68b28f55 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -79,6 +79,8 @@ type Client interface { GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error) /* Client-related methods */ + // WithCallerID sets and returns a new client with the given caller ID. + WithCallerID(string) Client // WithRespHandler sets and returns a new client with the given HTTP response handler. // This allows the caller to customize how the response is handled, including error handling logic. // Additionally, it is important for the caller to handle the content of the response body properly @@ -89,11 +91,20 @@ type Client interface { var _ Client = (*client)(nil) -type client struct { +// clientInner is the inner implementation of the PD HTTP client, which will +// implement some internal logics, such as HTTP client, service discovery, etc. +type clientInner struct { pdAddrs []string tlsConf *tls.Config cli *http.Client +} +type client struct { + // Wrap this struct is to make sure the inner implementation + // won't be exposed and cloud be consistent during the copy. + inner *clientInner + + callerID string respHandler func(resp *http.Response, res interface{}) error requestCounter *prometheus.CounterVec @@ -106,7 +117,7 @@ type ClientOption func(c *client) // WithHTTPClient configures the client with the given initialized HTTP client. func WithHTTPClient(cli *http.Client) ClientOption { return func(c *client) { - c.cli = cli + c.inner.cli = cli } } @@ -114,7 +125,7 @@ func WithHTTPClient(cli *http.Client) ClientOption { // This option won't work if the client is configured with WithHTTPClient. func WithTLSConfig(tlsConf *tls.Config) ClientOption { return func(c *client) { - c.tlsConf = tlsConf + c.inner.tlsConf = tlsConf } } @@ -134,7 +145,7 @@ func NewClient( pdAddrs []string, opts ...ClientOption, ) Client { - c := &client{} + c := &client{inner: &clientInner{}} // Apply the options first. for _, opt := range opts { opt(c) @@ -143,7 +154,7 @@ func NewClient( for i, addr := range pdAddrs { if !strings.HasPrefix(addr, httpScheme) { var scheme string - if c.tlsConf != nil { + if c.inner.tlsConf != nil { scheme = httpsScheme } else { scheme = httpScheme @@ -151,14 +162,14 @@ func NewClient( pdAddrs[i] = fmt.Sprintf("%s://%s", scheme, addr) } } - c.pdAddrs = pdAddrs + c.inner.pdAddrs = pdAddrs // Init the HTTP client if it's not configured. - if c.cli == nil { - c.cli = &http.Client{Timeout: defaultTimeout} - if c.tlsConf != nil { + if c.inner.cli == nil { + c.inner.cli = &http.Client{Timeout: defaultTimeout} + if c.inner.tlsConf != nil { transport := http.DefaultTransport.(*http.Transport).Clone() - transport.TLSClientConfig = c.tlsConf - c.cli.Transport = transport + transport.TLSClientConfig = c.inner.tlsConf + c.inner.cli.Transport = transport } } @@ -167,12 +178,22 @@ func NewClient( // Close closes the HTTP client. func (c *client) Close() { - if c.cli != nil { - c.cli.CloseIdleConnections() + if c.inner == nil { + return + } + if c.inner.cli != nil { + c.inner.cli.CloseIdleConnections() } log.Info("[pd] http client closed") } +// WithCallerID sets and returns a new client with the given caller ID. +func (c *client) WithCallerID(callerID string) Client { + newClient := *c + newClient.callerID = callerID + return &newClient +} + // WithRespHandler sets and returns a new client with the given HTTP response handler. func (c *client) WithRespHandler( handler func(resp *http.Response, res interface{}) error, @@ -218,8 +239,8 @@ func (c *client) requestWithRetry( err error addr string ) - for idx := 0; idx < len(c.pdAddrs); idx++ { - addr = c.pdAddrs[idx] + for idx := 0; idx < len(c.inner.pdAddrs); idx++ { + addr = c.inner.pdAddrs[idx] err = c.request(ctx, name, fmt.Sprintf("%s%s", addr, uri), method, body, res, headerOpts...) if err == nil { break @@ -239,6 +260,8 @@ func (c *client) request( logFields := []zap.Field{ zap.String("name", name), zap.String("url", url), + zap.String("method", method), + zap.String("caller-id", c.callerID), } log.Debug("[pd] request the http url", logFields...) req, err := http.NewRequestWithContext(ctx, method, url, body) @@ -250,7 +273,7 @@ func (c *client) request( opt(req.Header) } start := time.Now() - resp, err := c.cli.Do(req) + resp, err := c.inner.cli.Do(req) if err != nil { c.reqCounter(name, networkErrorStatus) log.Error("[pd] do http request failed", append(logFields, zap.Error(err))...) From 3d4a6056f5ef95fae86dcd0c9f7ae93e501a2b55 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 4 Dec 2023 14:27:28 +0800 Subject: [PATCH 2/3] Set the component header key with the caller ID Signed-off-by: JmPotato --- client/http/client.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/client/http/client.go b/client/http/client.go index a7a68b28f55..15752481502 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -217,6 +217,12 @@ func (c *client) execDuration(name string, duration time.Duration) { c.executionDuration.WithLabelValues(name).Observe(duration.Seconds()) } +// Header key definition constants. +const ( + PDAllowFollowerHandleKey = "PD-Allow-Follower-Handle" + ComponentSignatureKey = "component" +) + // HeaderOption configures the HTTP header. type HeaderOption func(header http.Header) @@ -272,6 +278,8 @@ func (c *client) request( for _, opt := range headerOpts { opt(req.Header) } + req.Header.Set(ComponentSignatureKey, c.callerID) + start := time.Now() resp, err := c.inner.cli.Do(req) if err != nil { From 00916552f650d452bfd772a074e9a9e0a612d455 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 4 Dec 2023 14:51:39 +0800 Subject: [PATCH 3/3] Add the client request check tests Signed-off-by: JmPotato --- client/http/client.go | 11 +++--- client/http/client_test.go | 73 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 5 deletions(-) create mode 100644 client/http/client_test.go diff --git a/client/http/client.go b/client/http/client.go index 15752481502..b79aa9ca002 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -32,6 +32,7 @@ import ( ) const ( + defaultCallerID = "pd-http-client" httpScheme = "http" httpsScheme = "https" networkErrorStatus = "network error" @@ -145,7 +146,7 @@ func NewClient( pdAddrs []string, opts ...ClientOption, ) Client { - c := &client{inner: &clientInner{}} + c := &client{inner: &clientInner{}, callerID: defaultCallerID} // Apply the options first. for _, opt := range opts { opt(c) @@ -219,8 +220,8 @@ func (c *client) execDuration(name string, duration time.Duration) { // Header key definition constants. const ( - PDAllowFollowerHandleKey = "PD-Allow-Follower-Handle" - ComponentSignatureKey = "component" + pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle" + componentSignatureKey = "component" ) // HeaderOption configures the HTTP header. @@ -229,7 +230,7 @@ type HeaderOption func(header http.Header) // WithAllowFollowerHandle sets the header field to allow a PD follower to handle this request. func WithAllowFollowerHandle() HeaderOption { return func(header http.Header) { - header.Set("PD-Allow-Follower-Handle", "true") + header.Set(pdAllowFollowerHandleKey, "true") } } @@ -278,7 +279,7 @@ func (c *client) request( for _, opt := range headerOpts { opt(req.Header) } - req.Header.Set(ComponentSignatureKey, c.callerID) + req.Header.Set(componentSignatureKey, c.callerID) start := time.Now() resp, err := c.inner.cli.Do(req) diff --git a/client/http/client_test.go b/client/http/client_test.go new file mode 100644 index 00000000000..621910e29ea --- /dev/null +++ b/client/http/client_test.go @@ -0,0 +1,73 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +// requestChecker is used to check the HTTP request sent by the client. +type requestChecker struct { + checker func(req *http.Request) error +} + +// RoundTrip implements the `http.RoundTripper` interface. +func (rc *requestChecker) RoundTrip(req *http.Request) (resp *http.Response, err error) { + return &http.Response{StatusCode: http.StatusOK}, rc.checker(req) +} + +func newHTTPClientWithRequestChecker(checker func(req *http.Request) error) *http.Client { + return &http.Client{ + Transport: &requestChecker{checker: checker}, + } +} + +func TestPDAllowFollowerHandleHeader(t *testing.T) { + re := require.New(t) + var expectedVal string + httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { + val := req.Header.Get(pdAllowFollowerHandleKey) + if val != expectedVal { + re.Failf("PD allow follower handler header check failed", + "should be %s, but got %s", expectedVal, val) + } + return nil + }) + c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c.GetRegions(context.Background()) + expectedVal = "true" + c.GetHistoryHotRegions(context.Background(), &HistoryHotRegionsRequest{}) +} + +func TestCallerID(t *testing.T) { + re := require.New(t) + expectedVal := defaultCallerID + httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { + val := req.Header.Get(componentSignatureKey) + if val != expectedVal { + re.Failf("Caller ID header check failed", + "should be %s, but got %s", expectedVal, val) + } + return nil + }) + c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c.GetRegions(context.Background()) + expectedVal = "test" + c.WithCallerID(expectedVal).GetRegions(context.Background()) +}