diff --git a/go.mod b/go.mod index b7d6a8b36..5e7a173f7 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/FZambia/statik v0.1.2-0.20180217151304-b9f012bb2a1b github.com/FZambia/tarantool v0.3.1 github.com/FZambia/viper-lite v0.0.0-20220110144934-1899f66c7d0e - github.com/centrifugal/centrifuge v0.33.5-0.20241026095149-106d43f21426 + github.com/centrifugal/centrifuge v0.33.5-0.20241103135221-783166d2ec2b github.com/centrifugal/protocol v0.13.4 github.com/cristalhq/jwt/v5 v5.4.0 github.com/gobwas/glob v0.2.3 diff --git a/go.sum b/go.sum index 34823831a..e0a22ed4c 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/centrifugal/centrifuge v0.33.5-0.20241026095149-106d43f21426 h1:g5zZaCr/BybYgq8Nqrnrvqvb3jGGO/Dloil3cFGzzbg= -github.com/centrifugal/centrifuge v0.33.5-0.20241026095149-106d43f21426/go.mod h1:Ck+7H3eVwoeyabKcj3L55oSunaORIOGPAIVB5xrQyGU= +github.com/centrifugal/centrifuge v0.33.5-0.20241103135221-783166d2ec2b h1:Ayx2/Pn0m51Rw9NTXdc3PGLt+kF+2JIEnVr5Hu30lOs= +github.com/centrifugal/centrifuge v0.33.5-0.20241103135221-783166d2ec2b/go.mod h1:yvzNn5hq/bFBpoXQwM8HbU481pAXQkyP2tzvJgFsiN8= github.com/centrifugal/protocol v0.13.4 h1:I0YxXtFNfn/ndDIZp5RkkqQcSSNH7DNPUbXKYtJXDzs= github.com/centrifugal/protocol v0.13.4/go.mod h1:7V5vI30VcoxJe4UD87xi7bOsvI0bmEhvbQuMjrFM2L4= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index d88b57a9d..7e7bc5df3 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -3,6 +3,8 @@ package proxy import ( "context" "encoding/json" + "errors" + "fmt" "net" "strings" @@ -26,6 +28,29 @@ type HttpStatusToCodeTransform struct { ToDisconnect TransformDisconnect `mapstructure:"to_disconnect" json:"to_disconnect"` } +func (t *HttpStatusToCodeTransform) Validate() error { + if t.StatusCode == 0 { + return errors.New("no status code specified") + } + if t.ToDisconnect.Code == 0 && t.ToError.Code == 0 { + return errors.New("no error or disconnect code set") + } + if t.ToDisconnect.Code > 0 && t.ToError.Code > 0 { + return errors.New("only error or disconnect code can be set") + } + if !tools.IsASCII(t.ToDisconnect.Reason) { + return errors.New("disconnect reason must be ASCII") + } + if !tools.IsASCII(t.ToError.Message) { + return errors.New("error message must be ASCII") + } + const reasonOrMessageMaxLength = 123 // limit comes from WebSocket close reason length limit. See https://datatracker.ietf.org/doc/html/rfc6455. + if len(t.ToDisconnect.Reason) > reasonOrMessageMaxLength { + return fmt.Errorf("disconnect reason can be up to %d characters long", reasonOrMessageMaxLength) + } + return nil +} + // Config for proxy. type Config struct { // Name is a unique name of proxy to reference. diff --git a/internal/tools/code_translate.go b/internal/tools/code_translate.go index 28b6f2e90..d421d1cab 100644 --- a/internal/tools/code_translate.go +++ b/internal/tools/code_translate.go @@ -1,11 +1,35 @@ package tools import ( + "errors" "net/http" "github.com/centrifugal/centrifuge" ) +type TransformDisconnect struct { + Code uint32 `mapstructure:"code" json:"code"` + Reason string `mapstructure:"reason" json:"reason"` +} + +type UniConnectCodeToDisconnectTransform struct { + Code uint32 `mapstructure:"code" json:"code"` + To TransformDisconnect `mapstructure:"to" json:"to"` +} + +func (t UniConnectCodeToDisconnectTransform) Validate() error { + if t.Code == 0 { + return errors.New("no code specified") + } + if t.To.Code == 0 { + return errors.New("no disconnect code specified") + } + if !IsASCII(t.To.Reason) { + return errors.New("disconnect reason must be ASCII") + } + return nil +} + type ConnectCodeToHTTPStatus struct { Enabled bool `mapstructure:"enabled" json:"enabled"` Transforms []ConnectCodeToHTTPStatusTransform `mapstructure:"transforms" json:"transforms"` @@ -16,6 +40,16 @@ type ConnectCodeToHTTPStatusTransform struct { To TransformedConnectErrorHttpResponse `mapstructure:"to" json:"to"` } +func (t ConnectCodeToHTTPStatusTransform) Validate() error { + if t.Code == 0 { + return errors.New("no code specified") + } + if t.To.Status == 0 { + return errors.New("no status_code specified") + } + return nil +} + type TransformedConnectErrorHttpResponse struct { Status int `mapstructure:"status_code" json:"status_code"` Body string `mapstructure:"body" json:"body"` diff --git a/main.go b/main.go index 6fd79e80a..b73c8a8cb 100644 --- a/main.go +++ b/main.go @@ -243,6 +243,9 @@ var defaults = map[string]any{ "uni_sse": false, "uni_http_stream": false, + "client_connect_code_to_unidirectional_disconnect.enabled": false, + "client_connect_code_to_unidirectional_disconnect.transforms": []any{}, + "uni_sse_connect_code_to_http_response.enabled": false, "uni_sse_connect_code_to_http_response.transforms": []any{}, "uni_http_stream_connect_code_to_http_response.enabled": false, @@ -1911,6 +1914,11 @@ func granularProxiesFromConfig(v *viper.Viper) []proxy.Config { if p.Endpoint == "" { log.Fatal().Msgf("no endpoint set for proxy %s", p.Name) } + for i, transform := range p.HttpStatusTransforms { + if err := transform.Validate(); err != nil { + log.Fatal().Msgf("error validating proxy_http_status_code_transforms[%d] in proxy %s: %v", i, p.Name, err) + } + } names[p.Name] = struct{}{} } @@ -2075,28 +2083,9 @@ func proxyMapConfig() (*client.ProxyMap, bool) { if v.IsSet("proxy_http_status_code_transforms") { tools.DecodeSlice(v, &httpStatusTransforms, "proxy_http_status_code_transforms") } - for _, transform := range httpStatusTransforms { - if transform.StatusCode == 0 { - log.Fatal().Msg("status should be set in proxy_http_status_code_transforms item") - } - if transform.ToDisconnect.Code == 0 && transform.ToError.Code == 0 { - log.Fatal().Msg("no error or disconnect code set in proxy_http_status_code_transforms item") - } - if transform.ToDisconnect.Code > 0 && transform.ToError.Code > 0 { - log.Fatal().Msg("only error or disconnect code can be set in proxy_http_status_code_transforms item, but not both") - } - if !tools.IsASCII(transform.ToDisconnect.Reason) { - log.Fatal().Msg("proxy_http_status_code_transforms item disconnect reason must be ASCII") - } - if !tools.IsASCII(transform.ToError.Message) { - log.Fatal().Msg("proxy_http_status_code_transforms item error message must be ASCII") - } - const reasonOrMessageMaxLength = 123 // limit comes from WebSocket close reason length limit. See https://datatracker.ietf.org/doc/html/rfc6455. - if len(transform.ToError.Message) > reasonOrMessageMaxLength { - log.Fatal().Msgf("proxy_http_status_code_transforms item error message can be up to %d characters long", reasonOrMessageMaxLength) - } - if len(transform.ToDisconnect.Reason) > reasonOrMessageMaxLength { - log.Fatal().Msgf("proxy_http_status_code_transforms item disconnect reason can be up to %d characters long", reasonOrMessageMaxLength) + for i, transform := range httpStatusTransforms { + if err := transform.Validate(); err != nil { + log.Fatal().Msgf("error validating proxy_http_status_code_transforms[%d]: %v", i, err) } } proxyConfig.HttpStatusTransforms = httpStatusTransforms @@ -2431,6 +2420,23 @@ func nodeConfig(version string) centrifuge.Config { } cfg.LogLevel = level cfg.LogHandler = newLogHandler().handle + + uniCodeTransformsEnabled := viper.GetBool("client_connect_code_to_unidirectional_disconnect.enabled") + if uniCodeTransformsEnabled { + var uniCodeToDisconnectTransforms []tools.UniConnectCodeToDisconnectTransform + if viper.IsSet("client_connect_code_to_unidirectional_disconnect.transforms") { + tools.DecodeSlice(viper.GetViper(), &uniCodeToDisconnectTransforms, "client_connect_code_to_unidirectional_disconnect.transforms") + } + uniCodeTransforms := make(map[uint32]centrifuge.Disconnect) + for _, transform := range uniCodeToDisconnectTransforms { + if err := transform.Validate(); err != nil { + log.Fatal().Msgf("error validating unidirectional code to disconnect transform: %v", err) + } + uniCodeTransforms[transform.Code] = centrifuge.Disconnect{Code: transform.To.Code, Reason: transform.To.Reason} + } + cfg.UnidirectionalCodeToDisconnect = uniCodeTransforms + } + return cfg } @@ -2570,6 +2576,11 @@ func uniSSEHandlerConfig() unisse.Config { if viper.IsSet("uni_sse_connect_code_to_http_response.transforms") { tools.DecodeSlice(viper.GetViper(), &connectCodeToHTTPStatusTransforms, "uni_sse_connect_code_to_http_response.transforms") } + for i, transform := range connectCodeToHTTPStatusTransforms { + if err := transform.Validate(); err != nil { + log.Fatal().Msgf("error validating uni_sse_connect_code_to_http_response.transforms[%d]: %v", i, err) + } + } return unisse.Config{ MaxRequestBodySize: viper.GetInt("uni_sse_max_request_body_size"), PingPongConfig: getPingPongConfig(), @@ -2586,6 +2597,11 @@ func uniStreamHandlerConfig() unihttpstream.Config { if viper.IsSet("uni_http_stream_connect_code_to_http_response.transforms") { tools.DecodeSlice(viper.GetViper(), &connectCodeToHTTPStatusTransforms, "uni_http_stream_connect_code_to_http_response.transforms") } + for i, transform := range connectCodeToHTTPStatusTransforms { + if err := transform.Validate(); err != nil { + log.Fatal().Msgf("error validating uni_http_stream_connect_code_to_http_response.transforms[%d]: %v", i, err) + } + } return unihttpstream.Config{ MaxRequestBodySize: viper.GetInt("uni_http_stream_max_request_body_size"), PingPongConfig: getPingPongConfig(),