Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop invalid record silently during handshake #604

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,10 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
} else {
switch {
case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
case errors.Is(err, recordlayer.ErrInvalidPacketLength):
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
continue
default:
if c.isHandshakeCompletedSuccessfully() {
// Keep read loop and pass the read error to Read()
Expand Down
79 changes: 78 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func TestHandshakeWithAlert(t *testing.T) {
clientErr <- err
}()

_, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), ca.RemoteAddr(), testCase.configServer, true)
_, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), testCase.configServer, true)
if !errors.Is(errServer, testCase.errServer) {
t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer)
}
Expand All @@ -402,6 +402,71 @@ func TestHandshakeWithAlert(t *testing.T) {
}
}

func TestHandshakeWithInvalidRecord(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

type result struct {
c *Conn
err error
}
clientErr := make(chan result, 1)
ca, cb := dpipe.Pipe()
caWithInvalidRecord := &connWithCallback{Conn: ca}

var msgSeq atomic.Int32
// Send invalid record after first message
caWithInvalidRecord.onWrite = func(b []byte) {
if msgSeq.Add(1) == 2 {
if _, err := ca.Write([]byte{0x01, 0x02}); err != nil {
t.Fatal(err)
}
}
}
go func() {
client, err := testClient(ctx, dtlsnet.PacketConnFromConn(caWithInvalidRecord), caWithInvalidRecord.RemoteAddr(), &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
}, true)
clientErr <- result{client, err}
}()

server, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
}, true)

errClient := <-clientErr

defer func() {
if server != nil {
if err := server.Close(); err != nil {
t.Fatal(err)
}
}

if errClient.c != nil {
if err := errClient.c.Close(); err != nil {
t.Fatal(err)
}
}
}()

if errServer != nil {
t.Fatalf("Server failed(%v)", errServer)
}

if errClient.err != nil {
t.Fatalf("Client failed(%v)", errClient.err)
}
}

func TestExportKeyingMaterial(t *testing.T) {
// Check for leaking routines
report := test.CheckRoutines(t)
Expand Down Expand Up @@ -3096,3 +3161,15 @@ func TestSkipHelloVerify(t *testing.T) {
t.Error(err)
}
}

type connWithCallback struct {
net.Conn
onWrite func([]byte)
}

func (c *connWithCallback) Write(b []byte) (int, error) {
if c.onWrite != nil {
c.onWrite(b)
}
return c.Conn.Write(b)
}
12 changes: 7 additions & 5 deletions pkg/protocol/recordlayer/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import (
)

var (
errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113
errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113
errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113
errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113
// ErrInvalidPacketLength is returned when the packet length too small or declared length do not match
ErrInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113

errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113
errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113
errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113
)
8 changes: 4 additions & 4 deletions pkg/protocol/recordlayer/recordlayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@

for offset := 0; len(buf) != offset; {
if len(buf)-offset <= FixedHeaderSize {
return nil, errInvalidPacketLength
return nil, ErrInvalidPacketLength
}

pktLen := (FixedHeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:])))
if offset+pktLen > len(buf) {
return nil, errInvalidPacketLength
return nil, ErrInvalidPacketLength
}

out = append(out, buf[offset:offset+pktLen])
Expand All @@ -129,12 +129,12 @@
lenIdx += cidLength
}
if len(buf)-offset <= headerSize {
return nil, errInvalidPacketLength
return nil, ErrInvalidPacketLength

Check warning on line 132 in pkg/protocol/recordlayer/recordlayer.go

View check run for this annotation

Codecov / codecov/patch

pkg/protocol/recordlayer/recordlayer.go#L132

Added line #L132 was not covered by tests
}

pktLen := (headerSize + int(binary.BigEndian.Uint16(buf[offset+lenIdx:])))
if offset+pktLen > len(buf) {
return nil, errInvalidPacketLength
return nil, ErrInvalidPacketLength

Check warning on line 137 in pkg/protocol/recordlayer/recordlayer.go

View check run for this annotation

Codecov / codecov/patch

pkg/protocol/recordlayer/recordlayer.go#L137

Added line #L137 was not covered by tests
}

out = append(out, buf[offset:offset+pktLen])
Expand Down
4 changes: 2 additions & 2 deletions pkg/protocol/recordlayer/recordlayer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ func TestUDPDecode(t *testing.T) {
{
Name: "Invalid packet length",
Data: []byte{0x14, 0xfe},
WantError: errInvalidPacketLength,
WantError: ErrInvalidPacketLength,
},
{
Name: "Packet declared invalid length",
Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0xFF, 0x01},
WantError: errInvalidPacketLength,
WantError: ErrInvalidPacketLength,
},
} {
dtlsPkts, err := UnpackDatagram(test.Data)
Expand Down
Loading