From 7b1686ea0b3db0e1e84927b2d417b8d7554c5330 Mon Sep 17 00:00:00 2001 From: Jonathan Basseri Date: Fri, 3 Apr 2020 17:56:11 -0700 Subject: [PATCH] Add DialContext option to v1 client This adds a configurable DialContext function to the client library, which will be invoked to create HTTP/HTTPS connections. Allowing a custom DialContext captures a wide variety of use cases and behaviors by delegating the responsibility to the DialContext function. For example, users can set dial timeouts, share connections, or tunnel/proxy connections using custom logic. Signed-off-by: Jonathan Basseri --- client.go | 35 +++++++++++++++++++++++++++++++++++ client_test.go | 23 +++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/client.go b/client.go index 035f1a1..1a6dc0d 100644 --- a/client.go +++ b/client.go @@ -44,6 +44,9 @@ var ( // ErrProxyNotSupported is returned when a client is unable to set a proxy for http requests. ErrProxyNotSupported = errors.New("client does not support http proxy") + // ErrDialerNotSupported is returned when a client is unable to set a DialContext for http requests. + ErrDialerNotSupported = errors.New("client does not support setting DialContext") + // DefaultPort is the default API port. DefaultPort = "5705" @@ -107,6 +110,8 @@ type Dialer interface { Dial(network, address string) (net.Conn, error) } +type dialContext = func(ctx context.Context, network, address string) (net.Conn, error) + // NewClient returns a Client instance ready for communication with the given // server endpoint. It will use the latest remote API version available in the // server. @@ -203,6 +208,36 @@ func (c *Client) SetTimeout(t time.Duration) { } } +// GetDialContext returns the current DialContext function, or nil if there is none. +func (c *Client) GetDialContext() dialContext { + c.configLock.RLock() + defer c.configLock.RUnlock() + + if c.httpClient == nil { + return nil + } + transport, supported := c.httpClient.Transport.(*http.Transport) + if !supported { + return nil + } + return transport.DialContext +} + +// SetDialContext uses the given dial function to establish TCP connections in the HTTPClient. +func (c *Client) SetDialContext(dial dialContext) error { + c.configLock.Lock() + defer c.configLock.Unlock() + + if client := c.httpClient; client != nil { + transport, supported := client.Transport.(*http.Transport) + if !supported { + return ErrDialerNotSupported + } + transport.DialContext = dial + } + return nil +} + func (c *Client) checkAPIVersion() error { serverAPIVersionString, err := c.getServerAPIVersionString() if err != nil { diff --git a/client_test.go b/client_test.go index 74f696f..5c3225e 100644 --- a/client_test.go +++ b/client_test.go @@ -9,6 +9,7 @@ import ( "context" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" @@ -243,6 +244,28 @@ func TestSetAuth(t *testing.T) { } } +func TestSetDialContext(t *testing.T) { + client, err := NewClient("http://localhost:8080") + if err != nil { + t.Fatal(err) + } + dialerCalled := false + dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { + dialerCalled = true + return nil, fmt.Errorf("not implemented") + } + if err := client.SetDialContext(dialContext); err != nil { + t.Fatalf("Failed to set DialContext: %v", err) + } + + if err := client.Ping(); err == nil { + t.Error("Expected Ping() to return an error") + } + if !dialerCalled { + t.Error("Expected custom DialContext to be called") + } +} + func TestPing(t *testing.T) { fakeRT := &FakeRoundTripper{message: "", status: http.StatusOK} client := newTestClient(fakeRT)