From d4b2f07d5c2432c8077b4534bab85896395c306e Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 19 Jul 2024 21:03:49 -0400 Subject: [PATCH] Perform handshake on first read/write Updates the connection to perform a handshake on first read/write instead of on accept. Closes https://github.com/pion/dtls/issues/279. --- bench_test.go | 2 +- config.go | 20 +- conn.go | 217 ++++++++++-------- conn_go_test.go | 45 +++- conn_test.go | 42 ++-- e2e/e2e_test.go | 18 +- examples/dial/cid/main.go | 7 +- examples/dial/psk/main.go | 7 +- examples/dial/selfsign/main.go | 7 +- examples/dial/verify/main.go | 7 +- examples/listen/cid/main.go | 16 +- examples/listen/psk/main.go | 10 - examples/listen/selfsign/main.go | 10 - .../verify-brute-force-protection/main.go | 9 - examples/listen/verify/main.go | 10 - handshaker.go | 2 + listener.go | 2 - resume.go | 12 +- 18 files changed, 227 insertions(+), 216 deletions(-) diff --git a/bench_test.go b/bench_test.go index 7b236f6d8..885b311f2 100644 --- a/bench_test.go +++ b/bench_test.go @@ -40,7 +40,7 @@ func TestSimpleReadWrite(t *testing.T) { return } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { + if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck t.Error(sErr) } gotHello <- struct{}{} diff --git a/config.go b/config.go index 0f6813d24..dd3a0b117 100644 --- a/config.go +++ b/config.go @@ -118,13 +118,10 @@ type Config struct { LoggerFactory logging.LoggerFactory - // ConnectContextMaker is a function to make a context used in Dial(), - // Client(), Server(), and Accept(). If nil, the default ConnectContextMaker - // is used. It can be implemented as following. + // ConnectContextMaker is no longer used. It is kept for compatibility + // reasons and will be removed in a future release. // - // func ConnectContextMaker() (context.Context, func()) { - // return context.WithTimeout(context.Background(), 30*time.Second) - // } + // Deprecated: Use the context parameter in [HandshakeContext] instead. ConnectContextMaker func() (context.Context, func()) // MTU is the length at which handshake messages will be fragmented to @@ -230,17 +227,6 @@ type Config struct { OnConnectionAttempt func(net.Addr) error } -func defaultConnectContextMaker() (context.Context, func()) { - return context.WithTimeout(context.Background(), 30*time.Second) -} - -func (c *Config) connectContextMaker() (context.Context, func()) { - if c.ConnectContextMaker == nil { - return defaultConnectContextMaker() - } - return c.ConnectContextMaker() -} - func (c *Config) includeCertificateSuites() bool { return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil } diff --git a/conn.go b/conn.go index 0ca5090d2..ba947aa31 100644 --- a/conn.go +++ b/conn.go @@ -73,6 +73,7 @@ type Conn struct { paddingLengthGenerator func(uint) uint handshakeCompletedSuccessfully atomic.Value + handshakeMutex sync.Mutex encryptedPackets []addrPkt @@ -94,9 +95,11 @@ type Conn struct { fsm *handshakeFSM replayProtectionWindow uint + + handshakeConfig *handshakeConfig } -func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) { +func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State) (*Conn, error) { if err := validateConfig(config); err != nil { return nil, err } @@ -127,42 +130,6 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien paddingLengthGenerator = func(uint) uint { return 0 } } - c := &Conn{ - rAddr: rAddr, - nextConn: netctx.NewPacketConn(nextConn), - fragmentBuffer: newFragmentBuffer(), - handshakeCache: newHandshakeCache(), - maximumTransmissionUnit: mtu, - paddingLengthGenerator: paddingLengthGenerator, - - decrypted: make(chan interface{}, 1), - log: logger, - - readDeadline: deadline.New(), - writeDeadline: deadline.New(), - - reading: make(chan struct{}, 1), - handshakeRecv: make(chan recvHandshakeState), - closed: closer.NewCloser(), - cancelHandshaker: func() {}, - - replayProtectionWindow: uint(replayProtectionWindow), - - state: State{ - isClient: isClient, - }, - } - - c.setRemoteEpoch(0) - c.setLocalEpoch(0) - return c, nil -} - -func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { - if conn == nil { - return nil, errNilNextConn - } - cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) if err != nil { return nil, err @@ -190,7 +157,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo curves = defaultCurves } - hsCfg := &handshakeConfig{ + handshakeConfig := &handshakeConfig{ localPSKCallback: config.PSK, localPSKIdentityHint: config.PSKIdentityHint, localCipherSuites: cipherSuites, @@ -209,7 +176,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo customCipherSuites: config.CustomCipherSuites, initialRetransmitInterval: workerInterval, disableRetransmitBackoff: config.DisableRetransmitBackoff, - log: conn.log, + log: logger, initialEpoch: 0, keyLogWriter: config.KeyLogWriter, sessionStore: config.SessionStore, @@ -222,33 +189,97 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo clientHelloMessageHook: config.ClientHelloMessageHook, serverHelloMessageHook: config.ServerHelloMessageHook, certificateRequestMessageHook: config.CertificateRequestMessageHook, + resumeState: resumeState, + } + + c := &Conn{ + rAddr: rAddr, + nextConn: netctx.NewPacketConn(nextConn), + handshakeConfig: handshakeConfig, + fragmentBuffer: newFragmentBuffer(), + handshakeCache: newHandshakeCache(), + maximumTransmissionUnit: mtu, + paddingLengthGenerator: paddingLengthGenerator, + + decrypted: make(chan interface{}, 1), + log: logger, + + readDeadline: deadline.New(), + writeDeadline: deadline.New(), + + reading: make(chan struct{}, 1), + handshakeRecv: make(chan recvHandshakeState), + closed: closer.NewCloser(), + cancelHandshaker: func() {}, + cancelHandshakeReader: func() {}, + + replayProtectionWindow: uint(replayProtectionWindow), + + state: State{ + isClient: isClient, + }, + } + + c.setRemoteEpoch(0) + c.setLocalEpoch(0) + return c, nil +} + +// Handshake runs the client or server DTLS handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first [Conn.Read] or [Conn.Write] will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// [Conn.HandshakeContext]. +func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server DTLS handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first [Conn.Read] or [Conn.Write] will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if c.isHandshakeCompletedSuccessfully() { + return nil } // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. - if !isClient { - cert, err := hsCfg.getCertificate(&ClientHelloInfo{}) + if !c.state.isClient { + cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{}) if err != nil && !errors.Is(err, errNoCertificates) { - return nil, err + return err } - hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites) + c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites) } var initialFlight flightVal var initialFSMState handshakeState - if initialState != nil { - if conn.state.isClient { + if c.handshakeConfig.resumeState != nil { + if c.state.isClient { initialFlight = flight5 } else { initialFlight = flight6 } initialFSMState = handshakeFinished - conn.state = *initialState + c.state = *c.handshakeConfig.resumeState } else { - if conn.state.isClient { + if c.state.isClient { initialFlight = flight1 } else { initialFlight = flight0 @@ -256,48 +287,17 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo initialFSMState = handshakePreparing } // Do handshake - if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { - return nil, err + if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil { + return err } - conn.log.Trace("Handshake Completed") + c.log.Trace("Handshake Completed") - return conn, nil + return nil } // Dial connects to the given network address and establishes a DTLS connection on top. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, use DialWithContext() instead. func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { - ctx, cancel := config.connectContextMaker() - defer cancel() - - return DialWithContext(ctx, network, rAddr, config) -} - -// Client establishes a DTLS connection over an existing connection. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, use ClientWithContext() instead. -func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { - ctx, cancel := config.connectContextMaker() - defer cancel() - - return ClientWithContext(ctx, conn, rAddr, config) -} - -// Server listens for incoming DTLS connections. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, use ServerWithContext() instead. -func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { - ctx, cancel := config.connectContextMaker() - defer cancel() - - return ServerWithContext(ctx, conn, rAddr, config) -} - -// DialWithContext connects to the given network address and establishes a DTLS -// connection on top. -func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { // net.ListenUDP is used rather than net.DialUDP as the latter prevents the // use of net.PacketConn.WriteTo. // https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115 @@ -306,11 +306,11 @@ func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, co return nil, err } - return ClientWithContext(ctx, pConn, rAddr, config) + return Client(pConn, rAddr, config) } -// ClientWithContext establishes a DTLS connection over an existing connection. -func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { +// Client establishes a DTLS connection over an existing connection. +func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { switch { case config == nil: return nil, errNoConfigProvided @@ -318,16 +318,11 @@ func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, return nil, errPSKAndIdentityMustBeSetForClient } - dconn, err := createConn(conn, rAddr, config, true) - if err != nil { - return nil, err - } - - return handshakeConn(ctx, dconn, config, true, nil) + return createConn(conn, rAddr, config, true, nil) } -// ServerWithContext listens for incoming DTLS connections. -func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { +// Server listens for incoming DTLS connections. +func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if config == nil { return nil, errNoConfigProvided } @@ -336,17 +331,35 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, return nil, err } } - dconn, err := createConn(conn, rAddr, config, false) - if err != nil { - return nil, err - } - return handshakeConn(ctx, dconn, config, false, nil) + return createConn(conn, rAddr, config, false, nil) +} + +// DialWithContext connects to the given network address and establishes a DTLS +// connection on top. +// +// Deprecated: Use Dial instead, the context parameter is no longer used. +func DialWithContext(_ context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { + return Dial(network, rAddr, config) +} + +// ClientWithContext establishes a DTLS connection over an existing connection. +// +// Deprecated: Use Client instead, the context parameter is no longer used. +func ClientWithContext(_ context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + return Client(conn, rAddr, config) +} + +// ServerWithContext listens for incoming DTLS connections. +// +// Deprecated: Use Server instead, the context parameter is no longer used. +func ServerWithContext(_ context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + return Server(conn, rAddr, config) } // Read reads data from the connection. func (c *Conn) Read(p []byte) (n int, err error) { - if !c.isHandshakeCompletedSuccessfully() { - return 0, errHandshakeInProgress + if err := c.Handshake(); err != nil { + return 0, err } select { @@ -389,8 +402,8 @@ func (c *Conn) Write(p []byte) (int, error) { default: } - if !c.isHandshakeCompletedSuccessfully() { - return 0, errHandshakeInProgress + if err := c.Handshake(); err != nil { + return 0, err } return len(p), c.writePackets(c.writeDeadline, []*packet{ diff --git a/conn_go_test.go b/conn_go_test.go index d9ca6e187..c79d1b70e 100644 --- a/conn_go_test.go +++ b/conn_go_test.go @@ -52,9 +52,6 @@ func TestContextConfig(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } config := &Config{ - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(context.Background(), 40*time.Millisecond) - }, Certificates: []tls.Certificate{cert}, } @@ -64,9 +61,15 @@ func TestContextConfig(t *testing.T) { }{ "Dial": { f: func() (func() (net.Conn, error), func()) { + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { - return Dial("udp", addr, config) + conn, err := Dial("udp", addr, config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { + cancel() } }, order: []byte{0, 1, 2}, @@ -75,7 +78,11 @@ func TestContextConfig(t *testing.T) { f: func() (func() (net.Conn, error), func()) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) return func() (net.Conn, error) { - return DialWithContext(ctx, "udp", addr, config) + conn, err := DialWithContext(ctx, "udp", addr, config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { cancel() } @@ -85,10 +92,16 @@ func TestContextConfig(t *testing.T) { "Client": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { - return Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + conn, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() + cancel() } }, order: []byte{0, 1, 2}, @@ -98,7 +111,11 @@ func TestContextConfig(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return ClientWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + conn, err := ClientWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { cancel() _ = ca.Close() @@ -109,10 +126,16 @@ func TestContextConfig(t *testing.T) { "Server": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { - return Server(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + conn, err := Server(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() + cancel() } }, order: []byte{0, 1, 2}, @@ -122,7 +145,11 @@ func TestContextConfig(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return ServerWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + conn, err := ServerWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) }, func() { cancel() _ = ca.Close() diff --git a/conn_test.go b/conn_test.go index 3d4b3ab47..03364e2af 100644 --- a/conn_test.go +++ b/conn_test.go @@ -295,7 +295,11 @@ func testClient(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Conf cfg.Certificates = []tls.Certificate{clientCert} } cfg.InsecureSkipVerify = true - return ClientWithContext(ctx, c, rAddr, cfg) + conn, err := ClientWithContext(ctx, c, rAddr, cfg) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) } func testServer(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool) (*Conn, error) { @@ -306,7 +310,11 @@ func testServer(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Conf } cfg.Certificates = []tls.Certificate{serverCert} } - return ServerWithContext(ctx, c, rAddr, cfg) + conn, err := ServerWithContext(ctx, c, rAddr, cfg) + if err != nil { + return nil, err + } + return conn, conn.HandshakeContext(ctx) } func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensions []extension.Extension) error { @@ -1135,17 +1143,18 @@ func TestClientCertificate(t *testing.T) { t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { - c *Conn - err error + c *Conn + err, hserr error } c := make(chan result) go func() { client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) - c <- result{client, err} + c <- result{client, err, client.Handshake()} }() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) + hserr := server.Handshake() res := <-c defer func() { if err == nil { @@ -1157,7 +1166,7 @@ func TestClientCertificate(t *testing.T) { }() if tt.wantErr { - if err != nil { + if err != nil || hserr != nil { // Error expected, test succeeded return } @@ -1556,23 +1565,24 @@ func TestServerCertificate(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { - c *Conn - err error + c *Conn + err, hserr error } srvCh := make(chan result) go func() { s, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) - srvCh <- result{s, err} + srvCh <- result{s, err, s.Handshake()} }() cli, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) + hserr := cli.Handshake() if err == nil { _ = cli.Close() } - if !tt.wantErr && err != nil { - t.Errorf("Client failed(%v)", err) + if !tt.wantErr && (err != nil || hserr != nil) { + t.Errorf("Client failed(%v, %v)", err, hserr) } - if tt.wantErr && err == nil { + if tt.wantErr && err == nil && hserr == nil { t.Fatal("Error expected") } @@ -3237,7 +3247,7 @@ func TestSkipHelloVerify(t *testing.T) { return } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { + if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck t.Error(sErr) } gotHello <- struct{}{} @@ -3306,7 +3316,7 @@ func TestApplicationDataQueueLimited(t *testing.T) { cfg := &Config{} cfg.Certificates = []tls.Certificate{serverCert} - dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false) + dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false, nil) if err != nil { t.Error(err) return @@ -3322,7 +3332,7 @@ func TestApplicationDataQueueLimited(t *testing.T) { time.Sleep(1 * time.Second) } }() - if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil { + if err := dconn.HandshakeContext(ctx); err == nil { t.Error("expected handshake to fail") } close(done) @@ -3402,7 +3412,7 @@ func TestHelloRandom(t *testing.T) { return } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { + if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck t.Error(sErr) } gotHello <- struct{}{} diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index ec1253ec8..0d8ba35bc 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -219,8 +219,7 @@ func clientPion(c *comm) { c.clientMutex.Lock() defer c.clientMutex.Unlock() - var err error - c.clientConn, err = dtls.DialWithContext(c.ctx, "udp", + conn, err := dtls.Dial("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.clientConfig, ) @@ -229,6 +228,13 @@ func clientPion(c *comm) { return } + if err := conn.HandshakeContext(c.ctx); err != nil { + c.errChan <- err + return + } + + c.clientConn = conn + simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) c.clientDone <- nil close(c.clientDone) @@ -254,6 +260,14 @@ func serverPion(c *comm) { return } + dtlsConn, ok := c.serverConn.(*dtls.Conn) + if ok { + if err := dtlsConn.HandshakeContext(c.ctx); err != nil { + c.errChan <- err + return + } + } + simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) c.serverDone <- nil close(c.serverDone) diff --git a/examples/dial/cid/main.go b/examples/dial/cid/main.go index 10e547706..4859e72b1 100644 --- a/examples/dial/cid/main.go +++ b/examples/dial/cid/main.go @@ -37,12 +37,17 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/dial/psk/main.go b/examples/dial/psk/main.go index b70efdcc3..94731e93a 100644 --- a/examples/dial/psk/main.go +++ b/examples/dial/psk/main.go @@ -36,12 +36,17 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/dial/selfsign/main.go b/examples/dial/selfsign/main.go index 5fa25a923..b3be5a8f8 100644 --- a/examples/dial/selfsign/main.go +++ b/examples/dial/selfsign/main.go @@ -38,12 +38,17 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/dial/verify/main.go b/examples/dial/verify/main.go index 07501954d..ed5352dd5 100644 --- a/examples/dial/verify/main.go +++ b/examples/dial/verify/main.go @@ -45,12 +45,17 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/listen/cid/main.go b/examples/listen/cid/main.go index 770bbcfa4..2deacae5e 100644 --- a/examples/listen/cid/main.go +++ b/examples/listen/cid/main.go @@ -5,10 +5,8 @@ package main import ( - "context" "fmt" "net" - "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" @@ -18,10 +16,6 @@ func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -32,13 +26,9 @@ func main() { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, - PSKIdentityHint: []byte("Pion DTLS Server"), - CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, - ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, + PSKIdentityHint: []byte("Pion DTLS Server"), + CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ConnectionIDGenerator: dtls.RandomCIDGenerator(8), } diff --git a/examples/listen/psk/main.go b/examples/listen/psk/main.go index 66f099693..777d7c7b8 100644 --- a/examples/listen/psk/main.go +++ b/examples/listen/psk/main.go @@ -5,10 +5,8 @@ package main import ( - "context" "fmt" "net" - "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" @@ -18,10 +16,6 @@ func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -35,10 +29,6 @@ func main() { PSKIdentityHint: []byte("Pion DTLS Server"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // Connect to a DTLS server diff --git a/examples/listen/selfsign/main.go b/examples/listen/selfsign/main.go index 025b667e4..6cbcf6bde 100644 --- a/examples/listen/selfsign/main.go +++ b/examples/listen/selfsign/main.go @@ -5,11 +5,9 @@ package main import ( - "context" "crypto/tls" "fmt" "net" - "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" @@ -24,10 +22,6 @@ func main() { certificate, genErr := selfsign.GenerateSelfSigned() util.Check(genErr) - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -36,10 +30,6 @@ func main() { config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // Connect to a DTLS server diff --git a/examples/listen/verify-brute-force-protection/main.go b/examples/listen/verify-brute-force-protection/main.go index b5fb82c42..2c07e7790 100644 --- a/examples/listen/verify-brute-force-protection/main.go +++ b/examples/listen/verify-brute-force-protection/main.go @@ -6,7 +6,6 @@ package main import ( - "context" "crypto/tls" "crypto/x509" "fmt" @@ -22,10 +21,6 @@ func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -52,10 +47,6 @@ func main() { ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ClientAuth: dtls.RequireAndVerifyClientCert, ClientCAs: certPool, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, // This function will be called on each connection attempt. OnConnectionAttempt: func(addr net.Addr) error { // *************** Brute Force Attack protection *************** diff --git a/examples/listen/verify/main.go b/examples/listen/verify/main.go index a02211e15..6e48753f4 100644 --- a/examples/listen/verify/main.go +++ b/examples/listen/verify/main.go @@ -5,12 +5,10 @@ package main import ( - "context" "crypto/tls" "crypto/x509" "fmt" "net" - "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" @@ -20,10 +18,6 @@ func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -45,10 +39,6 @@ func main() { ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ClientAuth: dtls.RequireAndVerifyClientCert, ClientCAs: certPool, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // Connect to a DTLS server diff --git a/handshaker.go b/handshaker.go index a585c3db8..62a4bf6e8 100644 --- a/handshaker.go +++ b/handshaker.go @@ -132,6 +132,8 @@ type handshakeConfig struct { clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + + resumeState *State } type flightConn interface { diff --git a/listener.go b/listener.go index 90dbbb427..cb75d4143 100644 --- a/listener.go +++ b/listener.go @@ -67,8 +67,6 @@ type listener struct { // Accept waits for and returns the next connection to the listener. // You have to either close or read on all connection that are created. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, set ConnectContextMaker. func (l *listener) Accept() (net.Conn, error) { c, raddr, err := l.parent.Accept() if err != nil { diff --git a/resume.go b/resume.go index 6cd1c5a69..0b76314a5 100644 --- a/resume.go +++ b/resume.go @@ -4,7 +4,6 @@ package dtls import ( - "context" "net" ) @@ -13,14 +12,5 @@ func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) ( if err := state.initCipherSuite(); err != nil { return nil, err } - dconn, err := createConn(conn, rAddr, config, state.isClient) - if err != nil { - return nil, err - } - c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state) - if err != nil { - return nil, err - } - - return c, nil + return createConn(conn, rAddr, config, state.isClient, state) }