Skip to content

Commit

Permalink
Encode session ids using protobufs.
Browse files Browse the repository at this point in the history
This will reduce their size and improve encoding / decoding performance.

Also fixes GO-2024-3106 (Stack exhaustion in Decoder.Decode in encoding/gob).
  • Loading branch information
fancycode committed Oct 28, 2024
1 parent b90913c commit 5046699
Show file tree
Hide file tree
Showing 11 changed files with 466 additions and 109 deletions.
10 changes: 6 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ VERSION := $(shell "$(CURDIR)/scripts/get-version.sh")
TARVERSION := $(shell "$(CURDIR)/scripts/get-version.sh" --tar)
PACKAGENAME := github.com/strukturag/nextcloud-spreed-signaling
ALL_PACKAGES := $(PACKAGENAME) $(PACKAGENAME)/client $(PACKAGENAME)/proxy $(PACKAGENAME)/server
PROTO_FILES := $(basename $(wildcard *.proto))
PROTO_GO_FILES := $(addsuffix .pb.go,$(PROTO_FILES)) $(addsuffix _grpc.pb.go,$(PROTO_FILES))
GRPC_PROTO_FILES := $(basename $(wildcard grpc_*.proto))
PROTO_FILES := $(filter-out $(GRPC_PROTO_FILES),$(basename $(wildcard *.proto)))
PROTO_GO_FILES := $(addsuffix .pb.go,$(PROTO_FILES))
GRPC_PROTO_GO_FILES := $(addsuffix .pb.go,$(GRPC_PROTO_FILES)) $(addsuffix _grpc.pb.go,$(GRPC_PROTO_FILES))
EASYJSON_GO_FILES := \
api_async_easyjson.go \
api_backend_easyjson.go \
Expand Down Expand Up @@ -139,7 +141,7 @@ coverhtml: vet
$*.proto
sed -i -e '1h;2,$$H;$$!d;g' -re 's|// versions.+// source:|// source:|' $*_grpc.pb.go

common: $(EASYJSON_GO_FILES) $(PROTO_GO_FILES)
common: $(EASYJSON_GO_FILES) $(PROTO_GO_FILES) $(GRPC_PROTO_GO_FILES)

$(BINDIR):
mkdir -p "$(BINDIR)"
Expand All @@ -166,7 +168,7 @@ clean:
rm -f "$(BINDIR)/proxy"

clean-generated: clean
rm -f $(EASYJSON_GO_FILES) $(PROTO_GO_FILES)
rm -f $(EASYJSON_GO_FILES) $(PROTO_GO_FILES) $(GRPC_PROTO_GO_FILES)

build: server proxy

Expand Down
36 changes: 7 additions & 29 deletions client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ package main

import (
"bytes"
"encoding/base64"
"encoding/json"
"flag"
"fmt"
Expand All @@ -43,7 +42,6 @@ import (

"github.com/dlintw/goconf"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
"github.com/gorilla/websocket"
"github.com/mailru/easyjson"

Expand Down Expand Up @@ -75,9 +73,6 @@ const (

// Maximum message size allowed from peer.
maxMessageSize = 64 * 1024

privateSessionName = "private-session"
publicSessionName = "public-session"
)

type Stats struct {
Expand Down Expand Up @@ -120,7 +115,7 @@ type MessagePayload struct {

type SignalingClient struct {
readyWg *sync.WaitGroup
cookie *securecookie.SecureCookie
cookie *signaling.SessionIdCodec

conn *websocket.Conn

Expand All @@ -135,7 +130,7 @@ type SignalingClient struct {
userId string
}

func NewSignalingClient(cookie *securecookie.SecureCookie, url string, stats *Stats, readyWg *sync.WaitGroup, doneWg *sync.WaitGroup) (*SignalingClient, error) {
func NewSignalingClient(cookie *signaling.SessionIdCodec, url string, stats *Stats, readyWg *sync.WaitGroup, doneWg *sync.WaitGroup) (*SignalingClient, error) {
conn, _, err := websocket.DefaultDialer.Dial(url, nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -215,19 +210,15 @@ func (c *SignalingClient) processMessage(message *signaling.ServerMessage) {
}

func (c *SignalingClient) privateToPublicSessionId(privateId string) string {
var data signaling.SessionIdData
if err := c.cookie.Decode(privateSessionName, privateId, &data); err != nil {
data, err := c.cookie.DecodePrivate(privateId)
if err != nil {
panic(fmt.Sprintf("could not decode private session id: %s", err))
}
encoded, err := c.cookie.Encode(publicSessionName, data)
publicId, err := c.cookie.EncodePublic(data)
if err != nil {
panic(fmt.Sprintf("could not encode public id: %s", err))
}
reversed, err := reverseSessionId(encoded)
if err != nil {
panic(fmt.Sprintf("could not reverse session id: %s", err))
}
return reversed
return publicId
}

func (c *SignalingClient) processHelloMessage(message *signaling.ServerMessage) {
Expand Down Expand Up @@ -493,19 +484,6 @@ func getLocalIP() string {
return ""
}

func reverseSessionId(s string) (string, error) {
// Note that we are assuming base64 encoded strings here.
decoded, err := base64.URLEncoding.DecodeString(s)
if err != nil {
return "", err
}

for i, j := 0, len(decoded)-1; i < j; i, j = i+1, j-1 {
decoded[i], decoded[j] = decoded[j], decoded[i]
}
return base64.URLEncoding.EncodeToString(decoded), nil
}

func main() {
flag.Parse()
log.SetFlags(0)
Expand Down Expand Up @@ -537,7 +515,7 @@ func main() {
default:
log.Fatalf("The sessions block key must be 16, 24 or 32 bytes but is %d bytes", len(blockKey))
}
cookie := securecookie.New([]byte(hashKey), blockBytes).MaxAge(0)
cookie := signaling.NewSessionIdCodec([]byte(hashKey), blockBytes)

cpus := runtime.NumCPU()
runtime.GOMAXPROCS(cpus)
Expand Down
94 changes: 36 additions & 58 deletions hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ import (
"github.com/dlintw/goconf"
"github.com/golang-jwt/jwt/v4"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
"github.com/gorilla/websocket"
"google.golang.org/protobuf/types/known/timestamppb"
)

var (
Expand Down Expand Up @@ -114,11 +114,6 @@ var (
DefaultTrustedProxies = DefaultPrivateIps()
)

const (
privateSessionName = "private-session"
publicSessionName = "public-session"
)

func init() {
RegisterHubStats()
}
Expand All @@ -127,7 +122,7 @@ type Hub struct {
version string
events AsyncEvents
upgrader websocket.Upgrader
cookie *securecookie.SecureCookie
cookie *SessionIdCodec
info *WelcomeServerMessage
infoInternal *WelcomeServerMessage
welcome atomic.Value // *ServerMessage
Expand Down Expand Up @@ -325,7 +320,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
ReadBufferSize: websocketReadBufferSize,
WriteBufferSize: websocketWriteBufferSize,
},
cookie: securecookie.New([]byte(hashKey), blockBytes).MaxAge(0),
cookie: NewSessionIdCodec([]byte(hashKey), blockBytes),
info: NewWelcomeServerMessage(version, DefaultFeatures...),
infoInternal: NewWelcomeServerMessage(version, DefaultFeaturesInternal...),

Expand Down Expand Up @@ -531,35 +526,6 @@ func (h *Hub) Reload(config *goconf.ConfigFile) {
h.rpcClients.Reload(config)
}

func reverseSessionId(s string) (string, error) {
// Note that we are assuming base64 encoded strings here.
decoded, err := base64.URLEncoding.DecodeString(s)
if err != nil {
return "", err
}

for i, j := 0, len(decoded)-1; i < j; i, j = i+1, j-1 {
decoded[i], decoded[j] = decoded[j], decoded[i]
}
return base64.URLEncoding.EncodeToString(decoded), nil
}

func (h *Hub) encodeSessionId(data *SessionIdData, sessionType string) (string, error) {
encoded, err := h.cookie.Encode(sessionType, data)
if err != nil {
return "", err
}
if sessionType == publicSessionName {
// We are reversing the public session ids because clients compare them
// to decide who calls whom. The prefix of the session id is increasing
// (a timestamp) but the suffix the (random) hash.
// By reversing we move the hash to the front, making the comparison of
// session ids "random".
encoded, err = reverseSessionId(encoded)
}
return encoded, err
}

func (h *Hub) getDecodeCache(cache_key string) *LruCache {
hash := fnv.New32a()
hash.Write([]byte(cache_key)) // nolint
Expand Down Expand Up @@ -587,36 +553,48 @@ func (h *Hub) setDecodedSessionId(id string, sessionType string, data *SessionId
cache.Set(cache_key, data)
}

func (h *Hub) decodeSessionId(id string, sessionType string) *SessionIdData {
func (h *Hub) decodePrivateSessionId(id string) *SessionIdData {
if len(id) == 0 {
return nil
}

cache_key := id + "|" + sessionType
cache_key := id + "|" + privateSessionName
cache := h.getDecodeCache(cache_key)
if result := cache.Get(cache_key); result != nil {
return result.(*SessionIdData)
}

if sessionType == publicSessionName {
var err error
id, err = reverseSessionId(id)
if err != nil {
return nil
}
data, err := h.cookie.DecodePrivate(id)
if err != nil {
return nil
}

cache.Set(cache_key, data)
return data
}

func (h *Hub) decodePublicSessionId(id string) *SessionIdData {
if len(id) == 0 {
return nil
}

cache_key := id + "|" + publicSessionName
cache := h.getDecodeCache(cache_key)
if result := cache.Get(cache_key); result != nil {
return result.(*SessionIdData)
}

var data SessionIdData
if h.cookie.Decode(sessionType, id, &data) != nil {
data, err := h.cookie.DecodePublic(id)
if err != nil {
return nil
}

cache.Set(cache_key, &data)
return &data
cache.Set(cache_key, data)
return data
}

func (h *Hub) GetSessionByPublicId(sessionId string) Session {
data := h.decodeSessionId(sessionId, publicSessionName)
data := h.decodePublicSessionId(sessionId)
if data == nil {
return nil
}
Expand All @@ -632,7 +610,7 @@ func (h *Hub) GetSessionByPublicId(sessionId string) Session {
}

func (h *Hub) GetSessionByResumeId(resumeId string) Session {
data := h.decodeSessionId(resumeId, privateSessionName)
data := h.decodePrivateSessionId(resumeId)
if data == nil {
return nil
}
Expand Down Expand Up @@ -834,7 +812,7 @@ func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData {
}
sessionIdData := &SessionIdData{
Sid: sid,
Created: time.Now(),
Created: timestamppb.Now(),
BackendId: backend.Id(),
}
return sessionIdData
Expand Down Expand Up @@ -862,12 +840,12 @@ func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend *
}

sessionIdData := h.newSessionIdData(backend)
privateSessionId, err := h.encodeSessionId(sessionIdData, privateSessionName)
privateSessionId, err := h.cookie.EncodePrivate(sessionIdData)
if err != nil {
client.SendMessage(message.NewWrappedErrorServerMessage(err))
return
}
publicSessionId, err := h.encodeSessionId(sessionIdData, publicSessionName)
publicSessionId, err := h.cookie.EncodePublic(sessionIdData)
if err != nil {
client.SendMessage(message.NewWrappedErrorServerMessage(err))
return
Expand Down Expand Up @@ -1172,7 +1150,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
return
}

data := h.decodeSessionId(resumeId, privateSessionName)
data := h.decodePrivateSessionId(resumeId)
if data == nil {
statsHubSessionResumeFailed.Inc()
if h.tryProxyResume(client, resumeId, message) {
Expand Down Expand Up @@ -2165,7 +2143,7 @@ func (h *Hub) processControlMsg(session Session, message *ClientMessage) {
var room *Room
switch msg.Recipient.Type {
case RecipientTypeSession:
data := h.decodeSessionId(msg.Recipient.SessionId, publicSessionName)
data := h.decodePublicSessionId(msg.Recipient.SessionId)
if data != nil {
if msg.Recipient.SessionId == session.PublicId() {
// Don't loop messages to the sender.
Expand Down Expand Up @@ -2285,12 +2263,12 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) {
}

sessionIdData := h.newSessionIdData(session.Backend())
privateSessionId, err := h.encodeSessionId(sessionIdData, privateSessionName)
privateSessionId, err := h.cookie.EncodePrivate(sessionIdData)
if err != nil {
log.Printf("Could not encode private virtual session id: %s", err)
return
}
publicSessionId, err := h.encodeSessionId(sessionIdData, publicSessionName)
publicSessionId, err := h.cookie.EncodePublic(sessionIdData)
if err != nil {
log.Printf("Could not encode public virtual session id: %s", err)
return
Expand Down
4 changes: 2 additions & 2 deletions hub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ func TestClientHelloV2(t *testing.T) {
assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello)
assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello)

data := hub.decodeSessionId(hello.Hello.SessionId, publicSessionName)
data := hub.decodePublicSessionId(hello.Hello.SessionId)
require.NotNil(data, "Could not decode session id: %s", hello.Hello.SessionId)

hub.mu.RLock()
Expand Down Expand Up @@ -1281,7 +1281,7 @@ func TestSessionIdsUnordered(t *testing.T) {
assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello)
assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello)

data := hub.decodeSessionId(hello.Hello.SessionId, publicSessionName)
data := hub.decodePublicSessionId(hello.Hello.SessionId)
if !assert.NotNil(data, "Could not decode session id: %s", hello.Hello.SessionId) {
break
}
Expand Down
Loading

0 comments on commit 5046699

Please sign in to comment.