Skip to content

Commit

Permalink
Merge pull request RackSec#33 from iaburton/custom-dialer
Browse files Browse the repository at this point in the history
Add custom dialer option
  • Loading branch information
sirsean authored Jul 9, 2018
2 parents 1f7cff9 + b1a86db commit a4725f0
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 6 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,22 @@ w.Debug("this is debug")
w.Write([]byte("these are some bytes"))
```

If you need further control over connection attempts, you can use the DialWithCustomDialer
function. To continue with the DialWithTLSConfig example:

```
netDialer := &net.Dialer{Timeout: time.Second*5} // easy timeouts
realNetwork := "tcp" // real network, other vars your dail func can close over
dial := func(network, addr string) (net.Conn, error) {
// cannot use "network" here as it'll simply be "custom" which will fail
return tls.DialWithDialer(netDialer, realNetwork, addr, &config)
}
w, err := DialWithCustomDialer("custom", "192.168.0.52:514", syslog.LOG_ERR, "testtag", dial)
```

Your custom dial func can set timeouts, proxy connections, and do whatever else it needs before returning a net.Conn.

# Generating TLS Certificates

We've provided a script that you can use to generate a self-signed keypair:
Expand Down
17 changes: 17 additions & 0 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func (w *Writer) getDialer() dialerFunctionWrapper {
dialers := map[string]dialerFunctionWrapper{
"": dialerFunctionWrapper{"unixDialer", w.unixDialer},
"tcp+tls": dialerFunctionWrapper{"tlsDialer", w.tlsDialer},
"custom": dialerFunctionWrapper{"customDialer", w.customDialer},
}
dialer, ok := dialers[w.network]
if !ok {
Expand Down Expand Up @@ -85,3 +86,19 @@ func (w *Writer) basicDialer() (serverConn, string, error) {
}
return sc, hostname, err
}

// customDialer uses the custom dialer when the Writer was created
// giving developers total control over how connections are made and returned.
// Note it does not check if cdialer is nil, as it should only be referenced from getDialer.
func (w *Writer) customDialer() (serverConn, string, error) {
c, err := w.customDial(w.network, w.raddr)
var sc serverConn
hostname := w.hostname
if err == nil {
sc = &netConn{conn: c}
if hostname == "" {
hostname = c.LocalAddr().String()
}
}
return sc, hostname, err
}
81 changes: 81 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package srslog
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"net"
"testing"
)

Expand Down Expand Up @@ -44,6 +46,13 @@ func TestGetDialer(t *testing.T) {
if "basicDialer" != dialer.Name {
t.Errorf("should get basicDialer, got: %v", dialer)
}

w.network = "custom"
w.customDial = func(string, string) (net.Conn, error) { return nil, nil }
dialer = w.getDialer()
if "customDialer" != dialer.Name {
t.Errorf("should get customDialer, got: %v", dialer)
}
}

func TestUnixDialer(t *testing.T) {
Expand Down Expand Up @@ -196,3 +205,75 @@ func TestUDPDialer(t *testing.T) {
t.Errorf("should not interfere with hostname")
}
}

func TestCustomDialer(t *testing.T) {
// A custom dialer can really be anything, so we don't test an actual connection
// instead we test the behavior of this code path

nwork, addr := "custom", "custom_addr_to_pass"
w := Writer{
priority: LOG_ERR,
tag: "tag",
hostname: "",
network: nwork,
raddr: addr,
customDial: func(n string, a string) (net.Conn, error) {
if n != nwork || a != addr {
return nil, errors.New("Unexpected network or address, expected: (" +
nwork + ":" + addr + ") but received (" + n + ":" + a + ")")
}
return fakeConn{addr: &fakeAddr{nwork, addr}}, nil
},
}

_, hostname, err := w.customDialer()

if err != nil {
t.Errorf("failed to dial: %v", err)
}

if hostname == "" {
t.Errorf("should set default hostname")
}

w.hostname = "my other hostname"

_, hostname, err = w.customDialer()

if err != nil {
t.Errorf("failed to dial: %v", err)
}

if hostname != "my other hostname" {
t.Errorf("should not interfere with hostname")
}
}

type fakeConn struct {
net.Conn
addr net.Addr
}

func (fc fakeConn) Close() error {
return nil
}

func (fc fakeConn) Write(p []byte) (int, error) {
return len(p), nil
}

func (fc fakeConn) LocalAddr() net.Addr {
return fc.addr
}

type fakeAddr struct {
nwork, addr string
}

func (fa *fakeAddr) Network() string {
return fa.nwork
}

func (fa *fakeAddr) String() string {
return fa.addr
}
40 changes: 34 additions & 6 deletions srslog.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package srslog
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"log"
"net"
"os"
)

Expand All @@ -15,6 +17,10 @@ type serverConn interface {
close() error
}

// DialFunc is the function signature to be used for a custom dialer callback
// with DialWithCustomDialer
type DialFunc func(string, string) (net.Conn, error)

// New establishes a new connection to the system log daemon. Each
// write to the returned Writer sends a log message with the given
// priority and prefix.
Expand All @@ -31,6 +37,22 @@ func Dial(network, raddr string, priority Priority, tag string) (*Writer, error)
return DialWithTLSConfig(network, raddr, priority, tag, nil)
}

// ErrNilDialFunc is returned from DialWithCustomDialer when a nil DialFunc is passed,
// avoiding a nil pointer deference panic.
var ErrNilDialFunc = errors.New("srslog: nil DialFunc passed to DialWithCustomDialer")

// DialWithCustomDialer establishes a connection by calling customDial.
// Each write to the returned Writer sends a log message with the given facility, severity and tag.
// Network must be "custom" in order for this package to use customDial.
// While network and raddr will be passed to customDial, it is allowed for customDial to ignore them.
// If customDial is nil, this function returns ErrNilDialFunc.
func DialWithCustomDialer(network, raddr string, priority Priority, tag string, customDial DialFunc) (*Writer, error) {
if customDial == nil {
return nil, ErrNilDialFunc
}
return dialAllParameters(network, raddr, priority, tag, nil, customDial)
}

// DialWithTLSCertPath establishes a secure connection to a log daemon by connecting to
// address raddr on the specified network. It uses certPath to load TLS certificates and configure
// the secure connection.
Expand Down Expand Up @@ -59,6 +81,11 @@ func DialWithTLSCert(network, raddr string, priority Priority, tag string, serve
// DialWithTLSConfig establishes a secure connection to a log daemon by connecting to
// address raddr on the specified network. It uses tlsConfig to configure the secure connection.
func DialWithTLSConfig(network, raddr string, priority Priority, tag string, tlsConfig *tls.Config) (*Writer, error) {
return dialAllParameters(network, raddr, priority, tag, tlsConfig, nil)
}

// implementation of the various functions above
func dialAllParameters(network, raddr string, priority Priority, tag string, tlsConfig *tls.Config, customDial DialFunc) (*Writer, error) {
if err := validatePriority(priority); err != nil {
return nil, err
}
Expand All @@ -69,12 +96,13 @@ func DialWithTLSConfig(network, raddr string, priority Priority, tag string, tls
hostname, _ := os.Hostname()

w := &Writer{
priority: priority,
tag: tag,
hostname: hostname,
network: network,
raddr: raddr,
tlsConfig: tlsConfig,
priority: priority,
tag: tag,
hostname: hostname,
network: network,
raddr: raddr,
tlsConfig: tlsConfig,
customDial: customDial,
}

_, err := w.connect()
Expand Down
3 changes: 3 additions & 0 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ type Writer struct {
framer Framer
formatter Formatter

//non-nil if custom dialer set, used in getDialer
customDial DialFunc

mu sync.RWMutex // guards conn
conn serverConn
}
Expand Down

0 comments on commit a4725f0

Please sign in to comment.