Skip to content

Commit

Permalink
Make localConnectionID thread safe
Browse files Browse the repository at this point in the history
Resolves #647
  • Loading branch information
Sean-Der committed Jul 12, 2024
1 parent 0a1b73a commit 602dc71
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 22 deletions.
12 changes: 6 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
return netError(err)
}

pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.localConnectionID))
pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.getLocalConnectionID()))
if err != nil {
return err
}
Expand Down Expand Up @@ -775,8 +775,8 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
h := &recordlayer.Header{}
// Set connection ID size so that records of content type tls12_cid will
// be parsed correctly.
if len(c.state.localConnectionID) > 0 {
h.ConnectionID = make([]byte, len(c.state.localConnectionID))
if len(c.state.getLocalConnectionID()) > 0 {
h.ConnectionID = make([]byte, len(c.state.getLocalConnectionID()))
}
if err := h.Unmarshal(buf); err != nil {
// Decode error must be silently discarded
Expand Down Expand Up @@ -832,15 +832,15 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A

// If a connection identifier had been negotiated and encryption is
// enabled, the connection identifier MUST be sent.
if len(c.state.localConnectionID) > 0 && h.ContentType != protocol.ContentTypeConnectionID {
if len(c.state.getLocalConnectionID()) > 0 && h.ContentType != protocol.ContentTypeConnectionID {
c.log.Debug("discarded packet missing connection ID after value negotiated")
return false, nil, nil
}

var err error
var hdr recordlayer.Header
if h.ContentType == protocol.ContentTypeConnectionID {
hdr.ConnectionID = make([]byte, len(c.state.localConnectionID))
hdr.ConnectionID = make([]byte, len(c.state.getLocalConnectionID()))
}
buf, err = c.state.cipherSuite.Decrypt(hdr, buf)
if err != nil {
Expand Down Expand Up @@ -872,7 +872,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
}

// If connection ID does not match discard the packet.
if !bytes.Equal(c.state.localConnectionID, h.ConnectionID) {
if !bytes.Equal(c.state.getLocalConnectionID(), h.ConnectionID) {
c.log.Debug("unexpected connection ID")
return false, nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1319,13 +1319,13 @@ func TestConnectionID(t *testing.T) {
}
}()

if !bytes.Equal(res.c.state.localConnectionID, tt.clientConnectionID) {
if !bytes.Equal(res.c.state.getLocalConnectionID(), tt.clientConnectionID) {
t.Errorf("Unexpected client local connection ID\nwant: %v\ngot:%v", tt.clientConnectionID, res.c.state.localConnectionID)
}
if !bytes.Equal(res.c.state.remoteConnectionID, tt.serverConnectionID) {
t.Errorf("Unexpected client remote connection ID\nwant: %v\ngot:%v", tt.serverConnectionID, res.c.state.remoteConnectionID)
}
if !bytes.Equal(server.state.localConnectionID, tt.serverConnectionID) {
if !bytes.Equal(server.state.getLocalConnectionID(), tt.serverConnectionID) {
t.Errorf("Unexpected server local connection ID\nwant: %v\ngot:%v", tt.serverConnectionID, server.state.localConnectionID)
}
if !bytes.Equal(server.state.remoteConnectionID, tt.clientConnectionID) {
Expand Down
4 changes: 2 additions & 2 deletions flight0handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak

// Connection Identifiers must be negotiated afresh on session resumption.
// https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension
state.localConnectionID = nil
state.setLocalConnectionID(nil)
state.remoteConnectionID = nil

state.handshakeRecvSequence = seq
Expand Down Expand Up @@ -87,7 +87,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak
// If the client doesn't support connection IDs, the server should not
// expect one to be sent.
if state.remoteConnectionID == nil {
state.localConnectionID = nil
state.setLocalConnectionID(nil)
}

if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
Expand Down
8 changes: 4 additions & 4 deletions flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,15 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha
// in which case we are just requesting that the server send us a CID to
// use.
if cfg.connectionIDGenerator != nil {
state.localConnectionID = cfg.connectionIDGenerator()
state.setLocalConnectionID(cfg.connectionIDGenerator())
// The presence of a generator indicates support for connection IDs. We
// use the presence of a non-nil local CID in flight 3 to determine
// whether we send a CID in the second ClientHello, so we convert any
// nil CID returned by a generator to []byte{}.
if state.localConnectionID == nil {
state.localConnectionID = []byte{}
if state.getLocalConnectionID() == nil {
state.setLocalConnectionID([]byte{})
}
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()})
}

clientHello := &handshake.MessageClientHello{
Expand Down
6 changes: 3 additions & 3 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
// If the server doesn't support connection IDs, the client should not
// expect one to be sent.
if state.remoteConnectionID == nil {
state.localConnectionID = nil
state.setLocalConnectionID(nil)
}

if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
Expand Down Expand Up @@ -284,8 +284,8 @@ func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha

// If we sent a connection ID on the first ClientHello, send it on the
// second.
if state.localConnectionID != nil {
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
if state.getLocalConnectionID() != nil {
extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()})
}

clientHello := &handshake.MessageClientHello{
Expand Down
4 changes: 2 additions & 2 deletions flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,8 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
// parsing the ClientHello, so avoid setting local connection ID if the
// client won't send it.
if cfg.connectionIDGenerator != nil && state.remoteConnectionID != nil {
state.localConnectionID = cfg.connectionIDGenerator()
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
state.setLocalConnectionID(cfg.connectionIDGenerator())
extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()})
}

var pkts []*packet
Expand Down
18 changes: 15 additions & 3 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type State struct {
// to be received from the remote endpoint.
// For a server, this is the connection ID sent in ServerHello.
// For a client, this is the connection ID sent in the ClientHello.
localConnectionID []byte
localConnectionID atomic.Value
// remoteConnectionID is the connection ID that the remote endpoint
// specifies should be sent.
// For a server, this is the connection ID received in the ClientHello.
Expand Down Expand Up @@ -111,7 +111,7 @@ func (s *State) serialize() *serializedState {
PeerCertificates: s.PeerCertificates,
IdentityHint: s.IdentityHint,
SessionID: s.SessionID,
LocalConnectionID: s.localConnectionID,
LocalConnectionID: s.getLocalConnectionID(),
RemoteConnectionID: s.remoteConnectionID,
IsClient: s.isClient,
NegotiatedProtocol: s.NegotiatedProtocol,
Expand Down Expand Up @@ -155,7 +155,7 @@ func (s *State) deserialize(serialized serializedState) {
s.IdentityHint = serialized.IdentityHint

// Set local and remote connection IDs
s.localConnectionID = serialized.LocalConnectionID
s.setLocalConnectionID(serialized.LocalConnectionID)
s.remoteConnectionID = serialized.RemoteConnectionID

s.SessionID = serialized.SessionID
Expand Down Expand Up @@ -259,6 +259,18 @@ func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {
return 0
}

func (s *State) getLocalConnectionID() []byte {
if val, ok := s.localConnectionID.Load().([]byte); ok {
return val
}

return nil
}

func (s *State) setLocalConnectionID(v []byte) {
s.localConnectionID.Store(v)
}

// RemoteRandomBytes returns the remote client hello random bytes
func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte {
return s.remoteRandom.RandomBytes
Expand Down

0 comments on commit 602dc71

Please sign in to comment.