diff --git a/conn.go b/conn.go index 338f793ad..cf8551b40 100644 --- a/conn.go +++ b/conn.go @@ -34,6 +34,9 @@ const ( inboundBufferSize = 8192 // Default replay protection window is specified by RFC 6347 Section 4.1.2.6 defaultReplayProtectionWindow = 64 + // maxAppDataPacketQueueSize is the maximum number of app data packets we will + // enqueue before the handshake is completed + maxAppDataPacketQueueSize = 100 ) func invalidKeyingLabels() map[string]bool { @@ -81,7 +84,7 @@ type Conn struct { replayProtectionWindow uint } -func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { +func createConn(nextConn net.Conn, config *Config, isClient bool) (*Conn, error) { err := validateConfig(config) if err != nil { return nil, err @@ -91,21 +94,6 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient return nil, errNilNextConn } - cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) - if err != nil { - return nil, err - } - - signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) - if err != nil { - return nil, err - } - - workerInterval := initialTickerInterval - if config.FlightInterval != 0 { - workerInterval = config.FlightInterval - } - loggerFactory := config.LoggerFactory if loggerFactory == nil { loggerFactory = logging.NewDefaultLoggerFactory() @@ -149,6 +137,38 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient c.setRemoteEpoch(0) c.setLocalEpoch(0) + return c, nil +} + +func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { + if conn == nil { + return nil, errNilNextConn + } + + cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) + if err != nil { + return nil, err + } + + signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) + if err != nil { + return nil, err + } + + workerInterval := initialTickerInterval + if config.FlightInterval != 0 { + workerInterval = config.FlightInterval + } + + mtu := config.MTU + if mtu <= 0 { + mtu = defaultMTU + } + + replayProtectionWindow := config.ReplayProtectionWindow + if replayProtectionWindow <= 0 { + replayProtectionWindow = defaultReplayProtectionWindow + } serverName := config.ServerName // Do not allow the use of an IP address literal as an SNI value. @@ -180,7 +200,7 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient clientCAs: config.ClientCAs, customCipherSuites: config.CustomCipherSuites, retransmitInterval: workerInterval, - log: logger, + log: conn.log, initialEpoch: 0, keyLogWriter: config.KeyLogWriter, sessionStore: config.SessionStore, @@ -205,16 +225,16 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient var initialFSMState handshakeState if initialState != nil { - if c.state.isClient { + if conn.state.isClient { initialFlight = flight5 } else { initialFlight = flight6 } initialFSMState = handshakeFinished - c.state = *initialState + conn.state = *initialState } else { - if c.state.isClient { + if conn.state.isClient { initialFlight = flight1 } else { initialFlight = flight0 @@ -222,13 +242,13 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient initialFSMState = handshakePreparing } // Do handshake - if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { + if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { return nil, err } - c.log.Trace("Handshake Completed") + conn.log.Trace("Handshake Completed") - return c, nil + return conn, nil } // Dial connects to the given network address and establishes a DTLS connection on top. @@ -279,7 +299,12 @@ func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Con return nil, errPSKAndIdentityMustBeSetForClient } - return createConn(ctx, conn, config, true, nil) + dconn, err := createConn(conn, config, true) + if err != nil { + return nil, err + } + + return handshakeConn(ctx, dconn, config, true, nil) } // ServerWithContext listens for incoming DTLS connections. @@ -287,8 +312,11 @@ func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Con if config == nil { return nil, errNoConfigProvided } - - return createConn(ctx, conn, config, false, nil) + dconn, err := createConn(conn, config, false) + if err != nil { + return nil, err + } + return handshakeConn(ctx, dconn, config, false, nil) } // Read reads data from the connection. @@ -662,7 +690,6 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo c.log.Debugf("discarded broken packet: %v", err) return false, nil, nil } - // Validate epoch remoteEpoch := c.state.getRemoteEpoch() if h.Epoch > remoteEpoch { @@ -673,8 +700,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo return false, nil, nil } if enqueue { - c.log.Debug("received packet of next epoch, queuing packet") - c.encryptedPackets = append(c.encryptedPackets, buf) + if len(c.encryptedPackets) < maxAppDataPacketQueueSize { + c.log.Debug("received packet of next epoch, queuing packet") + c.encryptedPackets = append(c.encryptedPackets, buf) + } else { + c.log.Debug("app data packet queue full, dropping packet") + } } return false, nil, nil } @@ -697,8 +728,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo if h.Epoch != 0 { if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { - c.encryptedPackets = append(c.encryptedPackets, buf) - c.log.Debug("handshake not finished, queuing packet") + if len(c.encryptedPackets) < maxAppDataPacketQueueSize { + c.encryptedPackets = append(c.encryptedPackets, buf) + c.log.Debug("handshake not finished, queuing packet") + } else { + c.log.Debug("app data packet queue full, dropping packet") + } } return false, nil, nil } @@ -749,8 +784,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo case *protocol.ChangeCipherSpec: if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { - c.encryptedPackets = append(c.encryptedPackets, buf) - c.log.Debugf("CipherSuite not initialized, queuing packet") + if len(c.encryptedPackets) < maxAppDataPacketQueueSize { + c.encryptedPackets = append(c.encryptedPackets, buf) + c.log.Debugf("CipherSuite not initialized, queuing packet") + } else { + c.log.Debug("app data packet queue full. dropping packet") + } } return false, nil, nil } diff --git a/conn_test.go b/conn_test.go index ea3c842f7..6083a050a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3050,3 +3050,88 @@ func (c *connWithCallback) Write(b []byte) (int, error) { } return c.Conn.Write(b) } + +func TestApplicationDataQueueLimited(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ca, cb := dpipe.Pipe() + defer ca.Close() + defer cb.Close() + + done := make(chan struct{}) + go func() { + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Error(err) + return + } + cfg := &Config{} + cfg.Certificates = []tls.Certificate{serverCert} + + dconn, err := createConn(cb, cfg, false) + if err != nil { + t.Error(err) + return + } + go func() { + for i := 0; i < 5; i++ { + dconn.lock.RLock() + qlen := len(dconn.encryptedPackets) + dconn.lock.RUnlock() + if qlen > maxAppDataPacketQueueSize { + t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets)) + } + t.Log(qlen) + time.Sleep(1 * time.Second) + } + + }() + if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil { + t.Error("expected handshake to fail") + } + close(done) + }() + extensions := []extension.Extension{} + + time.Sleep(50 * time.Millisecond) + + err := sendClientHello([]byte{}, ca, 0, extensions) + if err != nil { + t.Fatal(err) + } + + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 1000; i++ { + // Send an application data packet + packet, err := (&recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + SequenceNumber: uint64(3), + Epoch: 1, // use an epoch greater than 0 + }, + Content: &protocol.ApplicationData{ + Data: []byte{1, 2, 3, 4}, + }, + }).Marshal() + if err != nil { + t.Fatal(err) + } + ca.Write(packet) + if i%100 == 0 { + time.Sleep(10 * time.Millisecond) + } + } + time.Sleep(1 * time.Second) + ca.Close() + <-done +} diff --git a/resume.go b/resume.go index c470d856b..f070d7537 100644 --- a/resume.go +++ b/resume.go @@ -13,7 +13,11 @@ func Resume(state *State, conn net.Conn, config *Config) (*Conn, error) { if err := state.initCipherSuite(); err != nil { return nil, err } - c, err := createConn(context.Background(), conn, config, state.isClient, state) + dconn, err := createConn(conn, config, state.isClient) + if err != nil { + return nil, err + } + c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state) if err != nil { return nil, err }