From fde8a3b8ea20ef9ac66c83e12113c39060c55ae3 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 18 Sep 2024 01:06:23 +0200 Subject: [PATCH] [release/9.0-rc2] NRBF Fuzzer and bug fixes (#107788) * [NRBF] Don't use Unsafe.As when decoding DateTime(s) (#105749) * Add NrbfDecoder Fuzzer (#107385) * [NRBF] Fix bugs discovered by the fuzzer (#107368) * bug #1: don't allow for values out of the SerializationRecordType enum range * bug #2: throw SerializationException rather than KeyNotFoundException when the referenced record is missing or it points to a record of different type * bug #3: throw SerializationException rather than FormatException when it's being thrown by BinaryReader (or sth else that we use) * bug #4: document the fact that IOException can be thrown * bug #5: throw SerializationException rather than OverflowException when parsing the decimal fails * bug #6: 0 and 17 are illegal values for PrimitiveType enum * bug #7: throw SerializationException when a surrogate character is read (so far an ArgumentException was thrown) # Conflicts: # src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs * [NRBF] throw SerializationException when a surrogate character is read (#107532) (so far an ArgumentException was thrown) * [NRBF] Fuzzing non-seekable stream input (#107605) * [NRBF] More bug fixes (#107682) - Don't use `Debug.Fail` not followed by an exception (it may cause problems for apps deployed in Debug) - avoid Int32 overflow - throw for unexpected enum values just in case parsing has not rejected them - validate the number of chars read by BinaryReader.ReadChars - pass serialization record id to ex message - return false rather than throw EndOfStreamException when provided Stream has not enough data - don't restore the position in finally - limit max SZ and MD array length to Array.MaxLength, stop using LinkedList as List will be able to hold all elements now - remove internal enum values that were always illegal, but needed to be handled everywhere - Fix DebuggerDisplay * [NRBF] Comments and bug fixes from internal code review (#107735) * copy comments and asserts from Levis internal code review * apply Levis suggestion: don't store Array.MaxLength as a const, as it may change in the future * add missing and fix some of the existing comments * first bug fix: SerializationRecord.TypeNameMatches should throw ArgumentNullException for null Type argument * second bug fix: SerializationRecord.TypeNameMatches should know the difference between SZArray and single-dimension, non-zero offset arrays (example: int[] and int[*]) * third bug fix: don't cast bytes to booleans * fourth bug fix: don't cast bytes to DateTimes * add one test case that I've forgot in previous PR # Conflicts: # src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecord.cs * [NRBF] Address issues discovered by Threat Model (#106629) * introduce ArrayRecord.FlattenedLength * do not include invalid Type or Assembly names in the exception messages, as it's most likely corrupted/tampered/malicious data and could be used as a vector of attack. * It is possible to have binary array records have an element type of array without being marked as jagged --------- Co-authored-by: Buyaa Namnan --- .../libraries/fuzzing/deploy-to-onefuzz.yml | 8 + src/libraries/Fuzzing/DotnetFuzzing/Assert.cs | 11 ++ .../Dictionaries/nrbfdecoder.dict | 16 ++ .../DotnetFuzzing/DotnetFuzzing.csproj | 4 + .../Fuzzers/AssemblyNameInfoFuzzer.cs | 12 +- .../Fuzzers/NrbfDecoderFuzzer.cs | 126 +++++++++++++++ .../DotnetFuzzing/Fuzzers/TypeNameFuzzer.cs | 4 +- .../ref/System.Formats.Nrbf.cs | 1 + .../src/Resources/Strings.resx | 19 ++- .../System/Formats/Nrbf/AllowedRecordType.cs | 3 + .../src/System/Formats/Nrbf/ArrayInfo.cs | 18 ++- .../Formats/Nrbf/ArrayOfClassesRecord.cs | 4 + .../src/System/Formats/Nrbf/ArrayRecord.cs | 15 +- .../Formats/Nrbf/ArraySingleObjectRecord.cs | 7 +- .../Nrbf/ArraySinglePrimitiveRecord.cs | 116 ++++++++++---- .../Formats/Nrbf/ArraySingleStringRecord.cs | 9 +- .../System/Formats/Nrbf/BinaryArrayRecord.cs | 102 +++++++++++- .../Formats/Nrbf/BinaryLibraryRecord.cs | 11 +- .../src/System/Formats/Nrbf/ClassInfo.cs | 5 +- .../src/System/Formats/Nrbf/ClassTypeInfo.cs | 4 +- .../System/Formats/Nrbf/ClassWithIdRecord.cs | 5 +- .../Nrbf/ClassWithMembersAndTypesRecord.cs | 2 +- .../Formats/Nrbf/MemberReferenceRecord.cs | 2 +- .../src/System/Formats/Nrbf/MemberTypeInfo.cs | 21 ++- .../System/Formats/Nrbf/MessageEndRecord.cs | 9 +- .../src/System/Formats/Nrbf/NextInfo.cs | 4 +- .../src/System/Formats/Nrbf/NrbfDecoder.cs | 66 ++++---- .../src/System/Formats/Nrbf/NullsRecord.cs | 9 +- .../src/System/Formats/Nrbf/PayloadOptions.cs | 7 + .../src/System/Formats/Nrbf/PrimitiveType.cs | 22 ++- .../src/System/Formats/Nrbf/RecordMap.cs | 15 +- .../Formats/Nrbf/RectangularArrayRecord.cs | 56 ++++--- .../Formats/Nrbf/SerializationRecord.cs | 56 ++++++- .../Formats/Nrbf/SerializationRecordId.cs | 11 ++ .../Formats/Nrbf/SerializationRecordType.cs | 3 + .../Nrbf/SerializedStreamHeaderRecord.cs | 10 +- .../SystemClassWithMembersAndTypesRecord.cs | 19 +-- .../Nrbf/Utils/BinaryReaderExtensions.cs | 150 +++++++++++++++--- .../System/Formats/Nrbf/Utils/ThrowHelper.cs | 21 ++- .../Formats/Nrbf/Utils/TypeNameHelpers.cs | 31 +++- .../tests/ArraySinglePrimitiveRecordTests.cs | 48 ++++++ .../System.Formats.Nrbf/tests/AttackTests.cs | 2 +- .../tests/EdgeCaseTests.cs | 43 ++++- .../tests/InvalidInputTests.cs | 124 +++++++++++++++ .../tests/JaggedArraysTests.cs | 78 ++++++++- .../System.Formats.Nrbf/tests/ReadTests.cs | 4 +- .../tests/RectangularArraysTests.cs | 3 + .../tests/System.Formats.Nrbf.Tests.csproj | 2 + .../tests/TypeMatchTests.cs | 28 ++++ .../Reflection/Metadata/AssemblyNameInfo.cs | 8 + .../System/Reflection/Metadata/TypeName.cs | 26 ++- .../Reflection/Metadata/TypeNameParser.cs | 13 ++ .../Metadata/TypeNameParserHelpers.cs | 18 ++- .../Metadata/TypeNameParserOptions.cs | 7 + 54 files changed, 1179 insertions(+), 239 deletions(-) create mode 100644 src/libraries/Fuzzing/DotnetFuzzing/Dictionaries/nrbfdecoder.dict create mode 100644 src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs diff --git a/eng/pipelines/libraries/fuzzing/deploy-to-onefuzz.yml b/eng/pipelines/libraries/fuzzing/deploy-to-onefuzz.yml index 18c5ad2b5e832..25baed9956a05 100644 --- a/eng/pipelines/libraries/fuzzing/deploy-to-onefuzz.yml +++ b/eng/pipelines/libraries/fuzzing/deploy-to-onefuzz.yml @@ -97,6 +97,14 @@ extends: SYSTEM_ACCESSTOKEN: $(System.AccessToken) displayName: Send JsonDocumentFuzzer to OneFuzz + - task: onefuzz-task@0 + inputs: + onefuzzOSes: 'Windows' + env: + onefuzzDropDirectory: $(fuzzerProject)/deployment/NrbfDecoderFuzzer + SYSTEM_ACCESSTOKEN: $(System.AccessToken) + displayName: Send NrbfDecoderFuzzer to OneFuzz + - task: onefuzz-task@0 inputs: onefuzzOSes: 'Windows' diff --git a/src/libraries/Fuzzing/DotnetFuzzing/Assert.cs b/src/libraries/Fuzzing/DotnetFuzzing/Assert.cs index 810174ccbc131..a5f2a9dd1d195 100644 --- a/src/libraries/Fuzzing/DotnetFuzzing/Assert.cs +++ b/src/libraries/Fuzzing/DotnetFuzzing/Assert.cs @@ -18,6 +18,17 @@ static void Throw(T expected, T actual) => throw new Exception($"Expected={expected} Actual={actual}"); } + public static void NotNull(T value) + { + if (value == null) + { + ThrowNull(); + } + + static void ThrowNull() => + throw new Exception("Value is null"); + } + public static void SequenceEqual(ReadOnlySpan expected, ReadOnlySpan actual) { if (!expected.SequenceEqual(actual)) diff --git a/src/libraries/Fuzzing/DotnetFuzzing/Dictionaries/nrbfdecoder.dict b/src/libraries/Fuzzing/DotnetFuzzing/Dictionaries/nrbfdecoder.dict new file mode 100644 index 0000000000000..1b8f14f961cc1 --- /dev/null +++ b/src/libraries/Fuzzing/DotnetFuzzing/Dictionaries/nrbfdecoder.dict @@ -0,0 +1,16 @@ +# "Hello World!" +"\x00\x01\x00\x00\x00\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x00\x00\x00\x00\x06\x01\x00\x00\x00\x0C\x48\x65\x6C\x6C\x6F\x20\x57\x6F\x72\x6C\x64\x21\x0B" +# new DateTime(2024, 2, 29) +"\x00\x01\x00\x00\x00\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x00\x00\x00\x00\x04\x01\x00\x00\x00\x0F\x53\x79\x73\x74\x65\x6D\x2E\x44\x61\x74\x65\x54\x69\x6D\x65\x02\x00\x00\x00\x05\x74\x69\x63\x6B\x73\x08\x64\x61\x74\x65\x44\x61\x74\x61\x00\x00\x09\x10\x00\x00\x60\x5F\xB9\x38\xDC\x08\x00\x00\x60\x5F\xB9\x38\xDC\x08\x0B" +# new int[] { 1, 2, 3 } +"\x00\x01\x00\x00\x00\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x00\x00\x00\x00\x0F\x01\x00\x00\x00\x03\x00\x00\x00\x08\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x0B" +# new object[] { int.MaxValue, "string", null } +"\x00\x01\x00\x00\x00\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x00\x00\x00\x00\x10\x01\x00\x00\x00\x03\x00\x00\x00\x08\x08\xFF\xFF\xFF\x7F\x06\x02\x00\x00\x00\x06\x73\x74\x72\x69\x6E\x67\x0A\x0B" +# new int?[Array.MaxLength] (plenty of nulls) +"\x00\x01\x00\x00\x00\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x00\x00\x00\x00\x07\x01\x00\x00\x00\x00\x01\x00\x00\x00\xC7\xFF\xFF\x7F\x03\x6E\x53\x79\x73\x74\x65\x6D\x2E\x4E\x75\x6C\x6C\x61\x62\x6C\x65\x60\x31\x5B\x5B\x53\x79\x73\x74\x65\x6D\x2E\x49\x6E\x74\x33\x32\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x5D\x0E\xC7\xFF\xFF\x7F\x0B" +# [["jagged", "array"], ["of", "strings"]] +"\x00\x01\x00\x00\x00\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x00\x00\x00\x00\x07\x01\x00\x00\x00\x01\x01\x00\x00\x00\x02\x00\x00\x00\x06\x09\x02\x00\x00\x00\x09\x03\x00\x00\x00\x11\x02\x00\x00\x00\x02\x00\x00\x00\x06\x04\x00\x00\x00\x06\x6A\x61\x67\x67\x65\x64\x06\x05\x00\x00\x00\x05\x61\x72\x72\x61\x79\x11\x03\x00\x00\x00\x02\x00\x00\x00\x06\x06\x00\x00\x00\x02\x6F\x66\x06\x07\x00\x00\x00\x07\x73\x74\x72\x69\x6E\x67\x73\x0B" +# new Dictionary { { "1", true }, { "2", false } } +"\x00\x01\x00\x00\x00\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x00\x00\x00\x00\x04\x01\x00\x00\x00\xE3\x01\x53\x79\x73\x74\x65\x6D\x2E\x43\x6F\x6C\x6C\x65\x63\x74\x69\x6F\x6E\x73\x2E\x47\x65\x6E\x65\x72\x69\x63\x2E\x44\x69\x63\x74\x69\x6F\x6E\x61\x72\x79\x60\x32\x5B\x5B\x53\x79\x73\x74\x65\x6D\x2E\x53\x74\x72\x69\x6E\x67\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x2C\x5B\x53\x79\x73\x74\x65\x6D\x2E\x42\x6F\x6F\x6C\x65\x61\x6E\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x5D\x04\x00\x00\x00\x07\x56\x65\x72\x73\x69\x6F\x6E\x08\x43\x6F\x6D\x70\x61\x72\x65\x72\x08\x48\x61\x73\x68\x53\x69\x7A\x65\x0D\x4B\x65\x79\x56\x61\x6C\x75\x65\x50\x61\x69\x72\x73\x00\x03\x00\x03\x08\x92\x01\x53\x79\x73\x74\x65\x6D\x2E\x43\x6F\x6C\x6C\x65\x63\x74\x69\x6F\x6E\x73\x2E\x47\x65\x6E\x65\x72\x69\x63\x2E\x47\x65\x6E\x65\x72\x69\x63\x45\x71\x75\x61\x6C\x69\x74\x79\x43\x6F\x6D\x70\x61\x72\x65\x72\x60\x31\x5B\x5B\x53\x79\x73\x74\x65\x6D\x2E\x53\x74\x72\x69\x6E\x67\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x5D\x08\xE7\x01\x53\x79\x73\x74\x65\x6D\x2E\x43\x6F\x6C\x6C\x65\x63\x74\x69\x6F\x6E\x73\x2E\x47\x65\x6E\x65\x72\x69\x63\x2E\x4B\x65\x79\x56\x61\x6C\x75\x65\x50\x61\x69\x72\x60\x32\x5B\x5B\x53\x79\x73\x74\x65\x6D\x2E\x53\x74\x72\x69\x6E\x67\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x2C\x5B\x53\x79\x73\x74\x65\x6D\x2E\x42\x6F\x6F\x6C\x65\x61\x6E\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x5D\x5B\x5D\x02\x00\x00\x00\x09\x02\x00\x00\x00\x03\x00\x00\x00\x09\x03\x00\x00\x00\x04\x02\x00\x00\x00\x92\x01\x53\x79\x73\x74\x65\x6D\x2E\x43\x6F\x6C\x6C\x65\x63\x74\x69\x6F\x6E\x73\x2E\x47\x65\x6E\x65\x72\x69\x63\x2E\x47\x65\x6E\x65\x72\x69\x63\x45\x71\x75\x61\x6C\x69\x74\x79\x43\x6F\x6D\x70\x61\x72\x65\x72\x60\x31\x5B\x5B\x53\x79\x73\x74\x65\x6D\x2E\x53\x74\x72\x69\x6E\x67\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x5D\x00\x00\x00\x00\x07\x03\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\xE5\x01\x53\x79\x73\x74\x65\x6D\x2E\x43\x6F\x6C\x6C\x65\x63\x74\x69\x6F\x6E\x73\x2E\x47\x65\x6E\x65\x72\x69\x63\x2E\x4B\x65\x79\x56\x61\x6C\x75\x65\x50\x61\x69\x72\x60\x32\x5B\x5B\x53\x79\x73\x74\x65\x6D\x2E\x53\x74\x72\x69\x6E\x67\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x2C\x5B\x53\x79\x73\x74\x65\x6D\x2E\x42\x6F\x6F\x6C\x65\x61\x6E\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x5D\x04\xFC\xFF\xFF\xFF\xE5\x01\x53\x79\x73\x74\x65\x6D\x2E\x43\x6F\x6C\x6C\x65\x63\x74\x69\x6F\x6E\x73\x2E\x47\x65\x6E\x65\x72\x69\x63\x2E\x4B\x65\x79\x56\x61\x6C\x75\x65\x50\x61\x69\x72\x60\x32\x5B\x5B\x53\x79\x73\x74\x65\x6D\x2E\x53\x74\x72\x69\x6E\x67\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x2C\x5B\x53\x79\x73\x74\x65\x6D\x2E\x42\x6F\x6F\x6C\x65\x61\x6E\x2C\x20\x6D\x73\x63\x6F\x72\x6C\x69\x62\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x34\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x62\x37\x37\x61\x35\x63\x35\x36\x31\x39\x33\x34\x65\x30\x38\x39\x5D\x5D\x02\x00\x00\x00\x03\x6B\x65\x79\x05\x76\x61\x6C\x75\x65\x01\x00\x01\x06\x05\x00\x00\x00\x01\x31\x01\x01\xFA\xFF\xFF\xFF\xFC\xFF\xFF\xFF\x06\x07\x00\x00\x00\x01\x32\x00\x0B" +# new ComplexType2D { I = 1, J = 2 } (non-system class) +"\x00\x01\x00\x00\x00\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x00\x00\x00\x00\x0C\x02\x00\x00\x00\x3D\x42\x66\x44\x65\x6D\x6F\x2C\x20\x56\x65\x72\x73\x69\x6F\x6E\x3D\x31\x2E\x30\x2E\x30\x2E\x30\x2C\x20\x43\x75\x6C\x74\x75\x72\x65\x3D\x6E\x65\x75\x74\x72\x61\x6C\x2C\x20\x50\x75\x62\x6C\x69\x63\x4B\x65\x79\x54\x6F\x6B\x65\x6E\x3D\x6E\x75\x6C\x6C\x05\x01\x00\x00\x00\x14\x42\x66\x44\x65\x6D\x6F\x2E\x43\x6F\x6D\x70\x6C\x65\x78\x54\x79\x70\x65\x32\x44\x02\x00\x00\x00\x01\x49\x01\x4A\x00\x00\x08\x08\x02\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x0B" \ No newline at end of file diff --git a/src/libraries/Fuzzing/DotnetFuzzing/DotnetFuzzing.csproj b/src/libraries/Fuzzing/DotnetFuzzing/DotnetFuzzing.csproj index 6db5eb5a9ed04..84231742ee2fd 100644 --- a/src/libraries/Fuzzing/DotnetFuzzing/DotnetFuzzing.csproj +++ b/src/libraries/Fuzzing/DotnetFuzzing/DotnetFuzzing.csproj @@ -30,4 +30,8 @@ + + + + diff --git a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/AssemblyNameInfoFuzzer.cs b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/AssemblyNameInfoFuzzer.cs index 9ce1bd255c7b6..d166726665afe 100644 --- a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/AssemblyNameInfoFuzzer.cs +++ b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/AssemblyNameInfoFuzzer.cs @@ -24,15 +24,15 @@ public void FuzzTarget(ReadOnlySpan bytes) using PooledBoundedMemory inputPoisonedBefore = PooledBoundedMemory.Rent(chars, PoisonPagePlacement.Before); using PooledBoundedMemory inputPoisonedAfter = PooledBoundedMemory.Rent(chars, PoisonPagePlacement.After); - Test(inputPoisonedBefore); - Test(inputPoisonedAfter); + Test(inputPoisonedBefore.Span); + Test(inputPoisonedAfter.Span); } - private static void Test(PooledBoundedMemory inputPoisoned) + private static void Test(Span span) { - if (AssemblyNameInfo.TryParse(inputPoisoned.Span, out AssemblyNameInfo? fromTryParse)) + if (AssemblyNameInfo.TryParse(span, out AssemblyNameInfo? fromTryParse)) { - AssemblyNameInfo fromParse = AssemblyNameInfo.Parse(inputPoisoned.Span); + AssemblyNameInfo fromParse = AssemblyNameInfo.Parse(span); Assert.Equal(fromTryParse.Name, fromParse.Name); Assert.Equal(fromTryParse.FullName, fromParse.FullName); @@ -66,7 +66,7 @@ private static void Test(PooledBoundedMemory inputPoisoned) { try { - _ = AssemblyNameInfo.Parse(inputPoisoned.Span); + _ = AssemblyNameInfo.Parse(span); } catch (ArgumentException) { diff --git a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs new file mode 100644 index 0000000000000..5c6397187ff64 --- /dev/null +++ b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs @@ -0,0 +1,126 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Formats.Nrbf; +using System.Runtime.Serialization; +using System.Text; + +namespace DotnetFuzzing.Fuzzers +{ + internal sealed class NrbfDecoderFuzzer : IFuzzer + { + public string[] TargetAssemblies { get; } = ["System.Formats.Nrbf"]; + + public string[] TargetCoreLibPrefixes => []; + + public string Dictionary => "nrbfdecoder.dict"; + + public void FuzzTarget(ReadOnlySpan bytes) + { + Test(bytes, PoisonPagePlacement.Before); + Test(bytes, PoisonPagePlacement.After); + } + + private static void Test(ReadOnlySpan bytes, PoisonPagePlacement poisonPagePlacement) + { + using PooledBoundedMemory inputPoisoned = PooledBoundedMemory.Rent(bytes, poisonPagePlacement); + + using MemoryStream seekableStream = new(inputPoisoned.Memory.ToArray()); + Test(inputPoisoned.Span, seekableStream); + + // NrbfDecoder has few code paths dedicated to non-seekable streams, let's test them as well. + using NonSeekableStream nonSeekableStream = new(inputPoisoned.Memory.ToArray()); + Test(inputPoisoned.Span, nonSeekableStream); + } + + private static void Test(Span testSpan, Stream stream) + { + if (NrbfDecoder.StartsWithPayloadHeader(testSpan)) + { + try + { + SerializationRecord record = NrbfDecoder.Decode(stream, out IReadOnlyDictionary recordMap); + switch (record.RecordType) + { + case SerializationRecordType.ArraySingleObject: + SZArrayRecord arrayObj = (SZArrayRecord)record; + object?[] objArray = arrayObj.GetArray(); + Assert.Equal(arrayObj.Length, objArray.Length); + Assert.Equal(1, arrayObj.Rank); + break; + case SerializationRecordType.ArraySingleString: + SZArrayRecord arrayString = (SZArrayRecord)record; + string?[] array = arrayString.GetArray(); + Assert.Equal(arrayString.Length, array.Length); + Assert.Equal(1, arrayString.Rank); + Assert.Equal(true, arrayString.TypeNameMatches(typeof(string[]))); + break; + case SerializationRecordType.ArraySinglePrimitive: + case SerializationRecordType.BinaryArray: + ArrayRecord arrayBinary = (ArrayRecord)record; + Assert.NotNull(arrayBinary.TypeName); + break; + case SerializationRecordType.BinaryObjectString: + _ = ((PrimitiveTypeRecord)record).Value; + break; + case SerializationRecordType.ClassWithId: + case SerializationRecordType.ClassWithMembersAndTypes: + case SerializationRecordType.SystemClassWithMembersAndTypes: + ClassRecord classRecord = (ClassRecord)record; + Assert.NotNull(classRecord.TypeName); + + foreach (string name in classRecord.MemberNames) + { + Assert.Equal(true, classRecord.HasMember(name)); + } + break; + case SerializationRecordType.MemberPrimitiveTyped: + PrimitiveTypeRecord primitiveType = (PrimitiveTypeRecord)record; + Assert.NotNull(primitiveType.Value); + break; + case SerializationRecordType.MemberReference: + Assert.NotNull(record.TypeName); + break; + case SerializationRecordType.BinaryLibrary: + Assert.Equal(false, record.Id.Equals(default)); + break; + case SerializationRecordType.ObjectNull: + case SerializationRecordType.ObjectNullMultiple: + case SerializationRecordType.ObjectNullMultiple256: + Assert.Equal(default, record.Id); + break; + case SerializationRecordType.MessageEnd: + case SerializationRecordType.SerializedStreamHeader: + // case SerializationRecordType.ClassWithMembers: will cause NotSupportedException + // case SerializationRecordType.SystemClassWithMembers: will cause NotSupportedException + default: + throw new Exception("Unexpected RecordType"); + } + } + catch (SerializationException) { /* Reading from the stream encountered invalid NRBF data.*/ } + catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ } + catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ } + catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ } + catch (IOException) { /* An I/O error occurred. */ } + } + else + { + try + { + NrbfDecoder.Decode(stream); + throw new Exception("Decoding supposed to fail!"); + } + catch (SerializationException) { /* Everything has to start with a header */ } + catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ } + catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ } + } + } + + private class NonSeekableStream : MemoryStream + { + public NonSeekableStream(byte[] buffer) : base(buffer) { } + public override bool CanSeek => false; + } + } +} diff --git a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/TypeNameFuzzer.cs b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/TypeNameFuzzer.cs index f8b3e96083707..0a189da4f18af 100644 --- a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/TypeNameFuzzer.cs +++ b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/TypeNameFuzzer.cs @@ -3,8 +3,6 @@ using System.Buffers; using System.Reflection.Metadata; -using System.Runtime.InteropServices; -using System.Runtime.InteropServices.Marshalling; using System.Text; namespace DotnetFuzzing.Fuzzers @@ -55,7 +53,7 @@ private static void Test(Span testSpan) try { TypeName.Parse(testSpan); - Assert.Equal(true, false); // should never succeed + throw new Exception("Parsing was supposed to fail!"); } catch (ArgumentException) { } catch (InvalidOperationException) { } diff --git a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs index 8e12cf7c3712f..d7a6e01a72352 100644 --- a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs +++ b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs @@ -11,6 +11,7 @@ public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRec internal ArrayRecord() { } public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } } public abstract System.ReadOnlySpan Lengths { get; } + public virtual long FlattenedLength { get; } public int Rank { get { throw null; } } [System.Diagnostics.CodeAnalysis.RequiresDynamicCode("The code for an array of the specified type might not be available.")] public System.Array GetArray(System.Type expectedArrayType, bool allowNulls = true) { throw null; } diff --git a/src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx b/src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx index 349405150c65e..c6085fff72398 100644 --- a/src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx +++ b/src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx @@ -126,26 +126,23 @@ Unexpected Null Record count. - - The serialized array length ({0}) was larger than the configured limit {1}. - {0} Record Type is not supported by design. - Member reference was pointing to a record of unexpected type. + Invalid member reference. - Invalid type name: `{0}`. + Invalid type name. Expected the array to be of type {0}, but its element type was {1}. - Invalid type or assembly name: `{0},{1}`. + Invalid type or assembly name. - Duplicate member name: `{0}`. + Duplicate member name. Stream does not support seeking. @@ -160,6 +157,12 @@ Only arrays with zero offsets are supported. - Invalid assembly name: `{0}`. + Invalid assembly name. + + + Invalid format. + + + A surrogate character was read. \ No newline at end of file diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs index 8a3b304610555..063a243078206 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs @@ -3,6 +3,9 @@ namespace System.Formats.Nrbf; +// See [MS-NRBF] Sec. 2.7 for more information. +// https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/ca3ad2bc-777b-413a-a72a-9ba6ced76bc3 + [Flags] internal enum AllowedRecordTypes : uint { diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayInfo.cs index e8b28825888e4..da03a459f35aa 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayInfo.cs @@ -13,22 +13,26 @@ namespace System.Formats.Nrbf; /// /// ArrayInfo structures are described in [MS-NRBF] 2.4.2.1. /// -[DebuggerDisplay("Length={Length}, {ArrayType}, rank={Rank}")] +[DebuggerDisplay("{ArrayType}, rank={Rank}")] internal readonly struct ArrayInfo { - internal const int MaxArrayLength = 2147483591; // Array.MaxLength +#if NET8_0_OR_GREATER + internal static int MaxArrayLength => Array.MaxLength; // dynamic lookup in case the value changes in a future runtime +#else + internal const int MaxArrayLength = 2147483591; // hardcode legacy Array.MaxLength for downlevel runtimes +#endif internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArrayType arrayType = BinaryArrayType.Single, int rank = 1) { Id = id; - TotalElementsCount = totalElementsCount; + FlattenedLength = totalElementsCount; ArrayType = arrayType; Rank = rank; } internal SerializationRecordId Id { get; } - internal long TotalElementsCount { get; } + internal long FlattenedLength { get; } internal BinaryArrayType ArrayType { get; } @@ -36,8 +40,8 @@ internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArra internal int GetSZArrayLength() { - Debug.Assert(TotalElementsCount <= MaxArrayLength); - return (int)TotalElementsCount; + Debug.Assert(FlattenedLength <= MaxArrayLength); + return (int)FlattenedLength; } internal static ArrayInfo Decode(BinaryReader reader) @@ -47,7 +51,7 @@ internal static int ParseValidArrayLength(BinaryReader reader) { int length = reader.ReadInt32(); - if (length is < 0 or > MaxArrayLength) + if (length < 0 || length > MaxArrayLength) { ThrowHelper.ThrowInvalidValue(length); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs index 46e066bd39dbb..f345292c693a6 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Reflection.Metadata; using System.Formats.Nrbf.Utils; +using System.Diagnostics; namespace System.Formats.Nrbf; @@ -54,6 +55,7 @@ public override TypeName TypeName } int nullCount = ((NullsRecord)actual).NullCount; + Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount."); do { result[resultIndex++] = null; @@ -63,6 +65,8 @@ public override TypeName TypeName } } + Debug.Assert(resultIndex == result.Length, "We should have traversed the entirety of the newly created array."); + return result; } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs index ddfd91a29fb1a..237b7b72a2719 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs @@ -18,7 +18,7 @@ public abstract class ArrayRecord : SerializationRecord private protected ArrayRecord(ArrayInfo arrayInfo) { ArrayInfo = arrayInfo; - ValuesToRead = arrayInfo.TotalElementsCount; + ValuesToRead = arrayInfo.FlattenedLength; } /// @@ -27,6 +27,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo) /// A buffer of integers that represent the number of elements in every dimension. public abstract ReadOnlySpan Lengths { get; } + /// + /// When overridden in a derived class, gets the total number of all elements in every dimension. + /// + /// A number that represent the total number of all elements in every dimension. + public virtual long FlattenedLength => ArrayInfo.FlattenedLength; + /// /// Gets the rank of the array. /// @@ -44,7 +50,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo) internal long ValuesToRead { get; private protected set; } - private protected ArrayInfo ArrayInfo { get; } + internal ArrayInfo ArrayInfo { get; } + + internal bool IsJagged + => ArrayInfo.ArrayType == BinaryArrayType.Jagged + // It is possible to have binary array records have an element type of array without being marked as jagged. + || TypeName.GetElementType().IsArray; /// /// Allocates an array and fills it with the data provided in the serialized records (in case of primitive types like or ) or the serialized records themselves. diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs index 37e94842719a9..d0276ff3782e3 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs @@ -5,6 +5,7 @@ using System.IO; using System.Reflection.Metadata; using System.Formats.Nrbf.Utils; +using System.Diagnostics; namespace System.Formats.Nrbf; @@ -33,13 +34,15 @@ public override TypeName TypeName { object?[] values = new object?[Length]; - for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++) + int valueIndex = 0; + for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++) { SerializationRecord record = Records[recordIndex]; int nullCount = record is NullsRecord nullsRecord ? nullsRecord.NullCount : 0; if (nullCount == 0) { + // "new object[] { }" is special cased because it allows for storing reference to itself. values[valueIndex++] = record is MemberReferenceRecord referenceRecord && referenceRecord.Reference.Equals(Id) ? values // a reference to self, and a way to get StackOverflow exception ;) : record.GetValue(); @@ -59,6 +62,8 @@ public override TypeName TypeName while (nullCount > 0); } + Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array."); + return values; } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs index ee3a7916b069c..a13507b97015a 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs @@ -41,17 +41,9 @@ internal ArraySinglePrimitiveRecord(ArrayInfo arrayInfo, IReadOnlyList values public override T[] GetArray(bool allowNulls = true) => (T[])(_arrayNullsNotAllowed ??= (Values is T[] array ? array : Values.ToArray())); - internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() - { - Debug.Fail("GetAllowedRecordType should never be called on ArraySinglePrimitiveRecord"); - throw new InvalidOperationException(); - } + internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException(); - private protected override void AddValue(object value) - { - Debug.Fail("AddValue should never be called on ArraySinglePrimitiveRecord"); - throw new InvalidOperationException(); - } + private protected override void AddValue(object value) => throw new InvalidOperationException(); internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int count) { @@ -61,8 +53,32 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c return (List)(object)DecodeDecimals(reader, count); } + // char[] has a unique representation in NRBF streams. Typical strings are transcoded + // to UTF-8 and prefixed with the number of bytes in the UTF-8 representation. char[] + // is also serialized as UTF-8, but it is instead prefixed with the number of chars + // in the UTF-16 representation, not the number of bytes in the UTF-8 representation. + // This number doesn't directly precede the UTF-8 contents in the NRBF stream; it's + // instead contained within the ArrayInfo structure (passed to this method as the + // 'count' argument). + // + // The practical consequence of this is that we don't actually know how many UTF-8 + // bytes we need to consume in order to ensure we've read 'count' chars. We know that + // an n-length UTF-16 string turns into somewhere between [n .. 3n] UTF-8 bytes. + // The best we can do is that when reading an n-element char[], we'll ensure that + // there are at least n bytes remaining in the input stream. We'll still need to + // account for that even with this check, we might hit EOF before fully populating + // the char[]. But from a safety perspective, it does appropriately limit our + // allocations to be proportional to the amount of data present in the input stream, + // which is a sufficient defense against DoS. + long requiredBytes = count; - if (typeof(T) != typeof(char)) // the input is UTF8 + if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan)) + { + // We can't assume DateTime as represented by the runtime is 8 bytes. + // The only assumption we can make is that it's 8 bytes on the wire. + requiredBytes *= 8; + } + else if (typeof(T) != typeof(char)) { requiredBytes *= Unsafe.SizeOf(); } @@ -85,7 +101,11 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c } else if (typeof(T) == typeof(char)) { - return (T[])(object)reader.ReadChars(count); + return (T[])(object)reader.ParseChars(count); + } + else if (typeof(T) == typeof(TimeSpan) || typeof(T) == typeof(DateTime)) + { + return DecodeTime(reader, count); } // It's safe to pre-allocate, as we have ensured there is enough bytes in the stream. @@ -94,7 +114,7 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c #if NET reader.BaseStream.ReadExactly(resultAsBytes); #else - byte[] bytes = ArrayPool.Shared.Rent(Math.Min(count * Unsafe.SizeOf(), 256_000)); + byte[] bytes = ArrayPool.Shared.Rent((int)Math.Min(requiredBytes, 256_000)); while (!resultAsBytes.IsEmpty) { @@ -138,8 +158,7 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c } #endif } - else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double) - || typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan)) + else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double)) { Span span = MemoryMarshal.Cast(result); #if NET @@ -153,37 +172,62 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c } } + if (typeof(T) == typeof(bool)) + { + // See DontCastBytesToBooleans test to see what could go wrong. + bool[] booleans = (bool[])(object)result; + resultAsBytes = MemoryMarshal.AsBytes(result); + for (int i = 0; i < booleans.Length; i++) + { + // We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this. + if (resultAsBytes[i] != 0) // it can be any byte different than 0 + { + booleans[i] = true; // set it to 1 in explicit way + } + } + } + return result; } private static List DecodeDecimals(BinaryReader reader, int count) { List values = new(); -#if NET - Span buffer = stackalloc byte[256]; for (int i = 0; i < count; i++) { - int stringLength = reader.Read7BitEncodedInt(); - if (!(stringLength > 0 && stringLength <= buffer.Length)) - { - ThrowHelper.ThrowInvalidValue(stringLength); - } - - reader.BaseStream.ReadExactly(buffer.Slice(0, stringLength)); - - values.Add(decimal.Parse(buffer.Slice(0, stringLength), CultureInfo.InvariantCulture)); + values.Add(reader.ParseDecimal()); } -#else - for (int i = 0; i < count; i++) + return values; + } + + private static T[] DecodeTime(BinaryReader reader, int count) + { + T[] values = new T[count]; + for (int i = 0; i < values.Length; i++) { - values.Add(decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture)); + if (typeof(T) == typeof(DateTime)) + { + values[i] = (T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64()); + } + else if (typeof(T) == typeof(TimeSpan)) + { + values[i] = (T)(object)new TimeSpan(reader.ReadInt64()); + } + else + { + throw new InvalidOperationException(); + } } -#endif + return values; } private static List DecodeFromNonSeekableStream(BinaryReader reader, int count) { + // The count arg could originate from untrusted input, so we shouldn't + // pass it as-is to the ctor's capacity arg. We'll instead rely on + // List.Add's O(1) amortization to keep the entire loop O(count). + List values = new List(Math.Min(count, 4)); for (int i = 0; i < count; i++) { @@ -201,7 +245,7 @@ private static List DecodeFromNonSeekableStream(BinaryReader reader, int coun } else if (typeof(T) == typeof(char)) { - values.Add((T)(object)reader.ReadChar()); + values.Add((T)(object)reader.ParseChar()); } else if (typeof(T) == typeof(short)) { @@ -237,14 +281,16 @@ private static List DecodeFromNonSeekableStream(BinaryReader reader, int coun } else if (typeof(T) == typeof(DateTime)) { - values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadInt64())); + values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64())); } - else + else if (typeof(T) == typeof(TimeSpan)) { - Debug.Assert(typeof(T) == typeof(TimeSpan)); - values.Add((T)(object)new TimeSpan(reader.ReadInt64())); } + else + { + throw new InvalidOperationException(); + } } return values; diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs index de248bcef7675..42b9eadd97bd5 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs @@ -5,6 +5,7 @@ using System.IO; using System.Reflection.Metadata; using System.Formats.Nrbf.Utils; +using System.Diagnostics; namespace System.Formats.Nrbf; @@ -21,7 +22,7 @@ internal sealed class ArraySingleStringRecord : SZArrayRecord public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleString; /// - public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String); + public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType); private List Records { get; } @@ -47,7 +48,8 @@ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetA { string?[] values = new string?[Length]; - for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++) + int valueIndex = 0; + for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++) { SerializationRecord record = Records[recordIndex]; @@ -73,6 +75,7 @@ record = memberReference.GetReferencedRecord(); } int nullCount = ((NullsRecord)record).NullCount; + Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount."); do { values[valueIndex++] = null; @@ -81,6 +84,8 @@ record = memberReference.GetReferencedRecord(); while (nullCount > 0); } + Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array."); + return values; } } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs index 0c7e04e840a48..41b1f73f03550 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs @@ -6,6 +6,7 @@ using System.IO; using System.Reflection.Metadata; using System.Formats.Nrbf.Utils; +using System.Diagnostics; namespace System.Formats.Nrbf; @@ -27,12 +28,15 @@ internal sealed class BinaryArrayRecord : ArrayRecord ]; private TypeName? _typeName; + private long _totalElementsCount; private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) : base(arrayInfo) { MemberTypeInfo = memberTypeInfo; Values = []; + // We need to parse all elements of the jagged array to obtain total elements count. + _totalElementsCount = -1; } public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; @@ -40,6 +44,22 @@ private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) /// public override ReadOnlySpan Lengths => new int[1] { Length }; + /// + public override long FlattenedLength + { + get + { + if (_totalElementsCount < 0) + { + _totalElementsCount = IsJagged + ? GetJaggedArrayFlattenedLength(this) + : ArrayInfo.FlattenedLength; + } + + return _totalElementsCount; + } + } + public override TypeName TypeName => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo); @@ -84,6 +104,10 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls) case SerializationRecordType.ArraySinglePrimitive: case SerializationRecordType.ArraySingleObject: case SerializationRecordType.ArraySingleString: + + // Recursion depth is bounded by the depth of arrayType, which is + // a trustworthy Type instance. Don't need to worry about stack overflow. + ArrayRecord nestedArrayRecord = (ArrayRecord)record; Array nestedArray = nestedArrayRecord.GetArray(actualElementType, allowNulls); array.SetValue(nestedArray, resultIndex++); @@ -97,6 +121,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls) } int nullCount = ((NullsRecord)item).NullCount; + Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount."); do { array.SetValue(null, resultIndex++); @@ -110,6 +135,8 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls) } } + Debug.Assert(resultIndex == array.Length, "We should have traversed the entirety of the newly created array."); + return array; } @@ -122,6 +149,7 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay bool isRectangular = arrayType is BinaryArrayType.Rectangular; // It is an arbitrary limit in the current CoreCLR type loader. + // Don't change this value without reviewing the loop a few lines below. const int MaxSupportedArrayRank = 32; if (rank < 1 || rank > MaxSupportedArrayRank @@ -132,18 +160,26 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay } int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32 - long totalElementCount = 1; + long totalElementCount = 1; // to avoid integer overflow during the multiplication below for (int i = 0; i < lengths.Length; i++) { lengths[i] = ArrayInfo.ParseValidArrayLength(reader); totalElementCount *= lengths[i]; - if (totalElementCount > uint.MaxValue) + // n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]" + // but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But + // that's the same behavior that newarr and Array.CreateInstance exhibit, so at least + // we're consistent. + + if (totalElementCount > ArrayInfo.MaxArrayLength) { ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded } } + // Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so + // we don't need to read the NRBF stream 'LowerBounds' field here. + MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap); ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank); @@ -157,6 +193,65 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay : new BinaryArrayRecord(arrayInfo, memberTypeInfo); } + private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord) + { + long result = 0; + Queue? jaggedArrayRecords = null; + + do + { + if (jaggedArrayRecords is not null) + { + jaggedArrayRecord = jaggedArrayRecords.Dequeue(); + } + + Debug.Assert(jaggedArrayRecord.IsJagged); + + // In theory somebody could create a payload that would represent + // a very nested array with total elements count > long.MaxValue. + // That is why this method is using checked arithmetic. + result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves + + foreach (object value in jaggedArrayRecord.Values) + { + if (value is not SerializationRecord record) + { + continue; + } + + if (record.RecordType == SerializationRecordType.MemberReference) + { + record = ((MemberReferenceRecord)record).GetReferencedRecord(); + } + + switch (record.RecordType) + { + case SerializationRecordType.ArraySinglePrimitive: + case SerializationRecordType.ArraySingleObject: + case SerializationRecordType.ArraySingleString: + case SerializationRecordType.BinaryArray: + ArrayRecord nestedArrayRecord = (ArrayRecord)record; + if (nestedArrayRecord.IsJagged) + { + (jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord); + } + else + { + // Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion, + // just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value. + result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength); + } + break; + default: + break; + } + } + } + while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0); + + return result; + } + private protected override void AddValue(object value) => Values.Add(value); internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() @@ -186,6 +281,9 @@ private static Type MapElementType(Type arrayType, out bool isClassRecord) Type elementType = arrayType; int arrayNestingDepth = 0; + // Loop iteration counts are bound by the nesting depth of arrayType, + // which is a trustworthy input. No DoS concerns. + while (elementType.IsArray) { elementType = elementType.GetElementType()!; diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryLibraryRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryLibraryRecord.cs index ccd39922e23fb..b723d8083e4a9 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryLibraryRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryLibraryRecord.cs @@ -30,14 +30,7 @@ private BinaryLibraryRecord(SerializationRecordId libraryId, AssemblyNameInfo li public override SerializationRecordType RecordType => SerializationRecordType.BinaryLibrary; - public override TypeName TypeName - { - get - { - Debug.Fail("TypeName should never be called on BinaryLibraryRecord"); - return TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan()); - } - } + public override TypeName TypeName => TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan()); internal string? RawLibraryName { get; } @@ -57,7 +50,7 @@ internal static BinaryLibraryRecord Decode(BinaryReader reader, PayloadOptions o } else if (!options.UndoTruncatedTypeNames) { - ThrowHelper.ThrowInvalidAssemblyName(rawName); + ThrowHelper.ThrowInvalidAssemblyName(); } return new BinaryLibraryRecord(id, rawName); diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassInfo.cs index 75340b72a4f0d..a1cb7b47fb5ae 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassInfo.cs @@ -50,7 +50,8 @@ internal static ClassInfo Decode(BinaryReader reader) // Use Dictionary instead of List so that searching for member IDs by name // is O(n) instead of O(m * n), where m = memberCount and n = memberNameLength, - // in degenerate cases. + // in degenerate cases. Since memberCount may be hostile, don't allow it to be + // used as the initial capacity in the collection instance. Dictionary memberNames = new(StringComparer.Ordinal); for (int i = 0; i < memberCount; i++) { @@ -70,7 +71,7 @@ internal static ClassInfo Decode(BinaryReader reader) continue; } #endif - throw new SerializationException(SR.Format(SR.Serialization_DuplicateMemberName, memberName)); + ThrowHelper.ThrowDuplicateMemberName(); } return new ClassInfo(id, typeName, memberNames); diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassTypeInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassTypeInfo.cs index 6a9e9d7b90afe..dd5ee0e5bdf2c 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassTypeInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassTypeInfo.cs @@ -9,7 +9,7 @@ namespace System.Formats.Nrbf; /// -/// Identifies a class by it's name and library id. +/// Identifies a class by its name and library id. /// /// /// ClassTypeInfo structures are described in [MS-NRBF] 2.1.1.8. @@ -26,7 +26,7 @@ internal static ClassTypeInfo Decode(BinaryReader reader, PayloadOptions options string rawName = reader.ReadString(); SerializationRecordId libraryId = SerializationRecordId.Decode(reader); - BinaryLibraryRecord library = (BinaryLibraryRecord)recordMap[libraryId]; + BinaryLibraryRecord library = recordMap.GetRecord(libraryId); return new ClassTypeInfo(rawName.ParseNonSystemClassRecordTypeName(library, options)); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs index e18033524d17e..c643d3ce8c846 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs @@ -34,10 +34,7 @@ internal static ClassWithIdRecord Decode( SerializationRecordId id = SerializationRecordId.Decode(reader); SerializationRecordId metadataId = SerializationRecordId.Decode(reader); - if (recordMap[metadataId] is not ClassRecord referencedRecord) - { - throw new SerializationException(SR.Serialization_InvalidReference); - } + ClassRecord referencedRecord = recordMap.GetRecord(metadataId); return new ClassWithIdRecord(id, referencedRecord); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithMembersAndTypesRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithMembersAndTypesRecord.cs index 117e5e90ef681..d6d8c122d3ed9 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithMembersAndTypesRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithMembersAndTypesRecord.cs @@ -27,7 +27,7 @@ internal static ClassWithMembersAndTypesRecord Decode(BinaryReader reader, Recor MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, classInfo.MemberNames.Count, options, recordMap); SerializationRecordId libraryId = SerializationRecordId.Decode(reader); - BinaryLibraryRecord library = (BinaryLibraryRecord)recordMap[libraryId]; + BinaryLibraryRecord library = recordMap.GetRecord(libraryId); classInfo.LoadTypeName(library, options); return new ClassWithMembersAndTypesRecord(classInfo, memberTypeInfo); diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberReferenceRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberReferenceRecord.cs index 162cf0b1d5c57..14bd4e7ff1f2d 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberReferenceRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberReferenceRecord.cs @@ -38,5 +38,5 @@ private MemberReferenceRecord(SerializationRecordId reference, RecordMap recordM internal static MemberReferenceRecord Decode(BinaryReader reader, RecordMap recordMap) => new(SerializationRecordId.Decode(reader), recordMap); - internal SerializationRecord GetReferencedRecord() => RecordMap[Reference]; + internal SerializationRecord GetReferencedRecord() => RecordMap.GetRecord(Reference); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs index 9843a0b71f04c..57e47a02eec68 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs @@ -53,10 +53,14 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt case BinaryType.Class: info[i] = (type, ClassTypeInfo.Decode(reader, options, recordMap)); break; - default: - // Other types have no additional data. - Debug.Assert(type is BinaryType.String or BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.Object); + case BinaryType.String: + case BinaryType.StringArray: + case BinaryType.Object: + case BinaryType.ObjectArray: + // These types have no additional data. break; + default: + throw new InvalidOperationException(); } } @@ -97,7 +101,8 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt BinaryType.PrimitiveArray => (PrimitiveArray, default), BinaryType.Class => (NonSystemClass, default), BinaryType.SystemClass => (SystemClass, default), - _ => (ObjectArray, default) + BinaryType.ObjectArray => (ObjectArray, default), + _ => throw new InvalidOperationException() }; } @@ -105,7 +110,7 @@ internal bool ShouldBeRepresentedAsArrayOfClassRecords() { // This library tries to minimize the number of concepts the users need to learn to use it. // Since SZArrays are most common, it provides an SZArrayRecord abstraction. - // Every other array (jagged, multi-dimensional etc) is represented using SZArrayRecord. + // Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord. // The goal of this method is to determine whether given array can be represented as SZArrayRecord. (BinaryType binaryType, object? additionalInfo) = Infos[0]; @@ -144,15 +149,15 @@ internal TypeName GetArrayTypeName(ArrayInfo arrayInfo) TypeName elementTypeName = binaryType switch { - BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(PrimitiveType.String), - BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String), + BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.StringPrimitiveType), + BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType), BinaryType.Primitive => TypeNameHelpers.GetPrimitiveTypeName((PrimitiveType)additionalInfo!), BinaryType.PrimitiveArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName((PrimitiveType)additionalInfo!), BinaryType.Object => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.ObjectPrimitiveType), BinaryType.ObjectArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.ObjectPrimitiveType), BinaryType.SystemClass => (TypeName)additionalInfo!, BinaryType.Class => ((ClassTypeInfo)additionalInfo!).TypeName, - _ => throw new ArgumentOutOfRangeException(paramName: nameof(binaryType), actualValue: binaryType, message: null) + _ => throw new InvalidOperationException() }; // In general, arrayRank == 1 may have two different meanings: diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MessageEndRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MessageEndRecord.cs index 7cb28224a890e..62c7d57b3fa37 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MessageEndRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MessageEndRecord.cs @@ -24,12 +24,5 @@ private MessageEndRecord() public override SerializationRecordId Id => SerializationRecordId.NoId; - public override TypeName TypeName - { - get - { - Debug.Fail("TypeName should never be called on MessageEndRecord"); - return TypeName.Parse(nameof(MessageEndRecord).AsSpan()); - } - } + public override TypeName TypeName => TypeName.Parse(nameof(MessageEndRecord).AsSpan()); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NextInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NextInfo.cs index a01a25e60047a..08b1f53dca670 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NextInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NextInfo.cs @@ -27,7 +27,5 @@ internal NextInfo(AllowedRecordTypes allowed, SerializationRecord parent, internal PrimitiveType PrimitiveType { get; } internal NextInfo With(AllowedRecordTypes allowed, PrimitiveType primitiveType) - => allowed == Allowed && primitiveType == PrimitiveType - ? this // previous record was of the same type - : new(allowed, Parent, Stack, primitiveType); + => new(allowed, Parent, Stack, primitiveType); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs index a07c567b7a769..a315b37cff023 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs @@ -22,7 +22,7 @@ public static class NrbfDecoder // The header consists of: // - a byte that describes the record type (SerializationRecordType.SerializedStreamHeader) // - four 32 bit integers: - // - root Id (every value is valid) + // - root Id (every value except of 0 is valid) // - header Id (value is ignored) // - major version, it has to be equal 1. // - minor version, it has to be equal 0. @@ -46,6 +46,7 @@ public static bool StartsWithPayloadHeader(ReadOnlySpan bytes) /// is . /// The stream does not support reading or seeking. /// The stream was closed. + /// An I/O error occurred. /// When this method returns, will be restored to its original position. public static bool StartsWithPayloadHeader(Stream stream) { @@ -68,28 +69,22 @@ public static bool StartsWithPayloadHeader(Stream stream) return false; } - try + byte[] buffer = new byte[SerializedStreamHeaderRecord.Size]; + int offset = 0; + while (offset < buffer.Length) { -#if NET - Span buffer = stackalloc byte[SerializedStreamHeaderRecord.Size]; - stream.ReadExactly(buffer); -#else - byte[] buffer = new byte[SerializedStreamHeaderRecord.Size]; - int offset = 0; - while (offset < buffer.Length) + int read = stream.Read(buffer, offset, buffer.Length - offset); + if (read == 0) { - int read = stream.Read(buffer, offset, buffer.Length - offset); - if (read == 0) - throw new EndOfStreamException(); - offset += read; + stream.Position = beginning; + return false; } -#endif - return StartsWithPayloadHeader(buffer); - } - finally - { - stream.Position = beginning; + offset += read; } + + bool result = StartsWithPayloadHeader(buffer); + stream.Position = beginning; + return result; } /// @@ -107,6 +102,7 @@ public static bool StartsWithPayloadHeader(Stream stream) /// is . /// does not support reading or is already closed. /// Reading from encounters invalid NRBF data. + /// An I/O error occurred. /// /// Reading from encounters not supported records. /// For example, arrays with non-zero offset or not supported record types @@ -142,7 +138,14 @@ public static SerializationRecord Decode(Stream payload, out IReadOnlyDictionary #endif using BinaryReader reader = new(payload, ThrowOnInvalidUtf8Encoding, leaveOpen: leaveOpen); - return Decode(reader, options ?? new(), out recordMap); + try + { + return Decode(reader, options ?? new(), out recordMap); + } + catch (FormatException) // can be thrown by various BinaryReader methods + { + throw new SerializationException(SR.Serialization_InvalidFormat); + } } /// @@ -213,12 +216,7 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap recordMap, AllowedRecordTypes allowed, PayloadOptions options, out SerializationRecordType recordType) { - byte nextByte = reader.ReadByte(); - if (((uint)allowed & (1u << nextByte)) == 0) - { - ThrowHelper.ThrowForUnexpectedRecordType(nextByte); - } - recordType = (SerializationRecordType)nextByte; + recordType = reader.ReadSerializationRecordType(allowed); SerializationRecord record = recordType switch { @@ -237,7 +235,8 @@ private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap rec SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader), SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader), SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader), - _ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options), + SerializationRecordType.SystemClassWithMembersAndTypes => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options), + _ => throw new InvalidOperationException() }; recordMap.Add(record); @@ -254,7 +253,7 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader PrimitiveType.Boolean => new MemberPrimitiveTypedRecord(reader.ReadBoolean()), PrimitiveType.Byte => new MemberPrimitiveTypedRecord(reader.ReadByte()), PrimitiveType.SByte => new MemberPrimitiveTypedRecord(reader.ReadSByte()), - PrimitiveType.Char => new MemberPrimitiveTypedRecord(reader.ReadChar()), + PrimitiveType.Char => new MemberPrimitiveTypedRecord(reader.ParseChar()), PrimitiveType.Int16 => new MemberPrimitiveTypedRecord(reader.ReadInt16()), PrimitiveType.UInt16 => new MemberPrimitiveTypedRecord(reader.ReadUInt16()), PrimitiveType.Int32 => new MemberPrimitiveTypedRecord(reader.ReadInt32()), @@ -263,10 +262,10 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader PrimitiveType.UInt64 => new MemberPrimitiveTypedRecord(reader.ReadUInt64()), PrimitiveType.Single => new MemberPrimitiveTypedRecord(reader.ReadSingle()), PrimitiveType.Double => new MemberPrimitiveTypedRecord(reader.ReadDouble()), - PrimitiveType.Decimal => new MemberPrimitiveTypedRecord(decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture)), - PrimitiveType.DateTime => new MemberPrimitiveTypedRecord(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadInt64())), - // String is handled with a record, never on it's own - _ => new MemberPrimitiveTypedRecord(new TimeSpan(reader.ReadInt64())), + PrimitiveType.Decimal => new MemberPrimitiveTypedRecord(reader.ParseDecimal()), + PrimitiveType.DateTime => new MemberPrimitiveTypedRecord(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64())), + PrimitiveType.TimeSpan => new MemberPrimitiveTypedRecord(new TimeSpan(reader.ReadInt64())), + _ => throw new InvalidOperationException() }; } @@ -291,7 +290,8 @@ private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader PrimitiveType.Double => Decode(info, reader), PrimitiveType.Decimal => Decode(info, reader), PrimitiveType.DateTime => Decode(info, reader), - _ => Decode(info, reader), + PrimitiveType.TimeSpan => Decode(info, reader), + _ => throw new InvalidOperationException() }; static SerializationRecord Decode(ArrayInfo info, BinaryReader reader) where T : unmanaged diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NullsRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NullsRecord.cs index d3d859c193a9c..9c11db4307ced 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NullsRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NullsRecord.cs @@ -12,12 +12,5 @@ internal abstract class NullsRecord : SerializationRecord public override SerializationRecordId Id => SerializationRecordId.NoId; - public override TypeName TypeName - { - get - { - Debug.Fail($"TypeName should never be called on {GetType().Name}"); - return TypeName.Parse(GetType().Name.AsSpan()); - } - } + public override TypeName TypeName => TypeName.Parse(GetType().Name.AsSpan()); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/PayloadOptions.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/PayloadOptions.cs index 60d1aafcc5291..fdb3ccae4632e 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/PayloadOptions.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/PayloadOptions.cs @@ -25,10 +25,17 @@ public PayloadOptions() { } /// /// if truncated type names should be reassembled; otherwise, . /// + /// /// Example: /// TypeName: "Namespace.TypeName`1[[Namespace.GenericArgName" /// LibraryName: "AssemblyName]]" /// Is combined into "Namespace.TypeName`1[[Namespace.GenericArgName, AssemblyName]]" + /// + /// + /// Setting this to can render susceptible to Denial of Service + /// attacks when parsing or handling malicious input. + /// + /// The default value is . /// public bool UndoTruncatedTypeNames { get; set; } } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/PrimitiveType.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/PrimitiveType.cs index 9ddb9179518fa..f2e696e6a90e9 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/PrimitiveType.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/PrimitiveType.cs @@ -11,10 +11,6 @@ namespace System.Formats.Nrbf; /// internal enum PrimitiveType : byte { - /// - /// Used internally to express no value - /// - None = 0, Boolean = 1, Byte = 2, Char = 3, @@ -30,7 +26,19 @@ internal enum PrimitiveType : byte DateTime = 13, UInt16 = 14, UInt32 = 15, - UInt64 = 16, - Null = 17, - String = 18 + UInt64 = 16 + // This internal enum no longer contains Null and String as they were always illegal: + // - In case of BinaryArray (NRBF 2.4.3.1): + // "If the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration + // value in AdditionalTypeInfo MUST NOT be Null (17) or String (18)." + // - In case of MemberPrimitiveTyped (NRBF 2.5.1): + // "PrimitiveTypeEnum (1 byte): A PrimitiveTypeEnumeration + // value that specifies the Primitive Type of data that is being transmitted. + // This field MUST NOT contain a value of 17 (Null) or 18 (String)." + // - In case of ArraySinglePrimitive (NRBF 2.4.3.3): + // "A PrimitiveTypeEnumeration value that identifies the Primitive Type + // of the items of the Array. The value MUST NOT be 17 (Null) or 18 (String)." + // - In case of MemberTypeInfo (NRBF 2.3.1.2): + // "When the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration + // value in AdditionalInfo MUST NOT be Null (17) or String (18)." } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs index a25ab508f5db3..eafcbf93249c5 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs @@ -56,14 +56,15 @@ internal void Add(SerializationRecord record) return; } #endif - throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id)); + throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id._id)); } } } internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header) { - SerializationRecord rootRecord = _map[header.RootId]; + SerializationRecord rootRecord = GetRecord(header.RootId); + if (rootRecord is SystemClassWithMembersAndTypesRecord systemClass) { // update the record map, so it's visible also to those who access it via Id @@ -72,4 +73,14 @@ internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header) return rootRecord; } + + internal SerializationRecord GetRecord(SerializationRecordId recordId) + => _map.TryGetValue(recordId, out SerializationRecord? record) + ? record + : throw new SerializationException(SR.Serialization_InvalidReference); + + internal T GetRecord(SerializationRecordId recordId) where T : SerializationRecord + => _map.TryGetValue(recordId, out SerializationRecord? record) && record is T casted + ? casted + : throw new SerializationException(SR.Serialization_InvalidReference); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs index de3c6d671850a..f64dde36163d6 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs @@ -8,13 +8,14 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Formats.Nrbf.Utils; +using System.Diagnostics; namespace System.Formats.Nrbf; internal sealed class RectangularArrayRecord : ArrayRecord { private readonly int[] _lengths; - private readonly ICollection _values; + private readonly List _values; private TypeName? _typeName; private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo, @@ -24,18 +25,8 @@ private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo, MemberTypeInfo = memberTypeInfo; _lengths = lengths; - // A List can hold as many objects as an array, so for multi-dimensional arrays - // with more elements than Array.MaxLength we use LinkedList. - // Testing that many elements takes a LOT of time, so to ensure that both code paths are tested, - // we always use LinkedList code path for Debug builds. -#if DEBUG - _values = new LinkedList(); -#else - _values = arrayInfo.TotalElementsCount <= ArrayInfo.MaxArrayLength - ? new List(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength())) - : new LinkedList(); -#endif - + // ArrayInfo.GetSZArrayLength ensures to return a value <= Array.MaxLength + _values = new List(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength())); } public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; @@ -65,10 +56,12 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls) #if !NET8_0_OR_GREATER int[] indices = new int[_lengths.Length]; + nuint numElementsWritten = 0; // only for debugging; not used in release builds foreach (object value in _values) { result.SetValue(GetActualValue(value), indices); + numElementsWritten++; int dimension = indices.Length - 1; while (dimension >= 0) @@ -88,6 +81,9 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls) } } + Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection."); + Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array."); + return result; #else // Idea from Array.CoreCLR that maps an array of int indices into @@ -108,6 +104,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls) else if (ElementType == typeof(TimeSpan)) CopyTo(_values, result); else if (ElementType == typeof(DateTime)) CopyTo(_values, result); else if (ElementType == typeof(decimal)) CopyTo(_values, result); + else throw new InvalidOperationException(); } else { @@ -116,7 +113,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls) return result; - static void CopyTo(ICollection list, Array array) + static void CopyTo(List list, Array array) { ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array); ref T firstElementRef = ref Unsafe.As(ref arrayDataRef); @@ -127,6 +124,8 @@ static void CopyTo(ICollection list, Array array) targetElement = (T)GetActualValue(value)!; flattenedIndex++; } + + Debug.Assert(flattenedIndex == (ulong)array.LongLength, "We should have traversed the entirety of the array."); } #endif } @@ -167,7 +166,7 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr PrimitiveType.Boolean => sizeof(bool), PrimitiveType.Byte => sizeof(byte), PrimitiveType.SByte => sizeof(sbyte), - PrimitiveType.Char => sizeof(byte), // it's UTF8 + PrimitiveType.Char => sizeof(byte), // it's UTF8 (see comment below) PrimitiveType.Int16 => sizeof(short), PrimitiveType.UInt16 => sizeof(ushort), PrimitiveType.Int32 => sizeof(int), @@ -176,12 +175,29 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr PrimitiveType.Int64 => sizeof(long), PrimitiveType.UInt64 => sizeof(ulong), PrimitiveType.Double => sizeof(double), - _ => -1 + PrimitiveType.TimeSpan => sizeof(ulong), + PrimitiveType.DateTime => sizeof(ulong), + PrimitiveType.Decimal => -1, // represented as variable-length string + _ => throw new InvalidOperationException() }; if (sizeOfSingleValue > 0) { - long size = arrayInfo.TotalElementsCount * sizeOfSingleValue; + // NRBF encodes rectangular char[,,,...] by converting each standalone UTF-16 code point into + // its UTF-8 encoding. This means that surrogate code points (including adjacent surrogate + // pairs) occurring within a char[,,,...] cannot be encoded by NRBF. BinaryReader will detect + // that they're ill-formed and reject them on read. + // + // Per the comment in ArraySinglePrimitiveRecord.DecodePrimitiveTypes, we'll assume best-case + // encoding where 1 UTF-16 char encodes as a single UTF-8 byte, even though this might lead + // to encountering an EOF if we realize later that we actually need to read more bytes in + // order to fully populate the char[,,,...] array. Any such allocation is still linearly + // proportional to the length of the incoming payload, so it's not a DoS vector. + // The multiplication below is guaranteed not to overflow because FlattenedLength is bounded + // to <= Array.MaxLength (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8. + Debug.Assert(arrayInfo.FlattenedLength >= 0 && arrayInfo.FlattenedLength <= long.MaxValue / sizeOfSingleValue); + + long size = arrayInfo.FlattenedLength * sizeOfSingleValue; bool? isDataAvailable = reader.IsDataAvailable(size); if (isDataAvailable.HasValue) { @@ -215,7 +231,8 @@ private static Type MapPrimitive(PrimitiveType primitiveType) PrimitiveType.DateTime => typeof(DateTime), PrimitiveType.UInt16 => typeof(ushort), PrimitiveType.UInt32 => typeof(uint), - _ => typeof(ulong) + PrimitiveType.UInt64 => typeof(ulong), + _ => throw new InvalidOperationException() }; private static Type MapPrimitiveArray(PrimitiveType primitiveType) @@ -235,7 +252,8 @@ private static Type MapPrimitiveArray(PrimitiveType primitiveType) PrimitiveType.DateTime => typeof(DateTime[]), PrimitiveType.UInt16 => typeof(ushort[]), PrimitiveType.UInt32 => typeof(uint[]), - _ => typeof(ulong[]), + PrimitiveType.UInt64 => typeof(ulong[]), + _ => throw new InvalidOperationException() }; private static object? GetActualValue(object value) diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecord.cs index 751a932d8f8e0..8a2d9ad7653b3 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecord.cs @@ -13,7 +13,7 @@ namespace System.Formats.Nrbf; /// /// /// Every instance returned to the end user can be either , -/// a or an . +/// a , or an . /// /// [DebuggerDisplay("{RecordType}, {Id}")] @@ -50,7 +50,20 @@ internal SerializationRecord() // others can't derive from this type /// /// The type to compare against. /// if the serialized type name match provided type; otherwise, . - public bool TypeNameMatches(Type type) => Matches(type, TypeName); + /// is . + public bool TypeNameMatches(Type type) + { +#if NET + ArgumentNullException.ThrowIfNull(type); +#else + if (type is null) + { + throw new ArgumentNullException(nameof(type)); + } +#endif + + return Matches(type, TypeName); + } private static bool Matches(Type type, TypeName typeName) { @@ -61,10 +74,38 @@ private static bool Matches(Type type, TypeName typeName) return false; } + // The TypeName.FullName property getter is recursive and backed by potentially hostile + // input. See comments in that property getter for more information, including what defenses + // are in place to prevent attacks. + // + // Note that the equality comparison below is worst-case O(n) since the adversary could ensure + // that only the last char differs. Even if the strings have equal contents, we should still + // expect the comparison to take O(n) time since RuntimeType.FullName and TypeName.FullName + // will never reference the same string instance with current runtime implementations. + // + // Since a call to Matches could take place within a loop, and since TypeName.FullName could + // be arbitrarily long (it's attacker-controlled and the NRBF protocol allows backtracking via + // the ClassWithId record, providing a form of compression), this presents opportunity + // for an algorithmic complexity attack, where a (2 * l)-length payload has an l-length type + // name and an array with l elements, resulting in O(l^2) total work factor. Protection against + // such attack is provided by the fact that the System.Type object is fully under the app's + // control and is assumed to be trusted and a reasonable length. This brings the cumulative loop + // work factor back down to O(l * RuntimeType.FullName), which is acceptable. + // + // The above statement assumes that "(string)m == (string)n" has worst-case complexity + // O(min(m.Length, n.Length)). This is not stated in string's public docs, but it is + // a guaranteed behavior for all built-in Ordinal string comparisons. + // At first, check the non-allocating properties for mismatch. if (type.IsArray != typeName.IsArray || type.IsConstructedGenericType != typeName.IsConstructedGenericType || type.IsNested != typeName.IsNested - || (type.IsArray && type.GetArrayRank() != typeName.GetArrayRank())) + || (type.IsArray && type.GetArrayRank() != typeName.GetArrayRank()) +#if NET + || type.IsSZArray != typeName.IsSZArray // int[] vs int[*] +#else + || (type.IsArray && type.Name != typeName.Name) +#endif + ) { return false; } @@ -111,11 +152,16 @@ private static bool Matches(Type type, TypeName typeName) /// For reference records, it returns the referenced record. /// For other records, it returns the records themselves. /// + /// + /// Overrides of this method should take care not to allow + /// the introduction of cycles, even in the face of adversarial + /// edges in the object graph. + /// internal virtual object? GetValue() => this; internal virtual void HandleNextRecord(SerializationRecord nextRecord, NextInfo info) - => Debug.Fail($"HandleNextRecord should not have been called for '{GetType().Name}'"); + => throw new InvalidOperationException(); internal virtual void HandleNextValue(object value, NextInfo info) - => Debug.Fail($"HandleNextValue should not have been called for '{GetType().Name}'"); + => throw new InvalidOperationException(); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs index 7f51525e6e113..a8318cb72d11d 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Formats.Nrbf.Utils; using System.IO; using System.Linq; @@ -15,6 +16,7 @@ namespace System.Formats.Nrbf; /// /// The ID of . /// +[DebuggerDisplay("{_id}")] public readonly struct SerializationRecordId : IEquatable { #pragma warning disable CS0649 // the default value is used on purpose @@ -29,6 +31,15 @@ internal static SerializationRecordId Decode(BinaryReader reader) { int id = reader.ReadInt32(); + // Many object ids are required to be positive. See: + // - https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/8fac763f-e46d-43a1-b360-80eb83d2c5fb + // - https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/eb503ca5-e1f6-4271-a7ee-c4ca38d07996 + // - https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/7fcf30e1-4ad4-4410-8f1a-901a4a1ea832 (for library id) + // + // Exception: https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/0a192be0-58a1-41d0-8a54-9c91db0ab7bf may be negative + // The problem is that input generated with FormatterTypeStyle.XsdString ends up generating negative Ids anyway. + // That information is not reflected in payload in anyway, so we just always allow for negative Ids. + if (id == 0) { ThrowHelper.ThrowInvalidValue(id); diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordType.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordType.cs index 57760b8a377fc..b78bae2ac86ca 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordType.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordType.cs @@ -6,6 +6,9 @@ namespace System.Formats.Nrbf; /// /// Record type. /// +/// +/// SerializationRecordType enumeration is described in [MS-NRBF] 2.1.2.1. +/// public enum SerializationRecordType { /// diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializedStreamHeaderRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializedStreamHeaderRecord.cs index 4757958fcb777..b21ff8ca23732 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializedStreamHeaderRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializedStreamHeaderRecord.cs @@ -24,14 +24,8 @@ internal sealed class SerializedStreamHeaderRecord : SerializationRecord public override SerializationRecordType RecordType => SerializationRecordType.SerializedStreamHeader; - public override TypeName TypeName - { - get - { - Debug.Fail("TypeName should never be called on SerializedStreamHeaderRecord"); - return TypeName.Parse(nameof(SerializedStreamHeaderRecord).AsSpan()); - } - } + public override TypeName TypeName => TypeName.Parse(nameof(SerializedStreamHeaderRecord).AsSpan()); + public override SerializationRecordId Id => SerializationRecordId.NoId; internal SerializationRecordId RootId { get; } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs index 05d38ec736f10..ccecc2246e8c2 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs @@ -75,29 +75,30 @@ ulong value when TypeNameMatches(typeof(UIntPtr)) => Create(new UIntPtr(value)), _ => this }; } - else if (HasMember("_ticks") && MemberValues[0] is long ticks && TypeNameMatches(typeof(TimeSpan))) + else if (HasMember("_ticks") && GetRawValue("_ticks") is long ticks && TypeNameMatches(typeof(TimeSpan))) { return Create(new TimeSpan(ticks)); } } else if (MemberValues.Count == 2 && HasMember("ticks") && HasMember("dateData") - && MemberValues[0] is long value && MemberValues[1] is ulong + && GetRawValue("ticks") is long && GetRawValue("dateData") is ulong dateData && TypeNameMatches(typeof(DateTime))) { - return Create(Utils.BinaryReaderExtensions.CreateDateTimeFromData(value)); + return Create(Utils.BinaryReaderExtensions.CreateDateTimeFromData(dateData)); } - else if(MemberValues.Count == 4 + else if (MemberValues.Count == 4 && HasMember("lo") && HasMember("mid") && HasMember("hi") && HasMember("flags") - && MemberValues[0] is int && MemberValues[1] is int && MemberValues[2] is int && MemberValues[3] is int + && GetRawValue("lo") is int lo && GetRawValue("mid") is int mid + && GetRawValue("hi") is int hi && GetRawValue("flags") is int flags && TypeNameMatches(typeof(decimal))) { int[] bits = [ - GetInt32("lo"), - GetInt32("mid"), - GetInt32("hi"), - GetInt32("flags") + lo, + mid, + hi, + flags ]; return Create(new decimal(bits)); diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs index 73759a7a22dac..d5baa09dbd8fc 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs @@ -1,18 +1,41 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Globalization; using System.IO; +using System.Reflection; using System.Reflection.Metadata; using System.Runtime.CompilerServices; using System.Runtime.Serialization; +using System.Threading; namespace System.Formats.Nrbf.Utils; internal static class BinaryReaderExtensions { + private static object? s_baseAmbiguousDstDateTime; + + internal static SerializationRecordType ReadSerializationRecordType(this BinaryReader reader, AllowedRecordTypes allowed) + { + byte nextByte = reader.ReadByte(); + if (nextByte > (byte)SerializationRecordType.MethodReturn // MethodReturn is the last defined value. + || (nextByte > (byte)SerializationRecordType.ArraySingleString && nextByte < (byte)SerializationRecordType.MethodCall) // not part of the spec + || ((uint)allowed & (1u << nextByte)) == 0) // valid, but not allowed + { + ThrowHelper.ThrowForUnexpectedRecordType(nextByte); + } + + return (SerializationRecordType)nextByte; + } + internal static BinaryArrayType ReadArrayType(this BinaryReader reader) { + // To simplify the behavior and security review of the BinaryArrayRecord type, we + // do not support reading non-zero-offset arrays. If this should change in the + // future, the BinaryArrayRecord.Decode method and supporting infrastructure + // will need re-review. + byte arrayType = reader.ReadByte(); // Rectangular is the last defined value. if (arrayType > (byte)BinaryArrayType.Rectangular) @@ -43,8 +66,8 @@ internal static BinaryType ReadBinaryType(this BinaryReader reader) internal static PrimitiveType ReadPrimitiveType(this BinaryReader reader) { byte primitiveType = reader.ReadByte(); - // String is the last defined value, 4 is not used at all. - if (primitiveType is 4 or > (byte)PrimitiveType.String) + // Boolean is the first valid value (1), UInt64 (16) is the last one. 4 is not used at all. + if (primitiveType is 4 or < (byte)PrimitiveType.Boolean or > (byte)PrimitiveType.UInt64) { ThrowHelper.ThrowInvalidValue(primitiveType); } @@ -60,7 +83,7 @@ internal static object ReadPrimitiveValue(this BinaryReader reader, PrimitiveTyp PrimitiveType.Boolean => reader.ReadBoolean(), PrimitiveType.Byte => reader.ReadByte(), PrimitiveType.SByte => reader.ReadSByte(), - PrimitiveType.Char => reader.ReadChar(), + PrimitiveType.Char => reader.ParseChar(), PrimitiveType.Int16 => reader.ReadInt16(), PrimitiveType.UInt16 => reader.ReadUInt16(), PrimitiveType.Int32 => reader.ReadInt32(), @@ -69,41 +92,130 @@ internal static object ReadPrimitiveValue(this BinaryReader reader, PrimitiveTyp PrimitiveType.UInt64 => reader.ReadUInt64(), PrimitiveType.Single => reader.ReadSingle(), PrimitiveType.Double => reader.ReadDouble(), - PrimitiveType.Decimal => decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture), - PrimitiveType.DateTime => CreateDateTimeFromData(reader.ReadInt64()), - _ => new TimeSpan(reader.ReadInt64()), + PrimitiveType.Decimal => reader.ParseDecimal(), + PrimitiveType.DateTime => CreateDateTimeFromData(reader.ReadUInt64()), + PrimitiveType.TimeSpan => new TimeSpan(reader.ReadInt64()), + _ => throw new InvalidOperationException(), }; - // TODO: fix https://github.com/dotnet/runtime/issues/102826 + // BinaryFormatter serializes decimals as strings and we can't BinaryReader.ReadDecimal. + internal static decimal ParseDecimal(this BinaryReader reader) + { + // The spec (MS NRBF 2.1.1.6, https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/10b218f5-9b2b-4947-b4b7-07725a2c8127) + // says that the length of LengthPrefixedString must be of optimal size (using as few bytes as possible). + // BinaryReader.ReadString does not enforce that and we are OK with that, + // as it takes care of handling multiple edge cases and we don't want to re-implement it. + + string text = reader.ReadString(); + if (!decimal.TryParse(text, NumberStyles.Number, CultureInfo.InvariantCulture, out decimal result)) + { + ThrowHelper.ThrowInvalidFormat(); + } + + return result; + } + + internal static char ParseChar(this BinaryReader reader) + { + try + { + return reader.ReadChar(); + } + catch (ArgumentException) // A surrogate character was read. + { + throw new SerializationException(SR.Serialization_SurrogateCharacter); + } + } + + internal static char[] ParseChars(this BinaryReader reader, int count) + { + char[]? result; + try + { + result = reader.ReadChars(count); + } + catch (ArgumentException) // A surrogate character was read. + { + throw new SerializationException(SR.Serialization_SurrogateCharacter); + } + + if (result.Length != count) + { + // We might hit EOF before fully reading the requested + // number of chars. This means that ReadChars(count) could return a char[] with + // *fewer* than 'count' elements. + ThrowHelper.ThrowEndOfStreamException(); + } + + return result; + } + /// /// Creates a object from raw data with validation. /// - /// was invalid. - internal static DateTime CreateDateTimeFromData(long data) + /// was invalid. + internal static DateTime CreateDateTimeFromData(ulong dateData) { - // Copied from System.Runtime.Serialization.Formatters.Binary.BinaryParser - - // Use DateTime's public constructor to validate the input, but we - // can't return that result as it strips off the kind. To address - // that, store the value directly into a DateTime via an unsafe cast. - // See BinaryFormatterWriter.WriteDateTime for details. + ulong ticks = dateData & 0x3FFFFFFF_FFFFFFFFUL; + DateTimeKind kind = (DateTimeKind)(dateData >> 62); try { - const long TicksMask = 0x3FFFFFFFFFFFFFFF; - _ = new DateTime(data & TicksMask); + return ((uint)kind <= (uint)DateTimeKind.Local) ? new DateTime((long)ticks, kind) : CreateFromAmbiguousDst(ticks); } catch (ArgumentException ex) { - // Bad data throw new SerializationException(ex.Message, ex); } - return Unsafe.As(ref data); + [MethodImpl(MethodImplOptions.NoInlining)] + static DateTime CreateFromAmbiguousDst(ulong ticks) + { + // There's no public API to create a DateTime from an ambiguous DST, and we + // can't use private reflection to access undocumented .NET Framework APIs. + // However, the ISerializable pattern *is* a documented protocol, so we can + // use DateTime's serialization ctor to create a zero-tick "ambiguous" instance, + // then keep reusing it as the base to which we can add our tick offsets. + + if (s_baseAmbiguousDstDateTime is not DateTime baseDateTime) + { +#pragma warning disable SYSLIB0050 // Type or member is obsolete + SerializationInfo si = new(typeof(DateTime), new FormatterConverter()); + // We don't know the value of "ticks", so we don't specify it. + // If the code somehow runs on a very old runtime that does not know the concept of "dateData" + // (it should not be possible as the library targets .NET Standard 2.0) + // the ctor is going to throw rather than silently return an invalid value. + si.AddValue("dateData", 0xC0000000_00000000UL); // new value (serialized as ulong) + +#if NET + baseDateTime = CallPrivateSerializationConstructor(si, new StreamingContext(StreamingContextStates.All)); +#else + ConstructorInfo ci = typeof(DateTime).GetConstructor( + BindingFlags.Instance | BindingFlags.NonPublic, + binder: null, + new Type[] { typeof(SerializationInfo), typeof(StreamingContext) }, + modifiers: null); + + baseDateTime = (DateTime)ci.Invoke(new object[] { si, new StreamingContext(StreamingContextStates.All) }); +#endif + +#pragma warning restore SYSLIB0050 // Type or member is obsolete + Volatile.Write(ref s_baseAmbiguousDstDateTime, baseDateTime); // it's ok if two threads race here + } + + return baseDateTime.AddTicks((long)ticks); + } + +#if NET + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + extern static DateTime CallPrivateSerializationConstructor(SerializationInfo si, StreamingContext ct); +#endif } internal static bool? IsDataAvailable(this BinaryReader reader, long requiredBytes) { + Debug.Assert(requiredBytes >= 0); + if (!reader.BaseStream.CanSeek) { return null; diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs index f096bfc736098..ac8c861e5d199 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs @@ -6,28 +6,33 @@ namespace System.Formats.Nrbf.Utils; +// The exception messages do not contain member/type/assembly names on purpose, +// as it's most likely corrupted/tampered/malicious data. internal static class ThrowHelper { - internal static void ThrowInvalidValue(object value) + internal static void ThrowDuplicateMemberName() + => throw new SerializationException(SR.Serialization_DuplicateMemberName); + + internal static void ThrowInvalidValue(int value) => throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, value)); internal static void ThrowInvalidReference() => throw new SerializationException(SR.Serialization_InvalidReference); - internal static void ThrowInvalidTypeName(string name) - => throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeName, name)); + internal static void ThrowInvalidTypeName() + => throw new SerializationException(SR.Serialization_InvalidTypeName); internal static void ThrowUnexpectedNullRecordCount() => throw new SerializationException(SR.Serialization_UnexpectedNullRecordCount); - internal static void ThrowMaxArrayLength(long limit, long actual) - => throw new SerializationException(SR.Format(SR.Serialization_MaxArrayLength, actual, limit)); - internal static void ThrowArrayContainedNulls() => throw new SerializationException(SR.Serialization_ArrayContainedNulls); - internal static void ThrowInvalidAssemblyName(string rawName) - => throw new SerializationException(SR.Format(SR.Serialization_InvalidAssemblyName, rawName)); + internal static void ThrowInvalidAssemblyName() + => throw new SerializationException(SR.Serialization_InvalidAssemblyName); + + internal static void ThrowInvalidFormat() + => throw new SerializationException(SR.Serialization_InvalidFormat); internal static void ThrowEndOfStreamException() => throw new EndOfStreamException(); diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/TypeNameHelpers.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/TypeNameHelpers.cs index 97c3b4e42f68b..a2fba1b52ecbc 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/TypeNameHelpers.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/TypeNameHelpers.cs @@ -12,7 +12,8 @@ namespace System.Formats.Nrbf.Utils; internal static class TypeNameHelpers { - // PrimitiveType does not define Object, IntPtr or UIntPtr + // PrimitiveType does not define Object, IntPtr or UIntPtr. + internal const PrimitiveType StringPrimitiveType = (PrimitiveType)18; internal const PrimitiveType ObjectPrimitiveType = (PrimitiveType)19; internal const PrimitiveType IntPtrPrimitiveType = (PrimitiveType)20; internal const PrimitiveType UIntPtrPrimitiveType = (PrimitiveType)21; @@ -22,8 +23,6 @@ internal static class TypeNameHelpers internal static TypeName GetPrimitiveTypeName(PrimitiveType primitiveType) { - Debug.Assert(primitiveType is not (PrimitiveType.None or PrimitiveType.Null)); - TypeName? typeName = s_primitiveTypeNames[(int)primitiveType]; if (typeName is null) { @@ -44,11 +43,11 @@ internal static TypeName GetPrimitiveTypeName(PrimitiveType primitiveType) PrimitiveType.Decimal => "System.Decimal", PrimitiveType.TimeSpan => "System.TimeSpan", PrimitiveType.DateTime => "System.DateTime", - PrimitiveType.String => "System.String", + StringPrimitiveType => "System.String", ObjectPrimitiveType => "System.Object", IntPtrPrimitiveType => "System.IntPtr", UIntPtrPrimitiveType => "System.UIntPtr", - _ => throw new ArgumentOutOfRangeException(paramName: nameof(primitiveType), actualValue: primitiveType, message: null) + _ => throw new InvalidOperationException() }; s_primitiveTypeNames[(int)primitiveType] = typeName = TypeName.Parse(fullName.AsSpan()).WithCoreLibAssemblyName(); @@ -99,7 +98,7 @@ internal static PrimitiveType GetPrimitiveType() else if (typeof(T) == typeof(TimeSpan)) return PrimitiveType.TimeSpan; else if (typeof(T) == typeof(string)) - return PrimitiveType.String; + return StringPrimitiveType; else if (typeof(T) == typeof(IntPtr)) return IntPtrPrimitiveType; else if (typeof(T) == typeof(UIntPtr)) @@ -118,6 +117,17 @@ internal static TypeName ParseNonSystemClassRecordTypeName(this string rawName, Debug.Assert(payloadOptions.UndoTruncatedTypeNames); Debug.Assert(libraryRecord.RawLibraryName is not null); + // This is potentially a DoS vector, as somebody could submit: + // [1] BinaryLibraryRecord = + // [2] ClassRecord (lib = [1]) + // [3] ClassRecord (lib = [1]) + // ... + // [n] ClassRecord (lib = [1]) + // + // Which means somebody submits a payload of length O(long + n) and tricks us into + // performing O(long * n) work. For this reason, we have marked the UndoTruncatedTypeNames + // property as "keep this disabled unless you trust the input." + // Combining type and library allows us for handling truncated generic type names that may be present in resources. ArraySegment assemblyQualifiedName = RentAssemblyQualifiedName(rawName, libraryRecord.RawLibraryName); TypeName.TryParse(assemblyQualifiedName.AsSpan(), out TypeName? typeName, payloadOptions.TypeNameParseOptions); @@ -125,7 +135,7 @@ internal static TypeName ParseNonSystemClassRecordTypeName(this string rawName, if (typeName is null) { - throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeOrAssemblyName, rawName, libraryRecord.RawLibraryName)); + throw new SerializationException(SR.Serialization_InvalidTypeOrAssemblyName); } if (typeName.AssemblyName is null) @@ -149,6 +159,10 @@ internal static TypeName WithCoreLibAssemblyName(this TypeName systemType) private static TypeName With(this TypeName typeName, AssemblyNameInfo assemblyName) { + // This is a recursive method over potentially hostile TypeName arguments. + // We assume the complexity of the TypeName arg was appropriately bounded. + // See comment in TypeName.FullName property getter for more info. + if (!typeName.IsSimple) { if (typeName.IsArray) @@ -169,7 +183,7 @@ private static TypeName With(this TypeName typeName, AssemblyNameInfo assemblyNa else { // BinaryFormatter can not serialize pointers or references. - ThrowHelper.ThrowInvalidTypeName(typeName.FullName); + ThrowHelper.ThrowInvalidTypeName(); } } @@ -187,6 +201,7 @@ private static TypeName ParseWithoutAssemblyName(string rawName, PayloadOptions return typeName; } + // Complexity is O(typeName.Length + libraryName.Length) private static ArraySegment RentAssemblyQualifiedName(string typeName, string libraryName) { int length = typeName.Length + 1 + libraryName.Length; diff --git a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs index d4a4b5b3c690a..49d523088a89f 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs @@ -3,6 +3,8 @@ using System.Collections.Generic; using System.IO; +using System.Runtime.Serialization; +using System.Text; using Xunit; namespace System.Formats.Nrbf.Tests; @@ -24,6 +26,51 @@ public static IEnumerable GetCanReadArrayOfAnySizeArgs() } } + [Fact] + public void DontCastBytesToBooleans() + { + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer); + writer.Write((byte)SerializationRecordType.ArraySinglePrimitive); + writer.Write(1); // object ID + writer.Write(2); // length + writer.Write((byte)PrimitiveType.Boolean); // element type + writer.Write((byte)0x01); + writer.Write((byte)0x02); + writer.Write((byte)SerializationRecordType.MessageEnd); + stream.Position = 0; + + SZArrayRecord serializationRecord = (SZArrayRecord)NrbfDecoder.Decode(stream); + + bool[] bools = serializationRecord.GetArray(); + bool a = bools[0]; + Assert.True(a); + bool b = bools[1]; + Assert.True(b); + bool c = a && b; + Assert.True(c); + } + + [Fact] + public void DontCastBytesToDateTimes() + { + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer); + writer.Write((byte)SerializationRecordType.ArraySinglePrimitive); + writer.Write(1); // object ID + writer.Write(1); // length + writer.Write((byte)PrimitiveType.DateTime); // element type + writer.Write(ulong.MaxValue); // un-representable DateTime + writer.Write((byte)SerializationRecordType.MessageEnd); + stream.Position = 0; + + Assert.Throws(() => NrbfDecoder.Decode(stream)); + } + [Theory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Bool(int size, bool canSeek) => Test(size, canSeek); @@ -94,6 +141,7 @@ private void Test(int size, bool canSeek) SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(stream); Assert.Equal(size, arrayRecord.Length); + Assert.Equal(size, arrayRecord.FlattenedLength); T?[] output = arrayRecord.GetArray(); Assert.Equal(input, output); Assert.Same(output, arrayRecord.GetArray()); diff --git a/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs b/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs index d9f7ac05811ad..fe780d94698df 100644 --- a/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs @@ -154,7 +154,7 @@ public void ArraysOfBytesAreNotBeingPreAllocated() writer.Write((byte)SerializationRecordType.ArraySinglePrimitive); writer.Write(1); // object ID writer.Write(Array.MaxLength); // length - writer.Write((byte)2); // PrimitiveType.Byte + writer.Write((byte)PrimitiveType.Byte); writer.Write((byte)SerializationRecordType.MessageEnd); stream.Position = 0; diff --git a/src/libraries/System.Formats.Nrbf/tests/EdgeCaseTests.cs b/src/libraries/System.Formats.Nrbf/tests/EdgeCaseTests.cs index b443eba5ed4c9..f091d47ded8c5 100644 --- a/src/libraries/System.Formats.Nrbf/tests/EdgeCaseTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/EdgeCaseTests.cs @@ -1,4 +1,5 @@ -using System.IO; +using System.Collections.Generic; +using System.IO; using System.Runtime.Serialization.Formatters; using System.Runtime.Serialization.Formatters.Binary; using Microsoft.DotNet.XUnitExtensions; @@ -103,4 +104,44 @@ public void FormatterTypeStyleOtherThanTypesAlwaysAreNotSupportedByDesign(Format Assert.Throws(() => NrbfDecoder.Decode(ms)); } + + public static IEnumerable CanReadAllKindsOfDateTimes_Arguments + { + get + { + yield return new object[] { new DateTime(1990, 11, 24, 0, 0, 0, DateTimeKind.Local) }; + yield return new object[] { new DateTime(1990, 11, 25, 0, 0, 0, DateTimeKind.Utc) }; + yield return new object[] { new DateTime(1990, 11, 26, 0, 0, 0, DateTimeKind.Unspecified) }; + } + } + + [Theory] + [MemberData(nameof(CanReadAllKindsOfDateTimes_Arguments))] + public void CanReadAllKindsOfDateTimes_DateTimeIsTheRootRecord(DateTime input) + { + using MemoryStream stream = Serialize(input); + + PrimitiveTypeRecord dateTimeRecord = (PrimitiveTypeRecord)NrbfDecoder.Decode(stream); + + Assert.Equal(input.Ticks, dateTimeRecord.Value.Ticks); + Assert.Equal(input.Kind, dateTimeRecord.Value.Kind); + } + + [Serializable] + public class ClassWithDateTime + { + public DateTime Value; + } + + [Theory] + [MemberData(nameof(CanReadAllKindsOfDateTimes_Arguments))] + public void CanReadAllKindsOfDateTimes_DateTimeIsMemberOfTheRootRecord(DateTime input) + { + using MemoryStream stream = Serialize(new ClassWithDateTime() { Value = input }); + + ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(stream); + + Assert.Equal(input.Ticks, classRecord.GetDateTime(nameof(ClassWithDateTime.Value)).Ticks); + Assert.Equal(input.Kind, classRecord.GetDateTime(nameof(ClassWithDateTime.Value)).Kind); + } } diff --git a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs index bc134350eb7c9..6acb44d03697d 100644 --- a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs @@ -426,7 +426,10 @@ public static IEnumerable ThrowsForInvalidPrimitiveType_Arguments() { foreach (byte binaryType in new byte[] { (byte)0 /* BinaryType.Primitive */, (byte)7 /* BinaryType.PrimitiveArray */ }) { + yield return new object[] { recordType, binaryType, (byte)0 }; // value not used by the spec yield return new object[] { recordType, binaryType, (byte)4 }; // value not used by the spec + yield return new object[] { recordType, binaryType, (byte)17 }; // used by the spec, but illegal in given context + yield return new object[] { recordType, binaryType, (byte)18 }; // used by the spec, but illegal in given context yield return new object[] { recordType, binaryType, (byte)19 }; } } @@ -478,4 +481,125 @@ public void ThrowsOnInvalidArrayType() stream.Position = 0; Assert.Throws(() => NrbfDecoder.Decode(stream)); } + + [Theory] + [InlineData(18, typeof(NotSupportedException))] // not part of the spec, but still less than max allowed value (22) + [InlineData(19, typeof(NotSupportedException))] // same as above + [InlineData(20, typeof(NotSupportedException))] // same as above + [InlineData(23, typeof(SerializationException))] // not part of the spec and more than max allowed value (22) + [InlineData(64, typeof(SerializationException))] // same as above but also matches AllowedRecordTypes.SerializedStreamHeader + public void InvalidSerializationRecordType(byte recordType, Type expectedException) + { + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer); + writer.Write(recordType); // SerializationRecordType + writer.Write((byte)SerializationRecordType.MessageEnd); + + stream.Position = 0; + + Assert.Throws(expectedException, () => NrbfDecoder.Decode(stream)); + } + + [Fact] + public void MissingRootRecord() + { + const int RootRecordId = 1; + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer, rootId: RootRecordId); + writer.Write((byte)SerializationRecordType.BinaryObjectString); + writer.Write(RootRecordId + 1); // a different ID + writer.Write("theString"); + writer.Write((byte)SerializationRecordType.MessageEnd); + + stream.Position = 0; + + Assert.Throws(() => NrbfDecoder.Decode(stream)); + } + + [Fact] + public void Invalid7BitEncodedStringLength() + { + // The highest bit of the last byte is set (so it's invalid). + byte[] invalidLength = [byte.MaxValue, byte.MaxValue, byte.MaxValue, byte.MaxValue, byte.MaxValue]; + + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer); + writer.Write((byte)SerializationRecordType.BinaryObjectString); + writer.Write(1); // root record Id + writer.Write(invalidLength); // the length prefix + writer.Write(Encoding.UTF8.GetBytes("theString")); + writer.Write((byte)SerializationRecordType.MessageEnd); + + stream.Position = 0; + + Assert.Throws(() => NrbfDecoder.Decode(stream)); + } + + [Theory] + [InlineData("79228162514264337593543950336")] // invalid format (decimal.MaxValue + 1) + [InlineData("1111111111111111111111111111111111111111111111111")] // overflow + public void InvalidDecimal(string textRepresentation) + { + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer); + writer.Write((byte)SerializationRecordType.SystemClassWithMembersAndTypes); + writer.Write(1); // root record Id + writer.Write("ClassWithDecimalField"); // type name + writer.Write(1); // member count + writer.Write("memberName"); + writer.Write((byte)BinaryType.Primitive); + writer.Write((byte)PrimitiveType.Decimal); + writer.Write(textRepresentation); + writer.Write((byte)SerializationRecordType.MessageEnd); + + stream.Position = 0; + + Assert.Throws(() => NrbfDecoder.Decode(stream)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void SurrogateCharacters(bool array) + { + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer); + writer.Write((byte)SerializationRecordType.SystemClassWithMembersAndTypes); + writer.Write(1); // root record Id + writer.Write("ClassWithCharField"); // type name + writer.Write(1); // member count + writer.Write("memberName"); + + if (array) + { + writer.Write((byte)BinaryType.PrimitiveArray); + writer.Write((byte)PrimitiveType.Char); + writer.Write((byte)SerializationRecordType.ArraySinglePrimitive); + writer.Write(2); // array record Id + writer.Write(1); // array length + writer.Write((byte)PrimitiveType.Char); + } + else + { + writer.Write((byte)BinaryType.Primitive); + writer.Write((byte)PrimitiveType.Char); + } + + writer.Write((byte)0xC0); // a surrogate character + writer.Write((byte)SerializationRecordType.MessageEnd); + + stream.Position = 0; + + Assert.Throws(() => NrbfDecoder.Decode(stream)); + } } diff --git a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs index a72c3227c1eec..8bb844ff76a58 100644 --- a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs @@ -1,4 +1,5 @@ using System.Formats.Nrbf.Utils; +using System.IO; using System.Linq; using Xunit; @@ -6,29 +7,91 @@ namespace System.Formats.Nrbf.Tests; public class JaggedArraysTests : ReadTests { - [Fact] - public void CanReadJaggedArraysOfPrimitiveTypes_2D() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanReadJaggedArraysOfPrimitiveTypes_2D(bool useReferences) { int[][] input = new int[7][]; + int[] same = [1, 2, 3]; for (int i = 0; i < input.Length; i++) { - input[i] = [i, i, i]; + input[i] = useReferences + ? same // reuse the same object (represented as a single record that is referenced multiple times) + : [i, i, i]; // create new array } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); + } + + [Theory] + [InlineData(1)] // SerializationRecordType.ObjectNull + [InlineData(200)] // SerializationRecordType.ObjectNullMultiple256 + [InlineData(10_000)] // SerializationRecordType.ObjectNullMultiple + public void FlattenedLengthIncludesNullArrays(int nullCount) + { + int[][] input = new int[nullCount][]; + + var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); + + Verify(input, arrayRecord); + Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(nullCount, arrayRecord.FlattenedLength); + } + + [Fact] + public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged() + { + int[][][] input = new int[3][][]; + long totalElementsCount = 0; + for (int i = 0; i < input.Length; i++) + { + input[i] = new int[4][]; + totalElementsCount++; // count the arrays themselves + + for (int j = 0; j < input[i].Length; j++) + { + input[i][j] = [i, j, 0, 1, 2]; + totalElementsCount += input[i][j].Length; + totalElementsCount++; // count the arrays themselves + } + } + + byte[] serialized = Serialize(input).ToArray(); + const int ArrayTypeByteIndex = + sizeof(byte) + sizeof(int) * 4 + // stream header + sizeof(byte) + // SerializationRecordType.BinaryArray + sizeof(int); // SerializationRecordId + + Assert.Equal((byte)BinaryArrayType.Jagged, serialized[ArrayTypeByteIndex]); + + // change the reported array type + serialized[ArrayTypeByteIndex] = (byte)BinaryArrayType.Single; + + var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(new MemoryStream(serialized)); + + Verify(input, arrayRecord); + Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(3 + 3 * 4 + 3 * 4 * 5, totalElementsCount); + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); } [Fact] public void CanReadJaggedArraysOfPrimitiveTypes_3D() { int[][][] input = new int[7][][]; + long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { + totalElementsCount++; // count the arrays themselves input[i] = new int[1][]; + totalElementsCount++; // count the arrays themselves input[i][0] = [i, i, i]; + totalElementsCount += input[i][0].Length; } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); @@ -36,6 +99,8 @@ public void CanReadJaggedArraysOfPrimitiveTypes_3D() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); Assert.Equal(1, arrayRecord.Rank); + Assert.Equal(7 + 7 * 1 + 7 * 1 * 3, totalElementsCount); + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); } [Fact] @@ -60,6 +125,7 @@ public void CanReadJaggedArrayOfRectangularArrays() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); Assert.Equal(1, arrayRecord.Rank); + Assert.Equal(input.Length + input.Length * 3 * 3, arrayRecord.FlattenedLength); } [Fact] @@ -75,6 +141,7 @@ public void CanReadJaggedArraysOfStrings() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); } [Fact] @@ -90,6 +157,7 @@ public void CanReadJaggedArraysOfObjects() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); + Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); } [Serializable] @@ -102,14 +170,18 @@ public class ComplexType public void CanReadJaggedArraysOfComplexTypes() { ComplexType[][] input = new ComplexType[3][]; + long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray(); + totalElementsCount += input[i].Length; + totalElementsCount++; // count the arrays themselves } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); var output = (ClassRecord?[][])arrayRecord.GetArray(input.GetType()); for (int i = 0; i < input.Length; i++) { diff --git a/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs b/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs index b9bee7d881a69..0c7bd2045fa1f 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs @@ -45,10 +45,10 @@ protected static BinaryFormatter CreateBinaryFormatter() }; #pragma warning restore SYSLIB0011 // Type or member is obsolete - protected static void WriteSerializedStreamHeader(BinaryWriter writer, int major = 1, int minor = 0) + protected static void WriteSerializedStreamHeader(BinaryWriter writer, int major = 1, int minor = 0, int rootId = 1) { writer.Write((byte)SerializationRecordType.SerializedStreamHeader); - writer.Write(1); // root ID + writer.Write(rootId); // root ID writer.Write(1); // header ID writer.Write(major); // major version writer.Write(minor); // minor version diff --git a/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs b/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs index 25e7bb5a4d533..3191d57ba807c 100644 --- a/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs @@ -223,10 +223,13 @@ public void CanReadRectangularArraysOfComplexTypes_3D() internal static void Verify(Array input, ArrayRecord arrayRecord) { Assert.Equal(input.Rank, arrayRecord.Lengths.Length); + long totalElementsCount = 1; for (int i = 0; i < input.Rank; i++) { Assert.Equal(input.GetLength(i), arrayRecord.Lengths[i]); + totalElementsCount *= input.GetLength(i); } + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); Assert.Equal(input.GetType().FullName, arrayRecord.TypeName.FullName); Assert.Equal(input.GetType().GetAssemblyNameIncludingTypeForwards(), arrayRecord.TypeName.AssemblyName!.FullName); } diff --git a/src/libraries/System.Formats.Nrbf/tests/System.Formats.Nrbf.Tests.csproj b/src/libraries/System.Formats.Nrbf/tests/System.Formats.Nrbf.Tests.csproj index a9da043b5ecd5..c31537bb983ef 100644 --- a/src/libraries/System.Formats.Nrbf/tests/System.Formats.Nrbf.Tests.csproj +++ b/src/libraries/System.Formats.Nrbf/tests/System.Formats.Nrbf.Tests.csproj @@ -7,6 +7,8 @@ + + diff --git a/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs b/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs index 6a981197f82f1..e0827c1225b42 100644 --- a/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs @@ -73,6 +73,34 @@ public void CanRecognizeGenericSystemTypes() Verify(new Dictionary>>()); } + [Fact] + public void ThrowsForNullType() + { + List input = new List(); + + SerializationRecord record = NrbfDecoder.Decode(Serialize(input)); + + Assert.Throws(() => record.TypeNameMatches(type: null)); + } + + [Fact] + public void TakesCustomOffsetsIntoAccount() + { + int[] input = [1, 2, 3]; + + SerializationRecord record = NrbfDecoder.Decode(Serialize(input)); + + Assert.True(record.TypeNameMatches(typeof(int[]))); + + Type nonSzArray = typeof(int).Assembly.GetType("System.Int32[*]"); +#if NET + Assert.False(nonSzArray.IsSZArray); + Assert.True(nonSzArray.IsVariableBoundArray); +#endif + Assert.Equal(1, nonSzArray.GetArrayRank()); + Assert.False(record.TypeNameMatches(nonSzArray)); + } + [Fact] public void TakesGenericTypeDefinitionIntoAccount() { diff --git a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/AssemblyNameInfo.cs b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/AssemblyNameInfo.cs index bd1febff03659..3ac30c92db57a 100644 --- a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/AssemblyNameInfo.cs +++ b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/AssemblyNameInfo.cs @@ -81,6 +81,10 @@ internal AssemblyNameInfo(AssemblyNameParser.AssemblyNameParts parts) /// /// Gets the name of the culture associated with the assembly. /// + /// + /// Do not create a instance from this string unless + /// you know the string has originated from a trustworthy source. + /// public string? CultureName { get; } /// @@ -131,6 +135,10 @@ public string FullName /// /// Initializes a new instance of the class based on the stored information. /// + /// + /// Do not create an instance with string unless + /// you know the string has originated from a trustworthy source. + /// public AssemblyName ToAssemblyName() { AssemblyName assemblyName = new(); diff --git a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeName.cs b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeName.cs index 22ae86f08e6b5..2b620761bb708 100644 --- a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeName.cs +++ b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeName.cs @@ -95,7 +95,7 @@ private TypeName(string? fullName, /// If returns null, simply returns . /// public string AssemblyQualifiedName - => _assemblyQualifiedName ??= AssemblyName is null ? FullName : $"{FullName}, {AssemblyName.FullName}"; + => _assemblyQualifiedName ??= AssemblyName is null ? FullName : $"{FullName}, {AssemblyName.FullName}"; // see recursion comments in FullName /// /// Returns assembly name which contains this type, or null if this was not @@ -142,6 +142,17 @@ public string FullName { get { + // This is a recursive method over potentially hostile input. Protection against DoS is offered + // via the [Try]Parse method and TypeNameParserOptions.MaxNodes property at construction time. + // This FullName property getter and related methods assume that this TypeName instance has an + // acceptable node count. + // + // The node count controls the total amount of work performed by this method, including: + // - The max possible stack depth due to the recursive methods calls; and + // - The total number of bytes allocated by this function. For a deeply-nested TypeName + // object, the total allocation across the full object graph will be + // O(FullName.Length * GetNodeCount()). + if (_fullName is null) { if (IsConstructedGenericType) @@ -245,6 +256,8 @@ public string Name { get { + // Lookups to Name and FullName might be recursive. See comments in FullName property getter. + if (_name is null) { if (IsConstructedGenericType) @@ -425,6 +438,17 @@ public int GetArrayRank() /// The current type name is not simple. public TypeName WithAssemblyName(AssemblyNameInfo? assemblyName) { + // Recursive method. See comments in FullName property getter for more information + // on how this is protected against attack. + // + // n.b. AssemblyNameInfo could also be hostile. The typical exploit is that a single + // long AssemblyNameInfo is associated with one or more simple TypeName objects, + // leading to an alg. complexity attack (DoS). It's important that TypeName doesn't + // actually *do* anything with the provided AssemblyNameInfo rather than store it. + // For example, don't use it inside a string concat operation unless the caller + // explicitly requested that to happen. If the input is hostile, the caller should + // never perform such concats in a loop. + if (!IsSimple) { TypeNameParserHelpers.ThrowInvalidOperation_NotSimpleName(FullName); diff --git a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParser.cs b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParser.cs index 97294c8014f2b..08ed944d03108 100644 --- a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParser.cs +++ b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParser.cs @@ -80,6 +80,8 @@ private TypeNameParser(ReadOnlySpan name, bool throwOnError, TypeNameParse return null; } + // At this point, we have performed O(fullTypeNameLength) total work. + ReadOnlySpan fullTypeName = _inputString.Slice(0, fullTypeNameLength); _inputString = _inputString.Slice(fullTypeNameLength); @@ -142,6 +144,12 @@ private TypeNameParser(ReadOnlySpan name, bool throwOnError, TypeNameParse } } + // At this point, we may have performed O(fullTypeNameLength + _inputString.Length) total work. + // This will be the case if there was whitespace after the full type name in the original input + // string. We could end up looking at these same whitespace chars again later in this method, + // such as when parsing decorators. We rely on the TryDive routine to limit the total number + // of times we might inspect the same character. + // If there was an error stripping the generic args, back up to // before we started processing them, and let the decorator // parser try handling it. @@ -202,6 +210,9 @@ private TypeNameParser(ReadOnlySpan name, bool throwOnError, TypeNameParse result = new(fullName: null, assemblyName, elementOrGenericType: result, declaringType, genericArgs); } + // The loop below is protected by the dive check during the first decorator pass prior + // to assembly name parsing above. + if (previousDecorator != default) // some decorators were recognized { while (TryParseNextDecorator(ref capturedBeforeProcessing, out int parsedModifier)) @@ -245,6 +256,8 @@ private bool TryParseAssemblyName(ref AssemblyNameInfo? assemblyName) return null; } + // The loop below is protected by the dive check in GetFullTypeNameLength. + TypeName? declaringType = null; int nameOffset = 0; foreach (int nestedNameLength in nestedNameLengths) diff --git a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParserHelpers.cs b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParserHelpers.cs index 93859262eed99..7cafd746b7d17 100644 --- a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParserHelpers.cs +++ b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParserHelpers.cs @@ -16,6 +16,7 @@ internal static class TypeNameParserHelpers internal const int ByRef = -3; private const char EscapeCharacter = '\\'; #if NET8_0_OR_GREATER + // Keep this in sync with GetFullTypeNameLength/NeedsEscaping private static readonly SearchValues s_endOfFullTypeNameDelimitersSearchValues = SearchValues.Create("[]&*,+\\"); #endif @@ -30,7 +31,7 @@ internal static string GetGenericTypeFullName(ReadOnlySpan fullTypeName, R foreach (TypeName genericArg in genericArgs) { result.Append('['); - result.Append(genericArg.AssemblyQualifiedName); + result.Append(genericArg.AssemblyQualifiedName); // see recursion comments in TypeName.FullName result.Append(']'); result.Append(','); } @@ -97,11 +98,16 @@ static int GetUnescapedOffset(ReadOnlySpan input, int startOffset) return offset; } + // Keep this in sync with s_endOfFullTypeNameDelimitersSearchValues static bool NeedsEscaping(char c) => c is '[' or ']' or '&' or '*' or ',' or '+' or EscapeCharacter; } internal static ReadOnlySpan GetName(ReadOnlySpan fullName) { + // The two-value form of MemoryExtensions.LastIndexOfAny does not suffer + // from the behavior mentioned in the comment at the top of GetFullTypeNameLength. + // It always takes O(m * i) worst-case time and is safe to use here. + int offset = fullName.LastIndexOfAny('.', '+'); if (offset > 0 && fullName[offset - 1] == EscapeCharacter) // this should be very rare (IL Emit & pure IL) @@ -182,6 +188,13 @@ internal static string GetRankOrModifierStringRepresentation(int rankOrModifier, { Debug.Assert(rankOrModifier >= 2); + // O(rank) work, so we have to assume the rank is trusted. We don't put a hard cap on this, + // but within the TypeName parser, we do require the input string to contain the correct number + // of commas. This forces the input string to have at least O(rank) length, so there's no + // alg. complexity attack possible here. Callers can of course pass any arbitrary value to + // TypeName.MakeArrayTypeName, but per first sentence in this comment, we have to assume any + // such arbitrary value which is programmatically fed in originates from a trustworthy source. + builder.Append('['); builder.Append(',', rankOrModifier - 1); builder.Append(']'); @@ -310,6 +323,9 @@ internal static bool TryParseNextDecorator(ref ReadOnlySpan input, out int else if (TryStripFirstCharAndTrailingSpaces(ref input, ',')) { // [,,, ...] + // The runtime restricts arrays to rank 32, but we don't enforce that here. + // Instead, the max rank is controlled by the total number of commas present + // in the array decorator. checked { rank++; } goto ReadNextArrayToken; } diff --git a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParserOptions.cs b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParserOptions.cs index b7420c40aa9ea..53d6f7f164275 100644 --- a/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParserOptions.cs +++ b/src/libraries/System.Reflection.Metadata/src/System/Reflection/Metadata/TypeNameParserOptions.cs @@ -10,6 +10,13 @@ public sealed class TypeNameParseOptions /// /// Limits the maximum value of node count that parser can handle. /// + /// + /// + /// Setting this to a large value can render susceptible to Denial of Service + /// attacks when parsing or handling malicious input. + /// + /// The default value is 20. + /// public int MaxNodes { get => _maxNodes;