Skip to content

Commit

Permalink
Move common flags code to own struct.
Browse files Browse the repository at this point in the history
  • Loading branch information
fancycode committed Oct 30, 2023
1 parent a43686a commit 6316c53
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 98 deletions.
15 changes: 3 additions & 12 deletions clientsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type ClientSession struct {
userId string
userData *json.RawMessage

inCall atomic.Uint32
inCall Flags
supportsPermissions bool
permissions map[Permission]bool

Expand Down Expand Up @@ -169,24 +169,15 @@ func (s *ClientSession) ClientType() string {

// GetInCall is only used for internal clients.
func (s *ClientSession) GetInCall() int {
return int(s.inCall.Load())
return int(s.inCall.Get())
}

func (s *ClientSession) SetInCall(inCall int) bool {
if inCall < 0 {
inCall = 0
}

for {
old := s.inCall.Load()
if old == uint32(inCall) {
return false
}

if s.inCall.CompareAndSwap(old, uint32(inCall)) {
return true
}
}
return s.inCall.Set(uint32(inCall))
}

func (s *ClientSession) GetFeatures() []string {
Expand Down
75 changes: 75 additions & 0 deletions flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2023 struktur AG
*
* @author Joachim Bauch <[email protected]>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling

import "sync/atomic"

type Flags struct {
flags atomic.Uint32
}

func (f *Flags) Add(flags uint32) bool {
for {
old := f.flags.Load()
if old&flags == flags {
// Flags already set.
return false
}
newFlags := old | flags
if f.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
}

func (f *Flags) Remove(flags uint32) bool {
for {
old := f.flags.Load()
if old&flags == 0 {
// Flags not set.
return false
}
newFlags := old & ^flags
if f.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
}

func (f *Flags) Set(flags uint32) bool {
for {
old := f.flags.Load()
if old == flags {
return false
}

if f.flags.CompareAndSwap(old, flags) {
return true
}
}
}

func (f *Flags) Get() uint32 {
return f.flags.Load()
}
73 changes: 73 additions & 0 deletions flags_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2023 struktur AG
*
* @author Joachim Bauch <[email protected]>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling

import "testing"

func TestFlags(t *testing.T) {
var f Flags
if f.Get() != 0 {
t.Fatalf("Expected flags 0, got %d", f.Get())
}
if !f.Add(1) {
t.Error("expected true")
}
if f.Get() != 1 {
t.Fatalf("Expected flags 1, got %d", f.Get())
}
if f.Add(1) {
t.Error("expected false")
}
if f.Get() != 1 {
t.Fatalf("Expected flags 1, got %d", f.Get())
}
if !f.Add(2) {
t.Error("expected true")
}
if f.Get() != 3 {
t.Fatalf("Expected flags 3, got %d", f.Get())
}
if !f.Remove(1) {
t.Error("expected true")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
if f.Remove(1) {
t.Error("expected false")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
if !f.Add(3) {
t.Error("expected true")
}
if f.Get() != 3 {
t.Fatalf("Expected flags 3, got %d", f.Get())
}
if !f.Remove(1) {
t.Error("expected true")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
}
60 changes: 11 additions & 49 deletions virtualsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ type VirtualSession struct {
sessionId string
userId string
userData *json.RawMessage
inCall atomic.Uint32
flags atomic.Uint32
inCall Flags
flags Flags
options *AddSessionOptions
}

Expand All @@ -69,7 +69,6 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
userData: msg.User,
options: msg.Options,
}
result.flags.Store(msg.Flags)

if err := session.events.RegisterSessionListener(publicId, session.Backend(), result); err != nil {
return nil, err
Expand All @@ -80,6 +79,9 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
} else if !session.HasFeature(ClientFeatureInternalInCall) {
result.SetInCall(FlagInCall | FlagWithPhone)
}
if msg.Flags != 0 {
result.SetFlags(msg.Flags)
}

return result, nil
}
Expand All @@ -97,24 +99,15 @@ func (s *VirtualSession) ClientType() string {
}

func (s *VirtualSession) GetInCall() int {
return int(s.inCall.Load())
return int(s.inCall.Get())
}

func (s *VirtualSession) SetInCall(inCall int) bool {
if inCall < 0 {
inCall = 0
}

for {
old := s.inCall.Load()
if old == uint32(inCall) {
return false
}

if s.inCall.CompareAndSwap(old, uint32(inCall)) {
return true
}
}
return s.inCall.Set(uint32(inCall))
}

func (s *VirtualSession) Data() *SessionIdData {
Expand Down Expand Up @@ -240,50 +233,19 @@ func (s *VirtualSession) SessionId() string {
}

func (s *VirtualSession) AddFlags(flags uint32) bool {
for {
old := s.flags.Load()
if old&flags == flags {
// Flags already set.
return false
}
newFlags := old | flags
if s.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
return s.flags.Add(flags)
}

func (s *VirtualSession) RemoveFlags(flags uint32) bool {
for {
old := s.flags.Load()
if old&flags == 0 {
// Flags not set.
return false
}
newFlags := old & ^flags
if s.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
return s.flags.Remove(flags)
}

func (s *VirtualSession) SetFlags(flags uint32) bool {
for {
old := s.flags.Load()
if old == flags {
return false
}

if s.flags.CompareAndSwap(old, flags) {
return true
}
}
return s.flags.Set(flags)
}

func (s *VirtualSession) Flags() uint32 {
return s.flags.Load()
return s.flags.Get()
}

func (s *VirtualSession) Options() *AddSessionOptions {
Expand Down
37 changes: 0 additions & 37 deletions virtualsession_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,40 +711,3 @@ func TestVirtualSessionCleanup(t *testing.T) {
t.Error(err)
}
}

func TestVirtualSessionFlags(t *testing.T) {
s := &VirtualSession{
publicId: "dummy-for-testing",
}
if s.Flags() != 0 {
t.Fatalf("Expected flags 0, got %d", s.Flags())
}
s.AddFlags(1)
if s.Flags() != 1 {
t.Fatalf("Expected flags 1, got %d", s.Flags())
}
s.AddFlags(1)
if s.Flags() != 1 {
t.Fatalf("Expected flags 1, got %d", s.Flags())
}
s.AddFlags(2)
if s.Flags() != 3 {
t.Fatalf("Expected flags 3, got %d", s.Flags())
}
s.RemoveFlags(1)
if s.Flags() != 2 {
t.Fatalf("Expected flags 2, got %d", s.Flags())
}
s.RemoveFlags(1)
if s.Flags() != 2 {
t.Fatalf("Expected flags 2, got %d", s.Flags())
}
s.AddFlags(3)
if s.Flags() != 3 {
t.Fatalf("Expected flags 3, got %d", s.Flags())
}
s.RemoveFlags(1)
if s.Flags() != 2 {
t.Fatalf("Expected flags 2, got %d", s.Flags())
}
}

0 comments on commit 6316c53

Please sign in to comment.