Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move common flags code to own struct. #591

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions clientsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type ClientSession struct {
userId string
userData *json.RawMessage

inCall atomic.Uint32
inCall Flags
supportsPermissions bool
permissions map[Permission]bool

Expand Down Expand Up @@ -169,24 +169,15 @@ func (s *ClientSession) ClientType() string {

// GetInCall is only used for internal clients.
func (s *ClientSession) GetInCall() int {
return int(s.inCall.Load())
return int(s.inCall.Get())
}

func (s *ClientSession) SetInCall(inCall int) bool {
if inCall < 0 {
inCall = 0
}

for {
old := s.inCall.Load()
if old == uint32(inCall) {
return false
}

if s.inCall.CompareAndSwap(old, uint32(inCall)) {
return true
}
}
return s.inCall.Set(uint32(inCall))
}

func (s *ClientSession) GetFeatures() []string {
Expand Down
77 changes: 77 additions & 0 deletions flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2023 struktur AG
*
* @author Joachim Bauch <[email protected]>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling

import (
"sync/atomic"
)

type Flags struct {
flags atomic.Uint32
}

func (f *Flags) Add(flags uint32) bool {
for {
old := f.flags.Load()
if old&flags == flags {
// Flags already set.
return false
}
newFlags := old | flags
if f.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
}

func (f *Flags) Remove(flags uint32) bool {
for {
old := f.flags.Load()
if old&flags == 0 {
// Flags not set.
return false
}
newFlags := old & ^flags
if f.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
}

func (f *Flags) Set(flags uint32) bool {
for {
old := f.flags.Load()
if old == flags {
return false
}

if f.flags.CompareAndSwap(old, flags) {
return true
}
}
}

func (f *Flags) Get() uint32 {
return f.flags.Load()
}
140 changes: 140 additions & 0 deletions flags_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2023 struktur AG
*
* @author Joachim Bauch <[email protected]>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling

import (
"sync"
"sync/atomic"
"testing"
)

func TestFlags(t *testing.T) {
var f Flags
if f.Get() != 0 {
t.Fatalf("Expected flags 0, got %d", f.Get())
}
if !f.Add(1) {
t.Error("expected true")
}
if f.Get() != 1 {
t.Fatalf("Expected flags 1, got %d", f.Get())
}
if f.Add(1) {
t.Error("expected false")
}
if f.Get() != 1 {
t.Fatalf("Expected flags 1, got %d", f.Get())
}
if !f.Add(2) {
t.Error("expected true")
}
if f.Get() != 3 {
t.Fatalf("Expected flags 3, got %d", f.Get())
}
if !f.Remove(1) {
t.Error("expected true")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
if f.Remove(1) {
t.Error("expected false")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
if !f.Add(3) {
t.Error("expected true")
}
if f.Get() != 3 {
t.Fatalf("Expected flags 3, got %d", f.Get())
}
if !f.Remove(1) {
t.Error("expected true")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
}

func runConcurrentFlags(t *testing.T, count int, f func()) {
var start sync.WaitGroup
start.Add(1)
var ready sync.WaitGroup
var done sync.WaitGroup
for i := 0; i < count; i++ {
done.Add(1)
ready.Add(1)
go func() {
defer done.Done()
ready.Done()
start.Wait()
f()
}()
}
ready.Wait()
start.Done()
done.Wait()
}

func TestFlagsConcurrentAdd(t *testing.T) {
var flags Flags

var added atomic.Int32
runConcurrentFlags(t, 100, func() {
if flags.Add(1) {
added.Add(1)
}
})
if added.Load() != 1 {
t.Errorf("expected only one successfull attempt, got %d", added.Load())
}
}

func TestFlagsConcurrentRemove(t *testing.T) {
var flags Flags
flags.Set(1)

var removed atomic.Int32
runConcurrentFlags(t, 100, func() {
if flags.Remove(1) {
removed.Add(1)
}
})
if removed.Load() != 1 {
t.Errorf("expected only one successfull attempt, got %d", removed.Load())
}
}

func TestFlagsConcurrentSet(t *testing.T) {
var flags Flags

var set atomic.Int32
runConcurrentFlags(t, 100, func() {
if flags.Set(1) {
set.Add(1)
}
})
if set.Load() != 1 {
t.Errorf("expected only one successfull attempt, got %d", set.Load())
}
}
60 changes: 11 additions & 49 deletions virtualsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ type VirtualSession struct {
sessionId string
userId string
userData *json.RawMessage
inCall atomic.Uint32
flags atomic.Uint32
inCall Flags
flags Flags
options *AddSessionOptions
}

Expand All @@ -69,7 +69,6 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
userData: msg.User,
options: msg.Options,
}
result.flags.Store(msg.Flags)

if err := session.events.RegisterSessionListener(publicId, session.Backend(), result); err != nil {
return nil, err
Expand All @@ -80,6 +79,9 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
} else if !session.HasFeature(ClientFeatureInternalInCall) {
result.SetInCall(FlagInCall | FlagWithPhone)
}
if msg.Flags != 0 {
result.SetFlags(msg.Flags)
}

return result, nil
}
Expand All @@ -97,24 +99,15 @@ func (s *VirtualSession) ClientType() string {
}

func (s *VirtualSession) GetInCall() int {
return int(s.inCall.Load())
return int(s.inCall.Get())
}

func (s *VirtualSession) SetInCall(inCall int) bool {
if inCall < 0 {
inCall = 0
}

for {
old := s.inCall.Load()
if old == uint32(inCall) {
return false
}

if s.inCall.CompareAndSwap(old, uint32(inCall)) {
return true
}
}
return s.inCall.Set(uint32(inCall))
}

func (s *VirtualSession) Data() *SessionIdData {
Expand Down Expand Up @@ -240,50 +233,19 @@ func (s *VirtualSession) SessionId() string {
}

func (s *VirtualSession) AddFlags(flags uint32) bool {
for {
old := s.flags.Load()
if old&flags == flags {
// Flags already set.
return false
}
newFlags := old | flags
if s.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
return s.flags.Add(flags)
}

func (s *VirtualSession) RemoveFlags(flags uint32) bool {
for {
old := s.flags.Load()
if old&flags == 0 {
// Flags not set.
return false
}
newFlags := old & ^flags
if s.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
return s.flags.Remove(flags)
}

func (s *VirtualSession) SetFlags(flags uint32) bool {
for {
old := s.flags.Load()
if old == flags {
return false
}

if s.flags.CompareAndSwap(old, flags) {
return true
}
}
return s.flags.Set(flags)
}

func (s *VirtualSession) Flags() uint32 {
return s.flags.Load()
return s.flags.Get()
}

func (s *VirtualSession) Options() *AddSessionOptions {
Expand Down
Loading
Loading