Skip to content

Commit

Permalink
Fix error count check when parsing query parameters. (#145)
Browse files Browse the repository at this point in the history
The code should not silently proceed when a single error is found.
Tests added for non-integer count and startIndex query parameters.
  • Loading branch information
mibanescu authored Mar 12, 2024
1 parent 941a5ea commit df55fde
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
26 changes: 26 additions & 0 deletions handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,32 @@ func TestServerResourcesGetAllHandlerNegativeCount(t *testing.T) {
assertEqual(t, 0, len(response.Resources))
}

func TestServerResourcesGetAllHandlerNonIntCount(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/Users?count=BadBanana", nil)
rr := httptest.NewRecorder()
newTestServer().ServeHTTP(rr, req)

assertEqualStatusCode(t, http.StatusBadRequest, rr.Code)

var response errors.ScimError
assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response))
assertEqual(t, http.StatusBadRequest, response.Status)
assertEqual(t, "Bad Request. Invalid parameter provided in request: count.", response.Detail)
}

func TestServerResourcesGetAllHandlerNonIntStartIndex(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/Users?startIndex=BadBanana", nil)
rr := httptest.NewRecorder()
newTestServer().ServeHTTP(rr, req)

assertEqualStatusCode(t, http.StatusBadRequest, rr.Code)

var response errors.ScimError
assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response))
assertEqual(t, http.StatusBadRequest, response.Status)
assertEqual(t, "Bad Request. Invalid parameter provided in request: startIndex.", response.Detail)
}

func TestServerResourcesGetHandler(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/Users", nil)
rr := httptest.NewRecorder()
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, ref
startIndex = defaultStartIndex
}

if len(invalidParams) > 1 {
if len(invalidParams) > 0 {
scimErr := errors.ScimErrorBadParams(invalidParams)
return ListRequestParams{}, &scimErr
}
Expand Down

0 comments on commit df55fde

Please sign in to comment.