From d1b1e03b3f949ae7ca596c54708a15b11c8b6f15 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 4 Dec 2023 14:51:39 +0800 Subject: [PATCH] Add the client request check tests Signed-off-by: JmPotato --- client/http/client.go | 8 ++--- client/http/client_test.go | 73 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 client/http/client_test.go diff --git a/client/http/client.go b/client/http/client.go index 157524815024..3109fc470bb7 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -219,8 +219,8 @@ func (c *client) execDuration(name string, duration time.Duration) { // Header key definition constants. const ( - PDAllowFollowerHandleKey = "PD-Allow-Follower-Handle" - ComponentSignatureKey = "component" + pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle" + componentSignatureKey = "component" ) // HeaderOption configures the HTTP header. @@ -229,7 +229,7 @@ type HeaderOption func(header http.Header) // WithAllowFollowerHandle sets the header field to allow a PD follower to handle this request. func WithAllowFollowerHandle() HeaderOption { return func(header http.Header) { - header.Set("PD-Allow-Follower-Handle", "true") + header.Set(pdAllowFollowerHandleKey, "true") } } @@ -278,7 +278,7 @@ func (c *client) request( for _, opt := range headerOpts { opt(req.Header) } - req.Header.Set(ComponentSignatureKey, c.callerID) + req.Header.Set(componentSignatureKey, c.callerID) start := time.Now() resp, err := c.inner.cli.Do(req) diff --git a/client/http/client_test.go b/client/http/client_test.go new file mode 100644 index 000000000000..096435fb5668 --- /dev/null +++ b/client/http/client_test.go @@ -0,0 +1,73 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +// requestChecker is used to check the HTTP request sent by the client. +type requestChecker struct { + checker func(req *http.Request) error +} + +// RoundTrip implements the `http.RoundTripper` interface. +func (rc *requestChecker) RoundTrip(req *http.Request) (resp *http.Response, err error) { + return &http.Response{StatusCode: http.StatusOK}, rc.checker(req) +} + +func newHTTPClientWithRequestChecker(checker func(req *http.Request) error) *http.Client { + return &http.Client{ + Transport: &requestChecker{checker: checker}, + } +} + +func TestPDAllowFollowerHandleHeader(t *testing.T) { + re := require.New(t) + var exceptedVal string + httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { + val := req.Header.Get(pdAllowFollowerHandleKey) + if val != exceptedVal { + re.Failf("PD allow follower handler header check failed", + "should be %s, but got %s", exceptedVal, val) + } + return nil + }) + c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c.GetRegions(context.Background()) + exceptedVal = "true" + c.GetHistoryHotRegions(context.Background(), &HistoryHotRegionsRequest{}) +} + +func TestCallerID(t *testing.T) { + re := require.New(t) + var exceptedVal string + httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { + val := req.Header.Get(componentSignatureKey) + if val != exceptedVal { + re.Failf("Caller ID header check failed", + "should be %s, but got %s", exceptedVal, val) + } + return nil + }) + c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c.GetRegions(context.Background()) + exceptedVal = "test" + c.WithCallerID(exceptedVal).GetRegions(context.Background()) +}