Skip to content

Commit

Permalink
On Read Retransmit send FSM to SENDING
Browse files Browse the repository at this point in the history
RFC6347 Section-4.2.4 states

```
The implementation reads a retransmitted flight from the peer: the
implementation transitions to the SENDING state, where it
retransmits the flight, resets the retransmit timer, and returns
to the WAITING state.  The rationale here is that the receipt of a
duplicate message is the likely result of timer expiry on the peer
and therefore suggests that part of one's previous flight was
lost.
```

Resolves #478
  • Loading branch information
Sean-Der committed Jul 12, 2024
1 parent d1b179a commit 0364d6d
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 51 deletions.
70 changes: 41 additions & 29 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ type addrPkt struct {
data []byte
}

type recvHandshakeState struct {
done chan struct{}
isRetransmit bool
}

// Conn represents a DTLS connection
type Conn struct {
lock sync.RWMutex // Internal lock (must not be public)
Expand Down Expand Up @@ -82,7 +87,7 @@ type Conn struct {
log logging.LeveledLogger

reading chan struct{}
handshakeRecv chan chan struct{}
handshakeRecv chan recvHandshakeState
cancelHandshaker func()
cancelHandshakeReader func()

Expand Down Expand Up @@ -137,7 +142,7 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien
writeDeadline: deadline.New(),

reading: make(chan struct{}, 1),
handshakeRecv: make(chan chan struct{}),
handshakeRecv: make(chan recvHandshakeState),
closed: closer.NewCloser(),
cancelHandshaker: func() {},

Expand Down Expand Up @@ -704,9 +709,9 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
return err
}

var hasHandshake bool
var hasHandshake, isRetransmit bool
for _, p := range pkts {
hs, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true)
hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true)
if alert != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err == nil {
Expand All @@ -725,14 +730,20 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
if hs {
hasHandshake = true
}
if rtx {
isRetransmit = true

Check warning on line 734 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L734

Added line #L734 was not covered by tests
}
}
if hasHandshake {
done := make(chan struct{})
s := recvHandshakeState{
done: make(chan struct{}),
isRetransmit: isRetransmit,
}
select {
case c.handshakeRecv <- done:
case c.handshakeRecv <- s:
// If the other party may retransmit the flight,
// we should respond even if it not a new message.
<-done
<-s.done
case <-c.fsm.Done():
}
}
Expand All @@ -744,7 +755,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
c.encryptedPackets = nil

for _, p := range pkts {
_, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue
_, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue
if alert != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err == nil {
Expand All @@ -771,7 +782,7 @@ func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool {
return false
}

func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, bool, *alert.Alert, error) { //nolint:gocognit
h := &recordlayer.Header{}
// Set connection ID size so that records of content type tls12_cid will
// be parsed correctly.
Expand All @@ -782,7 +793,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
return false, false, nil, nil

Check warning on line 796 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L796

Added line #L796 was not covered by tests
}
// Validate epoch
remoteEpoch := c.state.getRemoteEpoch()
Expand All @@ -791,14 +802,14 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
h.Epoch, h.SequenceNumber,
)
return false, nil, nil
return false, false, nil, nil

Check warning on line 805 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L805

Added line #L805 was not covered by tests
}
if enqueue {
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debug("received packet of next epoch, queuing packet")
}
}
return false, nil, nil
return false, false, nil, nil
}

// Anti-replay protection
Expand All @@ -812,7 +823,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
h.Epoch, h.SequenceNumber,
)
return false, nil, nil
return false, false, nil, nil
}

// originalCID indicates whether the original record had content type
Expand All @@ -827,14 +838,14 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.log.Debug("handshake not finished, queuing packet")
}
}
return false, nil, nil
return false, false, nil, nil

Check warning on line 841 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L841

Added line #L841 was not covered by tests
}

// If a connection identifier had been negotiated and encryption is
// enabled, the connection identifier MUST be sent.
if len(c.state.localConnectionID) > 0 && h.ContentType != protocol.ContentTypeConnectionID {
c.log.Debug("discarded packet missing connection ID after value negotiated")
return false, nil, nil
return false, false, nil, nil

Check warning on line 848 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L848

Added line #L848 was not covered by tests
}

var err error
Expand All @@ -845,7 +856,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
buf, err = c.state.cipherSuite.Decrypt(hdr, buf)
if err != nil {
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
return false, nil, nil
return false, false, nil, nil

Check warning on line 859 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L859

Added line #L859 was not covered by tests
}
// If this is a connection ID record, make it look like a normal record for
// further processing.
Expand All @@ -854,7 +865,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
ip := &recordlayer.InnerPlaintext{}
if err := ip.Unmarshal(buf[h.Size():]); err != nil { //nolint:govet
c.log.Debugf("unpacking inner plaintext failed: %s", err)
return false, nil, nil
return false, false, nil, nil

Check warning on line 868 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L868

Added line #L868 was not covered by tests
}
unpacked := &recordlayer.Header{
ContentType: ip.RealType,
Expand All @@ -866,26 +877,27 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
buf, err = unpacked.Marshal()
if err != nil {
c.log.Debugf("converting CID record to inner plaintext failed: %s", err)
return false, nil, nil
return false, false, nil, nil

Check warning on line 880 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L880

Added line #L880 was not covered by tests
}
buf = append(buf, ip.Content...)
}

// If connection ID does not match discard the packet.
if !bytes.Equal(c.state.localConnectionID, h.ConnectionID) {
c.log.Debug("unexpected connection ID")
return false, nil, nil
return false, false, nil, nil

Check warning on line 888 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L888

Added line #L888 was not covered by tests
}
}

isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
isHandshake, isRetransmit, err := c.fragmentBuffer.push(append([]byte{}, buf...))
if err != nil {
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
c.log.Debugf("defragment failed: %s", err)
return false, nil, nil
return false, false, nil, nil

Check warning on line 897 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L897

Added line #L897 was not covered by tests
} else if isHandshake {
markPacketAsValid()

for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
header := &handshake.Header{}
if err := header.Unmarshal(out); err != nil {
Expand All @@ -895,12 +907,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
}

return true, nil, nil
return true, isRetransmit, nil, nil
}

r := &recordlayer.RecordLayer{}
if err := r.Unmarshal(buf); err != nil {
return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err

Check warning on line 915 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L915

Added line #L915 was not covered by tests
}

isLatestSeqNum := false
Expand All @@ -913,15 +925,15 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
}
_ = markPacketAsValid()
return false, a, &alertError{content}
return false, false, a, &alertError{content}
case *protocol.ChangeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
}
return false, nil, nil
return false, false, nil, nil
}

newRemoteEpoch := h.Epoch + 1
Expand All @@ -933,7 +945,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
}
case *protocol.ApplicationData:
if h.Epoch == 0 {
return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero

Check warning on line 948 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L948

Added line #L948 was not covered by tests
}

isLatestSeqNum = markPacketAsValid()
Expand All @@ -945,7 +957,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
}

default:
return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())

Check warning on line 960 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L960

Added line #L960 was not covered by tests
}

// Any valid connection ID record is a candidate for updating the remote
Expand All @@ -959,10 +971,10 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
}
}

return false, nil, nil
return false, false, nil, nil
}

func (c *Conn) recvHandshake() <-chan chan struct{} {
func (c *Conn) recvHandshake() <-chan recvHandshakeState {
return c.handshakeRecv
}

Expand Down
2 changes: 1 addition & 1 deletion flight1handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (f *flight1TestMockFlightConn) notify(context.Context, alert.Level, alert.D
return nil
}
func (f *flight1TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil }
func (f *flight1TestMockFlightConn) recvHandshake() <-chan chan struct{} { return nil }
func (f *flight1TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil }
func (f *flight1TestMockFlightConn) setLocalEpoch(uint16) {}
func (f *flight1TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil }
func (f *flight1TestMockFlightConn) sessionKey() []byte { return nil }
Expand Down
2 changes: 1 addition & 1 deletion flight4handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (f *flight4TestMockFlightConn) notify(context.Context, alert.Level, alert.D
return nil
}
func (f *flight4TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil }
func (f *flight4TestMockFlightConn) recvHandshake() <-chan chan struct{} { return nil }
func (f *flight4TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil }
func (f *flight4TestMockFlightConn) setLocalEpoch(uint16) {}
func (f *flight4TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil }
func (f *flight4TestMockFlightConn) sessionKey() []byte { return nil }
Expand Down
15 changes: 9 additions & 6 deletions fragment_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,29 @@ func (f *fragmentBuffer) size() int {
// Attempts to push a DTLS packet to the fragmentBuffer
// when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled
// when an error returns it is fatal, and the DTLS connection should be stopped
func (f *fragmentBuffer) push(buf []byte) (bool, error) {
func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) {
if f.size()+len(buf) >= fragmentBufferMaxSize {
return false, errFragmentBufferOverflow
return false, false, errFragmentBufferOverflow
}

frag := new(fragment)
if err := frag.recordLayerHeader.Unmarshal(buf); err != nil {
return false, err
return false, false, err

Check warning on line 53 in fragment_buffer.go

View check run for this annotation

Codecov / codecov/patch

fragment_buffer.go#L53

Added line #L53 was not covered by tests
}

// fragment isn't a handshake, we don't need to handle it
if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
return false, nil
return false, false, nil
}

for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) {
if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
return false, err
return false, false, err

Check warning on line 63 in fragment_buffer.go

View check run for this annotation

Codecov / codecov/patch

fragment_buffer.go#L63

Added line #L63 was not covered by tests
}

// Fragment is a retransmission. We have already assembled it before successfully
isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber

if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok {
f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{}
}
Expand All @@ -80,7 +83,7 @@ func (f *fragmentBuffer) push(buf []byte) (bool, error) {
buf = buf[end:]
}

return true, nil
return true, isRetransmit, nil
}

func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
Expand Down
6 changes: 3 additions & 3 deletions fragment_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestFragmentBuffer(t *testing.T) {
} {
fragmentBuffer := newFragmentBuffer()
for _, frag := range test.In {
status, err := fragmentBuffer.push(frag)
status, _, err := fragmentBuffer.push(frag)
if err != nil {
t.Error(err)
} else if !status {
Expand Down Expand Up @@ -122,13 +122,13 @@ func TestFragmentBuffer_Overflow(t *testing.T) {
fragmentBuffer := newFragmentBuffer()

// Push a buffer that doesn't exceed size limits
if _, err := fragmentBuffer.push([]byte{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}); err != nil {
if _, _, err := fragmentBuffer.push([]byte{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}); err != nil {
t.Fatal(err)
}

// Allocate a buffer that exceeds cache size
largeBuffer := make([]byte, fragmentBufferMaxSize)
if _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) {
if _, _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) {
t.Fatalf("Pushing a large buffer returned (%s) expected(%s)", err, errFragmentBufferOverflow)
}
}
15 changes: 10 additions & 5 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ type handshakeConfig struct {
type flightConn interface {
notify(ctx context.Context, level alert.Level, desc alert.Description) error
writePackets(context.Context, []*packet) error
recvHandshake() <-chan chan struct{}
recvHandshake() <-chan recvHandshakeState
setLocalEpoch(epoch uint16)
handleQueuedPackets(context.Context) error
sessionKey() []byte
Expand Down Expand Up @@ -280,10 +280,15 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState,
retransmitTimer := time.NewTimer(s.retransmitInterval)
for {
select {
case done := <-c.recvHandshake():
case state := <-c.recvHandshake():
if state.isRetransmit {
close(state.done)
return handshakeSending, nil

Check warning on line 286 in handshaker.go

View check run for this annotation

Codecov / codecov/patch

handshaker.go#L285-L286

Added lines #L285 - L286 were not covered by tests
}

nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
s.retransmitInterval = s.cfg.initialRetransmitInterval
close(done)
close(state.done)
if alert != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err != nil {
Expand Down Expand Up @@ -328,8 +333,8 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState,

func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) {
select {
case done := <-c.recvHandshake():
close(done)
case state := <-c.recvHandshake():
close(state.done)
return handshakeSending, nil
case <-ctx.Done():
return handshakeErrored, ctx.Err()
Expand Down
Loading

0 comments on commit 0364d6d

Please sign in to comment.