diff --git a/api/converter/errors.go b/api/converter/errors.go index 69ec75f4e..a89b0bfb1 100644 --- a/api/converter/errors.go +++ b/api/converter/errors.go @@ -25,3 +25,22 @@ func ErrorCodeOf(err error) string { } return "" } + +// ErrorMetadataOf returns the error metadata of the given error. +func ErrorMetadataOf(err error) map[string]string { + var connectErr *connect.Error + if !errors.As(err, &connectErr) { + return nil + } + for _, detail := range connectErr.Details() { + msg, valueErr := detail.Value() + if valueErr != nil { + continue + } + + if errorInfo, ok := msg.(*errdetails.ErrorInfo); ok { + return errorInfo.GetMetadata() + } + } + return nil +} diff --git a/api/types/updatable_project_fields_test.go b/api/types/updatable_project_fields_test.go index cba41401d..e8e61263b 100644 --- a/api/types/updatable_project_fields_test.go +++ b/api/types/updatable_project_fields_test.go @@ -26,7 +26,7 @@ import ( ) func TestUpdatableProjectFields(t *testing.T) { - var structError *validation.StructError + var formErr *validation.FormError t.Run("validation test", func(t *testing.T) { newName := "changed-name" newAuthWebhookURL := "http://localhost:3000" @@ -68,7 +68,7 @@ func TestUpdatableProjectFields(t *testing.T) { AuthWebhookMethods: &newAuthWebhookMethods, ClientDeactivateThreshold: &newClientDeactivateThreshold, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) }) t.Run("project name format test", func(t *testing.T) { @@ -82,36 +82,36 @@ func TestUpdatableProjectFields(t *testing.T) { fields = &types.UpdatableProjectFields{ Name: &invalidName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) reservedName := "new" fields = &types.UpdatableProjectFields{ Name: &reservedName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) reservedName = "default" fields = &types.UpdatableProjectFields{ Name: &reservedName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidName = "1" fields = &types.UpdatableProjectFields{ Name: &invalidName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidName = "over_30_chracaters_is_invalid_name" fields = &types.UpdatableProjectFields{ Name: &invalidName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidName = "invalid/name" fields = &types.UpdatableProjectFields{ Name: &invalidName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) }) } diff --git a/api/types/user_fields_test.go b/api/types/user_fields_test.go index 15e30339a..e60aae49b 100644 --- a/api/types/user_fields_test.go +++ b/api/types/user_fields_test.go @@ -26,7 +26,7 @@ import ( ) func TestSignupFields(t *testing.T) { - var structError *validation.StructError + var formErr *validation.FormError t.Run("password validation test", func(t *testing.T) { validUsername := "test" @@ -42,48 +42,48 @@ func TestSignupFields(t *testing.T) { Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "abcd" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "!@#$" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "abcd1234" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "abcd!@#$" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "1234!@#$" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "abcd1234!@abcd1234!@abcd1234!@1" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) }) } diff --git a/client/auth.go b/client/auth.go index 7e422c6e2..269d80d84 100644 --- a/client/auth.go +++ b/client/auth.go @@ -39,6 +39,11 @@ func NewAuthInterceptor(apiKey, token string) *AuthInterceptor { } } +// SetToken sets the token. +func (i *AuthInterceptor) SetToken(token string) { + i.token = token +} + // WrapUnary creates a unary server interceptor for authorization. func (i *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func( diff --git a/client/client.go b/client/client.go index 46b39d925..282a19858 100644 --- a/client/client.go +++ b/client/client.go @@ -99,6 +99,7 @@ type Client struct { client v1connect.YorkieServiceClient options Options clientOptions []connect.ClientOption + interceptor *AuthInterceptor logger *zap.Logger id *time.ActorID @@ -149,8 +150,8 @@ func New(opts ...Option) (*Client, error) { } var clientOptions []connect.ClientOption - - clientOptions = append(clientOptions, connect.WithInterceptors(NewAuthInterceptor(options.APIKey, options.Token))) + interceptor := NewAuthInterceptor(options.APIKey, options.Token) + clientOptions = append(clientOptions, connect.WithInterceptors(interceptor)) if options.MaxCallRecvMsgSize != 0 { clientOptions = append(clientOptions, connect.WithReadMaxBytes(options.MaxCallRecvMsgSize)) } @@ -169,6 +170,7 @@ func New(opts ...Option) (*Client, error) { clientOptions: clientOptions, options: options, logger: logger, + interceptor: interceptor, key: k, status: deactivated, @@ -205,6 +207,11 @@ func (c *Client) Dial(rpcAddr string) error { return nil } +// SetToken sets the given token of this client. +func (c *Client) SetToken(token string) { + c.interceptor.SetToken(token) +} + // Close closes all resources of this client. func (c *Client) Close() error { if err := c.Deactivate(context.Background()); err != nil { diff --git a/internal/metaerrors/metaerrors.go b/internal/metaerrors/metaerrors.go new file mode 100644 index 000000000..9fce8d7a1 --- /dev/null +++ b/internal/metaerrors/metaerrors.go @@ -0,0 +1,59 @@ +/* + * Copyright 2024 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package metaerrors provides a way to attach metadata to errors. +package metaerrors + +import "strings" + +// MetaError is an error that can have metadata attached to it. This can be used +// to send additional information to the SDK or to the user. +type MetaError struct { + // Err is the underlying error. + Err error + + // Metadata is a map of additional information that can be attached to the + // error. + Metadata map[string]string +} + +// New returns a new MetaError with the given error and metadata. +func New(err error, metadata map[string]string) *MetaError { + return &MetaError{ + Err: err, + Metadata: metadata, + } +} + +// Error returns the error message. +func (e MetaError) Error() string { + if len(e.Metadata) == 0 { + return e.Err.Error() + } + + sb := strings.Builder{} + + for key, val := range e.Metadata { + if sb.Len() > 0 { + sb.WriteString(",") + } + sb.WriteString(key) + sb.WriteString("=") + sb.WriteString(val) + } + + return e.Err.Error() + " [" + sb.String() + "]" +} diff --git a/internal/metaerrors/metaerrors_test.go b/internal/metaerrors/metaerrors_test.go new file mode 100644 index 000000000..70a3c286c --- /dev/null +++ b/internal/metaerrors/metaerrors_test.go @@ -0,0 +1,53 @@ +/* + * Copyright 2024 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package metaerrors_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/yorkie-team/yorkie/internal/metaerrors" +) + +func TestMetaError(t *testing.T) { + t.Run("test meta error", func(t *testing.T) { + err := errors.New("error message") + metaErr := metaerrors.New(err, map[string]string{"key": "value"}) + assert.Equal(t, "error message [key=value]", metaErr.Error()) + }) + + t.Run("test meta error without metadata", func(t *testing.T) { + err := errors.New("error message") + metaErr := metaerrors.New(err, nil) + assert.Equal(t, "error message", metaErr.Error()) + }) + + t.Run("test meta error with wrapped error", func(t *testing.T) { + err := fmt.Errorf("wrapped error: %w", errors.New("error message")) + metaErr := metaerrors.New(err, map[string]string{"key": "value"}) + assert.Equal(t, "wrapped error: error message [key=value]", metaErr.Error()) + + metaErr = metaerrors.New(errors.New("error message"), map[string]string{"key": "value"}) + assert.Equal(t, "error message [key=value]", metaErr.Error()) + + wrappedErr := fmt.Errorf("wrapped error: %w", metaErr) + assert.Equal(t, "wrapped error: error message [key=value]", wrappedErr.Error()) + }) +} diff --git a/internal/validation/validation.go b/internal/validation/validation.go index fbf4cbd01..e5ad41d6e 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -14,7 +14,7 @@ * limitations under the License. */ -// Package validation provides the validation functions. +// Package validation provides the validation functions for form and field. package validation import ( @@ -119,13 +119,13 @@ func (e Violation) Error() string { return e.Err.Error() } -// StructError is the error returned by the validation of struct. -type StructError struct { +// FormError represents the error of the form validation. +type FormError struct { Violations []Violation } // Error returns the error message. -func (s StructError) Error() string { +func (s FormError) Error() string { sb := strings.Builder{} for _, v := range s.Violations { @@ -223,16 +223,16 @@ func Validate(v string, tagOrRules []interface{}) error { // ValidateStruct validates the struct func ValidateStruct(s interface{}) error { if err := defaultValidator.Struct(s); err != nil { - structError := &StructError{} + formErr := &FormError{} for _, e := range err.(validator.ValidationErrors) { - structError.Violations = append(structError.Violations, Violation{ + formErr.Violations = append(formErr.Violations, Violation{ Tag: e.Tag(), Field: e.StructField(), Err: e, Description: e.Translate(trans), }) } - return structError + return formErr } return nil diff --git a/internal/validation/validation_test.go b/internal/validation/validation_test.go index be513e21d..9fe6bcea4 100644 --- a/internal/validation/validation_test.go +++ b/internal/validation/validation_test.go @@ -61,8 +61,8 @@ func TestValidation(t *testing.T) { user := User{Name: "invalid-key-$-wrong-string-value", Country: "korea"} err := ValidateStruct(user) - structError := err.(*StructError) - assert.Len(t, structError.Violations, 2, "user should be invalid") + formErr := err.(*FormError) + assert.Len(t, formErr.Violations, 2, "user should be invalid") }) t.Run("custom rule test", func(t *testing.T) { diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index cb15a53c8..ec997c316 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -28,17 +28,24 @@ import ( "time" "github.com/yorkie-team/yorkie/api/types" + "github.com/yorkie-team/yorkie/internal/metaerrors" "github.com/yorkie-team/yorkie/server/backend" "github.com/yorkie-team/yorkie/server/logging" ) var ( - // ErrNotAllowed is returned when the given user is not allowed for the access. - ErrNotAllowed = errors.New("method is not allowed for this user") + // ErrUnauthenticated is returned when the authentication is failed. + ErrUnauthenticated = errors.New("unauthenticated") + + // ErrPermissionDenied is returned when the given user is not allowed for the access. + ErrPermissionDenied = errors.New("method is not allowed for this user") // ErrUnexpectedStatusCode is returned when the response code is not 200 from the webhook. ErrUnexpectedStatusCode = errors.New("unexpected status code from webhook") + // ErrUnexpectedResponse is returned when the response from the webhook is not as expected. + ErrUnexpectedResponse = errors.New("unexpected response from webhook") + // ErrWebhookTimeout is returned when the webhook does not respond in time. ErrWebhookTimeout = errors.New("webhook timeout") ) @@ -64,7 +71,7 @@ func verifyAccess( if entry, ok := be.AuthWebhookCache.Get(cacheKey); ok { resp := entry if !resp.Allowed { - return fmt.Errorf("%s: %w", resp.Reason, ErrNotAllowed) + return fmt.Errorf("%s: %w", resp.Reason, ErrPermissionDenied) } return nil } @@ -86,7 +93,9 @@ func verifyAccess( } }() - if http.StatusOK != resp.StatusCode { + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusUnauthorized && + resp.StatusCode != http.StatusForbidden { return resp.StatusCode, ErrUnexpectedStatusCode } @@ -95,13 +104,22 @@ func verifyAccess( return resp.StatusCode, err } - if !authResp.Allowed { - return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Reason, ErrNotAllowed) + if resp.StatusCode == http.StatusOK && authResp.Allowed { + return resp.StatusCode, nil + } + if resp.StatusCode == http.StatusForbidden && !authResp.Allowed { + return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Reason, ErrPermissionDenied) + } + if resp.StatusCode == http.StatusUnauthorized && !authResp.Allowed { + return resp.StatusCode, metaerrors.New( + ErrUnauthenticated, + map[string]string{"reason": authResp.Reason}, + ) } - return resp.StatusCode, nil + return resp.StatusCode, fmt.Errorf("%d: %w", resp.StatusCode, ErrUnexpectedResponse) }); err != nil { - if errors.Is(err, ErrNotAllowed) { + if errors.Is(err, ErrPermissionDenied) { be.AuthWebhookCache.Add(cacheKey, authResp, be.Config.ParseAuthWebhookCacheUnauthTTL()) } @@ -120,7 +138,7 @@ func withExponentialBackoff(ctx context.Context, cfg *backend.Config, webhookFn statusCode, err := webhookFn() if !shouldRetry(statusCode, err) { if err == ErrUnexpectedStatusCode { - return fmt.Errorf("unexpected status code from webhook: %d", statusCode) + return fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode) } return err diff --git a/server/rpc/connecthelper/status.go b/server/rpc/connecthelper/status.go index 8708b401b..b56fd161c 100644 --- a/server/rpc/connecthelper/status.go +++ b/server/rpc/connecthelper/status.go @@ -26,6 +26,7 @@ import ( "github.com/yorkie-team/yorkie/api/converter" "github.com/yorkie-team/yorkie/api/types" + "github.com/yorkie-team/yorkie/internal/metaerrors" "github.com/yorkie-team/yorkie/internal/validation" "github.com/yorkie-team/yorkie/pkg/document/key" "github.com/yorkie-team/yorkie/pkg/document/time" @@ -78,11 +79,17 @@ var errorToConnectCode = map[error]connect.Code{ converter.ErrUnsupportedCounterType: connect.CodeUnimplemented, // Unauthenticated means the request does not have valid authentication - auth.ErrNotAllowed: connect.CodeUnauthenticated, - auth.ErrUnexpectedStatusCode: connect.CodeUnauthenticated, - auth.ErrWebhookTimeout: connect.CodeUnauthenticated, + auth.ErrUnauthenticated: connect.CodeUnauthenticated, database.ErrMismatchedPassword: connect.CodeUnauthenticated, + // Internal means an internal error occurred. + auth.ErrUnexpectedStatusCode: connect.CodeInternal, + auth.ErrUnexpectedResponse: connect.CodeInternal, + auth.ErrWebhookTimeout: connect.CodeInternal, + + // PermissionDenied means the request does not have permission for the operation. + auth.ErrPermissionDenied: connect.CodePermissionDenied, + // Canceled means the operation was canceled (typically by the caller). context.Canceled: connect.CodeCanceled, } @@ -124,7 +131,9 @@ var errorToCode = map[error]string{ converter.ErrUnsupportedValueType: "ErrUnsupportedValueType", converter.ErrUnsupportedCounterType: "ErrUnsupportedCounterType", - auth.ErrNotAllowed: "ErrNotAllowed", + auth.ErrPermissionDenied: "ErrPermissionDenied", + auth.ErrUnauthenticated: "ErrUnauthenticated", + auth.ErrUnexpectedResponse: "ErrUnexpectedResponse", auth.ErrUnexpectedStatusCode: "ErrUnexpectedStatusCode", auth.ErrWebhookTimeout: "ErrWebhookTimeout", database.ErrMismatchedPassword: "ErrMismatchedPassword", @@ -179,9 +188,47 @@ func errorToConnectError(err error) (*connect.Error, bool) { return connectErr, true } -// structErrorToConnectError returns connect.Error from the given struct error. -func structErrorToConnectError(err error) (*connect.Error, bool) { - var invalidFieldsError *validation.StructError +// metaErrorToConnectError returns connect.Error from the given rich error. +func metaErrorToConnectError(err error) (*connect.Error, bool) { + var metaErr *metaerrors.MetaError + if !errors.As(err, &metaErr) { + return nil, false + } + + // NOTE(hackerwins): This prevents panic when the cause is an unhashable + // error. + var connectCode connect.Code + var ok bool + defer func() { + if r := recover(); r != nil { + ok = false + } + }() + + connectCode, ok = errorToConnectCode[metaErr.Err] + if !ok { + return nil, false + } + + connectErr := connect.NewError(connectCode, err) + if code, ok := errorToCode[metaErr.Err]; ok { + errorInfo := &errdetails.ErrorInfo{ + Metadata: map[string]string{"code": code}, + } + for key, value := range metaErr.Metadata { + errorInfo.Metadata[key] = value + } + if detail, detailErr := connect.NewErrorDetail(errorInfo); detailErr == nil { + connectErr.AddDetail(detail) + } + } + + return connectErr, true +} + +// formErrorToConnectError returns connect.Error from the given form error. +func formErrorToConnectError(err error) (*connect.Error, bool) { + var invalidFieldsError *validation.FormError if !errors.As(err, &invalidFieldsError) { return nil, false } @@ -199,7 +246,7 @@ func structErrorToConnectError(err error) (*connect.Error, bool) { } func badRequestFromError(err error) (*errdetails.BadRequest, bool) { - var invalidFieldsError *validation.StructError + var invalidFieldsError *validation.FormError if !errors.As(err, &invalidFieldsError) { return nil, false } @@ -225,11 +272,15 @@ func ToStatusError(err error) error { return nil } + if err, ok := metaErrorToConnectError(err); ok { + return err + } + if err, ok := errorToConnectError(err); ok { return err } - if err, ok := structErrorToConnectError(err); ok { + if err, ok := formErrorToConnectError(err); ok { return err } diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index a37a6bf41..b8f9692b4 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -29,12 +29,15 @@ import ( "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/yorkie-team/yorkie/api/converter" "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/client" "github.com/yorkie-team/yorkie/pkg/document" "github.com/yorkie-team/yorkie/pkg/document/json" "github.com/yorkie-team/yorkie/pkg/document/presence" "github.com/yorkie-team/yorkie/server" + "github.com/yorkie-team/yorkie/server/rpc/auth" + "github.com/yorkie-team/yorkie/server/rpc/connecthelper" "github.com/yorkie-team/yorkie/test/helper" ) @@ -47,8 +50,18 @@ func newAuthServer(t *testing.T) (*httptest.Server, string) { var res types.AuthWebhookResponse if req.Token == token { + w.WriteHeader(http.StatusOK) // 200 res.Allowed = true + } else if req.Token == "not allowed token" { + w.WriteHeader(http.StatusForbidden) // 403 + res.Allowed = false + } else if req.Token == "" { + w.WriteHeader(http.StatusUnauthorized) // 401 + res.Allowed = false + res.Reason = "no token" } else { + w.WriteHeader(http.StatusUnauthorized) // 401 + res.Allowed = false res.Reason = "invalid token" } @@ -58,22 +71,20 @@ func newAuthServer(t *testing.T) (*httptest.Server, string) { } func newUnavailableAuthServer(t *testing.T, recoveryCnt uint64) *httptest.Server { - var retries uint64 + var requestCount uint64 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := types.NewAuthWebhookRequest(r.Body) assert.NoError(t, err) var res types.AuthWebhookResponse res.Allowed = true - if retries < recoveryCnt-1 { + + if requestCount < recoveryCnt { w.WriteHeader(http.StatusServiceUnavailable) - retries++ - } else { - retries = 0 } - _, err = res.Write(w) assert.NoError(t, err) + requestCount++ })) } @@ -89,7 +100,7 @@ func TestProjectAuthWebhook(t *testing.T) { project, err := adminCli.CreateProject(context.Background(), "auth-webhook-test") assert.NoError(t, err) - t.Run("authorization webhook test", func(t *testing.T) { + t.Run("successful authorization test", func(t *testing.T) { ctx := context.Background() authServer, token := newAuthServer(t) @@ -117,6 +128,22 @@ func TestProjectAuthWebhook(t *testing.T) { doc := document.New(helper.TestDocKey(t)) assert.NoError(t, cli.Attach(ctx, doc)) + }) + + t.Run("unauthenticated response test", func(t *testing.T) { + ctx := context.Background() + authServer, _ := newAuthServer(t) + + // project with authorization webhook + project.AuthWebhookURL = authServer.URL + _, err := adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) // client without token cliWithoutToken, err := client.Dial( @@ -127,6 +154,7 @@ func TestProjectAuthWebhook(t *testing.T) { defer func() { assert.NoError(t, cliWithoutToken.Close()) }() err = cliWithoutToken.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "reason": "no token"}, converter.ErrorMetadataOf(err)) // client with invalid token cliWithInvalidToken, err := client.Dial( @@ -138,9 +166,38 @@ func TestProjectAuthWebhook(t *testing.T) { defer func() { assert.NoError(t, cliWithInvalidToken.Close()) }() err = cliWithInvalidToken.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "reason": "invalid token"}, converter.ErrorMetadataOf(err)) }) - t.Run("Selected method authorization webhook test", func(t *testing.T) { + t.Run("permission denied response test", func(t *testing.T) { + ctx := context.Background() + authServer, _ := newAuthServer(t) + + // project with authorization webhook + project.AuthWebhookURL = authServer.URL + _, err := adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) + + // client with not allowed token + cliNotAllowed, err := client.Dial( + svr.RPCAddr(), + client.WithAPIKey(project.PublicKey), + client.WithToken("not allowed token"), + ) + assert.NoError(t, err) + defer func() { assert.NoError(t, cliNotAllowed.Close()) }() + err = cliNotAllowed.Activate(ctx) + assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrPermissionDenied), converter.ErrorCodeOf(err)) + }) + + t.Run("selected method authorization webhook test", func(t *testing.T) { ctx := context.Background() authServer, _ := newAuthServer(t) @@ -182,26 +239,40 @@ func TestProjectAuthWebhook(t *testing.T) { }) } -func TestAuthWebhook(t *testing.T) { - t.Run("authorization webhook that success after retries test", func(t *testing.T) { +func TestAuthWebhookErrorHandling(t *testing.T) { + var recoveryCnt uint64 = 4 + + conf := helper.TestConfig() + conf.Backend.AuthWebhookMaxRetries = recoveryCnt + conf.Backend.AuthWebhookMaxWaitInterval = "1000ms" + svr, err := server.New(conf) + assert.NoError(t, err) + assert.NoError(t, svr.Start()) + defer func() { assert.NoError(t, svr.Shutdown(true)) }() + + adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) + defer func() { adminCli.Close() }() + + t.Run("unexpected status code test", func(t *testing.T) { ctx := context.Background() + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := types.NewAuthWebhookRequest(r.Body) + assert.NoError(t, err) - var recoveryCnt uint64 - recoveryCnt = 4 - authServer := newUnavailableAuthServer(t, recoveryCnt) + var res types.AuthWebhookResponse + res.Allowed = true - conf := helper.TestConfig() - conf.Backend.AuthWebhookMaxRetries = recoveryCnt - conf.Backend.AuthWebhookMaxWaitInterval = "1000ms" - svr, err := server.New(conf) - assert.NoError(t, err) - assert.NoError(t, svr.Start()) - defer func() { assert.NoError(t, svr.Shutdown(true)) }() + // unexpected status code + w.WriteHeader(http.StatusBadRequest) - adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) - defer func() { adminCli.Close() }() - project, err := adminCli.CreateProject(context.Background(), "success-webhook-after-retries") + _, err = res.Write(w) + assert.NoError(t, err) + })) + + // project with authorization webhook + project, err := adminCli.CreateProject(context.Background(), "unexpected-status-code") assert.NoError(t, err) + project.AuthWebhookURL = authServer.URL _, err = adminCli.UpdateProject( ctx, @@ -214,35 +285,90 @@ func TestAuthWebhook(t *testing.T) { cli, err := client.Dial( svr.RPCAddr(), - client.WithToken("token"), client.WithAPIKey(project.PublicKey), + client.WithToken("token"), ) assert.NoError(t, err) defer func() { assert.NoError(t, cli.Close()) }() - err = cli.Activate(ctx) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrUnexpectedStatusCode), converter.ErrorCodeOf(err)) + }) + + t.Run("unexpected webhook response test", func(t *testing.T) { + ctx := context.Background() + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := types.NewAuthWebhookRequest(r.Body) + assert.NoError(t, err) + + var res types.AuthWebhookResponse + // mismatched response + res.Allowed = false + + _, err = res.Write(w) + assert.NoError(t, err) + })) + + // project with authorization webhook + project, err := adminCli.CreateProject(context.Background(), "unexpected-response-code") assert.NoError(t, err) - doc := document.New(helper.TestDocKey(t)) - err = cli.Attach(ctx, doc) + project.AuthWebhookURL = authServer.URL + _, err = adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) + + cli, err := client.Dial( + svr.RPCAddr(), + client.WithAPIKey(project.PublicKey), + client.WithToken("token"), + ) assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + err = cli.Activate(ctx) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrUnexpectedResponse), converter.ErrorCodeOf(err)) }) - t.Run("authorization webhook that fails after retries test", func(t *testing.T) { + t.Run("unavailable authentication server test(timeout)", func(t *testing.T) { ctx := context.Background() - authServer := newUnavailableAuthServer(t, 4) + authServer := newUnavailableAuthServer(t, recoveryCnt+1) - conf := helper.TestConfig() - conf.Backend.AuthWebhookMaxRetries = 2 - conf.Backend.AuthWebhookMaxWaitInterval = "1000ms" - svr, err := server.New(conf) + project, err := adminCli.CreateProject(context.Background(), "unavailable-auth-server") + assert.NoError(t, err) + project.AuthWebhookURL = authServer.URL + _, err = adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) assert.NoError(t, err) - assert.NoError(t, svr.Start()) - defer func() { assert.NoError(t, svr.Shutdown(true)) }() - adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) - defer func() { adminCli.Close() }() - project, err := adminCli.CreateProject(context.Background(), "fail-webhook-after-retries") + cli, err := client.Dial( + svr.RPCAddr(), + client.WithToken("token"), + client.WithAPIKey(project.PublicKey), + ) + assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + + err = cli.Activate(ctx) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrWebhookTimeout), converter.ErrorCodeOf(err)) + }) + + t.Run("successful authorization after temporarily unavailable server test", func(t *testing.T) { + ctx := context.Background() + authServer := newUnavailableAuthServer(t, recoveryCnt) + + project, err := adminCli.CreateProject(context.Background(), "success-webhook-after-retries") assert.NoError(t, err) project.AuthWebhookURL = authServer.URL _, err = adminCli.UpdateProject( @@ -263,10 +389,16 @@ func TestAuthWebhook(t *testing.T) { defer func() { assert.NoError(t, cli.Close()) }() err = cli.Activate(ctx) - assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.NoError(t, err) + + doc := document.New(helper.TestDocKey(t)) + err = cli.Attach(ctx, doc) + assert.NoError(t, err) }) +} - t.Run("authorized request cache test", func(t *testing.T) { +func TestAuthWebhookCache(t *testing.T) { + t.Run("authorized response cache test", func(t *testing.T) { ctx := context.Background() reqCnt := 0 authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -295,7 +427,7 @@ func TestAuthWebhook(t *testing.T) { adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) defer func() { adminCli.Close() }() - project, err := adminCli.CreateProject(context.Background(), "auth-request-cache") + project, err := adminCli.CreateProject(context.Background(), "authorized-response-cache") assert.NoError(t, err) project.AuthWebhookURL = authServer.URL _, err = adminCli.UpdateProject( @@ -344,13 +476,77 @@ func TestAuthWebhook(t *testing.T) { assert.Equal(t, 2, reqCnt) }) - t.Run("unauthorized request cache test", func(t *testing.T) { + t.Run("permission denied response cache test", func(t *testing.T) { + ctx := context.Background() + reqCnt := 0 + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := types.NewAuthWebhookRequest(r.Body) + assert.NoError(t, err) + + w.WriteHeader(http.StatusForbidden) + var res types.AuthWebhookResponse + res.Allowed = false + + _, err = res.Write(w) + assert.NoError(t, err) + + reqCnt++ + })) + + unauthorizedTTL := 1 * time.Second + conf := helper.TestConfig() + conf.Backend.AuthWebhookCacheUnauthTTL = unauthorizedTTL.String() + + svr, err := server.New(conf) + assert.NoError(t, err) + assert.NoError(t, svr.Start()) + defer func() { assert.NoError(t, svr.Shutdown(true)) }() + + adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) + defer func() { adminCli.Close() }() + project, err := adminCli.CreateProject(context.Background(), "permission-denied-cache") + assert.NoError(t, err) + project.AuthWebhookURL = authServer.URL + _, err = adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) + + cli, err := client.Dial( + svr.RPCAddr(), + client.WithToken("token"), + client.WithAPIKey(project.PublicKey), + ) + assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + + // 01. multiple requests. + for i := 0; i < 3; i++ { + err = cli.Activate(ctx) + assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err)) + } + + // 02. multiple requests after eviction by ttl. + time.Sleep(unauthorizedTTL) + for i := 0; i < 3; i++ { + err = cli.Activate(ctx) + assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err)) + } + assert.Equal(t, 2, reqCnt) + }) + + t.Run("other response not cached test", func(t *testing.T) { ctx := context.Background() reqCnt := 0 authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := types.NewAuthWebhookRequest(r.Body) assert.NoError(t, err) + w.WriteHeader(http.StatusUnauthorized) var res types.AuthWebhookResponse res.Allowed = false @@ -371,7 +567,7 @@ func TestAuthWebhook(t *testing.T) { adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) defer func() { adminCli.Close() }() - project, err := adminCli.CreateProject(context.Background(), "unauth-request-cache") + project, err := adminCli.CreateProject(context.Background(), "other-response-not-cached") assert.NoError(t, err) project.AuthWebhookURL = authServer.URL _, err = adminCli.UpdateProject( @@ -403,6 +599,49 @@ func TestAuthWebhook(t *testing.T) { err = cli.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) } - assert.Equal(t, 2, reqCnt) + assert.Equal(t, 6, reqCnt) + }) +} + +func TestAuthWebhookNewToken(t *testing.T) { + t.Run("set valid token after invalid token test", func(t *testing.T) { + ctx := context.Background() + authServer, validToken := newAuthServer(t) + + svr, err := server.New(helper.TestConfig()) + assert.NoError(t, err) + assert.NoError(t, svr.Start()) + defer func() { assert.NoError(t, svr.Shutdown(true)) }() + + adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) + defer func() { adminCli.Close() }() + project, err := adminCli.CreateProject(context.Background(), "new-auth-token") + assert.NoError(t, err) + project.AuthWebhookURL = authServer.URL + _, err = adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) + + cli, err := client.Dial( + svr.RPCAddr(), + client.WithToken("invalid"), + client.WithAPIKey(project.PublicKey), + ) + assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + + err = cli.Activate(ctx) + assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + + // activate again with valid token + metadata := converter.ErrorMetadataOf(err) + assert.Equal(t, "invalid token", metadata["reason"]) + cli.SetToken(validToken) + assert.NoError(t, cli.Activate(ctx)) }) }