From fc419a433e9125e24f69f3e9f467182948547562 Mon Sep 17 00:00:00 2001 From: Daniel Adam Date: Wed, 21 Aug 2024 10:11:21 +0200 Subject: [PATCH] Fix race between Conn.Close and Conn.Handshake --- conn.go | 40 ++++++++++++++++++++++++++++++---------- conn_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index c0c34073..26c56f2e 100644 --- a/conn.go +++ b/conn.go @@ -74,13 +74,13 @@ type Conn struct { handshakeCompletedSuccessfully atomic.Value handshakeMutex sync.Mutex + handshakeDone chan struct{} encryptedPackets []addrPkt connectionClosedByUser bool closeLock sync.Mutex closed *closer.Closer - handshakeLoopsFinished sync.WaitGroup readDeadline *deadline.Deadline writeDeadline *deadline.Deadline @@ -256,6 +256,12 @@ func (c *Conn) HandshakeContext(ctx context.Context) error { return nil } + handshakeDone := make(chan struct{}) + defer close(handshakeDone) + c.closeLock.Lock() + c.handshakeDone = handshakeDone + c.closeLock.Unlock() + // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. @@ -405,7 +411,12 @@ func (c *Conn) Write(p []byte) (int, error) { // Close closes the connection. func (c *Conn) Close() error { err := c.close(true) //nolint:contextcheck - c.handshakeLoopsFinished.Wait() + c.closeLock.Lock() + handshakeDone := c.handshakeDone + c.closeLock.Unlock() + if handshakeDone != nil { + <-handshakeDone + } return err } @@ -1026,7 +1037,6 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh done := make(chan struct{}) ctxRead, cancelRead := context.WithCancel(context.Background()) - c.cancelHandshakeReader = cancelRead cfg.onFlightState = func(_ flightVal, s handshakeState) { if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() { c.setHandshakeCompletedSuccessfully() @@ -1035,16 +1045,21 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh } ctxHs, cancel := context.WithCancel(context.Background()) + + c.closeLock.Lock() c.cancelHandshaker = cancel + c.cancelHandshakeReader = cancelRead + c.closeLock.Unlock() firstErr := make(chan error, 1) - c.handshakeLoopsFinished.Add(2) + var handshakeLoopsFinished sync.WaitGroup + handshakeLoopsFinished.Add(2) // Handshake routine should be live until close. // The other party may request retransmission of the last flight to cope with packet drop. go func() { - defer c.handshakeLoopsFinished.Done() + defer handshakeLoopsFinished.Done() err := c.fsm.Run(ctxHs, c, initialState) if !errors.Is(err, context.Canceled) { select { @@ -1064,7 +1079,7 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh // Force stop handshaker when the underlying connection is closed. cancel() }() - defer c.handshakeLoopsFinished.Done() + defer handshakeLoopsFinished.Done() for { if err := c.readAndBuffer(ctxRead); err != nil { var e *alertError @@ -1123,12 +1138,12 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh case err := <-firstErr: cancelRead() cancel() - c.handshakeLoopsFinished.Wait() + handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(err) case <-ctx.Done(): cancelRead() cancel() - c.handshakeLoopsFinished.Wait() + handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(ctx.Err()) case <-done: return nil @@ -1146,8 +1161,13 @@ func (c *Conn) translateHandshakeCtxError(err error) error { } func (c *Conn) close(byUser bool) error { - c.cancelHandshaker() - c.cancelHandshakeReader() + c.closeLock.Lock() + cancelHandshaker := c.cancelHandshaker + cancelHandshakeReader := c.cancelHandshakeReader + c.closeLock.Unlock() + + cancelHandshaker() + cancelHandshakeReader() if c.isHandshakeCompletedSuccessfully() && byUser { // Discard error from notify() to return non-error on the first user call of Close() diff --git a/conn_test.go b/conn_test.go index 44fddbdb..9edb3aee 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3662,3 +3662,54 @@ func TestMultiHandshake(t *testing.T) { t.Fatal(err) } } + +func TestCloseDuringHandshake(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(time.Second * 10).Stop() + + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 100; i++ { + _, cb := dpipe.Pipe() + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{serverCert}, + }) + if err != nil { + t.Fatal(err) + } + + waitChan := make(chan struct{}) + go func() { + close(waitChan) + _ = server.Handshake() + }() + + <-waitChan + if err = server.Close(); err != nil { + t.Fatal(err) + } + } +} + +func TestCloseWithoutHandshake(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(time.Second * 10).Stop() + + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + _, cb := dpipe.Pipe() + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{serverCert}, + }) + if err != nil { + t.Fatal(err) + } + if err = server.Close(); err != nil { + t.Fatal(err) + } +}