Skip to content

Commit

Permalink
Add the client request check tests
Browse files Browse the repository at this point in the history
Signed-off-by: JmPotato <[email protected]>
  • Loading branch information
JmPotato committed Dec 4, 2023
1 parent 3d4a605 commit d1b1e03
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
8 changes: 4 additions & 4 deletions client/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ func (c *client) execDuration(name string, duration time.Duration) {

// Header key definition constants.
const (
PDAllowFollowerHandleKey = "PD-Allow-Follower-Handle"
ComponentSignatureKey = "component"
pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle"
componentSignatureKey = "component"
)

// HeaderOption configures the HTTP header.
Expand All @@ -229,7 +229,7 @@ type HeaderOption func(header http.Header)
// WithAllowFollowerHandle sets the header field to allow a PD follower to handle this request.
func WithAllowFollowerHandle() HeaderOption {
return func(header http.Header) {
header.Set("PD-Allow-Follower-Handle", "true")
header.Set(pdAllowFollowerHandleKey, "true")
}
}

Expand Down Expand Up @@ -278,7 +278,7 @@ func (c *client) request(
for _, opt := range headerOpts {
opt(req.Header)
}
req.Header.Set(ComponentSignatureKey, c.callerID)
req.Header.Set(componentSignatureKey, c.callerID)

start := time.Now()
resp, err := c.inner.cli.Do(req)
Expand Down
73 changes: 73 additions & 0 deletions client/http/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2023 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package http

import (
"context"
"net/http"
"testing"

"github.com/stretchr/testify/require"
)

// requestChecker is used to check the HTTP request sent by the client.
type requestChecker struct {
checker func(req *http.Request) error
}

// RoundTrip implements the `http.RoundTripper` interface.
func (rc *requestChecker) RoundTrip(req *http.Request) (resp *http.Response, err error) {
return &http.Response{StatusCode: http.StatusOK}, rc.checker(req)
}

func newHTTPClientWithRequestChecker(checker func(req *http.Request) error) *http.Client {
return &http.Client{
Transport: &requestChecker{checker: checker},
}
}

func TestPDAllowFollowerHandleHeader(t *testing.T) {
re := require.New(t)
var exceptedVal string
httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error {
val := req.Header.Get(pdAllowFollowerHandleKey)
if val != exceptedVal {
re.Failf("PD allow follower handler header check failed",
"should be %s, but got %s", exceptedVal, val)
}
return nil
})
c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient))
c.GetRegions(context.Background())
exceptedVal = "true"
c.GetHistoryHotRegions(context.Background(), &HistoryHotRegionsRequest{})
}

func TestCallerID(t *testing.T) {
re := require.New(t)
var exceptedVal string
httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error {
val := req.Header.Get(componentSignatureKey)
if val != exceptedVal {
re.Failf("Caller ID header check failed",
"should be %s, but got %s", exceptedVal, val)
}
return nil
})
c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient))
c.GetRegions(context.Background())
exceptedVal = "test"
c.WithCallerID(exceptedVal).GetRegions(context.Background())
}

0 comments on commit d1b1e03

Please sign in to comment.