Skip to content

Commit

Permalink
[MM-54359] Implement session validation check (#620)
Browse files Browse the repository at this point in the history
* Implement session validation check

* Simplify tests

* Support zero ExpiresAt
  • Loading branch information
streamer45 authored Jan 18, 2024
1 parent cc47155 commit 738ae2b
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 7 deletions.
54 changes: 47 additions & 7 deletions server/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ const (
wsReconnectionTimeout = 10 * time.Second
)

var (
sessionAuthCheckInterval = 10 * time.Second
)

type CallsClientJoinData struct {
ChannelID string
Title string
Expand Down Expand Up @@ -379,7 +383,10 @@ func (p *Plugin) OnWebSocketDisconnect(connID, userID string) {
}
}

func (p *Plugin) wsReader(us *session, handlerID string) {
func (p *Plugin) wsReader(us *session, authSessionID, handlerID string) {
sessionAuthTicker := time.NewTicker(sessionAuthCheckInterval)
defer sessionAuthTicker.Stop()

for {
select {
case msg, ok := <-us.wsMsgCh:
Expand All @@ -397,6 +404,39 @@ func (p *Plugin) wsReader(us *session, handlerID string) {
return
case <-us.rtcCloseCh:
return
case <-sessionAuthTicker.C:
// Server versions prior to MM v9.5 won't have the session ID set so we
// cannot go ahead with this check.
// Should be removed as soon as we bump the minimum supported version.
if authSessionID == "" {
continue
}

if s, appErr := p.API.GetSession(authSessionID); appErr != nil || (s.ExpiresAt != 0 && time.Now().UnixMilli() >= s.ExpiresAt) {
fields := []any{
"channelID",
us.channelID,
"userID",
us.userID,
"connID",
us.connID,
}

if appErr != nil {
fields = append(fields, "err", appErr.Error())
} else {
fields = append(fields, "sessionID", s.Id, "expiresAt", fmt.Sprintf("%d", s.ExpiresAt))
}

p.LogInfo("invalid or expired session, closing RTC session", fields...)

// We forcefully disconnect any session that has been revoked or expired.
if err := p.closeRTCSession(us.userID, us.connID, us.channelID, handlerID); err != nil {
p.LogError("failed to close RTC session", append(fields[:5], "err", err.Error()))
}

return
}
}
}
}
Expand Down Expand Up @@ -507,7 +547,7 @@ func (p *Plugin) handleLeave(us *session, userID, connID, channelID string) erro
return nil
}

func (p *Plugin) handleJoin(userID, connID string, joinData CallsClientJoinData) (retErr error) {
func (p *Plugin) handleJoin(userID, connID, authSessionID string, joinData CallsClientJoinData) (retErr error) {
channelID := joinData.ChannelID
p.LogDebug("handleJoin", "userID", userID, "connID", connID, "channelID", channelID)

Expand Down Expand Up @@ -715,12 +755,12 @@ func (p *Plugin) handleJoin(userID, connID string, joinData CallsClientJoinData)
"CallID": state.Call.ID,
})

p.wsReader(us, handlerID)
p.wsReader(us, authSessionID, handlerID)

return nil
}

func (p *Plugin) handleReconnect(userID, connID, channelID, originalConnID, prevConnID string) error {
func (p *Plugin) handleReconnect(userID, connID, channelID, originalConnID, prevConnID, authSessionID string) error {
p.LogDebug("handleReconnect", "userID", userID, "connID", connID, "channelID", channelID,
"originalConnID", originalConnID, "prevConnID", prevConnID)

Expand Down Expand Up @@ -814,7 +854,7 @@ func (p *Plugin) handleReconnect(userID, connID, channelID, originalConnID, prev
handlerID = state.NodeID
}

p.wsReader(us, handlerID)
p.wsReader(us, authSessionID, handlerID)

if err := p.handleLeave(us, userID, connID, channelID); err != nil {
p.LogError(err.Error())
Expand Down Expand Up @@ -870,7 +910,7 @@ func (p *Plugin) WebSocketMessageHasBeenPosted(connID, userID string, req *model
}

go func() {
if err := p.handleJoin(userID, connID, joinData); err != nil {
if err := p.handleJoin(userID, connID, req.Session.Id, joinData); err != nil {
p.LogWarn(err.Error(), "userID", userID, "connID", connID, "channelID", channelID)
p.publishWebSocketEvent(wsEventError, map[string]interface{}{
"data": err.Error(),
Expand Down Expand Up @@ -898,7 +938,7 @@ func (p *Plugin) WebSocketMessageHasBeenPosted(connID, userID string, req *model
}

go func() {
if err := p.handleReconnect(userID, connID, channelID, originalConnID, prevConnID); err != nil {
if err := p.handleReconnect(userID, connID, channelID, originalConnID, prevConnID, req.Session.Id); err != nil {
p.LogWarn(err.Error(), "userID", userID, "connID", connID,
"originalConnID", originalConnID, "prevConnID", prevConnID, "channelID", channelID)
}
Expand Down
134 changes: 134 additions & 0 deletions server/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ package main
import (
"database/sql/driver"
"encoding/json"
"fmt"
"net/http"
"sync"
"testing"
"time"

"github.com/mattermost/mattermost-plugin-calls/server/cluster"
serverMocks "github.com/mattermost/mattermost-plugin-calls/server/mocks/github.com/mattermost/mattermost-plugin-calls/server/interfaces"
Expand Down Expand Up @@ -234,3 +238,133 @@ func TestHandleBotWSReconnect(t *testing.T) {
})
})
}

func TestWSReader(t *testing.T) {
mockAPI := &pluginMocks.MockAPI{}
mockMetrics := &serverMocks.MockMetrics{}

p := Plugin{
MattermostPlugin: plugin.MattermostPlugin{
API: mockAPI,
},
callsClusterLocks: map[string]*cluster.Mutex{},
metrics: mockMetrics,
}

t.Run("user session validation", func(t *testing.T) {
sessionAuthCheckInterval = time.Second

t.Run("empty session ID", func(t *testing.T) {
us := newUserSession("userID", "channelID", "connID", false)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
p.wsReader(us, "", "handlerID")
}()

time.Sleep(time.Second)
close(us.wsCloseCh)

wg.Wait()
})

t.Run("valid session", func(t *testing.T) {
mockAPI.On("GetSession", "authSessionID").Return(&model.Session{
Id: "authSessionID",
ExpiresAt: time.Now().UnixMilli() + 60000,
}, nil).Once()

us := newUserSession("userID", "channelID", "connID", false)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
p.wsReader(us, "authSessionID", "handlerID")
}()

time.Sleep(time.Second)
close(us.wsCloseCh)

wg.Wait()
})

t.Run("valid session, no expiration", func(t *testing.T) {
mockAPI.On("GetSession", "authSessionID").Return(&model.Session{
Id: "authSessionID",
}, nil).Once()

us := newUserSession("userID", "channelID", "connID", false)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
p.wsReader(us, "authSessionID", "handlerID")
}()

time.Sleep(time.Second)
close(us.wsCloseCh)

wg.Wait()
})

t.Run("expired session", func(t *testing.T) {
expiresAt := time.Now().UnixMilli()
us := newUserSession("userID", "channelID", "connID", false)

mockAPI.On("GetSession", "authSessionID").Return(&model.Session{
Id: "authSessionID",
ExpiresAt: expiresAt,
}, nil).Once()

mockAPI.On("LogInfo", "invalid or expired session, closing RTC session",
"origin", mock.AnythingOfType("string"),
"channelID", us.channelID, "userID", us.userID, "connID", us.connID,
"sessionID", "authSessionID", "expiresAt", fmt.Sprintf("%d", expiresAt)).Once()

mockAPI.On("LogDebug", "closeRTCSession",
"origin", mock.AnythingOfType("string"),
"userID", us.userID, "connID", us.connID, "channelID", us.channelID).Once()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
p.wsReader(us, "authSessionID", "handlerID")
}()

time.Sleep(2 * time.Second)
close(us.wsCloseCh)

wg.Wait()
})

t.Run("revoked session", func(t *testing.T) {
us := newUserSession("userID", "channelID", "connID", false)

mockAPI.On("GetSession", "authSessionID").Return(nil,
model.NewAppError("GetSessionById", "We encountered an error finding the session.", nil, "", http.StatusUnauthorized)).Once()

mockAPI.On("LogInfo", "invalid or expired session, closing RTC session",
"origin", mock.AnythingOfType("string"),
"channelID", us.channelID, "userID", us.userID, "connID", us.connID,
"err", "GetSessionById: We encountered an error finding the session.").Once()

mockAPI.On("LogDebug", "closeRTCSession",
"origin", mock.AnythingOfType("string"),
"userID", us.userID, "connID", us.connID, "channelID", us.channelID).Once()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
p.wsReader(us, "authSessionID", "handlerID")
}()

time.Sleep(time.Second * 2)
close(us.wsCloseCh)

wg.Wait()
})
})
}

0 comments on commit 738ae2b

Please sign in to comment.