Skip to content

Commit

Permalink
Drop invalid record silently during handshake
Browse files Browse the repository at this point in the history
Fix issue: invalid record in handshake staging cause readloop
exited then handshake failed.
  • Loading branch information
cnderrauber committed Jan 2, 2024
1 parent 3e8a7d7 commit 9ffd96c
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 12 deletions.
4 changes: 4 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,10 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
} else {
switch {
case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
case errors.Is(err, recordlayer.ErrInvalidPacketLength):
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
continue
default:
if c.isHandshakeCompletedSuccessfully() {
// Keep read loop and pass the read error to Read()
Expand Down
79 changes: 78 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func TestHandshakeWithAlert(t *testing.T) {
clientErr <- err
}()

_, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), ca.RemoteAddr(), testCase.configServer, true)
_, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), testCase.configServer, true)
if !errors.Is(errServer, testCase.errServer) {
t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer)
}
Expand All @@ -402,6 +402,71 @@ func TestHandshakeWithAlert(t *testing.T) {
}
}

func TestHandshakeWithInvalidRecord(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

type result struct {
c *Conn
err error
}
clientErr := make(chan result, 1)
ca, cb := dpipe.Pipe()
caWithInvalidRecord := &connWithCallback{Conn: ca}

var msgSeq atomic.Int32
// Send invalid record after first message
caWithInvalidRecord.onWrite = func(b []byte) {
if msgSeq.Add(1) == 2 {
if _, err := ca.Write([]byte{0x01, 0x02}); err != nil {
t.Fatal(err)
}
}
}
go func() {
client, err := testClient(ctx, dtlsnet.PacketConnFromConn(caWithInvalidRecord), caWithInvalidRecord.RemoteAddr(), &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
}, true)
clientErr <- result{client, err}
}()

server, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
}, true)

errClient := <-clientErr

defer func() {
if server != nil {
if err := server.Close(); err != nil {
t.Fatal(err)
}
}

if errClient.c != nil {
if err := errClient.c.Close(); err != nil {
t.Fatal(err)
}
}
}()

if errServer != nil {
t.Fatalf("Server failed(%v)", errServer)
}

if errClient.err != nil {
t.Fatalf("Client failed(%v)", errClient.err)
}
}

func TestExportKeyingMaterial(t *testing.T) {
// Check for leaking routines
report := test.CheckRoutines(t)
Expand Down Expand Up @@ -3096,3 +3161,15 @@ func TestSkipHelloVerify(t *testing.T) {
t.Error(err)
}
}

type connWithCallback struct {
net.Conn
onWrite func([]byte)
}

func (c *connWithCallback) Write(b []byte) (int, error) {
if c.onWrite != nil {
c.onWrite(b)
}
return c.Conn.Write(b)
}
12 changes: 7 additions & 5 deletions pkg/protocol/recordlayer/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import (
)

var (
errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113
errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113
errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113
errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113
// ErrInvalidPacketLength is returned when the packet length too small or declared length do not match
ErrInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113

errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113
errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113
errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113
)
8 changes: 4 additions & 4 deletions pkg/protocol/recordlayer/recordlayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ func UnpackDatagram(buf []byte) ([][]byte, error) {

for offset := 0; len(buf) != offset; {
if len(buf)-offset <= FixedHeaderSize {
return nil, errInvalidPacketLength
return nil, ErrInvalidPacketLength
}

pktLen := (FixedHeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:])))
if offset+pktLen > len(buf) {
return nil, errInvalidPacketLength
return nil, ErrInvalidPacketLength
}

out = append(out, buf[offset:offset+pktLen])
Expand All @@ -129,12 +129,12 @@ func ContentAwareUnpackDatagram(buf []byte, cidLength int) ([][]byte, error) {
lenIdx += cidLength
}
if len(buf)-offset <= headerSize {
return nil, errInvalidPacketLength
return nil, ErrInvalidPacketLength
}

pktLen := (headerSize + int(binary.BigEndian.Uint16(buf[offset+lenIdx:])))
if offset+pktLen > len(buf) {
return nil, errInvalidPacketLength
return nil, ErrInvalidPacketLength
}

out = append(out, buf[offset:offset+pktLen])
Expand Down
4 changes: 2 additions & 2 deletions pkg/protocol/recordlayer/recordlayer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ func TestUDPDecode(t *testing.T) {
{
Name: "Invalid packet length",
Data: []byte{0x14, 0xfe},
WantError: errInvalidPacketLength,
WantError: ErrInvalidPacketLength,
},
{
Name: "Packet declared invalid length",
Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0xFF, 0x01},
WantError: errInvalidPacketLength,
WantError: ErrInvalidPacketLength,
},
} {
dtlsPkts, err := UnpackDatagram(test.Data)
Expand Down

0 comments on commit 9ffd96c

Please sign in to comment.