From 738ae2b79d5de5befae43c24e9bcb10e010b8e1a Mon Sep 17 00:00:00 2001 From: Claudio Costa Date: Thu, 18 Jan 2024 09:59:01 -0600 Subject: [PATCH] [MM-54359] Implement session validation check (#620) * Implement session validation check * Simplify tests * Support zero ExpiresAt --- server/websocket.go | 54 ++++++++++++++-- server/websocket_test.go | 134 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+), 7 deletions(-) diff --git a/server/websocket.go b/server/websocket.go index 6a79e7c1c..84ccfe535 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -48,6 +48,10 @@ const ( wsReconnectionTimeout = 10 * time.Second ) +var ( + sessionAuthCheckInterval = 10 * time.Second +) + type CallsClientJoinData struct { ChannelID string Title string @@ -379,7 +383,10 @@ func (p *Plugin) OnWebSocketDisconnect(connID, userID string) { } } -func (p *Plugin) wsReader(us *session, handlerID string) { +func (p *Plugin) wsReader(us *session, authSessionID, handlerID string) { + sessionAuthTicker := time.NewTicker(sessionAuthCheckInterval) + defer sessionAuthTicker.Stop() + for { select { case msg, ok := <-us.wsMsgCh: @@ -397,6 +404,39 @@ func (p *Plugin) wsReader(us *session, handlerID string) { return case <-us.rtcCloseCh: return + case <-sessionAuthTicker.C: + // Server versions prior to MM v9.5 won't have the session ID set so we + // cannot go ahead with this check. + // Should be removed as soon as we bump the minimum supported version. + if authSessionID == "" { + continue + } + + if s, appErr := p.API.GetSession(authSessionID); appErr != nil || (s.ExpiresAt != 0 && time.Now().UnixMilli() >= s.ExpiresAt) { + fields := []any{ + "channelID", + us.channelID, + "userID", + us.userID, + "connID", + us.connID, + } + + if appErr != nil { + fields = append(fields, "err", appErr.Error()) + } else { + fields = append(fields, "sessionID", s.Id, "expiresAt", fmt.Sprintf("%d", s.ExpiresAt)) + } + + p.LogInfo("invalid or expired session, closing RTC session", fields...) + + // We forcefully disconnect any session that has been revoked or expired. + if err := p.closeRTCSession(us.userID, us.connID, us.channelID, handlerID); err != nil { + p.LogError("failed to close RTC session", append(fields[:5], "err", err.Error())) + } + + return + } } } } @@ -507,7 +547,7 @@ func (p *Plugin) handleLeave(us *session, userID, connID, channelID string) erro return nil } -func (p *Plugin) handleJoin(userID, connID string, joinData CallsClientJoinData) (retErr error) { +func (p *Plugin) handleJoin(userID, connID, authSessionID string, joinData CallsClientJoinData) (retErr error) { channelID := joinData.ChannelID p.LogDebug("handleJoin", "userID", userID, "connID", connID, "channelID", channelID) @@ -715,12 +755,12 @@ func (p *Plugin) handleJoin(userID, connID string, joinData CallsClientJoinData) "CallID": state.Call.ID, }) - p.wsReader(us, handlerID) + p.wsReader(us, authSessionID, handlerID) return nil } -func (p *Plugin) handleReconnect(userID, connID, channelID, originalConnID, prevConnID string) error { +func (p *Plugin) handleReconnect(userID, connID, channelID, originalConnID, prevConnID, authSessionID string) error { p.LogDebug("handleReconnect", "userID", userID, "connID", connID, "channelID", channelID, "originalConnID", originalConnID, "prevConnID", prevConnID) @@ -814,7 +854,7 @@ func (p *Plugin) handleReconnect(userID, connID, channelID, originalConnID, prev handlerID = state.NodeID } - p.wsReader(us, handlerID) + p.wsReader(us, authSessionID, handlerID) if err := p.handleLeave(us, userID, connID, channelID); err != nil { p.LogError(err.Error()) @@ -870,7 +910,7 @@ func (p *Plugin) WebSocketMessageHasBeenPosted(connID, userID string, req *model } go func() { - if err := p.handleJoin(userID, connID, joinData); err != nil { + if err := p.handleJoin(userID, connID, req.Session.Id, joinData); err != nil { p.LogWarn(err.Error(), "userID", userID, "connID", connID, "channelID", channelID) p.publishWebSocketEvent(wsEventError, map[string]interface{}{ "data": err.Error(), @@ -898,7 +938,7 @@ func (p *Plugin) WebSocketMessageHasBeenPosted(connID, userID string, req *model } go func() { - if err := p.handleReconnect(userID, connID, channelID, originalConnID, prevConnID); err != nil { + if err := p.handleReconnect(userID, connID, channelID, originalConnID, prevConnID, req.Session.Id); err != nil { p.LogWarn(err.Error(), "userID", userID, "connID", connID, "originalConnID", originalConnID, "prevConnID", prevConnID, "channelID", channelID) } diff --git a/server/websocket_test.go b/server/websocket_test.go index 5c527cffb..d4c536bfe 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -6,7 +6,11 @@ package main import ( "database/sql/driver" "encoding/json" + "fmt" + "net/http" + "sync" "testing" + "time" "github.com/mattermost/mattermost-plugin-calls/server/cluster" serverMocks "github.com/mattermost/mattermost-plugin-calls/server/mocks/github.com/mattermost/mattermost-plugin-calls/server/interfaces" @@ -234,3 +238,133 @@ func TestHandleBotWSReconnect(t *testing.T) { }) }) } + +func TestWSReader(t *testing.T) { + mockAPI := &pluginMocks.MockAPI{} + mockMetrics := &serverMocks.MockMetrics{} + + p := Plugin{ + MattermostPlugin: plugin.MattermostPlugin{ + API: mockAPI, + }, + callsClusterLocks: map[string]*cluster.Mutex{}, + metrics: mockMetrics, + } + + t.Run("user session validation", func(t *testing.T) { + sessionAuthCheckInterval = time.Second + + t.Run("empty session ID", func(t *testing.T) { + us := newUserSession("userID", "channelID", "connID", false) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + p.wsReader(us, "", "handlerID") + }() + + time.Sleep(time.Second) + close(us.wsCloseCh) + + wg.Wait() + }) + + t.Run("valid session", func(t *testing.T) { + mockAPI.On("GetSession", "authSessionID").Return(&model.Session{ + Id: "authSessionID", + ExpiresAt: time.Now().UnixMilli() + 60000, + }, nil).Once() + + us := newUserSession("userID", "channelID", "connID", false) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + p.wsReader(us, "authSessionID", "handlerID") + }() + + time.Sleep(time.Second) + close(us.wsCloseCh) + + wg.Wait() + }) + + t.Run("valid session, no expiration", func(t *testing.T) { + mockAPI.On("GetSession", "authSessionID").Return(&model.Session{ + Id: "authSessionID", + }, nil).Once() + + us := newUserSession("userID", "channelID", "connID", false) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + p.wsReader(us, "authSessionID", "handlerID") + }() + + time.Sleep(time.Second) + close(us.wsCloseCh) + + wg.Wait() + }) + + t.Run("expired session", func(t *testing.T) { + expiresAt := time.Now().UnixMilli() + us := newUserSession("userID", "channelID", "connID", false) + + mockAPI.On("GetSession", "authSessionID").Return(&model.Session{ + Id: "authSessionID", + ExpiresAt: expiresAt, + }, nil).Once() + + mockAPI.On("LogInfo", "invalid or expired session, closing RTC session", + "origin", mock.AnythingOfType("string"), + "channelID", us.channelID, "userID", us.userID, "connID", us.connID, + "sessionID", "authSessionID", "expiresAt", fmt.Sprintf("%d", expiresAt)).Once() + + mockAPI.On("LogDebug", "closeRTCSession", + "origin", mock.AnythingOfType("string"), + "userID", us.userID, "connID", us.connID, "channelID", us.channelID).Once() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + p.wsReader(us, "authSessionID", "handlerID") + }() + + time.Sleep(2 * time.Second) + close(us.wsCloseCh) + + wg.Wait() + }) + + t.Run("revoked session", func(t *testing.T) { + us := newUserSession("userID", "channelID", "connID", false) + + mockAPI.On("GetSession", "authSessionID").Return(nil, + model.NewAppError("GetSessionById", "We encountered an error finding the session.", nil, "", http.StatusUnauthorized)).Once() + + mockAPI.On("LogInfo", "invalid or expired session, closing RTC session", + "origin", mock.AnythingOfType("string"), + "channelID", us.channelID, "userID", us.userID, "connID", us.connID, + "err", "GetSessionById: We encountered an error finding the session.").Once() + + mockAPI.On("LogDebug", "closeRTCSession", + "origin", mock.AnythingOfType("string"), + "userID", us.userID, "connID", us.connID, "channelID", us.channelID).Once() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + p.wsReader(us, "authSessionID", "handlerID") + }() + + time.Sleep(time.Second * 2) + close(us.wsCloseCh) + + wg.Wait() + }) + }) +}