Skip to content

Commit

Permalink
Perform handshake on first read/write
Browse files Browse the repository at this point in the history
Updates the connection to perform a handshake on first read/write instead of on accept. Closes #279
  • Loading branch information
kevmo314 committed Jul 16, 2024
1 parent d013d0c commit f4011b1
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 79 deletions.
127 changes: 61 additions & 66 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@ type Conn struct {
fsm *handshakeFSM

replayProtectionWindow uint

handshakeConfig *handshakeConfig
}

func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) {
func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State) (*Conn, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
Expand Down Expand Up @@ -127,42 +129,6 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien
paddingLengthGenerator = func(uint) uint { return 0 }
}

c := &Conn{
rAddr: rAddr,
nextConn: netctx.NewPacketConn(nextConn),
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
maximumTransmissionUnit: mtu,
paddingLengthGenerator: paddingLengthGenerator,

decrypted: make(chan interface{}, 1),
log: logger,

readDeadline: deadline.New(),
writeDeadline: deadline.New(),

reading: make(chan struct{}, 1),
handshakeRecv: make(chan recvHandshakeState),
closed: closer.NewCloser(),
cancelHandshaker: func() {},

replayProtectionWindow: uint(replayProtectionWindow),

state: State{
isClient: isClient,
},
}

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
if conn == nil {
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -190,7 +156,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
curves = defaultCurves
}

hsCfg := &handshakeConfig{
handshakeConfig := &handshakeConfig{
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
Expand All @@ -209,7 +175,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
customCipherSuites: config.CustomCipherSuites,
initialRetransmitInterval: workerInterval,
disableRetransmitBackoff: config.DisableRetransmitBackoff,
log: conn.log,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
Expand All @@ -222,47 +188,85 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
clientHelloMessageHook: config.ClientHelloMessageHook,
serverHelloMessageHook: config.ServerHelloMessageHook,
certificateRequestMessageHook: config.CertificateRequestMessageHook,
resumeState: resumeState,
}

c := &Conn{
rAddr: rAddr,
nextConn: netctx.NewPacketConn(nextConn),
handshakeConfig: handshakeConfig,
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
maximumTransmissionUnit: mtu,
paddingLengthGenerator: paddingLengthGenerator,

decrypted: make(chan interface{}, 1),
log: logger,

readDeadline: deadline.New(),
writeDeadline: deadline.New(),

reading: make(chan struct{}, 1),
handshakeRecv: make(chan recvHandshakeState),
closed: closer.NewCloser(),
cancelHandshaker: func() {},

replayProtectionWindow: uint(replayProtectionWindow),

state: State{
isClient: isClient,
},
}

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

func (c *Conn) Handshake() error {

Check warning on line 226 in conn.go

View workflow job for this annotation

GitHub Actions / lint / Go

exported: exported method Conn.Handshake should have comment or be unexported (revive)
return c.HandshakeContext(context.Background())
}

func (c *Conn) HandshakeContext(ctx context.Context) error {

Check warning on line 230 in conn.go

View workflow job for this annotation

GitHub Actions / lint / Go

exported: exported method Conn.HandshakeContext should have comment or be unexported (revive)
// rfc5246#section-7.4.3
// In addition, the hash and signature algorithms MUST be compatible
// with the key in the server's end-entity certificate.
if !isClient {
cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
if !c.state.isClient {
cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{})
if err != nil && !errors.Is(err, errNoCertificates) {
return nil, err
return err
}
hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites)
}

var initialFlight flightVal
var initialFSMState handshakeState

if initialState != nil {
if conn.state.isClient {
if c.handshakeConfig.resumeState != nil {
if c.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished

conn.state = *initialState
c.state = *c.handshakeConfig.resumeState
} else {
if conn.state.isClient {
if c.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil {
return err
}

conn.log.Trace("Handshake Completed")
c.log.Trace("Handshake Completed")

return conn, nil
return nil
}

// Dial connects to the given network address and establishes a DTLS connection on top.
Expand Down Expand Up @@ -318,12 +322,7 @@ func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
return nil, errPSKAndIdentityMustBeSetForClient
}

dconn, err := createConn(conn, rAddr, config, true)
if err != nil {
return nil, err
}

return handshakeConn(ctx, dconn, config, true, nil)
return createConn(conn, rAddr, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
Expand All @@ -336,17 +335,13 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
return nil, err
}
}
dconn, err := createConn(conn, rAddr, config, false)
if err != nil {
return nil, err
}
return handshakeConn(ctx, dconn, config, false, nil)
return createConn(conn, rAddr, config, false, nil)
}

// Read reads data from the connection.
func (c *Conn) Read(p []byte) (n int, err error) {
if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
if err := c.Handshake(); err != nil {
return 0, err
}

select {
Expand Down Expand Up @@ -389,8 +384,8 @@ func (c *Conn) Write(p []byte) (int, error) {
default:
}

if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
if err := c.Handshake(); err != nil {
return 0, err
}

return len(p), c.writePackets(c.writeDeadline, []*packet{
Expand Down
4 changes: 2 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3306,7 +3306,7 @@ func TestApplicationDataQueueLimited(t *testing.T) {
cfg := &Config{}
cfg.Certificates = []tls.Certificate{serverCert}

dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false)
dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false, nil)
if err != nil {
t.Error(err)
return
Expand All @@ -3322,7 +3322,7 @@ func TestApplicationDataQueueLimited(t *testing.T) {
time.Sleep(1 * time.Second)
}
}()
if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil {
if err := dconn.HandshakeContext(ctx); err == nil {
t.Error("expected handshake to fail")
}
close(done)
Expand Down
2 changes: 2 additions & 0 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ type handshakeConfig struct {
clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message
serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message
certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message

resumeState *State
}

type flightConn interface {
Expand Down
12 changes: 1 addition & 11 deletions resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package dtls

import (
"context"
"net"
)

Expand All @@ -13,14 +12,5 @@ func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) (
if err := state.initCipherSuite(); err != nil {
return nil, err
}
dconn, err := createConn(conn, rAddr, config, state.isClient)
if err != nil {
return nil, err
}
c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state)
if err != nil {
return nil, err
}

return c, nil
return createConn(conn, rAddr, config, state.isClient, state)
}

0 comments on commit f4011b1

Please sign in to comment.