Skip to content

Commit

Permalink
Merge pull request #500 from strukturag/atomic-19
Browse files Browse the repository at this point in the history
Switch to atomic types from Go 1.19
  • Loading branch information
fancycode authored Oct 30, 2023
2 parents 2c5ad32 + c134883 commit a43686a
Show file tree
Hide file tree
Showing 17 changed files with 184 additions and 211 deletions.
12 changes: 6 additions & 6 deletions capabilities_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ func TestCapabilities(t *testing.T) {
}

func TestInvalidateCapabilities(t *testing.T) {
var called uint32
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) {
atomic.AddUint32(&called, 1)
called.Add(1)
})

ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
Expand All @@ -209,7 +209,7 @@ func TestInvalidateCapabilities(t *testing.T) {
t.Errorf("expected direct response")
}

if value := atomic.LoadUint32(&called); value != 1 {
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}

Expand All @@ -224,7 +224,7 @@ func TestInvalidateCapabilities(t *testing.T) {
t.Errorf("expected direct response")
}

if value := atomic.LoadUint32(&called); value != 2 {
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}

Expand All @@ -239,7 +239,7 @@ func TestInvalidateCapabilities(t *testing.T) {
t.Errorf("expected cached response")
}

if value := atomic.LoadUint32(&called); value != 2 {
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}

Expand All @@ -258,7 +258,7 @@ func TestInvalidateCapabilities(t *testing.T) {
t.Errorf("expected direct response")
}

if value := atomic.LoadUint32(&called); value != 3 {
if value := called.Load(); value != 3 {
t.Errorf("expected called %d, got %d", 3, value)
}
}
17 changes: 8 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"sync"
"sync/atomic"
"time"
"unsafe"

"github.com/gorilla/websocket"
"github.com/mailru/easyjson"
Expand Down Expand Up @@ -108,11 +107,11 @@ type Client struct {
addr string
handler ClientHandler
agent string
closed uint32
closed atomic.Int32
country *string
logRTT bool

session unsafe.Pointer
session atomic.Pointer[ClientSession]

mu sync.Mutex

Expand Down Expand Up @@ -150,19 +149,19 @@ func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler Cli
}

func (c *Client) IsConnected() bool {
return atomic.LoadUint32(&c.closed) == 0
return c.closed.Load() == 0
}

func (c *Client) IsAuthenticated() bool {
return c.GetSession() != nil
}

func (c *Client) GetSession() *ClientSession {
return (*ClientSession)(atomic.LoadPointer(&c.session))
return c.session.Load()
}

func (c *Client) SetSession(session *ClientSession) {
atomic.StorePointer(&c.session, unsafe.Pointer(session))
c.session.Store(session)
}

func (c *Client) RemoteAddr() string {
Expand All @@ -188,7 +187,7 @@ func (c *Client) Country() string {
}

func (c *Client) Close() {
if atomic.LoadUint32(&c.closed) >= 2 {
if c.closed.Load() >= 2 {
// Prevent reentrant call in case this was the second closing
// step. Would otherwise deadlock in the "Once.Do" call path
// through "Hub.processUnregister" (which calls "Close" again).
Expand All @@ -201,7 +200,7 @@ func (c *Client) Close() {
}

func (c *Client) doClose() {
closed := atomic.AddUint32(&c.closed, 1)
closed := c.closed.Add(1)
if closed == 1 {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down Expand Up @@ -329,7 +328,7 @@ func (c *Client) ReadPump() {
}

// Stop processing if the client was closed.
if atomic.LoadUint32(&c.closed) != 0 {
if !c.IsConnected() {
bufferPool.Put(decodeBuffer)
break
}
Expand Down
22 changes: 11 additions & 11 deletions client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,17 @@ const (
)

type Stats struct {
numRecvMessages uint64
numSentMessages uint64
numRecvMessages atomic.Uint64
numSentMessages atomic.Uint64
resetRecvMessages uint64
resetSentMessages uint64

start time.Time
}

func (s *Stats) reset(start time.Time) {
s.resetRecvMessages = atomic.AddUint64(&s.numRecvMessages, 0)
s.resetSentMessages = atomic.AddUint64(&s.numSentMessages, 0)
s.resetRecvMessages = s.numRecvMessages.Load()
s.resetSentMessages = s.numSentMessages.Load()
s.start = start
}

Expand All @@ -103,9 +103,9 @@ func (s *Stats) Log() {
return
}

totalSentMessages := atomic.AddUint64(&s.numSentMessages, 0)
totalSentMessages := s.numSentMessages.Load()
sentMessages := totalSentMessages - s.resetSentMessages
totalRecvMessages := atomic.AddUint64(&s.numRecvMessages, 0)
totalRecvMessages := s.numRecvMessages.Load()
recvMessages := totalRecvMessages - s.resetRecvMessages
log.Printf("Stats: sent=%d (%d/sec), recv=%d (%d/sec), delta=%d",
totalSentMessages, sentMessages/perSec,
Expand All @@ -125,7 +125,7 @@ type SignalingClient struct {
conn *websocket.Conn

stats *Stats
closed uint32
closed atomic.Bool

stopChan chan struct{}

Expand Down Expand Up @@ -164,7 +164,7 @@ func NewSignalingClient(cookie *securecookie.SecureCookie, url string, stats *St
}

func (c *SignalingClient) Close() {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
if !c.closed.CompareAndSwap(false, true) {
return
}

Expand Down Expand Up @@ -197,7 +197,7 @@ func (c *SignalingClient) Send(message *signaling.ClientMessage) {
}

func (c *SignalingClient) processMessage(message *signaling.ServerMessage) {
atomic.AddUint64(&c.stats.numRecvMessages, 1)
c.stats.numRecvMessages.Add(1)
switch message.Type {
case "hello":
c.processHelloMessage(message)
Expand Down Expand Up @@ -334,7 +334,7 @@ func (c *SignalingClient) writeInternal(message *signaling.ClientMessage) bool {
}

writer.Close()
atomic.AddUint64(&c.stats.numSentMessages, 1)
c.stats.numSentMessages.Add(1)
return true

close:
Expand Down Expand Up @@ -383,7 +383,7 @@ func (c *SignalingClient) SendMessages(clients []*SignalingClient) {
sessionIds[c] = c.PublicSessionId()
}

for atomic.LoadUint32(&c.closed) == 0 {
for !c.closed.Load() {
now := time.Now()

sender := c
Expand Down
24 changes: 11 additions & 13 deletions clientsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"sync"
"sync/atomic"
"time"
"unsafe"

"github.com/pion/sdp/v3"
)
Expand All @@ -50,9 +49,6 @@ var (
type ResponseHandlerFunc func(message *ClientMessage) bool

type ClientSession struct {
roomJoinTime int64
inCall uint32

hub *Hub
events AsyncEvents
privateId string
Expand All @@ -64,6 +60,7 @@ type ClientSession struct {
userId string
userData *json.RawMessage

inCall atomic.Uint32
supportsPermissions bool
permissions map[Permission]bool

Expand All @@ -76,7 +73,8 @@ type ClientSession struct {
mu sync.Mutex

client *Client
room unsafe.Pointer
room atomic.Pointer[Room]
roomJoinTime atomic.Int64
roomSessionId string

publisherWaiters ChannelWaiters
Expand Down Expand Up @@ -171,7 +169,7 @@ func (s *ClientSession) ClientType() string {

// GetInCall is only used for internal clients.
func (s *ClientSession) GetInCall() int {
return int(atomic.LoadUint32(&s.inCall))
return int(s.inCall.Load())
}

func (s *ClientSession) SetInCall(inCall int) bool {
Expand All @@ -180,12 +178,12 @@ func (s *ClientSession) SetInCall(inCall int) bool {
}

for {
old := atomic.LoadUint32(&s.inCall)
old := s.inCall.Load()
if old == uint32(inCall) {
return false
}

if atomic.CompareAndSwapUint32(&s.inCall, old, uint32(inCall)) {
if s.inCall.CompareAndSwap(old, uint32(inCall)) {
return true
}
}
Expand Down Expand Up @@ -340,11 +338,11 @@ func (s *ClientSession) IsExpired(now time.Time) bool {
}

func (s *ClientSession) SetRoom(room *Room) {
atomic.StorePointer(&s.room, unsafe.Pointer(room))
s.room.Store(room)
if room != nil {
atomic.StoreInt64(&s.roomJoinTime, time.Now().UnixNano())
s.roomJoinTime.Store(time.Now().UnixNano())
} else {
atomic.StoreInt64(&s.roomJoinTime, 0)
s.roomJoinTime.Store(0)
}

s.seenJoinedLock.Lock()
Expand All @@ -353,11 +351,11 @@ func (s *ClientSession) SetRoom(room *Room) {
}

func (s *ClientSession) GetRoom() *Room {
return (*Room)(atomic.LoadPointer(&s.room))
return s.room.Load()
}

func (s *ClientSession) getRoomJoinTime() time.Time {
t := atomic.LoadInt64(&s.roomJoinTime)
t := s.roomJoinTime.Load()
if t == 0 {
return time.Time{}
}
Expand Down
6 changes: 3 additions & 3 deletions closer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

type Closer struct {
closed uint32
closed atomic.Bool
C chan struct{}
}

Expand All @@ -37,11 +37,11 @@ func NewCloser() *Closer {
}

func (c *Closer) IsClosed() bool {
return atomic.LoadUint32(&c.closed) != 0
return c.closed.Load()
}

func (c *Closer) Close() {
if atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
if c.closed.CompareAndSwap(false, true) {
close(c.C)
}
}
16 changes: 6 additions & 10 deletions grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ const (
var (
lookupGrpcIp = net.LookupIP // can be overwritten from tests

customResolverPrefix uint64
customResolverPrefix atomic.Uint64
)

func init() {
Expand All @@ -75,12 +75,12 @@ func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl {
}

type GrpcClient struct {
isSelf uint32

ip net.IP
target string
conn *grpc.ClientConn
impl *grpcClientImpl

isSelf atomic.Bool
}

type customIpResolver struct {
Expand Down Expand Up @@ -125,7 +125,7 @@ func NewGrpcClient(target string, ip net.IP, opts ...grpc.DialOption) (*GrpcClie
var conn *grpc.ClientConn
var err error
if ip != nil {
prefix := atomic.AddUint64(&customResolverPrefix, 1)
prefix := customResolverPrefix.Add(1)
addr := ip.String()
hostname := target
if host, port, err := net.SplitHostPort(target); err == nil {
Expand Down Expand Up @@ -168,15 +168,11 @@ func (c *GrpcClient) Close() error {
}

func (c *GrpcClient) IsSelf() bool {
return atomic.LoadUint32(&c.isSelf) != 0
return c.isSelf.Load()
}

func (c *GrpcClient) SetSelf(self bool) {
if self {
atomic.StoreUint32(&c.isSelf, 1)
} else {
atomic.StoreUint32(&c.isSelf, 0)
}
c.isSelf.Store(self)
}

func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) {
Expand Down
Loading

0 comments on commit a43686a

Please sign in to comment.