Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make localConnectionID thread safe #648

Merged
merged 1 commit into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
return netError(err)
}

pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.localConnectionID))
pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.getLocalConnectionID()))
if err != nil {
return err
}
Expand Down Expand Up @@ -775,8 +775,8 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
h := &recordlayer.Header{}
// Set connection ID size so that records of content type tls12_cid will
// be parsed correctly.
if len(c.state.localConnectionID) > 0 {
h.ConnectionID = make([]byte, len(c.state.localConnectionID))
if len(c.state.getLocalConnectionID()) > 0 {
h.ConnectionID = make([]byte, len(c.state.getLocalConnectionID()))
}
if err := h.Unmarshal(buf); err != nil {
// Decode error must be silently discarded
Expand Down Expand Up @@ -832,15 +832,15 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A

// If a connection identifier had been negotiated and encryption is
// enabled, the connection identifier MUST be sent.
if len(c.state.localConnectionID) > 0 && h.ContentType != protocol.ContentTypeConnectionID {
if len(c.state.getLocalConnectionID()) > 0 && h.ContentType != protocol.ContentTypeConnectionID {
c.log.Debug("discarded packet missing connection ID after value negotiated")
return false, nil, nil
}

var err error
var hdr recordlayer.Header
if h.ContentType == protocol.ContentTypeConnectionID {
hdr.ConnectionID = make([]byte, len(c.state.localConnectionID))
hdr.ConnectionID = make([]byte, len(c.state.getLocalConnectionID()))
}
buf, err = c.state.cipherSuite.Decrypt(hdr, buf)
if err != nil {
Expand Down Expand Up @@ -872,7 +872,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
}

// If connection ID does not match discard the packet.
if !bytes.Equal(c.state.localConnectionID, h.ConnectionID) {
if !bytes.Equal(c.state.getLocalConnectionID(), h.ConnectionID) {
c.log.Debug("unexpected connection ID")
return false, nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1319,13 +1319,13 @@ func TestConnectionID(t *testing.T) {
}
}()

if !bytes.Equal(res.c.state.localConnectionID, tt.clientConnectionID) {
if !bytes.Equal(res.c.state.getLocalConnectionID(), tt.clientConnectionID) {
t.Errorf("Unexpected client local connection ID\nwant: %v\ngot:%v", tt.clientConnectionID, res.c.state.localConnectionID)
}
if !bytes.Equal(res.c.state.remoteConnectionID, tt.serverConnectionID) {
t.Errorf("Unexpected client remote connection ID\nwant: %v\ngot:%v", tt.serverConnectionID, res.c.state.remoteConnectionID)
}
if !bytes.Equal(server.state.localConnectionID, tt.serverConnectionID) {
if !bytes.Equal(server.state.getLocalConnectionID(), tt.serverConnectionID) {
t.Errorf("Unexpected server local connection ID\nwant: %v\ngot:%v", tt.serverConnectionID, server.state.localConnectionID)
}
if !bytes.Equal(server.state.remoteConnectionID, tt.clientConnectionID) {
Expand Down
4 changes: 2 additions & 2 deletions flight0handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak

// Connection Identifiers must be negotiated afresh on session resumption.
// https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension
state.localConnectionID = nil
state.setLocalConnectionID(nil)
state.remoteConnectionID = nil

state.handshakeRecvSequence = seq
Expand Down Expand Up @@ -87,7 +87,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak
// If the client doesn't support connection IDs, the server should not
// expect one to be sent.
if state.remoteConnectionID == nil {
state.localConnectionID = nil
state.setLocalConnectionID(nil)
}

if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
Expand Down
8 changes: 4 additions & 4 deletions flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,15 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha
// in which case we are just requesting that the server send us a CID to
// use.
if cfg.connectionIDGenerator != nil {
state.localConnectionID = cfg.connectionIDGenerator()
state.setLocalConnectionID(cfg.connectionIDGenerator())
// The presence of a generator indicates support for connection IDs. We
// use the presence of a non-nil local CID in flight 3 to determine
// whether we send a CID in the second ClientHello, so we convert any
// nil CID returned by a generator to []byte{}.
if state.localConnectionID == nil {
state.localConnectionID = []byte{}
if state.getLocalConnectionID() == nil {
state.setLocalConnectionID([]byte{})
}
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()})
}

clientHello := &handshake.MessageClientHello{
Expand Down
6 changes: 3 additions & 3 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
// If the server doesn't support connection IDs, the client should not
// expect one to be sent.
if state.remoteConnectionID == nil {
state.localConnectionID = nil
state.setLocalConnectionID(nil)
}

if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
Expand Down Expand Up @@ -284,8 +284,8 @@ func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha

// If we sent a connection ID on the first ClientHello, send it on the
// second.
if state.localConnectionID != nil {
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
if state.getLocalConnectionID() != nil {
extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()})
}

clientHello := &handshake.MessageClientHello{
Expand Down
4 changes: 2 additions & 2 deletions flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,8 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
// parsing the ClientHello, so avoid setting local connection ID if the
// client won't send it.
if cfg.connectionIDGenerator != nil && state.remoteConnectionID != nil {
state.localConnectionID = cfg.connectionIDGenerator()
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
state.setLocalConnectionID(cfg.connectionIDGenerator())
extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()})
}

var pkts []*packet
Expand Down
18 changes: 15 additions & 3 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type State struct {
// to be received from the remote endpoint.
// For a server, this is the connection ID sent in ServerHello.
// For a client, this is the connection ID sent in the ClientHello.
localConnectionID []byte
localConnectionID atomic.Value
// remoteConnectionID is the connection ID that the remote endpoint
// specifies should be sent.
// For a server, this is the connection ID received in the ClientHello.
Expand Down Expand Up @@ -111,7 +111,7 @@ func (s *State) serialize() *serializedState {
PeerCertificates: s.PeerCertificates,
IdentityHint: s.IdentityHint,
SessionID: s.SessionID,
LocalConnectionID: s.localConnectionID,
LocalConnectionID: s.getLocalConnectionID(),
RemoteConnectionID: s.remoteConnectionID,
IsClient: s.isClient,
NegotiatedProtocol: s.NegotiatedProtocol,
Expand Down Expand Up @@ -155,7 +155,7 @@ func (s *State) deserialize(serialized serializedState) {
s.IdentityHint = serialized.IdentityHint

// Set local and remote connection IDs
s.localConnectionID = serialized.LocalConnectionID
s.setLocalConnectionID(serialized.LocalConnectionID)
s.remoteConnectionID = serialized.RemoteConnectionID

s.SessionID = serialized.SessionID
Expand Down Expand Up @@ -259,6 +259,18 @@ func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {
return 0
}

func (s *State) getLocalConnectionID() []byte {
if val, ok := s.localConnectionID.Load().([]byte); ok {
return val
}

return nil
}

func (s *State) setLocalConnectionID(v []byte) {
s.localConnectionID.Store(v)
}

// RemoteRandomBytes returns the remote client hello random bytes
func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte {
return s.remoteRandom.RandomBytes
Expand Down
Loading