From f02d4630ab98cc00839751c9814aed9bfa5f67e7 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Fri, 8 Mar 2024 13:43:50 -0500 Subject: [PATCH] Prevent zero length values in slices and maps in codec (#2819) --- codec/codec.go | 2 + codec/reflectcodec/type_codec.go | 25 +++++++++++ codec/test_codec.go | 75 ++++++++++++++++++++++++++++---- 3 files changed, 93 insertions(+), 9 deletions(-) diff --git a/codec/codec.go b/codec/codec.go index 7aacb9085848..6ee799667182 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -15,6 +15,8 @@ var ( ErrDoesNotImplementInterface = errors.New("does not implement interface") ErrUnexportedField = errors.New("unexported field") ErrExtraSpace = errors.New("trailing buffer space") + ErrMarshalZeroLength = errors.New("can't marshal zero length value") + ErrUnmarshalZeroLength = errors.New("can't unmarshal zero length value") ) // Codec marshals and unmarshals diff --git a/codec/reflectcodec/type_codec.go b/codec/reflectcodec/type_codec.go index 7f567d0bb7fa..31243b2594ae 100644 --- a/codec/reflectcodec/type_codec.go +++ b/codec/reflectcodec/type_codec.go @@ -159,6 +159,10 @@ func (c *genericCodec) size( return 0, false, err } + if size == 0 { + return 0, false, fmt.Errorf("can't marshal slice of zero length values: %w", codec.ErrMarshalZeroLength) + } + // For fixed-size types we manually calculate lengths rather than // processing each element separately to improve performance. if constSize { @@ -235,6 +239,10 @@ func (c *genericCodec) size( return 0, false, err } + if keySize == 0 && valueSize == 0 { + return 0, false, fmt.Errorf("can't marshal map with zero length entries: %w", codec.ErrMarshalZeroLength) + } + switch { case keyConstSize && valueConstSize: numElts := value.Len() @@ -394,9 +402,13 @@ func (c *genericCodec) marshal( return p.Err } for i := 0; i < numElts; i++ { // Process each element in the slice + startOffset := p.Offset if err := c.marshal(value.Index(i), p, typeStack); err != nil { return err } + if startOffset == p.Offset { + return fmt.Errorf("couldn't marshal slice of zero length values: %w", codec.ErrMarshalZeroLength) + } } return nil case reflect.Array: @@ -479,6 +491,8 @@ func (c *genericCodec) marshal( allKeyBytes := slices.Clone(p.Bytes[startOffset:p.Offset]) p.Offset = startOffset for _, key := range sortedKeys { + keyStartOffset := p.Offset + // pack key startIndex := key.startIndex - startOffset endIndex := key.endIndex - startOffset @@ -492,6 +506,9 @@ func (c *genericCodec) marshal( if err := c.marshal(value.MapIndex(key.key), p, typeStack); err != nil { return err } + if keyStartOffset == p.Offset { + return fmt.Errorf("couldn't marshal map with zero length entries: %w", codec.ErrMarshalZeroLength) + } } return nil @@ -625,9 +642,14 @@ func (c *genericCodec) unmarshal( zeroValue := reflect.Zero(innerType) for i := 0; i < numElts; i++ { value.Set(reflect.Append(value, zeroValue)) + + startOffset := p.Offset if err := c.unmarshal(p, value.Index(i), typeStack); err != nil { return err } + if startOffset == p.Offset { + return fmt.Errorf("couldn't unmarshal slice of zero length values: %w", codec.ErrUnmarshalZeroLength) + } } return nil case reflect.Array: @@ -755,6 +777,9 @@ func (c *genericCodec) unmarshal( if err := c.unmarshal(p, mapValue, typeStack); err != nil { return err } + if keyStartOffset == p.Offset { + return fmt.Errorf("couldn't unmarshal map with zero length entries: %w", codec.ErrUnmarshalZeroLength) + } // Assign the key-value pair in the map value.SetMapIndex(mapKey, mapValue) diff --git a/codec/test_codec.go b/codec/test_codec.go index d58e2d818f9e..04d2b53abd38 100644 --- a/codec/test_codec.go +++ b/codec/test_codec.go @@ -36,7 +36,9 @@ var ( TestNilSliceSerialization, TestEmptySliceSerialization, TestSliceWithEmptySerialization, - TestSliceWithEmptySerializationOutOfMemory, + TestSliceWithEmptySerializationError, + TestMapWithEmptySerialization, + TestMapWithEmptySerializationError, TestSliceTooLarge, TestNegativeNumbers, TestTooLargeUnmarshal, @@ -731,7 +733,7 @@ func TestEmptySliceSerialization(codec GeneralCodec, t testing.TB) { require.Equal(val, valUnmarshaled) } -// Test marshaling slice that is not nil and not empty +// Test marshaling empty slice of zero length structs func TestSliceWithEmptySerialization(codec GeneralCodec, t testing.TB) { require := require.New(t) @@ -745,9 +747,9 @@ func TestSliceWithEmptySerialization(codec GeneralCodec, t testing.TB) { require.NoError(manager.RegisterCodec(0, codec)) val := &nestedSliceStruct{ - Arr: make([]emptyStruct, 1000), + Arr: make([]emptyStruct, 0), } - expected := []byte{0x00, 0x00, 0x00, 0x00, 0x03, 0xE8} // codec version (0x00, 0x00) then 1000 for numElts + expected := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00} // codec version (0x00, 0x00) then (0x00, 0x00, 0x00, 0x00) for numElts result, err := manager.Marshal(0, val) require.NoError(err) require.Equal(expected, result) @@ -760,10 +762,10 @@ func TestSliceWithEmptySerialization(codec GeneralCodec, t testing.TB) { version, err := manager.Unmarshal(expected, &unmarshaled) require.NoError(err) require.Zero(version) - require.Len(unmarshaled.Arr, 1000) + require.Empty(unmarshaled.Arr) } -func TestSliceWithEmptySerializationOutOfMemory(codec GeneralCodec, t testing.TB) { +func TestSliceWithEmptySerializationError(codec GeneralCodec, t testing.TB) { require := require.New(t) type emptyStruct struct{} @@ -776,14 +778,69 @@ func TestSliceWithEmptySerializationOutOfMemory(codec GeneralCodec, t testing.TB require.NoError(manager.RegisterCodec(0, codec)) val := &nestedSliceStruct{ - Arr: make([]emptyStruct, math.MaxInt32), + Arr: make([]emptyStruct, 1), } _, err := manager.Marshal(0, val) - require.ErrorIs(err, ErrMaxSliceLenExceeded) + require.ErrorIs(err, ErrMarshalZeroLength) + + _, err = manager.Size(0, val) + require.ErrorIs(err, ErrMarshalZeroLength) + + b := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01} // codec version (0x00, 0x00) then (0x00, 0x00, 0x00, 0x01) for numElts + + unmarshaled := nestedSliceStruct{} + _, err = manager.Unmarshal(b, &unmarshaled) + require.ErrorIs(err, ErrUnmarshalZeroLength) +} + +// Test marshaling empty map of zero length structs +func TestMapWithEmptySerialization(codec GeneralCodec, t testing.TB) { + require := require.New(t) + + type emptyStruct struct{} + + manager := NewDefaultManager() + require.NoError(manager.RegisterCodec(0, codec)) + + val := make(map[emptyStruct]emptyStruct) + expected := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00} // codec version (0x00, 0x00) then (0x00, 0x00, 0x00, 0x00) for numElts + result, err := manager.Marshal(0, val) + require.NoError(err) + require.Equal(expected, result) bytesLen, err := manager.Size(0, val) require.NoError(err) - require.Equal(6, bytesLen) // 2 byte codec version + 4 byte length prefix + require.Len(result, bytesLen) + + var unmarshaled map[emptyStruct]emptyStruct + version, err := manager.Unmarshal(expected, &unmarshaled) + require.NoError(err) + require.Zero(version) + require.Empty(unmarshaled) +} + +func TestMapWithEmptySerializationError(codec GeneralCodec, t testing.TB) { + require := require.New(t) + + type emptyStruct struct{} + + manager := NewDefaultManager() + require.NoError(manager.RegisterCodec(0, codec)) + + val := map[emptyStruct]emptyStruct{ + {}: {}, + } + _, err := manager.Marshal(0, val) + require.ErrorIs(err, ErrMarshalZeroLength) + + _, err = manager.Size(0, val) + require.ErrorIs(err, ErrMarshalZeroLength) + + b := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01} // codec version (0x00, 0x00) then (0x00, 0x00, 0x00, 0x01) for numElts + + var unmarshaled map[emptyStruct]emptyStruct + _, err = manager.Unmarshal(b, &unmarshaled) + require.ErrorIs(err, ErrUnmarshalZeroLength) } func TestSliceTooLarge(codec GeneralCodec, t testing.TB) {