Skip to content

Commit

Permalink
On Read Retransmit send FSM to SENDING
Browse files Browse the repository at this point in the history
RFC6347 Section-4.2.4 states

```
The implementation reads a retransmitted flight from the peer: the
implementation transitions to the SENDING state, where it
retransmits the flight, resets the retransmit timer, and returns
to the WAITING state.  The rationale here is that the receipt of a
duplicate message is the likely result of timer expiry on the peer
and therefore suggests that part of one's previous flight was
lost.
```

Resolves #478
  • Loading branch information
Sean-Der committed Jul 15, 2024
1 parent ec76652 commit d013d0c
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 55 deletions.
70 changes: 41 additions & 29 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ type addrPkt struct {
data []byte
}

type recvHandshakeState struct {
done chan struct{}
isRetransmit bool
}

// Conn represents a DTLS connection
type Conn struct {
lock sync.RWMutex // Internal lock (must not be public)
Expand Down Expand Up @@ -82,7 +87,7 @@ type Conn struct {
log logging.LeveledLogger

reading chan struct{}
handshakeRecv chan chan struct{}
handshakeRecv chan recvHandshakeState
cancelHandshaker func()
cancelHandshakeReader func()

Expand Down Expand Up @@ -137,7 +142,7 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien
writeDeadline: deadline.New(),

reading: make(chan struct{}, 1),
handshakeRecv: make(chan chan struct{}),
handshakeRecv: make(chan recvHandshakeState),
closed: closer.NewCloser(),
cancelHandshaker: func() {},

Expand Down Expand Up @@ -704,9 +709,9 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
return err
}

var hasHandshake bool
var hasHandshake, isRetransmit bool
for _, p := range pkts {
hs, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true)
hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true)
if alert != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err == nil {
Expand All @@ -725,14 +730,20 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
if hs {
hasHandshake = true
}
if rtx {
isRetransmit = true
}
}
if hasHandshake {
done := make(chan struct{})
s := recvHandshakeState{
done: make(chan struct{}),
isRetransmit: isRetransmit,
}
select {
case c.handshakeRecv <- done:
case c.handshakeRecv <- s:
// If the other party may retransmit the flight,
// we should respond even if it not a new message.
<-done
<-s.done
case <-c.fsm.Done():
}
}
Expand All @@ -744,7 +755,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
c.encryptedPackets = nil

for _, p := range pkts {
_, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue
_, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue
if alert != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err == nil {
Expand All @@ -771,7 +782,7 @@ func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool {
return false
}

func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, bool, *alert.Alert, error) { //nolint:gocognit
h := &recordlayer.Header{}
// Set connection ID size so that records of content type tls12_cid will
// be parsed correctly.
Expand All @@ -782,7 +793,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
return false, false, nil, nil
}
// Validate epoch
remoteEpoch := c.state.getRemoteEpoch()
Expand All @@ -791,14 +802,14 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
h.Epoch, h.SequenceNumber,
)
return false, nil, nil
return false, false, nil, nil
}
if enqueue {
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debug("received packet of next epoch, queuing packet")
}
}
return false, nil, nil
return false, false, nil, nil
}

// Anti-replay protection
Expand All @@ -812,7 +823,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
h.Epoch, h.SequenceNumber,
)
return false, nil, nil
return false, false, nil, nil
}

// originalCID indicates whether the original record had content type
Expand All @@ -827,14 +838,14 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.log.Debug("handshake not finished, queuing packet")
}
}
return false, nil, nil
return false, false, nil, nil
}

// If a connection identifier had been negotiated and encryption is
// enabled, the connection identifier MUST be sent.
if len(c.state.getLocalConnectionID()) > 0 && h.ContentType != protocol.ContentTypeConnectionID {
c.log.Debug("discarded packet missing connection ID after value negotiated")
return false, nil, nil
return false, false, nil, nil
}

var err error
Expand All @@ -845,7 +856,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
buf, err = c.state.cipherSuite.Decrypt(hdr, buf)
if err != nil {
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
return false, nil, nil
return false, false, nil, nil
}
// If this is a connection ID record, make it look like a normal record for
// further processing.
Expand All @@ -854,7 +865,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
ip := &recordlayer.InnerPlaintext{}
if err := ip.Unmarshal(buf[h.Size():]); err != nil { //nolint:govet
c.log.Debugf("unpacking inner plaintext failed: %s", err)
return false, nil, nil
return false, false, nil, nil
}
unpacked := &recordlayer.Header{
ContentType: ip.RealType,
Expand All @@ -866,26 +877,27 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
buf, err = unpacked.Marshal()
if err != nil {
c.log.Debugf("converting CID record to inner plaintext failed: %s", err)
return false, nil, nil
return false, false, nil, nil
}
buf = append(buf, ip.Content...)
}

// If connection ID does not match discard the packet.
if !bytes.Equal(c.state.getLocalConnectionID(), h.ConnectionID) {
c.log.Debug("unexpected connection ID")
return false, nil, nil
return false, false, nil, nil
}
}

isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
isHandshake, isRetransmit, err := c.fragmentBuffer.push(append([]byte{}, buf...))
if err != nil {
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
c.log.Debugf("defragment failed: %s", err)
return false, nil, nil
return false, false, nil, nil
} else if isHandshake {
markPacketAsValid()

for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
header := &handshake.Header{}
if err := header.Unmarshal(out); err != nil {
Expand All @@ -895,12 +907,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
}

return true, nil, nil
return true, isRetransmit, nil, nil
}

r := &recordlayer.RecordLayer{}
if err := r.Unmarshal(buf); err != nil {
return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
}

isLatestSeqNum := false
Expand All @@ -913,15 +925,15 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
}
_ = markPacketAsValid()
return false, a, &alertError{content}
return false, false, a, &alertError{content}
case *protocol.ChangeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
}
return false, nil, nil
return false, false, nil, nil
}

newRemoteEpoch := h.Epoch + 1
Expand All @@ -933,7 +945,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
}
case *protocol.ApplicationData:
if h.Epoch == 0 {
return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
}

isLatestSeqNum = markPacketAsValid()
Expand All @@ -945,7 +957,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
}

default:
return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
}

// Any valid connection ID record is a candidate for updating the remote
Expand All @@ -959,10 +971,10 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
}
}

return false, nil, nil
return false, false, nil, nil
}

func (c *Conn) recvHandshake() <-chan chan struct{} {
func (c *Conn) recvHandshake() <-chan recvHandshakeState {
return c.handshakeRecv
}

Expand Down
21 changes: 21 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3482,3 +3482,24 @@ func TestOnConnectionAttempt(t *testing.T) {
t.Fatal("OnConnectionAttempt fired for client")
}
}

func TestFragmentBuffer_Retransmission(t *testing.T) {
fragmentBuffer := newFragmentBuffer()
frag := []byte{0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}

if _, isRetransmission, err := fragmentBuffer.push(frag); err != nil {
t.Fatal(err)
} else if isRetransmission {
t.Fatal("fragment should not be retransmission")
}

if v, _ := fragmentBuffer.pop(); v == nil {
t.Fatal("Failed to pop fragment")
}

if _, isRetransmission, err := fragmentBuffer.push(frag); err != nil {
t.Fatal(err)
} else if !isRetransmission {
t.Fatal("fragment should be retransmission")
}
}
31 changes: 27 additions & 4 deletions e2e/e2e_lossy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ func TestPionE2ELossy(t *testing.T) {
}

for _, test := range []struct {
LossChanceRange int
DoClientAuth bool
CipherSuites []dtls.CipherSuiteID
MTU int
LossChanceRange int
DoClientAuth bool
CipherSuites []dtls.CipherSuiteID
MTU int
DisableServerFlightInterval bool
}{
{
LossChanceRange: 0,
Expand Down Expand Up @@ -109,6 +110,20 @@ func TestPionE2ELossy(t *testing.T) {
MTU: 100,
DoClientAuth: true,
},
// Incoming retransmitted handshakes should cause us to retransmit. Disabling the FlightInterval on one side
// means that a incoming re-transmissions causes the retransmission to be fired
{
LossChanceRange: 10,
DisableServerFlightInterval: true,
},
{
LossChanceRange: 20,
DisableServerFlightInterval: true,
},
{
LossChanceRange: 50,
DisableServerFlightInterval: true,
},
} {
name := fmt.Sprintf("Loss%d_MTU%d", test.LossChanceRange, test.MTU)
if test.DoClientAuth {
Expand All @@ -117,6 +132,10 @@ func TestPionE2ELossy(t *testing.T) {
for _, ciph := range test.CipherSuites {
name += "_With" + ciph.String()
}
if test.DisableServerFlightInterval {
name += "_WithNoServerFlightInterval"
}

test := test
t.Run(name, func(t *testing.T) {
// Limit runtime in case of deadlocks
Expand Down Expand Up @@ -162,6 +181,10 @@ func TestPionE2ELossy(t *testing.T) {
cfg.ClientAuth = dtls.RequireAnyClientCert
}

if test.DisableServerFlightInterval {
cfg.FlightInterval = time.Hour
}

server, startupErr := dtls.Server(dtlsnet.PacketConnFromConn(br.GetConn1()), br.GetConn1().RemoteAddr(), cfg)
serverDone <- runResult{server, startupErr}
}()
Expand Down
2 changes: 1 addition & 1 deletion flight1handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (f *flight1TestMockFlightConn) notify(context.Context, alert.Level, alert.D
return nil
}
func (f *flight1TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil }
func (f *flight1TestMockFlightConn) recvHandshake() <-chan chan struct{} { return nil }
func (f *flight1TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil }
func (f *flight1TestMockFlightConn) setLocalEpoch(uint16) {}
func (f *flight1TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil }
func (f *flight1TestMockFlightConn) sessionKey() []byte { return nil }
Expand Down
2 changes: 1 addition & 1 deletion flight4handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (f *flight4TestMockFlightConn) notify(context.Context, alert.Level, alert.D
return nil
}
func (f *flight4TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil }
func (f *flight4TestMockFlightConn) recvHandshake() <-chan chan struct{} { return nil }
func (f *flight4TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil }
func (f *flight4TestMockFlightConn) setLocalEpoch(uint16) {}
func (f *flight4TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil }
func (f *flight4TestMockFlightConn) sessionKey() []byte { return nil }
Expand Down
15 changes: 9 additions & 6 deletions fragment_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,29 @@ func (f *fragmentBuffer) size() int {
// Attempts to push a DTLS packet to the fragmentBuffer
// when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled
// when an error returns it is fatal, and the DTLS connection should be stopped
func (f *fragmentBuffer) push(buf []byte) (bool, error) {
func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) {
if f.size()+len(buf) >= fragmentBufferMaxSize {
return false, errFragmentBufferOverflow
return false, false, errFragmentBufferOverflow
}

frag := new(fragment)
if err := frag.recordLayerHeader.Unmarshal(buf); err != nil {
return false, err
return false, false, err
}

// fragment isn't a handshake, we don't need to handle it
if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
return false, nil
return false, false, nil
}

for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) {
if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
return false, err
return false, false, err
}

// Fragment is a retransmission. We have already assembled it before successfully
isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber

if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok {
f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{}
}
Expand All @@ -80,7 +83,7 @@ func (f *fragmentBuffer) push(buf []byte) (bool, error) {
buf = buf[end:]
}

return true, nil
return true, isRetransmit, nil
}

func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
Expand Down
Loading

0 comments on commit d013d0c

Please sign in to comment.