Skip to content

Commit

Permalink
Add TargetHostName to QuicConnection (#84976)
Browse files Browse the repository at this point in the history
* Add TargetHostName to QuicConnection
Fixes #80508

* Make TargetHostName not nullable

* Fix build

* Fix build of tests

* Fix failing tests

* Code review feedback

* Use unencoded hostname in user-facing properties/params

* Fix failing tests

* Revert unwanted changes

* Add test for IDN cert validation

* Fix test again

* Fix trailing dot in hostname
  • Loading branch information
rzikm authored Apr 25, 2023
1 parent a6f20f2 commit 25b61f6
Show file tree
Hide file tree
Showing 13 changed files with 232 additions and 78 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
using System.Collections.Generic;
using System.Globalization;
using System.Runtime.InteropServices;

namespace System.Net.Security
{
internal static class TargetHostNameHelper
{
private static readonly IdnMapping s_idnMapping = new IdnMapping();
private static readonly IndexOfAnyValues<char> s_safeDnsChars =
IndexOfAnyValues.Create("-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz");

private static bool IsSafeDnsString(ReadOnlySpan<char> name) =>
name.IndexOfAnyExcept(s_safeDnsChars) < 0;

internal static string NormalizeHostName(string? targetHost)
{
if (string.IsNullOrEmpty(targetHost))
{
return string.Empty;
}

// RFC 6066 section 3 says to exclude trailing dot from fully qualified DNS hostname
targetHost = targetHost.TrimEnd('.');

try
{
return s_idnMapping.GetAscii(targetHost);
}
catch (ArgumentException) when (IsSafeDnsString(targetHost))
{
// Seems like name that does not confrom to IDN but apers somewhat valid according to original DNS rfc.
}

return targetHost;
}

// Simplified version of IPAddressParser.Parse to avoid allocations and dependencies.
// It purposely ignores scopeId as we don't really use so we do not need to map it to actual interface id.
internal static unsafe bool IsValidAddress(string? hostname)
{
if (string.IsNullOrEmpty(hostname))
{
return false;
}

ReadOnlySpan<char> ipSpan = hostname.AsSpan();

int end = ipSpan.Length;

if (ipSpan.Contains(':'))
{
// The address is parsed as IPv6 if and only if it contains a colon. This is valid because
// we don't support/parse a port specification at the end of an IPv4 address.
Span<ushort> numbers = stackalloc ushort[IPAddressParserStatics.IPv6AddressShorts];

fixed (char* ipStringPtr = &MemoryMarshal.GetReference(ipSpan))
{
return IPv6AddressHelper.IsValidStrict(ipStringPtr, 0, ref end);
}
}
else if (char.IsDigit(ipSpan[0]))
{
long tmpAddr;

fixed (char* ipStringPtr = &MemoryMarshal.GetReference(ipSpan))
{
tmpAddr = IPv4AddressHelper.ParseNonCanonical(ipStringPtr, 0, ref end, notImplicitFile: true);
}

if (tmpAddr != IPv4AddressHelper.Invalid && end == ipSpan.Length)
{
return true;
}
}

return false;
}
}
}
2 changes: 2 additions & 0 deletions src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ internal QuicConnection() { }
public System.Net.Security.SslApplicationProtocol NegotiatedApplicationProtocol { get { throw null; } }
public System.Security.Cryptography.X509Certificates.X509Certificate? RemoteCertificate { get { throw null; } }
public System.Net.IPEndPoint RemoteEndPoint { get { throw null; } }
public string TargetHostName { get { throw null; } }
public System.Threading.Tasks.ValueTask<System.Net.Quic.QuicStream> AcceptInboundStreamAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask CloseAsync(long errorCode, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public static System.Threading.Tasks.ValueTask<System.Net.Quic.QuicConnection> ConnectAsync(System.Net.Quic.QuicClientConnectionOptions options, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down Expand Up @@ -122,6 +123,7 @@ public override void Flush() { }
public override int ReadByte() { throw null; }
public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; }
public override void SetLength(long value) { }
public override string ToString() { throw null; }
public override void Write(byte[] buffer, int offset, int count) { }
public override void Write(System.ReadOnlySpan<byte> buffer) { }
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down
4 changes: 4 additions & 0 deletions src/libraries/System.Net.Quic/src/System.Net.Quic.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
<Compile Include="$(CommonPath)System\Net\IPAddressParserStatics.cs" Link="Common\System\Net\IPAddressParserStatics.cs" />
<Compile Include="$(CommonPath)System\Net\Internals\IPEndPointExtensions.cs" Link="Common\System\Net\Internals\IPEndPointExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\Security\TlsAlertMessage.cs" Link="Common\System\Net\Security\TlsAlertMessage.cs" />
<Compile Include="$(CommonPath)System\Net\Security\TargetHostNameHelper.cs" Link="Common\System\Net\Security\TargetHostNameHelper.cs" />
<!-- IP parser -->
<Compile Include="$(CommonPath)System\Net\IPv4AddressHelper.Common.cs" Link="System\Net\IPv4AddressHelper.Common.cs" />
<Compile Include="$(CommonPath)System\Net\IPv6AddressHelper.Common.cs" Link="System\Net\IPv6AddressHelper.Common.cs" />
</ItemGroup>
<!-- Unsupported platforms -->
<ItemGroup Condition="'$(TargetPlatformIdentifier)' == ''">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ private readonly struct SslConnectionOptions
/// <summary>
/// Host name send in SNI, set only for outbound/client connections. Configured via <see cref="SslClientAuthenticationOptions.TargetHost"/>.
/// </summary>
private readonly string? _targetHost;
private readonly string _targetHost;
/// <summary>
/// Always <c>true</c> for outbound/client connections. Configured for inbound/server ones via <see cref="SslServerAuthenticationOptions.ClientCertificateRequired"/>.
/// </summary>
Expand All @@ -47,8 +47,10 @@ private readonly struct SslConnectionOptions
/// </summary>
private readonly X509ChainPolicy? _certificateChainPolicy;

internal string TargetHost => _targetHost;

public SslConnectionOptions(QuicConnection connection, bool isClient,
string? targetHost, bool certificateRequired, X509RevocationMode
string targetHost, bool certificateRequired, X509RevocationMode
revocationMode, RemoteCertificateValidationCallback? validationCallback,
X509ChainPolicy? certificateChainPolicy)
{
Expand Down Expand Up @@ -118,7 +120,7 @@ public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER*
if (result is not null)
{
bool checkCertName = !chain!.ChainPolicy!.VerificationFlags.HasFlag(X509VerificationFlags.IgnoreInvalidName);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, result, checkCertName, !_isClient, _targetHost, certificateBuffer, certificateLength);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, result, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), certificateBuffer, certificateLength);
}
else if (_certificateRequired)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ public static async ValueTask<QuicConnection> ConnectAsync(QuicClientConnectionO
/// </summary>
public IPEndPoint LocalEndPoint => _localEndPoint;

/// <summary>
/// Gets the name of the server the client is trying to connect to. That name is used for server certificate validation. It can be a DNS name or an IP address.
/// </summary>
/// <returns>The name of the server the client is trying to connect to.</returns>
public string TargetHostName => _sslConnectionOptions.TargetHost ?? string.Empty;

/// <summary>
/// The certificate provided by the peer.
/// For an outbound/client connection will always have the peer's (server) certificate; for an inbound/server one, only if the connection requested and the peer (client) provided one.
Expand Down Expand Up @@ -279,10 +285,16 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options,
MsQuicHelpers.SetMsQuicParameter(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS, quicAddress);
}

// RFC 6066 forbids IP literals
// DNI mapping is handled by MsQuic
var hostname = TargetHostNameHelper.IsValidAddress(options.ClientAuthenticationOptions.TargetHost)
? string.Empty
: options.ClientAuthenticationOptions.TargetHost ?? string.Empty;

_sslConnectionOptions = new SslConnectionOptions(
this,
isClient: true,
options.ClientAuthenticationOptions.TargetHost,
hostname,
certificateRequired: true,
options.ClientAuthenticationOptions.CertificateRevocationCheckMode,
options.ClientAuthenticationOptions.RemoteCertificateValidationCallback,
Expand Down Expand Up @@ -312,7 +324,7 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options,
await valueTask.ConfigureAwait(false);
}

internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, string? targetHost, CancellationToken cancellationToken = default)
internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, string targetHost, CancellationToken cancellationToken = default)
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);

Expand All @@ -322,10 +334,16 @@ internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, str
_defaultStreamErrorCode = options.DefaultStreamErrorCode;
_defaultCloseErrorCode = options.DefaultCloseErrorCode;

// RFC 6066 forbids IP literals, avoid setting IP address here for consistency with SslStream
if (TargetHostNameHelper.IsValidAddress(targetHost))
{
targetHost = string.Empty;
}

_sslConnectionOptions = new SslConnectionOptions(
this,
isClient: false,
targetHost: null,
targetHost,
options.ServerAuthenticationOptions.ClientCertificateRequired,
options.ServerAuthenticationOptions.CertificateRevocationCheckMode,
options.ServerAuthenticationOptions.RemoteCertificateValidationCallback,
Expand Down
53 changes: 53 additions & 0 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1204,5 +1204,58 @@ public async Task IdleTimeout_ThrowsQuicException()
await AssertThrowsQuicExceptionAsync(QuicError.ConnectionIdle, async () => await acceptTask).WaitAsync(TimeSpan.FromSeconds(10));
}
}

private async Task SniTestCore(string hostname, bool shouldSendSni)
{
string expectedHostName = shouldSendSni ? hostname : string.Empty;

using X509Certificate serverCert = Configuration.Certificates.GetSelfSignedServerCertificate();
var listenerOptions = new QuicListenerOptions()
{
ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0),
ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
ConnectionOptionsCallback = (_, _, _) =>
{
var serverOptions = CreateQuicServerOptions();
serverOptions.ServerAuthenticationOptions.ServerCertificateContext = null;
serverOptions.ServerAuthenticationOptions.ServerCertificate = null;
serverOptions.ServerAuthenticationOptions.ServerCertificateSelectionCallback = (sender, actualHostName) =>
{
Assert.Equal(expectedHostName, actualHostName);
return serverCert;
};
return ValueTask.FromResult(serverOptions);
}
};

// Use whatever endpoint, it'll get overwritten in CreateConnectedQuicConnection.
QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(listenerOptions.ListenEndPoint);
clientOptions.ClientAuthenticationOptions.TargetHost = hostname;
clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = delegate { return true; };


(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions);
await using (clientConnection)
await using (serverConnection)
{
Assert.Equal(expectedHostName, clientConnection.TargetHostName);
Assert.Equal(expectedHostName, serverConnection.TargetHostName);
}
}

[Theory]
[InlineData("a")]
[InlineData("test")]
[InlineData("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")] // max allowed hostname length is 63
[InlineData("\u017C\u00F3\u0142\u0107 g\u0119\u015Bl\u0105 ja\u017A\u0144. \u7EA2\u70E7. \u7167\u308A\u713C\u304D")]
public Task ClientSendsSniServerReceives_Ok(string hostname) => SniTestCore(hostname, true);

[Theory]
[InlineData("127.0.0.1")]
[InlineData("::1")]
[InlineData("2001:11:22::1")]
[InlineData("fe80::9c3a:b64d:6249:1de8%2")]
[InlineData("fe80::9c3a:b64d:6249:1de8")]
public Task DoesNotSendIPAsSni(string target) => SniTestCore(target, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ public async Task TestConnect()
{
await using QuicListener listener = await CreateQuicListener();

ValueTask<QuicConnection> connectTask = CreateQuicConnection(listener.LocalEndPoint);
var options = CreateQuicClientOptions(listener.LocalEndPoint);
ValueTask<QuicConnection> connectTask = CreateQuicConnection(options);
ValueTask<QuicConnection> acceptTask = listener.AcceptConnectionAsync();

await new Task[] { connectTask.AsTask(), acceptTask.AsTask() }.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);
Expand All @@ -34,6 +35,8 @@ public async Task TestConnect()
Assert.Equal(clientConnection.LocalEndPoint, serverConnection.RemoteEndPoint);
Assert.Equal(ApplicationProtocol.ToString(), clientConnection.NegotiatedApplicationProtocol.ToString());
Assert.Equal(ApplicationProtocol.ToString(), serverConnection.NegotiatedApplicationProtocol.ToString());
Assert.Equal(options.ClientAuthenticationOptions.TargetHost, clientConnection.TargetHostName);
Assert.Equal(options.ClientAuthenticationOptions.TargetHost, serverConnection.TargetHostName);
}

private static async Task<QuicStream> OpenAndUseStreamAsync(QuicConnection c)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
Link="Common\System\NotImplemented.cs" />
<Compile Include="$(CommonPath)System\Net\Security\TlsAlertMessage.cs"
Link="Common\System\Net\Security\TlsAlertMessage.cs" />
<Compile Include="$(CommonPath)System\Net\Security\TargetHostNameHelper.cs"
Link="Common\System\Net\Security\TargetHostNameHelper.cs" />
<Compile Include="$(CommonPath)System\Net\Security\SafeCredentialReference.cs"
Link="Common\System\Net\Security\SafeCredentialReference.cs" />
<Compile Include="$(CommonPath)System\Net\Security\SSPIHandleCache.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,13 @@
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;

namespace System.Net.Security
{
internal sealed class SslAuthenticationOptions
{
private static readonly IdnMapping s_idnMapping = new IdnMapping();

// Simplified version of IPAddressParser.Parse to avoid allocations and dependencies.
// It purposely ignores scopeId as we don't really use so we do not need to map it to actual interface id.
private static unsafe bool IsValidAddress(ReadOnlySpan<char> ipSpan)
{
int end = ipSpan.Length;

if (ipSpan.Contains(':'))
{
// The address is parsed as IPv6 if and only if it contains a colon. This is valid because
// we don't support/parse a port specification at the end of an IPv4 address.
Span<ushort> numbers = stackalloc ushort[IPAddressParserStatics.IPv6AddressShorts];

fixed (char* ipStringPtr = &MemoryMarshal.GetReference(ipSpan))
{
return IPv6AddressHelper.IsValidStrict(ipStringPtr, 0, ref end);
}
}
else if (char.IsDigit(ipSpan[0]))
{
long tmpAddr;

fixed (char* ipStringPtr = &MemoryMarshal.GetReference(ipSpan))
{
tmpAddr = IPv4AddressHelper.ParseNonCanonical(ipStringPtr, 0, ref end, notImplicitFile: true);
}

if (tmpAddr != IPv4AddressHelper.Invalid && end == ipSpan.Length)
{
return true;
}
}

return false;
}

private static readonly IndexOfAnyValues<char> s_safeDnsChars =
IndexOfAnyValues.Create("-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz");

private static bool IsSafeDnsString(ReadOnlySpan<char> name) =>
name.IndexOfAnyExcept(s_safeDnsChars) < 0;

internal SslAuthenticationOptions()
{
TargetHost = string.Empty;
Expand Down Expand Up @@ -93,29 +48,11 @@ internal void UpdateOptions(SslClientAuthenticationOptions sslClientAuthenticati
IsServer = false;
RemoteCertRequired = true;
CertificateContext = sslClientAuthenticationOptions.ClientCertificateContext;
if (!string.IsNullOrEmpty(sslClientAuthenticationOptions.TargetHost))
{
// RFC 6066 section 3 says to exclude trailing dot from fully qualified DNS hostname
string targetHost = sslClientAuthenticationOptions.TargetHost.TrimEnd('.');

// RFC 6066 forbids IP literals
if (IsValidAddress(targetHost))
{
TargetHost = string.Empty;
}
else
{
try
{
TargetHost = s_idnMapping.GetAscii(targetHost);
}
catch (ArgumentException) when (IsSafeDnsString(targetHost))
{
// Seems like name that does not confrom to IDN but apers somewhat valid according to orogional DNS rfc.
TargetHost = targetHost;
}
}
}
// RFC 6066 forbids IP literals
TargetHost = TargetHostNameHelper.IsValidAddress(sslClientAuthenticationOptions.TargetHost)
? string.Empty
: sslClientAuthenticationOptions.TargetHost ?? string.Empty;

// Client specific options.
CertificateRevocationCheckMode = sslClientAuthenticationOptions.CertificateRevocationCheckMode;
Expand Down
Loading

0 comments on commit 25b61f6

Please sign in to comment.