Skip to content

Commit

Permalink
Perform handshake on first read/write
Browse files Browse the repository at this point in the history
Updates the connection to perform a handshake on first read/write
instead of on accept.

Closes #279.
  • Loading branch information
kevmo314 authored and Sean-Der committed Jul 21, 2024
1 parent 6178064 commit e406468
Show file tree
Hide file tree
Showing 18 changed files with 227 additions and 249 deletions.
2 changes: 1 addition & 1 deletion bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestSimpleReadWrite(t *testing.T) {
return
}
buf := make([]byte, 1024)
if _, sErr = server.Read(buf); sErr != nil {
if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck
t.Error(sErr)
}
gotHello <- struct{}{}
Expand Down
21 changes: 0 additions & 21 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package dtls

import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
Expand Down Expand Up @@ -118,15 +117,6 @@ type Config struct {

LoggerFactory logging.LoggerFactory

// ConnectContextMaker is a function to make a context used in Dial(),
// Client(), Server(), and Accept(). If nil, the default ConnectContextMaker
// is used. It can be implemented as following.
//
// func ConnectContextMaker() (context.Context, func()) {
// return context.WithTimeout(context.Background(), 30*time.Second)
// }
ConnectContextMaker func() (context.Context, func())

// MTU is the length at which handshake messages will be fragmented to
// fit within the maximum transmission unit (default is 1200 bytes)
MTU int
Expand Down Expand Up @@ -230,17 +220,6 @@ type Config struct {
OnConnectionAttempt func(net.Addr) error
}

func defaultConnectContextMaker() (context.Context, func()) {
return context.WithTimeout(context.Background(), 30*time.Second)
}

func (c *Config) connectContextMaker() (context.Context, func()) {
if c.ConnectContextMaker == nil {
return defaultConnectContextMaker()
}
return c.ConnectContextMaker()
}

func (c *Config) includeCertificateSuites() bool {
return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil
}
Expand Down
195 changes: 93 additions & 102 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ type Conn struct {
paddingLengthGenerator func(uint) uint

handshakeCompletedSuccessfully atomic.Value
handshakeMutex sync.Mutex

encryptedPackets []addrPkt

Expand All @@ -94,9 +95,11 @@ type Conn struct {
fsm *handshakeFSM

replayProtectionWindow uint

handshakeConfig *handshakeConfig
}

func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) {
func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State) (*Conn, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
Expand Down Expand Up @@ -127,42 +130,6 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien
paddingLengthGenerator = func(uint) uint { return 0 }
}

c := &Conn{
rAddr: rAddr,
nextConn: netctx.NewPacketConn(nextConn),
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
maximumTransmissionUnit: mtu,
paddingLengthGenerator: paddingLengthGenerator,

decrypted: make(chan interface{}, 1),
log: logger,

readDeadline: deadline.New(),
writeDeadline: deadline.New(),

reading: make(chan struct{}, 1),
handshakeRecv: make(chan recvHandshakeState),
closed: closer.NewCloser(),
cancelHandshaker: func() {},

replayProtectionWindow: uint(replayProtectionWindow),

state: State{
isClient: isClient,
},
}

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
if conn == nil {
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -190,7 +157,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
curves = defaultCurves
}

hsCfg := &handshakeConfig{
handshakeConfig := &handshakeConfig{
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
Expand All @@ -209,7 +176,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
customCipherSuites: config.CustomCipherSuites,
initialRetransmitInterval: workerInterval,
disableRetransmitBackoff: config.DisableRetransmitBackoff,
log: conn.log,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
Expand All @@ -222,82 +189,115 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
clientHelloMessageHook: config.ClientHelloMessageHook,
serverHelloMessageHook: config.ServerHelloMessageHook,
certificateRequestMessageHook: config.CertificateRequestMessageHook,
resumeState: resumeState,
}

c := &Conn{
rAddr: rAddr,
nextConn: netctx.NewPacketConn(nextConn),
handshakeConfig: handshakeConfig,
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
maximumTransmissionUnit: mtu,
paddingLengthGenerator: paddingLengthGenerator,

decrypted: make(chan interface{}, 1),
log: logger,

readDeadline: deadline.New(),
writeDeadline: deadline.New(),

reading: make(chan struct{}, 1),
handshakeRecv: make(chan recvHandshakeState),
closed: closer.NewCloser(),
cancelHandshaker: func() {},
cancelHandshakeReader: func() {},

replayProtectionWindow: uint(replayProtectionWindow),

state: State{
isClient: isClient,
},
}

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

// Handshake runs the client or server DTLS handshake
// protocol if it has not yet been run.
//
// Most uses of this package need not call Handshake explicitly: the
// first [Conn.Read] or [Conn.Write] will call it automatically.
//
// For control over canceling or setting a timeout on a handshake, use
// [Conn.HandshakeContext].
func (c *Conn) Handshake() error {
return c.HandshakeContext(context.Background())
}

// HandshakeContext runs the client or server DTLS handshake
// protocol if it has not yet been run.
//
// The provided Context must be non-nil. If the context is canceled before
// the handshake is complete, the handshake is interrupted and an error is returned.
// Once the handshake has completed, cancellation of the context will not affect the
// connection.
//
// Most uses of this package need not call HandshakeContext explicitly: the
// first [Conn.Read] or [Conn.Write] will call it automatically.
func (c *Conn) HandshakeContext(ctx context.Context) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()

if c.isHandshakeCompletedSuccessfully() {
return nil
}

// rfc5246#section-7.4.3
// In addition, the hash and signature algorithms MUST be compatible
// with the key in the server's end-entity certificate.
if !isClient {
cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
if !c.state.isClient {
cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{})
if err != nil && !errors.Is(err, errNoCertificates) {
return nil, err
return err
}
hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites)
}

var initialFlight flightVal
var initialFSMState handshakeState

if initialState != nil {
if conn.state.isClient {
if c.handshakeConfig.resumeState != nil {
if c.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished

conn.state = *initialState
c.state = *c.handshakeConfig.resumeState
} else {
if conn.state.isClient {
if c.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil {
return err
}

conn.log.Trace("Handshake Completed")
c.log.Trace("Handshake Completed")

return conn, nil
return nil
}

// Dial connects to the given network address and establishes a DTLS connection on top.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use DialWithContext() instead.
func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return DialWithContext(ctx, network, rAddr, config)
}

// Client establishes a DTLS connection over an existing connection.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ClientWithContext() instead.
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return ClientWithContext(ctx, conn, rAddr, config)
}

// Server listens for incoming DTLS connections.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ServerWithContext() instead.
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return ServerWithContext(ctx, conn, rAddr, config)
}

// DialWithContext connects to the given network address and establishes a DTLS
// connection on top.
func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
// net.ListenUDP is used rather than net.DialUDP as the latter prevents the
// use of net.PacketConn.WriteTo.
// https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115
Expand All @@ -306,28 +306,23 @@ func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, co
return nil, err
}

return ClientWithContext(ctx, pConn, rAddr, config)
return Client(pConn, rAddr, config)
}

// ClientWithContext establishes a DTLS connection over an existing connection.
func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
// Client establishes a DTLS connection over an existing connection.
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
switch {
case config == nil:
return nil, errNoConfigProvided
case config.PSK != nil && config.PSKIdentityHint == nil:
return nil, errPSKAndIdentityMustBeSetForClient
}

dconn, err := createConn(conn, rAddr, config, true)
if err != nil {
return nil, err
}

return handshakeConn(ctx, dconn, config, true, nil)
return createConn(conn, rAddr, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
// Server listens for incoming DTLS connections.
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
if config == nil {
return nil, errNoConfigProvided
}
Expand All @@ -336,17 +331,13 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
return nil, err
}
}
dconn, err := createConn(conn, rAddr, config, false)
if err != nil {
return nil, err
}
return handshakeConn(ctx, dconn, config, false, nil)
return createConn(conn, rAddr, config, false, nil)
}

// Read reads data from the connection.
func (c *Conn) Read(p []byte) (n int, err error) {
if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
if err := c.Handshake(); err != nil {
return 0, err
}

select {
Expand Down Expand Up @@ -389,8 +380,8 @@ func (c *Conn) Write(p []byte) (int, error) {
default:
}

if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
if err := c.Handshake(); err != nil {
return 0, err
}

return len(p), c.writePackets(c.writeDeadline, []*packet{
Expand Down
Loading

0 comments on commit e406468

Please sign in to comment.