Skip to content

Commit

Permalink
Add unix socket support
Browse files Browse the repository at this point in the history
  • Loading branch information
grongor committed Jul 9, 2020
1 parent a3a5a0e commit f073d50
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 4 deletions.
6 changes: 5 additions & 1 deletion cmd/snmp-proxy/snmp-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ func main() {
requester := snmpproxy.NewGosnmpRequester(mibDataProvider)

apiListener := snmpproxy.NewApiListener(validator, requester, config.Logger, config.Api.Listen)
apiListener.Start()

err = apiListener.Start()
if err != nil {
config.Logger.Fatalw("failed to start API listener", zap.Error(err))
}

signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
Expand Down
25 changes: 22 additions & 3 deletions snmpproxy/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"encoding/json"
"errors"
"io/ioutil"
"net"
"net/http"
"strings"
"time"

"github.com/prometheus/client_golang/prometheus"
Expand All @@ -21,15 +23,32 @@ type ApiListener struct {
server *http.Server
}

func (l *ApiListener) Start() {
func (l *ApiListener) Start() error {
var (
ln net.Listener
err error
)

if strings.HasSuffix(l.server.Addr, ".sock") {
ln, err = net.Listen("unix", l.server.Addr)
} else {
ln, err = net.Listen("tcp", l.server.Addr)
}

if err != nil {
return err
}

go func() {
err := l.server.ListenAndServe()
err := l.server.Serve(ln)
if !errors.Is(err, http.ErrServerClosed) {
l.logger.Fatalw("failed to start API listener", zap.Error(err))
l.logger.Fatalw("http.Serve error", zap.Error(err))
}
}()

l.logger.Info("API listener started")

return nil
}

func (l *ApiListener) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
Expand Down
54 changes: 54 additions & 0 deletions snmpproxy/api_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package snmpproxy_test

import (
"context"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -232,6 +235,57 @@ func TestStartAndClose(t *testing.T) {
require.Error(err)
}

func TestStartAndCloseOnSocket(t *testing.T) {
require := require.New(t)

prometheus.DefaultRegisterer = prometheus.NewRegistry()

requester := &mockRequester{}
defer requester.AssertExpectations(t)

requester.On("ExecuteRequest", mock.Anything).Once().Return([][]interface{}{{".1.2.3", 123}}, nil)

f, err := ioutil.TempFile("", "snmp-proxy-test-*.sock")
require.NoError(err)
require.NoError(f.Close())
require.NoError(os.Remove(f.Name()))

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), f.Name())
err = listener.Start()
require.NoError(err)

time.Sleep(time.Millisecond * 10)

client := http.Client{
Transport: &http.Transport{
DialContext: func(context.Context, string, string) (net.Conn, error) {
return net.Dial("unix", f.Name())
},
},
}

response, err := client.Post("http://socket/snmp-proxy", "", strings.NewReader(getRequestBody))
require.NoError(err)
require.Equal(http.StatusOK, response.StatusCode)
require.Equal(`{"result":[[".1.2.3",123]]}`, read(response.Body))

listener.Close()

response, err = client.Post("http://socket/snmp-proxy", "", strings.NewReader(getRequestBody))
require.Nil(response)
require.Error(err)
}

func TestStartError(t *testing.T) {
require := require.New(t)

prometheus.DefaultRegisterer = prometheus.NewRegistry()

listener := snmpproxy.NewApiListener(newValidator(), &mockRequester{}, zap.NewNop().Sugar(), "localhost:80")
err := listener.Start()
require.EqualError(err, "listen tcp 127.0.0.1:80: bind: permission denied")
}

func read(r io.Reader) string {
b, err := ioutil.ReadAll(r)
if err != nil {
Expand Down

0 comments on commit f073d50

Please sign in to comment.