diff --git a/pkg/apiclient/apiclient_test.go b/pkg/apiclient/apiclient_test.go index 7bb3b36befdde..b4b35d0b80d48 100644 --- a/pkg/apiclient/apiclient_test.go +++ b/pkg/apiclient/apiclient_test.go @@ -7,10 +7,34 @@ import ( ) func Test_parseHeaders(t *testing.T) { - headerString := []string{"foo:", "foo1:bar1", "foo2:bar2:bar2"} - headers, err := parseHeaders(headerString) - assert.NoError(t, err) - assert.Equal(t, headers.Get("foo"), "") - assert.Equal(t, headers.Get("foo1"), "bar1") - assert.Equal(t, headers.Get("foo2"), "bar2:bar2") + t.Run("Header parsed successfully", func(t *testing.T) { + headerString := []string{"foo:", "foo1:bar1", "foo2:bar2:bar2"} + headers, err := parseHeaders(headerString) + assert.NoError(t, err) + assert.Equal(t, headers.Get("foo"), "") + assert.Equal(t, headers.Get("foo1"), "bar1") + assert.Equal(t, headers.Get("foo2"), "bar2:bar2") + }) + + t.Run("Header parsed error", func(t *testing.T) { + headerString := []string{"foo"} + _, err := parseHeaders(headerString) + assert.ErrorContains(t, err, "additional headers must be colon(:)-separated: foo") + }) +} + +func Test_parseGRPCHeaders(t *testing.T) { + t.Run("Header parsed successfully", func(t *testing.T) { + headerStrings := []string{"origin: https://foo.bar", "content-length: 123"} + headers, err := parseGRPCHeaders(headerStrings) + assert.NoError(t, err) + assert.Equal(t, headers.Get("origin"), []string{" https://foo.bar"}) + assert.Equal(t, headers.Get("content-length"), []string{" 123"}) + }) + + t.Run("Header parsed error", func(t *testing.T) { + headerString := []string{"foo"} + _, err := parseGRPCHeaders(headerString) + assert.ErrorContains(t, err, "additional headers must be colon(:)-separated: foo") + }) } diff --git a/pkg/apiclient/grpcproxy.go b/pkg/apiclient/grpcproxy.go index 28af7b62783df..72fea42efee3f 100644 --- a/pkg/apiclient/grpcproxy.go +++ b/pkg/apiclient/grpcproxy.go @@ -131,14 +131,14 @@ func (c *client) startGRPCProxy() (*grpc.Server, net.Listener, error) { } md, _ := metadata.FromIncomingContext(stream.Context()) + headersMD, err := parseGRPCHeaders(c.Headers) - for _, kv := range c.Headers { - if len(strings.Split(kv, ":"))%2 == 1 { - return fmt.Errorf("additional headers key/values must be separated by a colon(:): %s", kv) - } - md.Append(strings.Split(kv, ":")[0], strings.Split(kv, ":")[1]) + if err != nil { + return err } + md = metadata.Join(md, headersMD) + resp, err := c.executeRequest(fullMethodName, msg, md) if err != nil { return err @@ -216,3 +216,16 @@ func (c *client) useGRPCProxy() (net.Addr, io.Closer, error) { return nil }), nil } + +func parseGRPCHeaders(headerStrings []string) (metadata.MD, error) { + md := metadata.New(map[string]string{}) + for _, kv := range headerStrings { + i := strings.IndexByte(kv, ':') + // zero means meaningless empty header name + if i <= 0 { + return nil, fmt.Errorf("additional headers must be colon(:)-separated: %s", kv) + } + md.Append(kv[0:i], kv[i+1:]) + } + return md, nil +}