From a6c64c3c26417f579022e20cbcee992b4814ed2c Mon Sep 17 00:00:00 2001 From: Tomas Weinfurt Date: Thu, 14 Sep 2023 21:40:29 -0700 Subject: [PATCH] fix ReceiveFrom with dual mode socket (#92086) * fix ReceiveFrom with dual mode socket * test * feedback --- .../src/Resources/Strings.resx | 2 +- .../src/System/Net/IPEndPoint.cs | 6 +-- .../tests/FunctionalTests/IPEndPointTest.cs | 16 ++++++- .../Net/Sockets/SocketAsyncEventArgs.cs | 17 +++++-- .../tests/FunctionalTests/ReceiveFrom.cs | 46 +++++++++++++++++++ 5 files changed, 76 insertions(+), 11 deletions(-) diff --git a/src/libraries/System.Net.Primitives/src/Resources/Strings.resx b/src/libraries/System.Net.Primitives/src/Resources/Strings.resx index 958a0e2e269f9..65d4809398b3b 100644 --- a/src/libraries/System.Net.Primitives/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Primitives/src/Resources/Strings.resx @@ -64,7 +64,7 @@ This property is not implemented by this class. - The AddressFamily {0} is not valid for the {1} end point, use {2} instead. + The AddressFamily {0} is not valid for the {1} end point. The supplied {0} is an invalid size for the {1} end point. diff --git a/src/libraries/System.Net.Primitives/src/System/Net/IPEndPoint.cs b/src/libraries/System.Net.Primitives/src/System/Net/IPEndPoint.cs index 3531f266e6c50..ff47d2fbc515e 100644 --- a/src/libraries/System.Net.Primitives/src/System/Net/IPEndPoint.cs +++ b/src/libraries/System.Net.Primitives/src/System/Net/IPEndPoint.cs @@ -155,9 +155,9 @@ public override EndPoint Create(SocketAddress socketAddress) { ArgumentNullException.ThrowIfNull(socketAddress); - if (socketAddress.Family != AddressFamily) - { - throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, socketAddress.Family.ToString(), GetType().FullName, AddressFamily.ToString()), nameof(socketAddress)); + if (socketAddress.Family is not (AddressFamily.InterNetwork or AddressFamily.InterNetworkV6)) + { + throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, socketAddress.Family.ToString(), GetType().FullName), nameof(socketAddress)); } int minSize = AddressFamily == AddressFamily.InterNetworkV6 ? SocketAddress.IPv6AddressSize : SocketAddress.IPv4AddressSize; diff --git a/src/libraries/System.Net.Primitives/tests/FunctionalTests/IPEndPointTest.cs b/src/libraries/System.Net.Primitives/tests/FunctionalTests/IPEndPointTest.cs index bb9b95d438e99..c233dee628dfe 100644 --- a/src/libraries/System.Net.Primitives/tests/FunctionalTests/IPEndPointTest.cs +++ b/src/libraries/System.Net.Primitives/tests/FunctionalTests/IPEndPointTest.cs @@ -143,6 +143,19 @@ public static void ToString_Invoke_ReturnsExpected(IPEndPoint endPoint, string e Assert.Equal(expected, endPoint.ToString()); } + [Fact] + public static void Create_DifferentAF_Success() + { + SocketAddress sa = new SocketAddress(AddressFamily.InterNetwork, SocketAddress.GetMaximumAddressSize(AddressFamily.InterNetworkV6)); + var ep = new IPEndPoint(IPAddress.IPv6Any, 0); + Assert.NotNull(ep.Create(sa)); + + sa = new SocketAddress(AddressFamily.InterNetworkV6); + ep = new IPEndPoint(IPAddress.Any, 0); + + Assert.NotNull(ep.Create(sa)); + } + public static IEnumerable Serialize_TestData() { yield return new object[] { new IPAddress(2), 16 }; @@ -195,8 +208,7 @@ public static void Create_NullSocketAddress_ThrowsArgumentNullException() public static IEnumerable Create_InvalidAddressFamily_TestData() { - yield return new object[] { new IPEndPoint(2, 500), new SocketAddress(Sockets.AddressFamily.InterNetworkV6) }; - yield return new object[] { new IPEndPoint(IPAddress.Parse("192.169.0.9"), 500), new SocketAddress(Sockets.AddressFamily.InterNetworkV6) }; + yield return new object[] { new IPEndPoint(2, 500), new SocketAddress(Sockets.AddressFamily.Unknown) }; yield return new object[] { new IPEndPoint(IPAddress.Parse("0:0:0:0:0:0:0:1"), 500), new SocketAddress(Sockets.AddressFamily.InterNetwork) }; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index e04739d5fe7a6..e94d862571a0f 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -927,13 +927,13 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags { try { - if (_remoteEndPoint!.AddressFamily == _socketAddress!.Family) + if (_remoteEndPoint!.AddressFamily == AddressFamily.InterNetworkV6 && _socketAddress!.Family == AddressFamily.InterNetwork) { - _remoteEndPoint = _remoteEndPoint!.Create(_socketAddress); + _remoteEndPoint = new IPEndPoint(_socketAddress.GetIPAddress().MapToIPv6(), _socketAddress.GetPort()); } - else if (_remoteEndPoint!.AddressFamily == AddressFamily.InterNetworkV6 && _socketAddress.Family == AddressFamily.InterNetwork) + else { - _remoteEndPoint = new IPEndPoint(_socketAddress.GetIPAddress().MapToIPv6(), _socketAddress.GetPort()); + _remoteEndPoint = _remoteEndPoint!.Create(_socketAddress!); } } catch @@ -949,7 +949,14 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags { try { - _remoteEndPoint = _remoteEndPoint!.Create(_socketAddress!); + if (_remoteEndPoint!.AddressFamily == AddressFamily.InterNetworkV6 && _socketAddress!.Family == AddressFamily.InterNetwork) + { + _remoteEndPoint = new IPEndPoint(_socketAddress.GetIPAddress().MapToIPv6(), _socketAddress.GetPort()); + } + else + { + _remoteEndPoint = _remoteEndPoint!.Create(_socketAddress!); + } } catch { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs index 1a5ec7d05d28e..1ec2adeadcf51 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs @@ -168,6 +168,52 @@ public async Task ReceiveSent_UDP_Success(bool ipv4) } } + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ReceiveSent_DualMode_Success(bool ipv4) + { + const int Offset = 10; + const int DatagramSize = 256; + const int DatagramsToSend = 16; + + IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; + using Socket receiver = new Socket(SocketType.Dgram, ProtocolType.Udp); + using Socket sender = new Socket(SocketType.Dgram, ProtocolType.Udp); + if (receiver.DualMode != true || sender.DualMode != true) + { + throw new SkipException("DualMode not available"); + } + + ConfigureNonBlocking(sender); + ConfigureNonBlocking(receiver); + + receiver.BindToAnonymousPort(address); + sender.BindToAnonymousPort(address); + + byte[] sendBuffer = new byte[DatagramSize]; + var receiveInternalBuffer = new byte[DatagramSize + Offset]; + var emptyBuffer = new byte[Offset]; + ArraySegment receiveBuffer = new ArraySegment(receiveInternalBuffer, Offset, DatagramSize); + + Random rnd = new Random(0); + + for (int i = 0; i < DatagramsToSend; i++) + { + rnd.NextBytes(sendBuffer); + sender.SendTo(sendBuffer, receiver.LocalEndPoint); + + IPEndPoint remoteEp = new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0); + + SocketReceiveFromResult result = await ReceiveFromAsync(receiver, receiveBuffer, remoteEp); + + Assert.Equal(DatagramSize, result.ReceivedBytes); + AssertExtensions.SequenceEqual(emptyBuffer, new ReadOnlySpan(receiveInternalBuffer, 0, Offset)); + AssertExtensions.SequenceEqual(sendBuffer, new ReadOnlySpan(receiveInternalBuffer, Offset, DatagramSize)); + Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint); + } + } + [Theory] [InlineData(false)] [InlineData(true)]