diff --git a/client/http/client.go b/client/http/client.go index 36355a90d19..b79aa9ca002 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -32,6 +32,7 @@ import ( ) const ( + defaultCallerID = "pd-http-client" httpScheme = "http" httpsScheme = "https" networkErrorStatus = "network error" @@ -79,6 +80,8 @@ type Client interface { GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error) /* Client-related methods */ + // WithCallerID sets and returns a new client with the given caller ID. + WithCallerID(string) Client // WithRespHandler sets and returns a new client with the given HTTP response handler. // This allows the caller to customize how the response is handled, including error handling logic. // Additionally, it is important for the caller to handle the content of the response body properly @@ -89,11 +92,20 @@ type Client interface { var _ Client = (*client)(nil) -type client struct { +// clientInner is the inner implementation of the PD HTTP client, which will +// implement some internal logics, such as HTTP client, service discovery, etc. +type clientInner struct { pdAddrs []string tlsConf *tls.Config cli *http.Client +} + +type client struct { + // Wrap this struct is to make sure the inner implementation + // won't be exposed and cloud be consistent during the copy. + inner *clientInner + callerID string respHandler func(resp *http.Response, res interface{}) error requestCounter *prometheus.CounterVec @@ -106,7 +118,7 @@ type ClientOption func(c *client) // WithHTTPClient configures the client with the given initialized HTTP client. func WithHTTPClient(cli *http.Client) ClientOption { return func(c *client) { - c.cli = cli + c.inner.cli = cli } } @@ -114,7 +126,7 @@ func WithHTTPClient(cli *http.Client) ClientOption { // This option won't work if the client is configured with WithHTTPClient. func WithTLSConfig(tlsConf *tls.Config) ClientOption { return func(c *client) { - c.tlsConf = tlsConf + c.inner.tlsConf = tlsConf } } @@ -134,7 +146,7 @@ func NewClient( pdAddrs []string, opts ...ClientOption, ) Client { - c := &client{} + c := &client{inner: &clientInner{}, callerID: defaultCallerID} // Apply the options first. for _, opt := range opts { opt(c) @@ -143,7 +155,7 @@ func NewClient( for i, addr := range pdAddrs { if !strings.HasPrefix(addr, httpScheme) { var scheme string - if c.tlsConf != nil { + if c.inner.tlsConf != nil { scheme = httpsScheme } else { scheme = httpScheme @@ -151,14 +163,14 @@ func NewClient( pdAddrs[i] = fmt.Sprintf("%s://%s", scheme, addr) } } - c.pdAddrs = pdAddrs + c.inner.pdAddrs = pdAddrs // Init the HTTP client if it's not configured. - if c.cli == nil { - c.cli = &http.Client{Timeout: defaultTimeout} - if c.tlsConf != nil { + if c.inner.cli == nil { + c.inner.cli = &http.Client{Timeout: defaultTimeout} + if c.inner.tlsConf != nil { transport := http.DefaultTransport.(*http.Transport).Clone() - transport.TLSClientConfig = c.tlsConf - c.cli.Transport = transport + transport.TLSClientConfig = c.inner.tlsConf + c.inner.cli.Transport = transport } } @@ -167,12 +179,22 @@ func NewClient( // Close closes the HTTP client. func (c *client) Close() { - if c.cli != nil { - c.cli.CloseIdleConnections() + if c.inner == nil { + return + } + if c.inner.cli != nil { + c.inner.cli.CloseIdleConnections() } log.Info("[pd] http client closed") } +// WithCallerID sets and returns a new client with the given caller ID. +func (c *client) WithCallerID(callerID string) Client { + newClient := *c + newClient.callerID = callerID + return &newClient +} + // WithRespHandler sets and returns a new client with the given HTTP response handler. func (c *client) WithRespHandler( handler func(resp *http.Response, res interface{}) error, @@ -196,13 +218,19 @@ func (c *client) execDuration(name string, duration time.Duration) { c.executionDuration.WithLabelValues(name).Observe(duration.Seconds()) } +// Header key definition constants. +const ( + pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle" + componentSignatureKey = "component" +) + // HeaderOption configures the HTTP header. type HeaderOption func(header http.Header) // WithAllowFollowerHandle sets the header field to allow a PD follower to handle this request. func WithAllowFollowerHandle() HeaderOption { return func(header http.Header) { - header.Set("PD-Allow-Follower-Handle", "true") + header.Set(pdAllowFollowerHandleKey, "true") } } @@ -218,8 +246,8 @@ func (c *client) requestWithRetry( err error addr string ) - for idx := 0; idx < len(c.pdAddrs); idx++ { - addr = c.pdAddrs[idx] + for idx := 0; idx < len(c.inner.pdAddrs); idx++ { + addr = c.inner.pdAddrs[idx] err = c.request(ctx, name, fmt.Sprintf("%s%s", addr, uri), method, body, res, headerOpts...) if err == nil { break @@ -239,6 +267,8 @@ func (c *client) request( logFields := []zap.Field{ zap.String("name", name), zap.String("url", url), + zap.String("method", method), + zap.String("caller-id", c.callerID), } log.Debug("[pd] request the http url", logFields...) req, err := http.NewRequestWithContext(ctx, method, url, body) @@ -249,8 +279,10 @@ func (c *client) request( for _, opt := range headerOpts { opt(req.Header) } + req.Header.Set(componentSignatureKey, c.callerID) + start := time.Now() - resp, err := c.cli.Do(req) + resp, err := c.inner.cli.Do(req) if err != nil { c.reqCounter(name, networkErrorStatus) log.Error("[pd] do http request failed", append(logFields, zap.Error(err))...) diff --git a/client/http/client_test.go b/client/http/client_test.go new file mode 100644 index 00000000000..621910e29ea --- /dev/null +++ b/client/http/client_test.go @@ -0,0 +1,73 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +// requestChecker is used to check the HTTP request sent by the client. +type requestChecker struct { + checker func(req *http.Request) error +} + +// RoundTrip implements the `http.RoundTripper` interface. +func (rc *requestChecker) RoundTrip(req *http.Request) (resp *http.Response, err error) { + return &http.Response{StatusCode: http.StatusOK}, rc.checker(req) +} + +func newHTTPClientWithRequestChecker(checker func(req *http.Request) error) *http.Client { + return &http.Client{ + Transport: &requestChecker{checker: checker}, + } +} + +func TestPDAllowFollowerHandleHeader(t *testing.T) { + re := require.New(t) + var expectedVal string + httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { + val := req.Header.Get(pdAllowFollowerHandleKey) + if val != expectedVal { + re.Failf("PD allow follower handler header check failed", + "should be %s, but got %s", expectedVal, val) + } + return nil + }) + c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c.GetRegions(context.Background()) + expectedVal = "true" + c.GetHistoryHotRegions(context.Background(), &HistoryHotRegionsRequest{}) +} + +func TestCallerID(t *testing.T) { + re := require.New(t) + expectedVal := defaultCallerID + httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { + val := req.Header.Get(componentSignatureKey) + if val != expectedVal { + re.Failf("Caller ID header check failed", + "should be %s, but got %s", expectedVal, val) + } + return nil + }) + c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c.GetRegions(context.Background()) + expectedVal = "test" + c.WithCallerID(expectedVal).GetRegions(context.Background()) +}