Skip to content

Commit

Permalink
Change implementation to accept batch of requests
Browse files Browse the repository at this point in the history
  • Loading branch information
grongor committed Jul 7, 2020
1 parent c1dde7d commit 28346a6
Show file tree
Hide file tree
Showing 9 changed files with 712 additions and 365 deletions.
41 changes: 30 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Use cases
- use HTTP authentication for your SNMP requests
- encrypt "SNMP traffic" between the client and the server
- bypass firewall
- ...post your use-case via issue :-)
- ...post your use-case via issue/PR :-)

Only SNMP versions 1 and 2c are supported. If you want support for version 3, please, send a pull request.

Expand All @@ -35,16 +35,25 @@ How it works
The application provides a single HTTP endpoint `/snmp-proxy`, which accepts POST requests:
```json
{
"request_type": "getNext",
"host": "192.168.1.1",
"community": "public",
"oids": [
".1.2.3",
".4.5.6"
],
"version": "2c",
"timeout": 10,
"retries": 3
"retries": 3,
"requests": [
{
"request_type": "getNext",
"oids": [
".1.2.3",
".4.5.6"
]
},
{
"request_type": "walk",
"oids": [".7.8.9"],
"max_repetitions": 20
}
]
}
```

Expand All @@ -53,10 +62,20 @@ to the client. Response might look like this:
```json
{
"result": [
".1.2.3.4.5",
123,
".4.5.6.7.8",
"lorem"
[
".1.2.3.4.5",
123,
".4.5.6.7.8",
"lorem"
],
[
".7.8.9.1.1",
"some",
".7.8.9.1.2",
"values",
".7.8.9.1.3",
"here"
]
]
}
```
Expand Down
8 changes: 6 additions & 2 deletions snmpproxy/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (l *ApiListener) ServeHTTP(writer http.ResponseWriter, request *http.Reques
return
}

var apiRequest Request
apiRequest := &ApiRequest{}
var response = Response{}

defer func() {
Expand All @@ -56,7 +56,7 @@ func (l *ApiListener) ServeHTTP(writer http.ResponseWriter, request *http.Reques
return
}

err = json.Unmarshal(body, &apiRequest)
err = json.Unmarshal(body, apiRequest)
if err != nil {
l.logger.Debugw("failed unmarshal API request", zap.Error(err), "requestBody", string(body))
writer.WriteHeader(http.StatusBadRequest)
Expand All @@ -78,10 +78,14 @@ func (l *ApiListener) ServeHTTP(writer http.ResponseWriter, request *http.Reques

result, err := l.requester.ExecuteRequest(apiRequest)
if err == nil {
l.logger.Debugw("request successful", "request", apiRequest)

writer.WriteHeader(http.StatusOK)

response.Result = result
} else {
l.logger.Debugw("request failed", zap.Error(err), "request", apiRequest)

writer.WriteHeader(http.StatusInternalServerError)

response.Error = err.Error()
Expand Down
44 changes: 23 additions & 21 deletions snmpproxy/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ type mockRequester struct {
mock.Mock
}

func (r *mockRequester) ExecuteRequest(request snmpproxy.Request) ([]interface{}, error) {
args := r.Mock.Called(request)
func (r *mockRequester) ExecuteRequest(apiRequest *snmpproxy.ApiRequest) ([][]interface{}, error) {
args := r.Mock.Called(apiRequest)

result := args.Get(0)
if result == nil {
return nil, args.Error(1)
}

return result.([]interface{}), args.Error(1)
return result.([][]interface{}), args.Error(1)
}

type errReader struct {
Expand All @@ -42,15 +42,16 @@ func (errReader) Read(_ []byte) (n int, err error) {

const getRequestBody = `
{
"request_type": "get",
"host": "localhost",
"oids": [".1.2.3"],
"version": "2c",
"timeout": 3
"timeout": 3,
"requests": [
{"request_type": "get", "oids": [".1.2.3"]}
]
}
`

func TestHandlerErrorNotPost(t *testing.T) {
func TestListenerErrorNotPost(t *testing.T) {
require := require.New(t)

prometheus.DefaultRegisterer = prometheus.NewRegistry()
Expand All @@ -68,7 +69,7 @@ func TestHandlerErrorNotPost(t *testing.T) {
require.Equal(http.StatusMethodNotAllowed, response.StatusCode)
}

func TestHandlerErrorReadingRequest(t *testing.T) {
func TestListenerErrorReadingRequest(t *testing.T) {
require := require.New(t)

prometheus.DefaultRegisterer = prometheus.NewRegistry()
Expand All @@ -87,7 +88,7 @@ func TestHandlerErrorReadingRequest(t *testing.T) {
require.Equal(`{"error":"test error"}`, read(response.Body))
}

func TestHandlerErrorUnmarshalingRequest(t *testing.T) {
func TestListenerErrorUnmarshalingRequest(t *testing.T) {
tests := []struct {
name string
requestBody string
Expand All @@ -101,7 +102,7 @@ func TestHandlerErrorUnmarshalingRequest(t *testing.T) {
{
name: "not expected json struct",
requestBody: `{"something": "else"}`,
err: `{"error":"field request_type mustn't be empty"}`,
err: `{"error":"field host mustn't be empty"}`,
},
}
for _, test := range tests {
Expand All @@ -127,7 +128,7 @@ func TestHandlerErrorUnmarshalingRequest(t *testing.T) {
}
}

func TestHandlerErrorRequestValidatorError(t *testing.T) {
func TestListenerErrorRequestValidatorError(t *testing.T) {
require := require.New(t)

prometheus.DefaultRegisterer = prometheus.NewRegistry()
Expand All @@ -136,13 +137,14 @@ func TestHandlerErrorRequestValidatorError(t *testing.T) {

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "")

requestBody := `
const requestBody = `
{
"request_type": "get",
"host": "localhost",
"oids": [".1.2.3"],
"version": "2c",
"timeout": 100
"timeout": 100,
"requests": [
{"request_type": "get", "oids": [".1.2.3"]}
]
}
`

Expand All @@ -157,7 +159,7 @@ func TestHandlerErrorRequestValidatorError(t *testing.T) {
require.Equal(`{"error":"maximum allowed timeout is 10 seconds, got 100 seconds"}`, read(response.Body))
}

func TestHandlerErrorRequesterError(t *testing.T) {
func TestListenerErrorRequesterError(t *testing.T) {
require := require.New(t)

prometheus.DefaultRegisterer = prometheus.NewRegistry()
Expand All @@ -180,15 +182,15 @@ func TestHandlerErrorRequesterError(t *testing.T) {
require.Equal(`{"error":"some error"}`, read(response.Body))
}

func TestHandlerNoError(t *testing.T) {
func TestListenerNoError(t *testing.T) {
require := require.New(t)

prometheus.DefaultRegisterer = prometheus.NewRegistry()

requester := &mockRequester{}
defer requester.AssertExpectations(t)

requester.On("ExecuteRequest", mock.Anything).Once().Return([]interface{}{".1.2.3", 123}, nil)
requester.On("ExecuteRequest", mock.Anything).Once().Return([][]interface{}{{".1.2.3", 123}}, nil)

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "")

Expand All @@ -200,7 +202,7 @@ func TestHandlerNoError(t *testing.T) {
response := recorder.Result()

require.Equal(http.StatusOK, response.StatusCode)
require.Equal(`{"result":[".1.2.3",123]}`, read(response.Body))
require.Equal(`{"result":[[".1.2.3",123]]}`, read(response.Body))
}

func TestStartAndClose(t *testing.T) {
Expand All @@ -211,7 +213,7 @@ func TestStartAndClose(t *testing.T) {
requester := &mockRequester{}
defer requester.AssertExpectations(t)

requester.On("ExecuteRequest", mock.Anything).Once().Return([]interface{}{".1.2.3", 123}, nil)
requester.On("ExecuteRequest", mock.Anything).Once().Return([][]interface{}{{".1.2.3", 123}}, nil)

listener := snmpproxy.NewApiListener(newValidator(), requester, zap.NewNop().Sugar(), "localhost:15721")
listener.Start()
Expand All @@ -221,7 +223,7 @@ func TestStartAndClose(t *testing.T) {
response, err := http.Post("http://localhost:15721/snmp-proxy", "", strings.NewReader(getRequestBody))
require.NoError(err)
require.Equal(http.StatusOK, response.StatusCode)
require.Equal(`{"result":[".1.2.3",123]}`, read(response.Body))
require.Equal(`{"result":[[".1.2.3",123]]}`, read(response.Body))

listener.Close()

Expand Down
46 changes: 32 additions & 14 deletions snmpproxy/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type RequestType string
func (t *RequestType) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return fmt.Errorf("RequestType should be a string, got %s: %w", string(data), err)
return fmt.Errorf("RequestType must be a string, got %s: %w", string(data), err)
}

*t = RequestType(s)
Expand All @@ -40,7 +40,7 @@ type SnmpVersion gosnmp.SnmpVersion
func (v *SnmpVersion) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return fmt.Errorf("snmpVersion should be a string, got %s: %w", string(data), err)
return fmt.Errorf("snmpVersion must be a string, got %s: %w", string(data), err)
}

switch s {
Expand All @@ -58,14 +58,9 @@ func (v *SnmpVersion) UnmarshalJSON(data []byte) error {
}

type Request struct {
RequestType RequestType `json:"request_type"`
Host string `json:"host"`
Community string `json:"community"`
Version SnmpVersion `json:"version"`
Retries uint8 `json:"retries"`
MaxRepetitions uint8 `json:"max_repetitions"`
Timeout time.Duration `json:"timeout"`
Oids []string `json:"oids"`
RequestType RequestType `json:"request_type"`
Oids []string `json:"oids"`
MaxRepetitions uint8 `json:"max_repetitions"`
}

func (r *Request) UnmarshalJSON(data []byte) error {
Expand All @@ -74,13 +69,36 @@ func (r *Request) UnmarshalJSON(data []byte) error {
var t tmp

if err := json.Unmarshal(data, &t); err != nil {
return fmt.Errorf("failed to unmarshal request body into Request struct, got %+v: %w", string(data), err)
return fmt.Errorf("failed to unmarshal Request struct, got %+v: %w", string(data), err)
}

if t.RequestType == "" {
return fmt.Errorf("field request_type mustn't be empty")
}

*r = Request(t)

return nil
}

type ApiRequest struct {
Host string `json:"host"`
Community string `json:"community"`
Version SnmpVersion `json:"version"`
Retries uint8 `json:"retries"`
Timeout time.Duration `json:"timeout"`
Requests []Request `json:"requests"`
}

func (r *ApiRequest) UnmarshalJSON(data []byte) error {
type tmp ApiRequest

var t tmp

if err := json.Unmarshal(data, &t); err != nil {
return fmt.Errorf("failed to unmarshal request body into ApiRequest struct, got %+v: %w", string(data), err)
}

if t.Host == "" {
return fmt.Errorf("field host mustn't be empty")
}
Expand All @@ -107,14 +125,14 @@ func (r *Request) UnmarshalJSON(data []byte) error {

t.Timeout *= time.Second

*r = Request(t)
*r = ApiRequest(t)

return nil
}

type Response struct {
Error string `json:"error,omitempty"`
Result []interface{} `json:"result,omitempty"`
Error string `json:"error,omitempty"`
Result [][]interface{} `json:"result,omitempty"`
}

func (r *Response) Bytes() []byte {
Expand Down
Loading

0 comments on commit 28346a6

Please sign in to comment.