From 9d39efb07ff6df2e121938dfc2809463ba2717e1 Mon Sep 17 00:00:00 2001 From: Daniel Adam Date: Thu, 22 Aug 2024 10:38:44 +0200 Subject: [PATCH] fixup! Synchronize parallel calls to Conn.Close and Conn.handshake --- conn.go | 26 +++++++++++++++++++------- conn_test.go | 29 +++++++++++++++++++++++++---- 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index 46c44ad2..26c56f2e 100644 --- a/conn.go +++ b/conn.go @@ -74,13 +74,13 @@ type Conn struct { handshakeCompletedSuccessfully atomic.Value handshakeMutex sync.Mutex + handshakeDone chan struct{} encryptedPackets []addrPkt connectionClosedByUser bool closeLock sync.Mutex closed *closer.Closer - handshakeLoopsFinished sync.WaitGroup readDeadline *deadline.Deadline writeDeadline *deadline.Deadline @@ -256,6 +256,12 @@ func (c *Conn) HandshakeContext(ctx context.Context) error { return nil } + handshakeDone := make(chan struct{}) + defer close(handshakeDone) + c.closeLock.Lock() + c.handshakeDone = handshakeDone + c.closeLock.Unlock() + // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. @@ -405,7 +411,12 @@ func (c *Conn) Write(p []byte) (int, error) { // Close closes the connection. func (c *Conn) Close() error { err := c.close(true) //nolint:contextcheck - c.handshakeLoopsFinished.Wait() + c.closeLock.Lock() + handshakeDone := c.handshakeDone + c.closeLock.Unlock() + if handshakeDone != nil { + <-handshakeDone + } return err } @@ -1042,12 +1053,13 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh firstErr := make(chan error, 1) - c.handshakeLoopsFinished.Add(2) + var handshakeLoopsFinished sync.WaitGroup + handshakeLoopsFinished.Add(2) // Handshake routine should be live until close. // The other party may request retransmission of the last flight to cope with packet drop. go func() { - defer c.handshakeLoopsFinished.Done() + defer handshakeLoopsFinished.Done() err := c.fsm.Run(ctxHs, c, initialState) if !errors.Is(err, context.Canceled) { select { @@ -1067,7 +1079,7 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh // Force stop handshaker when the underlying connection is closed. cancel() }() - defer c.handshakeLoopsFinished.Done() + defer handshakeLoopsFinished.Done() for { if err := c.readAndBuffer(ctxRead); err != nil { var e *alertError @@ -1126,12 +1138,12 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh case err := <-firstErr: cancelRead() cancel() - c.handshakeLoopsFinished.Wait() + handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(err) case <-ctx.Done(): cancelRead() cancel() - c.handshakeLoopsFinished.Wait() + handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(ctx.Err()) case <-done: return nil diff --git a/conn_test.go b/conn_test.go index c909c2af..9edb3aee 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3667,12 +3667,13 @@ func TestCloseDuringHandshake(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 10).Stop() + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + for i := 0; i < 100; i++ { _, cb := dpipe.Pipe() - serverCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{serverCert}, }) @@ -3692,3 +3693,23 @@ func TestCloseDuringHandshake(t *testing.T) { } } } + +func TestCloseWithoutHandshake(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(time.Second * 10).Stop() + + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + _, cb := dpipe.Pipe() + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{serverCert}, + }) + if err != nil { + t.Fatal(err) + } + if err = server.Close(); err != nil { + t.Fatal(err) + } +}