Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform handshake on first read/write #649

Merged
merged 1 commit into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestSimpleReadWrite(t *testing.T) {
return
}
buf := make([]byte, 1024)
if _, sErr = server.Read(buf); sErr != nil {
if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck
t.Error(sErr)
}
gotHello <- struct{}{}
Expand Down
21 changes: 0 additions & 21 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package dtls

import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
Expand Down Expand Up @@ -118,15 +117,6 @@ type Config struct {

LoggerFactory logging.LoggerFactory

// ConnectContextMaker is a function to make a context used in Dial(),
// Client(), Server(), and Accept(). If nil, the default ConnectContextMaker
// is used. It can be implemented as following.
//
// func ConnectContextMaker() (context.Context, func()) {
// return context.WithTimeout(context.Background(), 30*time.Second)
// }
ConnectContextMaker func() (context.Context, func())

// MTU is the length at which handshake messages will be fragmented to
// fit within the maximum transmission unit (default is 1200 bytes)
MTU int
Expand Down Expand Up @@ -230,17 +220,6 @@ type Config struct {
OnConnectionAttempt func(net.Addr) error
}

func defaultConnectContextMaker() (context.Context, func()) {
return context.WithTimeout(context.Background(), 30*time.Second)
}

func (c *Config) connectContextMaker() (context.Context, func()) {
if c.ConnectContextMaker == nil {
return defaultConnectContextMaker()
}
return c.ConnectContextMaker()
}

func (c *Config) includeCertificateSuites() bool {
return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil
}
Expand Down
195 changes: 93 additions & 102 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
paddingLengthGenerator func(uint) uint

handshakeCompletedSuccessfully atomic.Value
handshakeMutex sync.Mutex

encryptedPackets []addrPkt

Expand All @@ -94,9 +95,11 @@
fsm *handshakeFSM

replayProtectionWindow uint

handshakeConfig *handshakeConfig
}

func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) {
func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State) (*Conn, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
Expand Down Expand Up @@ -127,42 +130,6 @@
paddingLengthGenerator = func(uint) uint { return 0 }
}

c := &Conn{
rAddr: rAddr,
nextConn: netctx.NewPacketConn(nextConn),
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
maximumTransmissionUnit: mtu,
paddingLengthGenerator: paddingLengthGenerator,

decrypted: make(chan interface{}, 1),
log: logger,

readDeadline: deadline.New(),
writeDeadline: deadline.New(),

reading: make(chan struct{}, 1),
handshakeRecv: make(chan recvHandshakeState),
closed: closer.NewCloser(),
cancelHandshaker: func() {},

replayProtectionWindow: uint(replayProtectionWindow),

state: State{
isClient: isClient,
},
}

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
if conn == nil {
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -190,7 +157,7 @@
curves = defaultCurves
}

hsCfg := &handshakeConfig{
handshakeConfig := &handshakeConfig{
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
Expand All @@ -209,7 +176,7 @@
customCipherSuites: config.CustomCipherSuites,
initialRetransmitInterval: workerInterval,
disableRetransmitBackoff: config.DisableRetransmitBackoff,
log: conn.log,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
Expand All @@ -222,82 +189,115 @@
clientHelloMessageHook: config.ClientHelloMessageHook,
serverHelloMessageHook: config.ServerHelloMessageHook,
certificateRequestMessageHook: config.CertificateRequestMessageHook,
resumeState: resumeState,
}

c := &Conn{
rAddr: rAddr,
nextConn: netctx.NewPacketConn(nextConn),
handshakeConfig: handshakeConfig,
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
maximumTransmissionUnit: mtu,
paddingLengthGenerator: paddingLengthGenerator,

decrypted: make(chan interface{}, 1),
log: logger,

readDeadline: deadline.New(),
writeDeadline: deadline.New(),

reading: make(chan struct{}, 1),
handshakeRecv: make(chan recvHandshakeState),
closed: closer.NewCloser(),
cancelHandshaker: func() {},
cancelHandshakeReader: func() {},

replayProtectionWindow: uint(replayProtectionWindow),

state: State{
isClient: isClient,
},
}

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

// Handshake runs the client or server DTLS handshake
// protocol if it has not yet been run.
//
// Most uses of this package need not call Handshake explicitly: the
// first [Conn.Read] or [Conn.Write] will call it automatically.
//
// For control over canceling or setting a timeout on a handshake, use
// [Conn.HandshakeContext].
func (c *Conn) Handshake() error {
return c.HandshakeContext(context.Background())
}

// HandshakeContext runs the client or server DTLS handshake
// protocol if it has not yet been run.
//
// The provided Context must be non-nil. If the context is canceled before
// the handshake is complete, the handshake is interrupted and an error is returned.
// Once the handshake has completed, cancellation of the context will not affect the
// connection.
//
// Most uses of this package need not call HandshakeContext explicitly: the
// first [Conn.Read] or [Conn.Write] will call it automatically.
func (c *Conn) HandshakeContext(ctx context.Context) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()

if c.isHandshakeCompletedSuccessfully() {
return nil
}

// rfc5246#section-7.4.3
// In addition, the hash and signature algorithms MUST be compatible
// with the key in the server's end-entity certificate.
if !isClient {
cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
if !c.state.isClient {
cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{})
if err != nil && !errors.Is(err, errNoCertificates) {
return nil, err
return err

Check warning on line 264 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L264

Added line #L264 was not covered by tests
}
hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites)
}

var initialFlight flightVal
var initialFSMState handshakeState

if initialState != nil {
if conn.state.isClient {
if c.handshakeConfig.resumeState != nil {
if c.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished

conn.state = *initialState
c.state = *c.handshakeConfig.resumeState
} else {
if conn.state.isClient {
if c.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil {
return err
}

conn.log.Trace("Handshake Completed")
c.log.Trace("Handshake Completed")

return conn, nil
return nil
}

// Dial connects to the given network address and establishes a DTLS connection on top.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use DialWithContext() instead.
func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return DialWithContext(ctx, network, rAddr, config)
}

// Client establishes a DTLS connection over an existing connection.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ClientWithContext() instead.
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return ClientWithContext(ctx, conn, rAddr, config)
}

// Server listens for incoming DTLS connections.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ServerWithContext() instead.
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return ServerWithContext(ctx, conn, rAddr, config)
}

// DialWithContext connects to the given network address and establishes a DTLS
// connection on top.
func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
// net.ListenUDP is used rather than net.DialUDP as the latter prevents the
// use of net.PacketConn.WriteTo.
// https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115
Expand All @@ -306,28 +306,23 @@
return nil, err
}

return ClientWithContext(ctx, pConn, rAddr, config)
return Client(pConn, rAddr, config)
}

// ClientWithContext establishes a DTLS connection over an existing connection.
func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
// Client establishes a DTLS connection over an existing connection.
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
switch {
case config == nil:
return nil, errNoConfigProvided
case config.PSK != nil && config.PSKIdentityHint == nil:
return nil, errPSKAndIdentityMustBeSetForClient
}

dconn, err := createConn(conn, rAddr, config, true)
if err != nil {
return nil, err
}

return handshakeConn(ctx, dconn, config, true, nil)
return createConn(conn, rAddr, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
// Server listens for incoming DTLS connections.
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
if config == nil {
return nil, errNoConfigProvided
}
Expand All @@ -336,17 +331,13 @@
return nil, err
}
}
dconn, err := createConn(conn, rAddr, config, false)
if err != nil {
return nil, err
}
return handshakeConn(ctx, dconn, config, false, nil)
return createConn(conn, rAddr, config, false, nil)
}

// Read reads data from the connection.
func (c *Conn) Read(p []byte) (n int, err error) {
if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
if err := c.Handshake(); err != nil {
return 0, err

Check warning on line 340 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L340

Added line #L340 was not covered by tests
}

select {
Expand Down Expand Up @@ -389,8 +380,8 @@
default:
}

if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
if err := c.Handshake(); err != nil {
return 0, err

Check warning on line 384 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L384

Added line #L384 was not covered by tests
}

return len(p), c.writePackets(c.writeDeadline, []*packet{
Expand Down
Loading
Loading