diff --git a/dialer_wrapper.go b/dialer_wrapper.go index a630415..cebaf96 100644 --- a/dialer_wrapper.go +++ b/dialer_wrapper.go @@ -17,14 +17,16 @@ var ( ) type dialerOpts struct { - name string - monitoring bool - tracing bool - parentDialer *net.Dialer + name string + monitoring bool + tracing bool + parentDialContextFunc dialerContextFunc } type dialerOpt func(*dialerOpts) +type dialerContextFunc func(context.Context, string, string) (net.Conn, error) + // DialWithName sets the name of the dialer for tracking and monitoring. // This is the name for the dialer (default is `default`), but for `NewDialContextFunc` can be overwritten from the // Context using `DialNameToContext`. @@ -50,8 +52,13 @@ func DialWithTracing() dialerOpt { // DialWithDialer allows you to override the `net.Dialer` instance used to actually conduct the dials. func DialWithDialer(parentDialer *net.Dialer) dialerOpt { + return DialWithDialContextFunc(parentDialer.DialContext) +} + +// DialWithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`. +func DialWithDialContextFunc(parentDialerFunc dialerContextFunc) dialerOpt { return func(opts *dialerOpts) { - opts.parentDialer = parentDialer + opts.parentDialContextFunc = parentDialerFunc } } @@ -72,7 +79,7 @@ func DialNameToContext(ctx context.Context, dialerName string) context.Context { // NewDialContextFunc returns a `DialContext` function that tracks outbound connections. // The signature is compatible with `http.Tranport.DialContext` and is meant to be used there. func NewDialContextFunc(optFuncs ...dialerOpt) func(context.Context, string, string) (net.Conn, error) { - opts := &dialerOpts{name: defaultName, monitoring: true, parentDialer: &net.Dialer{}} + opts := &dialerOpts{name: defaultName, monitoring: true, parentDialContextFunc: (&net.Dialer{}).DialContext} for _, f := range optFuncs { f(opts) } @@ -113,7 +120,7 @@ func dialClientConnTracker(ctx context.Context, network string, addr string, dia if opts.monitoring { reportDialerConnAttempt(dialerName) } - conn, err := opts.parentDialer.DialContext(ctx, network, addr) + conn, err := opts.parentDialContextFunc(ctx, network, addr) if err != nil { if event != nil { event.Errorf("failed dialing: %v", err)