Skip to content

Commit

Permalink
add DailContext func override.
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Witkowski committed Nov 14, 2016
1 parent 180ba6f commit e6f6ee4
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions dialer_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ var (
)

type dialerOpts struct {
name string
monitoring bool
tracing bool
parentDialer *net.Dialer
name string
monitoring bool
tracing bool
parentDialContextFunc dialerContextFunc
}

type dialerOpt func(*dialerOpts)

type dialerContextFunc func(context.Context, string, string) (net.Conn, error)

// DialWithName sets the name of the dialer for tracking and monitoring.
// This is the name for the dialer (default is `default`), but for `NewDialContextFunc` can be overwritten from the
// Context using `DialNameToContext`.
Expand All @@ -50,8 +52,13 @@ func DialWithTracing() dialerOpt {

// DialWithDialer allows you to override the `net.Dialer` instance used to actually conduct the dials.
func DialWithDialer(parentDialer *net.Dialer) dialerOpt {
return DialWithDialContextFunc(parentDialer.DialContext)
}

// DialWithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`.
func DialWithDialContextFunc(parentDialerFunc dialerContextFunc) dialerOpt {
return func(opts *dialerOpts) {
opts.parentDialer = parentDialer
opts.parentDialContextFunc = parentDialerFunc
}
}

Expand All @@ -72,7 +79,7 @@ func DialNameToContext(ctx context.Context, dialerName string) context.Context {
// NewDialContextFunc returns a `DialContext` function that tracks outbound connections.
// The signature is compatible with `http.Tranport.DialContext` and is meant to be used there.
func NewDialContextFunc(optFuncs ...dialerOpt) func(context.Context, string, string) (net.Conn, error) {
opts := &dialerOpts{name: defaultName, monitoring: true, parentDialer: &net.Dialer{}}
opts := &dialerOpts{name: defaultName, monitoring: true, parentDialContextFunc: (&net.Dialer{}).DialContext}
for _, f := range optFuncs {
f(opts)
}
Expand Down Expand Up @@ -113,7 +120,7 @@ func dialClientConnTracker(ctx context.Context, network string, addr string, dia
if opts.monitoring {
reportDialerConnAttempt(dialerName)
}
conn, err := opts.parentDialer.DialContext(ctx, network, addr)
conn, err := opts.parentDialContextFunc(ctx, network, addr)
if err != nil {
if event != nil {
event.Errorf("failed dialing: %v", err)
Expand Down

0 comments on commit e6f6ee4

Please sign in to comment.