From f4011b144428a1c20b8ce9c055f9c7a81f681620 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Tue, 16 Jul 2024 16:59:43 -0400 Subject: [PATCH] Perform handshake on first read/write Updates the connection to perform a handshake on first read/write instead of on accept. Closes https://github.com/pion/dtls/issues/279 --- conn.go | 127 ++++++++++++++++++++++++-------------------------- conn_test.go | 4 +- handshaker.go | 2 + resume.go | 12 +---- 4 files changed, 66 insertions(+), 79 deletions(-) diff --git a/conn.go b/conn.go index 0ca5090d2..112507565 100644 --- a/conn.go +++ b/conn.go @@ -94,9 +94,11 @@ type Conn struct { fsm *handshakeFSM replayProtectionWindow uint + + handshakeConfig *handshakeConfig } -func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) { +func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State) (*Conn, error) { if err := validateConfig(config); err != nil { return nil, err } @@ -127,42 +129,6 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien paddingLengthGenerator = func(uint) uint { return 0 } } - c := &Conn{ - rAddr: rAddr, - nextConn: netctx.NewPacketConn(nextConn), - fragmentBuffer: newFragmentBuffer(), - handshakeCache: newHandshakeCache(), - maximumTransmissionUnit: mtu, - paddingLengthGenerator: paddingLengthGenerator, - - decrypted: make(chan interface{}, 1), - log: logger, - - readDeadline: deadline.New(), - writeDeadline: deadline.New(), - - reading: make(chan struct{}, 1), - handshakeRecv: make(chan recvHandshakeState), - closed: closer.NewCloser(), - cancelHandshaker: func() {}, - - replayProtectionWindow: uint(replayProtectionWindow), - - state: State{ - isClient: isClient, - }, - } - - c.setRemoteEpoch(0) - c.setLocalEpoch(0) - return c, nil -} - -func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { - if conn == nil { - return nil, errNilNextConn - } - cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) if err != nil { return nil, err @@ -190,7 +156,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo curves = defaultCurves } - hsCfg := &handshakeConfig{ + handshakeConfig := &handshakeConfig{ localPSKCallback: config.PSK, localPSKIdentityHint: config.PSKIdentityHint, localCipherSuites: cipherSuites, @@ -209,7 +175,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo customCipherSuites: config.CustomCipherSuites, initialRetransmitInterval: workerInterval, disableRetransmitBackoff: config.DisableRetransmitBackoff, - log: conn.log, + log: logger, initialEpoch: 0, keyLogWriter: config.KeyLogWriter, sessionStore: config.SessionStore, @@ -222,33 +188,71 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo clientHelloMessageHook: config.ClientHelloMessageHook, serverHelloMessageHook: config.ServerHelloMessageHook, certificateRequestMessageHook: config.CertificateRequestMessageHook, + resumeState: resumeState, + } + + c := &Conn{ + rAddr: rAddr, + nextConn: netctx.NewPacketConn(nextConn), + handshakeConfig: handshakeConfig, + fragmentBuffer: newFragmentBuffer(), + handshakeCache: newHandshakeCache(), + maximumTransmissionUnit: mtu, + paddingLengthGenerator: paddingLengthGenerator, + + decrypted: make(chan interface{}, 1), + log: logger, + + readDeadline: deadline.New(), + writeDeadline: deadline.New(), + + reading: make(chan struct{}, 1), + handshakeRecv: make(chan recvHandshakeState), + closed: closer.NewCloser(), + cancelHandshaker: func() {}, + + replayProtectionWindow: uint(replayProtectionWindow), + + state: State{ + isClient: isClient, + }, } + c.setRemoteEpoch(0) + c.setLocalEpoch(0) + return c, nil +} + +func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +func (c *Conn) HandshakeContext(ctx context.Context) error { // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. - if !isClient { - cert, err := hsCfg.getCertificate(&ClientHelloInfo{}) + if !c.state.isClient { + cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{}) if err != nil && !errors.Is(err, errNoCertificates) { - return nil, err + return err } - hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites) + c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites) } var initialFlight flightVal var initialFSMState handshakeState - if initialState != nil { - if conn.state.isClient { + if c.handshakeConfig.resumeState != nil { + if c.state.isClient { initialFlight = flight5 } else { initialFlight = flight6 } initialFSMState = handshakeFinished - conn.state = *initialState + c.state = *c.handshakeConfig.resumeState } else { - if conn.state.isClient { + if c.state.isClient { initialFlight = flight1 } else { initialFlight = flight0 @@ -256,13 +260,13 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo initialFSMState = handshakePreparing } // Do handshake - if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { - return nil, err + if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil { + return err } - conn.log.Trace("Handshake Completed") + c.log.Trace("Handshake Completed") - return conn, nil + return nil } // Dial connects to the given network address and establishes a DTLS connection on top. @@ -318,12 +322,7 @@ func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, return nil, errPSKAndIdentityMustBeSetForClient } - dconn, err := createConn(conn, rAddr, config, true) - if err != nil { - return nil, err - } - - return handshakeConn(ctx, dconn, config, true, nil) + return createConn(conn, rAddr, config, true, nil) } // ServerWithContext listens for incoming DTLS connections. @@ -336,17 +335,13 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, return nil, err } } - dconn, err := createConn(conn, rAddr, config, false) - if err != nil { - return nil, err - } - return handshakeConn(ctx, dconn, config, false, nil) + return createConn(conn, rAddr, config, false, nil) } // Read reads data from the connection. func (c *Conn) Read(p []byte) (n int, err error) { - if !c.isHandshakeCompletedSuccessfully() { - return 0, errHandshakeInProgress + if err := c.Handshake(); err != nil { + return 0, err } select { @@ -389,8 +384,8 @@ func (c *Conn) Write(p []byte) (int, error) { default: } - if !c.isHandshakeCompletedSuccessfully() { - return 0, errHandshakeInProgress + if err := c.Handshake(); err != nil { + return 0, err } return len(p), c.writePackets(c.writeDeadline, []*packet{ diff --git a/conn_test.go b/conn_test.go index 3d4b3ab47..cb89be89f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3306,7 +3306,7 @@ func TestApplicationDataQueueLimited(t *testing.T) { cfg := &Config{} cfg.Certificates = []tls.Certificate{serverCert} - dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false) + dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false, nil) if err != nil { t.Error(err) return @@ -3322,7 +3322,7 @@ func TestApplicationDataQueueLimited(t *testing.T) { time.Sleep(1 * time.Second) } }() - if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil { + if err := dconn.HandshakeContext(ctx); err == nil { t.Error("expected handshake to fail") } close(done) diff --git a/handshaker.go b/handshaker.go index a585c3db8..62a4bf6e8 100644 --- a/handshaker.go +++ b/handshaker.go @@ -132,6 +132,8 @@ type handshakeConfig struct { clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + + resumeState *State } type flightConn interface { diff --git a/resume.go b/resume.go index 6cd1c5a69..0b76314a5 100644 --- a/resume.go +++ b/resume.go @@ -4,7 +4,6 @@ package dtls import ( - "context" "net" ) @@ -13,14 +12,5 @@ func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) ( if err := state.initCipherSuite(); err != nil { return nil, err } - dconn, err := createConn(conn, rAddr, config, state.isClient) - if err != nil { - return nil, err - } - c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state) - if err != nil { - return nil, err - } - - return c, nil + return createConn(conn, rAddr, config, state.isClient, state) }