diff --git a/conn.go b/conn.go index b01c00af..f55237c2 100644 --- a/conn.go +++ b/conn.go @@ -73,6 +73,7 @@ type Conn struct { paddingLengthGenerator func(uint) uint handshakeCompletedSuccessfully atomic.Value + handshakeMutex sync.Mutex encryptedPackets []addrPkt @@ -229,6 +230,13 @@ func (c *Conn) Handshake() error { } func (c *Conn) HandshakeContext(ctx context.Context) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if c.isHandshakeCompletedSuccessfully() { + return nil + } + // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. diff --git a/conn_test.go b/conn_test.go index 089308d7..db7c2ab4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1143,17 +1143,18 @@ func TestClientCertificate(t *testing.T) { t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { - c *Conn - err error + c *Conn + err, hserr error } c := make(chan result) go func() { client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) - c <- result{client, err} + c <- result{client, err, client.Handshake()} }() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) + hserr := server.Handshake() res := <-c defer func() { if err == nil { @@ -1165,7 +1166,7 @@ func TestClientCertificate(t *testing.T) { }() if tt.wantErr { - if err != nil { + if err != nil || hserr != nil { // Error expected, test succeeded return } @@ -1564,23 +1565,24 @@ func TestServerCertificate(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { - c *Conn - err error + c *Conn + err, hserr error } srvCh := make(chan result) go func() { s, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) - srvCh <- result{s, err} + srvCh <- result{s, err, s.Handshake()} }() cli, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) if err == nil { _ = cli.Close() } - if !tt.wantErr && err != nil { - t.Errorf("Client failed(%v)", err) + hserr := cli.Handshake() + if !tt.wantErr && (err != nil || hserr != nil) { + t.Errorf("Client failed(%v, %v)", err, hserr) } - if tt.wantErr && err == nil { + if tt.wantErr && err == nil && hserr == nil { t.Fatal("Error expected") } diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index ec1253ec..aad9649a 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -219,8 +219,7 @@ func clientPion(c *comm) { c.clientMutex.Lock() defer c.clientMutex.Unlock() - var err error - c.clientConn, err = dtls.DialWithContext(c.ctx, "udp", + conn, err := dtls.DialWithContext(c.ctx, "udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.clientConfig, ) @@ -229,6 +228,13 @@ func clientPion(c *comm) { return } + if err := conn.HandshakeContext(c.ctx); err != nil { + c.errChan <- err + return + } + + c.clientConn = conn + simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) c.clientDone <- nil close(c.clientDone) @@ -254,6 +260,11 @@ func serverPion(c *comm) { return } + if err := (c.serverConn.(*dtls.Conn)).HandshakeContext(c.ctx); err != nil { + c.errChan <- err + return + } + simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) c.serverDone <- nil close(c.serverDone)