Skip to content

Commit

Permalink
tools/ctl: add caller ID for pd-ctl (#8214)
Browse files Browse the repository at this point in the history
ref #7300

Signed-off-by: husharp <[email protected]>
  • Loading branch information
HuSharp authored May 24, 2024
1 parent 4cd42b3 commit 7cc3b4e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
8 changes: 4 additions & 4 deletions tools/pd-ctl/pdctl/command/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
)

const (
pdControlCallerID = "pd-ctl"
PDControlCallerID = "pd-ctl"
clusterPrefix = "pd/api/v1/cluster"
)

Expand Down Expand Up @@ -107,7 +107,7 @@ func initNewPDClient(cmd *cobra.Command, opts ...pd.ClientOption) error {
if PDCli != nil {
PDCli.Close()
}
PDCli = pd.NewClient(pdControlCallerID, getEndpoints(cmd), opts...)
PDCli = pd.NewClient(PDControlCallerID, getEndpoints(cmd), opts...).WithCallerID(PDControlCallerID)
return nil
}

Expand All @@ -122,7 +122,7 @@ func initNewPDClientWithTLS(cmd *cobra.Command, caPath, certPath, keyPath string

// TODO: replace dialClient with the PD HTTP client completely.
var dialClient = &http.Client{
Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, pdControlCallerID),
Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, PDControlCallerID),
}

// RequireHTTPSClient creates a HTTPS client if the related flags are set
Expand Down Expand Up @@ -153,7 +153,7 @@ func initHTTPSClient(caPath, certPath, keyPath string) error {
}
dialClient = &http.Client{
Transport: apiutil.NewCallerIDRoundTripper(
&http.Transport{TLSClientConfig: tlsConfig}, pdControlCallerID),
&http.Transport{TLSClientConfig: tlsConfig}, PDControlCallerID),
}
return nil
}
Expand Down
45 changes: 34 additions & 11 deletions tools/pd-ctl/tests/global_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,44 @@ package tests

import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"

"github.com/pingcap/log"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/stretchr/testify/require"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/assertutil"
"github.com/tikv/pd/pkg/utils/testutil"
"github.com/tikv/pd/server"
cmd "github.com/tikv/pd/tools/pd-ctl/pdctl"
"go.uber.org/zap"
"github.com/tikv/pd/tools/pd-ctl/pdctl/command"
)

const pdControlCallerID = "pd-ctl"

func TestSendAndGetComponent(t *testing.T) {
re := require.New(t)
handler := func(context.Context, *server.Server) (http.Handler, apiutil.APIServiceGroup, error) {
mux := http.NewServeMux()
// check pd http sdk api
mux.HandleFunc("/pd/api/v1/cluster", func(w http.ResponseWriter, r *http.Request) {
callerID := apiutil.GetCallerIDOnHTTP(r)
re.Equal(command.PDControlCallerID, callerID)
cluster := &metapb.Cluster{Id: 1}
clusterBytes, err := json.Marshal(cluster)
re.NoError(err)
w.Write(clusterBytes)
})
// check http client api
// TODO: remove this comment after replacing dialClient with the PD HTTP client completely.
mux.HandleFunc("/pd/api/v1/health", func(w http.ResponseWriter, r *http.Request) {
callerID := apiutil.GetCallerIDOnHTTP(r)
for k := range r.Header {
log.Info("header", zap.String("key", k))
}
log.Info("caller id", zap.String("caller-id", callerID))
re.Equal(pdControlCallerID, callerID)
re.Equal(command.PDControlCallerID, callerID)
fmt.Fprint(w, callerID)
})
mux.HandleFunc("/pd/api/v1/stores", func(w http.ResponseWriter, r *http.Request) {
callerID := apiutil.GetCallerIDOnHTTP(r)
re.Equal(command.PDControlCallerID, callerID)
fmt.Fprint(w, callerID)
})
info := apiutil.APIServiceGroup{
Expand All @@ -64,8 +75,20 @@ func TestSendAndGetComponent(t *testing.T) {
}()

cmd := cmd.GetRootCmd()
args := []string{"-u", pdAddr, "health"}
args := []string{"-u", pdAddr, "cluster"}
output, err := ExecuteCommand(cmd, args...)
re.NoError(err)
re.Equal(fmt.Sprintf("%s\n", pdControlCallerID), string(output))
re.Equal(fmt.Sprintf("%s\n", `{
"id": 1
}`), string(output))

args = []string{"-u", pdAddr, "health"}
output, err = ExecuteCommand(cmd, args...)
re.NoError(err)
re.Equal(fmt.Sprintf("%s\n", command.PDControlCallerID), string(output))

args = []string{"-u", pdAddr, "store"}
output, err = ExecuteCommand(cmd, args...)
re.NoError(err)
re.Equal(fmt.Sprintf("%s\n", command.PDControlCallerID), string(output))
}

0 comments on commit 7cc3b4e

Please sign in to comment.