Skip to content

Commit

Permalink
Add Code type for authentication webhook response
Browse files Browse the repository at this point in the history
  • Loading branch information
chacha912 committed Oct 16, 2024
1 parent 3e49afb commit cdd952e
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 22 deletions.
21 changes: 19 additions & 2 deletions api/types/auth_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,27 @@ func NewAuthWebhookRequest(reader io.Reader) (*AuthWebhookRequest, error) {
return req, nil
}

// Code represents the result of an authentication webhook request.
type Code int

const (
// CodeOK indicates that the request is fully authenticated and has
// the necessary permissions.
CodeOK Code = 200

// CodeUnauthenticated indicates that the request does not have valid
// authentication credentials for the operation.
CodeUnauthenticated Code = 401

// CodePermissionDenied indicates that the authenticated request lacks
// the necessary permissions.
CodePermissionDenied Code = 403
)

// AuthWebhookResponse represents the response of authentication webhook.
type AuthWebhookResponse struct {
Allowed bool `json:"allowed"`
Reason string `json:"reason"`
Code Code `json:"code"`
Message string `json:"message"`
}

// NewAuthWebhookResponse creates a new instance of AuthWebhookResponse.
Expand Down
30 changes: 21 additions & 9 deletions server/rpc/auth/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,20 @@ import (
)

var (
// ErrNotAllowed is returned when the given user is not allowed for the access.
ErrNotAllowed = errors.New("method is not allowed for this user")
// ErrPermissionDenied is returned when the given user is not allowed for the access.
ErrPermissionDenied = errors.New("method is not allowed for this user")

// ErrUnexpectedStatusCode is returned when the response code is not 200 from the webhook.
ErrUnexpectedStatusCode = errors.New("unexpected status code from webhook")

// ErrUnexpectedResponse is returned when the response from the webhook is not as expected.
ErrUnexpectedResponse = errors.New("unexpected response from webhook")

// ErrWebhookTimeout is returned when the webhook does not respond in time.
ErrWebhookTimeout = errors.New("webhook timeout")

// ErrUnauthenticated is returned when the request lacks valid authentication credentials.
ErrUnauthenticated = errors.New("request lacks valid authentication credentials")
)

// verifyAccess verifies the given user is allowed to access the given method.
Expand All @@ -63,8 +69,8 @@ func verifyAccess(
cacheKey := string(reqBody)
if entry, ok := be.AuthWebhookCache.Get(cacheKey); ok {
resp := entry
if !resp.Allowed {
return fmt.Errorf("%s: %w", resp.Reason, ErrNotAllowed)
if resp.Code != types.CodeOK {
return fmt.Errorf("%s: %w", resp.Message, ErrPermissionDenied)
}
return nil
}
Expand Down Expand Up @@ -95,13 +101,19 @@ func verifyAccess(
return resp.StatusCode, err
}

if !authResp.Allowed {
return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Reason, ErrNotAllowed)
if authResp.Code == types.CodeOK {
return resp.StatusCode, nil
}
if authResp.Code == types.CodePermissionDenied {
return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Message, ErrPermissionDenied)
}
if authResp.Code == types.CodeUnauthenticated {
return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Message, ErrUnauthenticated)
}

return resp.StatusCode, nil
return resp.StatusCode, fmt.Errorf("%d: %w", authResp.Code, ErrUnexpectedResponse)
}); err != nil {
if errors.Is(err, ErrNotAllowed) {
if errors.Is(err, ErrPermissionDenied) {
be.AuthWebhookCache.Add(cacheKey, authResp, be.Config.ParseAuthWebhookCacheUnauthTTL())
}

Expand All @@ -120,7 +132,7 @@ func withExponentialBackoff(ctx context.Context, cfg *backend.Config, webhookFn
statusCode, err := webhookFn()
if !shouldRetry(statusCode, err) {
if err == ErrUnexpectedStatusCode {
return fmt.Errorf("unexpected status code from webhook: %d", statusCode)
return fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode)
}

return err
Expand Down
10 changes: 8 additions & 2 deletions server/rpc/connecthelper/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,15 @@ var errorToConnectCode = map[error]connect.Code{
converter.ErrUnsupportedCounterType: connect.CodeUnimplemented,

// Unauthenticated means the request does not have valid authentication
auth.ErrNotAllowed: connect.CodeUnauthenticated,
auth.ErrUnexpectedStatusCode: connect.CodeUnauthenticated,
auth.ErrUnexpectedResponse: connect.CodeUnauthenticated,
auth.ErrWebhookTimeout: connect.CodeUnauthenticated,
auth.ErrUnauthenticated: connect.CodeUnauthenticated,
database.ErrMismatchedPassword: connect.CodeUnauthenticated,

// PermissionDenied means the request does not have permission for the operation.
auth.ErrPermissionDenied: connect.CodePermissionDenied,

// Canceled means the operation was canceled (typically by the caller).
context.Canceled: connect.CodeCanceled,
}
Expand Down Expand Up @@ -124,7 +128,9 @@ var errorToCode = map[error]string{
converter.ErrUnsupportedValueType: "ErrUnsupportedValueType",
converter.ErrUnsupportedCounterType: "ErrUnsupportedCounterType",

auth.ErrNotAllowed: "ErrNotAllowed",
auth.ErrPermissionDenied: "ErrPermissionDenied",
auth.ErrUnauthenticated: "ErrUnauthenticated",
auth.ErrUnexpectedResponse: "ErrUnexpectedResponse",
auth.ErrUnexpectedStatusCode: "ErrUnexpectedStatusCode",
auth.ErrWebhookTimeout: "ErrWebhookTimeout",
database.ErrMismatchedPassword: "ErrMismatchedPassword",
Expand Down
22 changes: 13 additions & 9 deletions test/integration/auth_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ import (
"github.com/rs/xid"
"github.com/stretchr/testify/assert"

"github.com/yorkie-team/yorkie/api/converter"
"github.com/yorkie-team/yorkie/api/types"
"github.com/yorkie-team/yorkie/client"
"github.com/yorkie-team/yorkie/pkg/document"
"github.com/yorkie-team/yorkie/pkg/document/json"
"github.com/yorkie-team/yorkie/pkg/document/presence"
"github.com/yorkie-team/yorkie/server"
"github.com/yorkie-team/yorkie/server/rpc/auth"
"github.com/yorkie-team/yorkie/server/rpc/connecthelper"
"github.com/yorkie-team/yorkie/test/helper"
)

Expand All @@ -47,9 +50,10 @@ func newAuthServer(t *testing.T) (*httptest.Server, string) {

var res types.AuthWebhookResponse
if req.Token == token {
res.Allowed = true
res.Code = types.CodeOK
} else {
res.Reason = "invalid token"
res.Code = types.CodeUnauthenticated
res.Message = "invalid token"
}

_, err = res.Write(w)
Expand All @@ -64,7 +68,7 @@ func newUnavailableAuthServer(t *testing.T, recoveryCnt uint64) *httptest.Server
assert.NoError(t, err)

var res types.AuthWebhookResponse
res.Allowed = true
res.Code = types.CodeOK
if retries < recoveryCnt-1 {
w.WriteHeader(http.StatusServiceUnavailable)
retries++
Expand Down Expand Up @@ -186,8 +190,7 @@ func TestAuthWebhook(t *testing.T) {
t.Run("authorization webhook that success after retries test", func(t *testing.T) {
ctx := context.Background()

var recoveryCnt uint64
recoveryCnt = 4
var recoveryCnt uint64 = 4
authServer := newUnavailableAuthServer(t, recoveryCnt)

conf := helper.TestConfig()
Expand Down Expand Up @@ -264,6 +267,7 @@ func TestAuthWebhook(t *testing.T) {

err = cli.Activate(ctx)
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
assert.Equal(t, connecthelper.CodeOf(auth.ErrWebhookTimeout), converter.ErrorCodeOf(err))
})

t.Run("authorized request cache test", func(t *testing.T) {
Expand All @@ -274,7 +278,7 @@ func TestAuthWebhook(t *testing.T) {
assert.NoError(t, err)

var res types.AuthWebhookResponse
res.Allowed = true
res.Code = types.CodeOK

_, err = res.Write(w)
assert.NoError(t, err)
Expand Down Expand Up @@ -352,7 +356,7 @@ func TestAuthWebhook(t *testing.T) {
assert.NoError(t, err)

var res types.AuthWebhookResponse
res.Allowed = false
res.Code = types.CodePermissionDenied

_, err = res.Write(w)
assert.NoError(t, err)
Expand Down Expand Up @@ -394,14 +398,14 @@ func TestAuthWebhook(t *testing.T) {
// 01. multiple requests.
for i := 0; i < 3; i++ {
err = cli.Activate(ctx)
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err))
}

// 02. multiple requests after eviction by ttl.
time.Sleep(unauthorizedTTL)
for i := 0; i < 3; i++ {
err = cli.Activate(ctx)
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err))
}
assert.Equal(t, 2, reqCnt)
})
Expand Down

0 comments on commit cdd952e

Please sign in to comment.