From e3c403eedd6fa9c70523ff4e61cb74fac9ed66cc Mon Sep 17 00:00:00 2001 From: Sean DuBois Date: Fri, 12 Jul 2024 15:32:53 -0400 Subject: [PATCH] Make localConnectionID thread safe Resolves #647 --- conn.go | 12 ++++++------ conn_test.go | 4 ++-- flight0handler.go | 4 ++-- flight1handler.go | 8 ++++---- flight3handler.go | 6 +++--- flight4handler.go | 4 ++-- state.go | 18 +++++++++++++++--- 7 files changed, 34 insertions(+), 22 deletions(-) diff --git a/conn.go b/conn.go index d82228f31..50213652c 100644 --- a/conn.go +++ b/conn.go @@ -699,7 +699,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { return netError(err) } - pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.localConnectionID)) + pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.getLocalConnectionID())) if err != nil { return err } @@ -775,8 +775,8 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A h := &recordlayer.Header{} // Set connection ID size so that records of content type tls12_cid will // be parsed correctly. - if len(c.state.localConnectionID) > 0 { - h.ConnectionID = make([]byte, len(c.state.localConnectionID)) + if len(c.state.getLocalConnectionID()) > 0 { + h.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) } if err := h.Unmarshal(buf); err != nil { // Decode error must be silently discarded @@ -832,7 +832,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A // If a connection identifier had been negotiated and encryption is // enabled, the connection identifier MUST be sent. - if len(c.state.localConnectionID) > 0 && h.ContentType != protocol.ContentTypeConnectionID { + if len(c.state.getLocalConnectionID()) > 0 && h.ContentType != protocol.ContentTypeConnectionID { c.log.Debug("discarded packet missing connection ID after value negotiated") return false, nil, nil } @@ -840,7 +840,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A var err error var hdr recordlayer.Header if h.ContentType == protocol.ContentTypeConnectionID { - hdr.ConnectionID = make([]byte, len(c.state.localConnectionID)) + hdr.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) } buf, err = c.state.cipherSuite.Decrypt(hdr, buf) if err != nil { @@ -872,7 +872,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A } // If connection ID does not match discard the packet. - if !bytes.Equal(c.state.localConnectionID, h.ConnectionID) { + if !bytes.Equal(c.state.getLocalConnectionID(), h.ConnectionID) { c.log.Debug("unexpected connection ID") return false, nil, nil } diff --git a/conn_test.go b/conn_test.go index 5560edc3a..a4f234e10 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1319,13 +1319,13 @@ func TestConnectionID(t *testing.T) { } }() - if !bytes.Equal(res.c.state.localConnectionID, tt.clientConnectionID) { + if !bytes.Equal(res.c.state.getLocalConnectionID(), tt.clientConnectionID) { t.Errorf("Unexpected client local connection ID\nwant: %v\ngot:%v", tt.clientConnectionID, res.c.state.localConnectionID) } if !bytes.Equal(res.c.state.remoteConnectionID, tt.serverConnectionID) { t.Errorf("Unexpected client remote connection ID\nwant: %v\ngot:%v", tt.serverConnectionID, res.c.state.remoteConnectionID) } - if !bytes.Equal(server.state.localConnectionID, tt.serverConnectionID) { + if !bytes.Equal(server.state.getLocalConnectionID(), tt.serverConnectionID) { t.Errorf("Unexpected server local connection ID\nwant: %v\ngot:%v", tt.serverConnectionID, server.state.localConnectionID) } if !bytes.Equal(server.state.remoteConnectionID, tt.clientConnectionID) { diff --git a/flight0handler.go b/flight0handler.go index 0a45c58d4..c965fd0b5 100644 --- a/flight0handler.go +++ b/flight0handler.go @@ -25,7 +25,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak // Connection Identifiers must be negotiated afresh on session resumption. // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension - state.localConnectionID = nil + state.setLocalConnectionID(nil) state.remoteConnectionID = nil state.handshakeRecvSequence = seq @@ -87,7 +87,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak // If the client doesn't support connection IDs, the server should not // expect one to be sent. if state.remoteConnectionID == nil { - state.localConnectionID = nil + state.setLocalConnectionID(nil) } if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { diff --git a/flight1handler.go b/flight1handler.go index 6448fef79..bcbddc3ec 100644 --- a/flight1handler.go +++ b/flight1handler.go @@ -126,15 +126,15 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha // in which case we are just requesting that the server send us a CID to // use. if cfg.connectionIDGenerator != nil { - state.localConnectionID = cfg.connectionIDGenerator() + state.setLocalConnectionID(cfg.connectionIDGenerator()) // The presence of a generator indicates support for connection IDs. We // use the presence of a non-nil local CID in flight 3 to determine // whether we send a CID in the second ClientHello, so we convert any // nil CID returned by a generator to []byte{}. - if state.localConnectionID == nil { - state.localConnectionID = []byte{} + if state.getLocalConnectionID() == nil { + state.setLocalConnectionID([]byte{}) } - extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID}) + extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) } clientHello := &handshake.MessageClientHello{ diff --git a/flight3handler.go b/flight3handler.go index b3f82fe76..678b1e5a7 100644 --- a/flight3handler.go +++ b/flight3handler.go @@ -77,7 +77,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh // If the server doesn't support connection IDs, the client should not // expect one to be sent. if state.remoteConnectionID == nil { - state.localConnectionID = nil + state.setLocalConnectionID(nil) } if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { @@ -284,8 +284,8 @@ func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha // If we sent a connection ID on the first ClientHello, send it on the // second. - if state.localConnectionID != nil { - extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID}) + if state.getLocalConnectionID() != nil { + extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) } clientHello := &handshake.MessageClientHello{ diff --git a/flight4handler.go b/flight4handler.go index 840a24f15..b1ff6c3ff 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -255,8 +255,8 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha // parsing the ClientHello, so avoid setting local connection ID if the // client won't send it. if cfg.connectionIDGenerator != nil && state.remoteConnectionID != nil { - state.localConnectionID = cfg.connectionIDGenerator() - extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID}) + state.setLocalConnectionID(cfg.connectionIDGenerator()) + extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) } var pkts []*packet diff --git a/state.go b/state.go index bc93bdf0c..178d937ea 100644 --- a/state.go +++ b/state.go @@ -36,7 +36,7 @@ type State struct { // to be received from the remote endpoint. // For a server, this is the connection ID sent in ServerHello. // For a client, this is the connection ID sent in the ClientHello. - localConnectionID []byte + localConnectionID atomic.Value // remoteConnectionID is the connection ID that the remote endpoint // specifies should be sent. // For a server, this is the connection ID received in the ClientHello. @@ -111,7 +111,7 @@ func (s *State) serialize() *serializedState { PeerCertificates: s.PeerCertificates, IdentityHint: s.IdentityHint, SessionID: s.SessionID, - LocalConnectionID: s.localConnectionID, + LocalConnectionID: s.getLocalConnectionID(), RemoteConnectionID: s.remoteConnectionID, IsClient: s.isClient, NegotiatedProtocol: s.NegotiatedProtocol, @@ -155,7 +155,7 @@ func (s *State) deserialize(serialized serializedState) { s.IdentityHint = serialized.IdentityHint // Set local and remote connection IDs - s.localConnectionID = serialized.LocalConnectionID + s.setLocalConnectionID(serialized.LocalConnectionID) s.remoteConnectionID = serialized.RemoteConnectionID s.SessionID = serialized.SessionID @@ -259,6 +259,18 @@ func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile { return 0 } +func (s *State) getLocalConnectionID() []byte { + if val, ok := s.localConnectionID.Load().([]byte); ok { + return val + } + + return nil +} + +func (s *State) setLocalConnectionID(v []byte) { + s.localConnectionID.Store(v) +} + // RemoteRandomBytes returns the remote client hello random bytes func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte { return s.remoteRandom.RandomBytes