Add Socket(SocketSafeHandle) ctor (#34727)
authorStephen Toub <stoub@microsoft.com>
Fri, 10 Apr 2020 23:54:09 +0000 (19:54 -0400)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 23:54:09 +0000 (19:54 -0400)
27 files changed:
src/libraries/Common/src/Interop/Unix/System.Native/Interop.Fcntl.cs
src/libraries/Common/src/Interop/Unix/System.Native/Interop.GetSocketType.cs [new file with mode: 0644]
src/libraries/Common/src/Interop/Unix/System.Native/Interop.Stat.cs
src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSADuplicateSocket.cs
src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAPROTOCOL_INFOW.cs [new file with mode: 0644]
src/libraries/Common/src/Interop/Windows/WinSock/Interop.getsockname.cs
src/libraries/Common/src/System/Net/ByteOrder.cs
src/libraries/Common/src/System/Net/Internals/IPEndPointExtensions.cs
src/libraries/Common/src/System/Net/SocketAddress.cs
src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs
src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs
src/libraries/Native/Unix/System.Native/pal_io.c
src/libraries/Native/Unix/System.Native/pal_io.h
src/libraries/Native/Unix/System.Native/pal_networking.c
src/libraries/Native/Unix/System.Native/pal_networking.h
src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
src/libraries/System.Net.Sockets/src/Resources/Strings.resx
src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SafeSocketHandle.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Unix.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/CreateSocketTests.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs

index cec6451..0877ca7 100644 (file)
@@ -17,6 +17,9 @@ internal static partial class Interop
             [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FcntlSetIsNonBlocking", SetLastError=true)]
             internal static extern int SetIsNonBlocking(SafeHandle fd, int isNonBlocking);
 
+            [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FcntlGetIsNonBlocking", SetLastError = true)]
+            internal static extern int GetIsNonBlocking(SafeHandle fd, out bool isNonBlocking);
+
             [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FcntlSetFD", SetLastError=true)]
             internal static extern int SetFD(SafeHandle fd, int flags);
 
diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.GetSocketType.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.GetSocketType.cs
new file mode 100644 (file)
index 0000000..f857bf3
--- /dev/null
@@ -0,0 +1,15 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Net.Sockets;
+using System.Runtime.InteropServices;
+
+internal static partial class Interop
+{
+    internal static partial class Sys
+    {
+        [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetSocketType")]
+        internal static extern Error GetSocketType(SafeSocketHandle socket, out AddressFamily addressFamily, out SocketType socketType, out ProtocolType protocolType);
+    }
+}
index d06fbda..628bd9c 100644 (file)
@@ -55,7 +55,7 @@ internal static partial class Interop
         }
 
         [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FStat", SetLastError = true)]
-        internal static extern int FStat(SafeFileHandle fd, out FileStatus output);
+        internal static extern int FStat(SafeHandle fd, out FileStatus output);
 
         [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_Stat", SetLastError = true)]
         internal static extern int Stat(string path, out FileStatus output);
index 23fa032..d079c1c 100644 (file)
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System;
 using System.Net.Sockets;
 using System.Runtime.InteropServices;
 
@@ -10,42 +9,6 @@ internal static partial class Interop
 {
     internal static partial class Winsock
     {
-        [StructLayout(LayoutKind.Sequential)]
-        internal unsafe struct WSAPROTOCOLCHAIN
-        {
-            private const int MAX_PROTOCOL_CHAIN = 7;
-
-            internal int ChainLen;
-            internal fixed uint ChainEntries[MAX_PROTOCOL_CHAIN];
-        }
-
-        [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
-        internal unsafe struct WSAPROTOCOL_INFOW
-        {
-            private const int WSAPROTOCOL_LEN = 255;
-
-            internal uint dwServiceFlags1;
-            internal uint dwServiceFlags2;
-            internal uint dwServiceFlags3;
-            internal uint dwServiceFlags4;
-            internal uint dwProviderFlags;
-            internal Guid ProviderId;
-            internal uint dwCatalogEntryId;
-            internal WSAPROTOCOLCHAIN ProtocolChain;
-            internal int iVersion;
-            internal AddressFamily iAddressFamily;
-            internal int iMaxSockAddr;
-            internal int iMinSockAddr;
-            internal SocketType iSocketType;
-            internal ProtocolType iProtocol;
-            internal int iProtocolMaxOffset;
-            internal int iNetworkByteOrder;
-            internal int iSecurityScheme;
-            internal uint dwMessageSize;
-            internal uint dwProviderReserved;
-            internal fixed char szProtocol[WSAPROTOCOL_LEN + 1];
-        }
-
         [DllImport(Interop.Libraries.Ws2_32, CharSet = CharSet.Unicode, SetLastError = true)]
         internal static extern unsafe int WSADuplicateSocket(
             [In] SafeSocketHandle s,
diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAPROTOCOL_INFOW.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAPROTOCOL_INFOW.cs
new file mode 100644 (file)
index 0000000..502b284
--- /dev/null
@@ -0,0 +1,51 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Net.Sockets;
+
+internal static partial class Interop
+{
+    internal static partial class Winsock
+    {
+        public const int SO_PROTOCOL_INFOW = 0x2005;
+
+        [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+        internal unsafe struct WSAPROTOCOL_INFOW
+        {
+            private const int WSAPROTOCOL_LEN = 255;
+
+            internal uint dwServiceFlags1;
+            internal uint dwServiceFlags2;
+            internal uint dwServiceFlags3;
+            internal uint dwServiceFlags4;
+            internal uint dwProviderFlags;
+            internal Guid ProviderId;
+            internal uint dwCatalogEntryId;
+            internal WSAPROTOCOLCHAIN ProtocolChain;
+            internal int iVersion;
+            internal AddressFamily iAddressFamily;
+            internal int iMaxSockAddr;
+            internal int iMinSockAddr;
+            internal SocketType iSocketType;
+            internal ProtocolType iProtocol;
+            internal int iProtocolMaxOffset;
+            internal int iNetworkByteOrder;
+            internal int iSecurityScheme;
+            internal uint dwMessageSize;
+            internal uint dwProviderReserved;
+            internal fixed char szProtocol[WSAPROTOCOL_LEN + 1];
+        }
+
+        [StructLayout(LayoutKind.Sequential)]
+        internal unsafe struct WSAPROTOCOLCHAIN
+        {
+            private const int MAX_PROTOCOL_CHAIN = 7;
+
+            internal int ChainLen;
+            internal fixed uint ChainEntries[MAX_PROTOCOL_CHAIN];
+        }
+    }
+}
index 4d44869..5471f62 100644 (file)
@@ -10,9 +10,9 @@ internal static partial class Interop
     internal static partial class Winsock
     {
         [DllImport(Interop.Libraries.Ws2_32, SetLastError = true)]
-        internal static extern SocketError getsockname(
-            [In] SafeSocketHandle socketHandle,
-            [Out] byte[] socketAddress,
-            [In, Out] ref int socketAddressSize);
+        internal static extern unsafe SocketError getsockname(
+            SafeSocketHandle socketHandle,
+            byte* socketAddress,
+            int* socketAddressSize);
     }
 }
index 1783deb..5bf3365 100644 (file)
@@ -12,7 +12,7 @@ namespace System.Net
             bytes[index + 1] = unchecked((byte)host);
         }
 
-        public static ushort NetworkBytesToHostUInt16(this byte[] bytes, int index)
+        public static ushort NetworkBytesToHostUInt16(this ReadOnlySpan<byte> bytes, int index)
         {
             return (ushort)(((ushort)bytes[index] << 8) | (ushort)bytes[index + 1]);
         }
index 4c095f7..89becdd 100644 (file)
@@ -55,7 +55,7 @@ namespace System.Net.Sockets
             return result;
         }
 
-        private static System.Net.SocketAddress GetNetSocketAddress(Internals.SocketAddress address)
+        internal static System.Net.SocketAddress GetNetSocketAddress(Internals.SocketAddress address)
         {
             var result = new System.Net.SocketAddress(address.Family, address.Size);
             for (int index = 0; index < address.Size; index++)
index a3a93bd..b02d4a7 100644 (file)
@@ -130,6 +130,12 @@ namespace System.Net.Internals
             SocketAddressPal.SetPort(Buffer, unchecked((ushort)port));
         }
 
+        internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan<byte> buffer)
+        {
+            Buffer = buffer.ToArray();
+            InternalSize = Buffer.Length;
+        }
+
         internal IPAddress GetIPAddress()
         {
             if (Family == AddressFamily.InterNetworkV6)
index 027f050..c28b62c 100644 (file)
@@ -52,7 +52,7 @@ namespace System.Net
             }
         }
 
-        public static unsafe AddressFamily GetAddressFamily(byte[] buffer)
+        public static unsafe AddressFamily GetAddressFamily(ReadOnlySpan<byte> buffer)
         {
             AddressFamily family;
             Interop.Error err;
@@ -76,7 +76,7 @@ namespace System.Net
             ThrowOnFailure(err);
         }
 
-        public static unsafe ushort GetPort(byte[] buffer)
+        public static unsafe ushort GetPort(ReadOnlySpan<byte> buffer)
         {
             ushort port;
             Interop.Error err;
index b28bad4..82f0a6d 100644 (file)
@@ -11,9 +11,9 @@ namespace System.Net
         public const int IPv6AddressSize = 28;
         public const int IPv4AddressSize = 16;
 
-        public static unsafe AddressFamily GetAddressFamily(byte[] buffer)
+        public static unsafe AddressFamily GetAddressFamily(ReadOnlySpan<byte> buffer)
         {
-            return (AddressFamily)BitConverter.ToInt16(buffer, 0);
+            return (AddressFamily)BitConverter.ToInt16(buffer);
         }
 
         public static unsafe void SetAddressFamily(byte[] buffer, AddressFamily family)
@@ -35,7 +35,7 @@ namespace System.Net
 #endif
         }
 
-        public static unsafe ushort GetPort(byte[] buffer)
+        public static unsafe ushort GetPort(ReadOnlySpan<byte> buffer)
         {
             return buffer.NetworkBytesToHostUInt16(2);
         }
index d3690fa..6090846 100644 (file)
@@ -594,6 +594,24 @@ int32_t SystemNative_FcntlSetIsNonBlocking(intptr_t fd, int32_t isNonBlocking)
     return fcntl(fileDescriptor, F_SETFL, flags);
 }
 
+int32_t SystemNative_FcntlGetIsNonBlocking(intptr_t fd, int32_t* isNonBlocking)
+{
+    if (isNonBlocking == NULL)
+    {
+        return Error_EFAULT;
+    }
+
+    int flags = fcntl(ToFileDescriptor(fd), F_GETFL);
+    if (flags == -1)
+    {
+        *isNonBlocking = 0;
+        return -1;
+    }
+
+    *isNonBlocking = ((flags & O_NONBLOCK) == O_NONBLOCK) ? 1 : 0;
+    return 0;
+}
+
 int32_t SystemNative_MkDir(const char* path, int32_t mode)
 {
     int32_t result;
index 136f534..9f68aa3 100644 (file)
@@ -472,6 +472,13 @@ PALEXPORT int32_t SystemNative_FcntlSetPipeSz(intptr_t fd, int32_t size);
 PALEXPORT int32_t SystemNative_FcntlSetIsNonBlocking(intptr_t fd, int32_t isNonBlocking);
 
 /**
+ * Gets whether or not a file descriptor is non-blocking.
+ *
+ * Returns 0 for success, -1 for failure. Sets errno for failure.
+ */
+PALEXPORT int32_t SystemNative_FcntlGetIsNonBlocking(intptr_t fd, int32_t* isNonBlocking);
+
+/**
  * Create a directory. Implemented as a shim to mkdir(2).
  *
  * Returns 0 for success, -1 for failure. Sets errno for failure.
index 7faa21b..08540e2 100644 (file)
@@ -1524,7 +1524,6 @@ int32_t SystemNative_GetPeerName(intptr_t socket, uint8_t* socketAddress, int32_
         return SystemNative_ConvertErrorPlatformToPal(errno);
     }
 
-    assert(addrLen <= (socklen_t)*socketAddressLen);
     *socketAddressLen = (int32_t)addrLen;
     return Error_SUCCESS;
 }
@@ -2254,6 +2253,128 @@ static bool TryConvertProtocolTypePalToPlatform(int32_t palAddressFamily, int32_
     }
 }
 
+static bool TryConvertProtocolTypePlatformToPal(int32_t palAddressFamily, int platformProtocolType, int32_t* palProtocolType)
+{
+    assert(palProtocolType != NULL);
+
+    switch (palAddressFamily)
+    {
+#ifdef AF_PACKET
+        case AddressFamily_AF_PACKET:
+            // protocol is the IEEE 802.3 protocol number in network order.
+            *palProtocolType = platformProtocolType;
+            return true;
+#endif
+#if HAVE_LINUX_CAN_H
+        case AddressFamily_AF_CAN:
+            switch (platformProtocolType)
+            {
+                case 0:
+                    *palProtocolType = ProtocolType_PT_UNSPECIFIED;
+                    return true;
+
+                case CAN_RAW:
+                    *palProtocolType = ProtocolType_PT_RAW;
+                    return true;
+
+                default:
+                    *palProtocolType = (int)platformProtocolType;
+                    return false;
+            }
+#endif
+        case AddressFamily_AF_INET:
+            switch (platformProtocolType)
+            {
+                case 0:
+                    *palProtocolType = ProtocolType_PT_UNSPECIFIED;
+                    return true;
+
+                case IPPROTO_ICMP:
+                    *palProtocolType = ProtocolType_PT_ICMP;
+                    return true;
+
+                case IPPROTO_TCP:
+                    *palProtocolType = ProtocolType_PT_TCP;
+                    return true;
+
+                case IPPROTO_UDP:
+                    *palProtocolType = ProtocolType_PT_UDP;
+                    return true;
+
+                case IPPROTO_IGMP:
+                    *palProtocolType = ProtocolType_PT_IGMP;
+                    return true;
+
+                case IPPROTO_RAW:
+                    *palProtocolType = ProtocolType_PT_RAW;
+                    return true;
+
+                default:
+                    *palProtocolType = (int)palProtocolType;
+                    return false;
+            }
+
+        case AddressFamily_AF_INET6:
+            switch (platformProtocolType)
+            {
+                case 0:
+                    *palProtocolType = ProtocolType_PT_UNSPECIFIED;
+                    return true;
+
+                case IPPROTO_ICMPV6:
+                    *palProtocolType = ProtocolType_PT_ICMPV6;
+                    return true;
+
+                case IPPROTO_TCP:
+                    *palProtocolType = ProtocolType_PT_TCP;
+                    return true;
+
+                case IPPROTO_UDP:
+                    *palProtocolType = ProtocolType_PT_UDP;
+                    return true;
+
+                case IPPROTO_IGMP:
+                    *palProtocolType = ProtocolType_PT_IGMP;
+                    return true;
+
+                case IPPROTO_RAW:
+                    *palProtocolType = ProtocolType_PT_RAW;
+                    return true;
+
+                case IPPROTO_DSTOPTS:
+                    *palProtocolType = ProtocolType_PT_DSTOPTS;
+                    return true;
+
+                case IPPROTO_NONE:
+                    *palProtocolType = ProtocolType_PT_NONE;
+                    return true;
+
+                case IPPROTO_ROUTING:
+                    *palProtocolType = ProtocolType_PT_ROUTING;
+                    return true;
+
+                case IPPROTO_FRAGMENT:
+                    *palProtocolType = ProtocolType_PT_FRAGMENT;
+                    return true;
+
+                default:
+                    *palProtocolType = (int)platformProtocolType;
+                    return false;
+            }
+
+        default:
+            switch (platformProtocolType)
+            {
+                case 0:
+                    *palProtocolType = ProtocolType_PT_UNSPECIFIED;
+                    return true;
+                default:
+                    *palProtocolType = (int)platformProtocolType;
+                    return false;
+            }
+    }
+}
+
 int32_t SystemNative_Socket(int32_t addressFamily, int32_t socketType, int32_t protocolType, intptr_t* createdSocket)
 {
     if (createdSocket == NULL)
@@ -2297,6 +2418,48 @@ int32_t SystemNative_Socket(int32_t addressFamily, int32_t socketType, int32_t p
     return Error_SUCCESS;
 }
 
+int32_t SystemNative_GetSocketType(intptr_t socket, int32_t* addressFamily, int32_t* socketType, int32_t* protocolType)
+{
+    if (addressFamily == NULL || socketType == NULL || protocolType == NULL)
+    {
+        return Error_EFAULT;
+    }
+
+    int fd = ToFileDescriptor(socket);
+
+#ifdef SO_DOMAIN
+    int domainValue;
+    socklen_t domainLength = sizeof(int);
+    if (getsockopt(fd, SOL_SOCKET, SO_DOMAIN, &domainValue, &domainLength) != 0 ||
+        !TryConvertAddressFamilyPlatformToPal((sa_family_t)domainValue, addressFamily))
+#endif
+    {
+        *addressFamily = AddressFamily_AF_UNKNOWN;
+    }
+
+#ifdef SO_TYPE
+    int typeValue;
+    socklen_t typeLength = sizeof(int);
+    if (getsockopt(fd, SOL_SOCKET, SO_TYPE, &typeValue, &typeLength) != 0 ||
+        !TryConvertSocketTypePlatformToPal(typeValue, socketType))
+#endif
+    {
+        *socketType = SocketType_UNKNOWN;
+    }
+
+#ifdef SO_PROTOCOL
+    int protocolValue;
+    socklen_t protocolLength = sizeof(int);
+    if (getsockopt(fd, SOL_SOCKET, SO_PROTOCOL, &protocolValue, &protocolLength) != 0 ||
+        !TryConvertProtocolTypePlatformToPal(*addressFamily, protocolValue, protocolType))
+#endif
+    {
+        *protocolType = ProtocolType_PT_UNKNOWN;
+    }
+
+    return Error_SUCCESS;
+}
+
 int32_t SystemNative_GetAtOutOfBandMark(intptr_t socket, int32_t* atMark)
 {
     if (atMark == NULL)
index dbc4ebc..3dbc089 100644 (file)
@@ -59,6 +59,7 @@ typedef enum
  */
 typedef enum
 {
+    AddressFamily_AF_UNKNOWN = -1, // System.Net.AddressFamily.Unknown
     AddressFamily_AF_UNSPEC = 0,   // System.Net.AddressFamily.Unspecified
     AddressFamily_AF_UNIX = 1,     // System.Net.AddressFamily.Unix
     AddressFamily_AF_INET = 2,     // System.Net.AddressFamily.InterNetwork
@@ -74,6 +75,7 @@ typedef enum
  */
 typedef enum
 {
+    SocketType_UNKNOWN = -1,       // System.Net.SocketType.Unknown
     SocketType_SOCK_STREAM = 1,    // System.Net.SocketType.Stream
     SocketType_SOCK_DGRAM = 2,     // System.Net.SocketType.Dgram
     SocketType_SOCK_RAW = 3,       // System.Net.SocketType.Raw
@@ -88,6 +90,7 @@ typedef enum
  */
 typedef enum
 {
+    ProtocolType_PT_UNKNOWN = -1,    // System.Net.ProtocolType.Unknown
     ProtocolType_PT_UNSPECIFIED = 0, // System.Net.ProtocolType.Unspecified
     ProtocolType_PT_ICMP = 1,        // System.Net.ProtocolType.Icmp
     ProtocolType_PT_TCP = 6,         // System.Net.ProtocolType.Tcp
@@ -402,6 +405,8 @@ PALEXPORT int32_t SystemNative_SetRawSockOpt(
 
 PALEXPORT int32_t SystemNative_Socket(int32_t addressFamily, int32_t socketType, int32_t protocolType, intptr_t* createdSocket);
 
+PALEXPORT int32_t SystemNative_GetSocketType(intptr_t socket, int32_t* addressFamily, int32_t* socketType, int32_t* protocolType);
+
 PALEXPORT int32_t SystemNative_GetAtOutOfBandMark(intptr_t socket, int32_t* available);
 
 PALEXPORT int32_t SystemNative_GetBytesAvailable(intptr_t socket, int32_t* available);
index e72a4e2..629e16c 100644 (file)
@@ -222,6 +222,7 @@ namespace System.Net.Sockets
     }
     public partial class Socket : System.IDisposable
     {
+        public Socket(System.Net.Sockets.SafeSocketHandle handle) { }
         public Socket(System.Net.Sockets.AddressFamily addressFamily, System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType) { }
         public Socket(System.Net.Sockets.SocketInformation socketInformation) { }
         public Socket(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType) { }
index 3b274c8..3078fab 100644 (file)
@@ -57,6 +57,9 @@
   <resheader name="writer">
     <value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
   </resheader>
+  <data name="Arg_InvalidHandle" xml:space="preserve">
+    <value>Invalid handle.</value>
+  </data>
   <data name="net_invalidversion" xml:space="preserve">
     <value>This protocol version is not supported.</value>
   </data>
index 89b1058..3d74a78 100644 (file)
     <Compile Include="$(CommonPath)Interop\Windows\WinSock\WSABuffer.cs">
       <Link>Common\Interop\Windows\WinSock\WSABuffer.cs</Link>
     </Compile>
+    <Compile Include="$(CommonPath)Interop\Windows\WinSock\Interop.WSAPROTOCOL_INFOW.cs">
+      <Link>Common\Interop\Windows\WinSock\Interop.WSAPROTOCOL_INFOW.cs</Link>
+    </Compile>
     <Compile Include="$(CommonPath)Interop\Windows\Kernel32\Interop.CancelIoEx.cs">
       <Link>Common\Interop\Windows\Interop.CancelIoEx.cs</Link>
     </Compile>
     <Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.GetSocketErrorOption.cs">
       <Link>Common\Interop\Unix\System.Native\Interop.GetSocketErrorOption.cs</Link>
     </Compile>
+    <Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.GetSocketType.cs">
+      <Link>Common\Interop\Unix\System.Native\Interop.GetSocketType.cs</Link>
+    </Compile>
     <Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.GetSockName.cs">
       <Link>Common\Interop\Unix\System.Native\Interop.GetSockName.cs</Link>
     </Compile>
     <Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SetReceiveTimeout.cs">
       <Link>Common\Interop\Unix\System.Native\Interop.SetReceiveTimeout.cs</Link>
     </Compile>
+    <Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Stat.cs">
+      <Link>Common\Interop\Unix\Interop.Stat.cs</Link>
+    </Compile>
     <Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Listen.cs">
       <Link>Common\Interop\Unix\System.Native\Interop.Listen.cs</Link>
     </Compile>
index e39df10..dff1d27 100644 (file)
@@ -35,14 +35,17 @@ namespace System.Net.Sockets
         public SafeSocketHandle(IntPtr preexistingHandle, bool ownsHandle)
             : base(ownsHandle)
         {
+            OwnsHandle = ownsHandle;
             SetHandleAndValid(preexistingHandle);
         }
 
-        private SafeSocketHandle() : base(true) { }
+        private SafeSocketHandle() : base(ownsHandle: true) => OwnsHandle = true;
+
+        internal bool OwnsHandle { get; }
 
         private bool TryOwnClose()
         {
-            return Interlocked.CompareExchange(ref _ownClose, 1, 0) == 0;
+            return OwnsHandle && Interlocked.CompareExchange(ref _ownClose, 1, 0) == 0;
         }
 
         private volatile bool _released;
index 804b9d3..d509263 100644 (file)
@@ -47,6 +47,36 @@ namespace System.Net.Sockets
             Debug.Assert(!_handle.LastConnectFailed);
         }
 
+        private static unsafe void LoadSocketTypeFromHandle(
+            SafeSocketHandle handle, out AddressFamily addressFamily, out SocketType socketType, out ProtocolType protocolType, out bool blocking)
+        {
+            // Validate that the supplied handle is indeed a socket.
+            if (Interop.Sys.FStat(handle, out Interop.Sys.FileStatus stat) == -1 ||
+                (stat.Mode & Interop.Sys.FileTypes.S_IFSOCK) != Interop.Sys.FileTypes.S_IFSOCK)
+            {
+                throw new SocketException((int)SocketError.NotSocket);
+            }
+
+            // On Linux, GetSocketType will be able to query SO_DOMAIN, SO_TYPE, and SO_PROTOCOL to get the
+            // address family, socket type, and protocol type, respectively.  On macOS, this will only succeed
+            // in getting the socket type, and the others will be unknown.  Subsequently the Socket ctor
+            // can use getsockname to retrieve the address family as part of trying to get the local end point.
+            Interop.Error e = Interop.Sys.GetSocketType(handle, out addressFamily, out socketType, out protocolType);
+            Debug.Assert(e == Interop.Error.SUCCESS, e.ToString());
+
+            // Get whether the socket is in non-blocking mode.  On Unix, we automatically put the underlying
+            // Socket into non-blocking mode whenever an async method is first invoked on the instance, but we
+            // maintain a shadow bool that maintains the Socket.Blocking value set by the developer.  Because
+            // we're querying the underlying socket here, and don't have access to the original Socket instance
+            // (if there even was one... the Socket(SafeSocketHandle) ctor is likely being used because there
+            // wasn't one, Socket.Blocking will end up reflecting the actual state of the socket even if the
+            // developer didn't set Blocking = false.
+            bool nonBlocking;
+            int rv = Interop.Sys.Fcntl.GetIsNonBlocking(handle, out nonBlocking);
+            blocking = !nonBlocking;
+            Debug.Assert(rv == 0 || blocking, e.ToString()); // ignore failures
+        }
+
         internal void ReplaceHandleIfNecessaryAfterFailedConnect()
         {
             if (!_handle.LastConnectFailed)
index d606032..2034cab 100644 (file)
@@ -7,6 +7,7 @@ using System.Collections;
 using System.Diagnostics;
 using System.IO;
 using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
 using System.Threading;
 
 namespace System.Net.Sockets
@@ -56,7 +57,14 @@ namespace System.Net.Sockets
             IPEndPoint ep = new IPEndPoint(tempAddress, 0);
 
             Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(ep);
-            errorCode = SocketPal.GetSockName(_handle, socketAddress.Buffer, ref socketAddress.InternalSize);
+            unsafe
+            {
+                fixed (byte* bufferPtr = socketAddress.Buffer)
+                fixed (int* sizePtr = &socketAddress.InternalSize)
+                {
+                    errorCode = SocketPal.GetSockName(_handle, bufferPtr, sizePtr);
+                }
+            }
 
             if (errorCode == SocketError.Success)
             {
@@ -76,6 +84,28 @@ namespace System.Net.Sockets
             if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
         }
 
+        private unsafe void LoadSocketTypeFromHandle(
+            SafeSocketHandle handle, out AddressFamily addressFamily, out SocketType socketType, out ProtocolType protocolType, out bool blocking)
+        {
+            Interop.Winsock.WSAPROTOCOL_INFOW info = default;
+            int optionLength = sizeof(Interop.Winsock.WSAPROTOCOL_INFOW);
+
+            // Get the address family, socket type, and protocol type from the socket.
+            if (Interop.Winsock.getsockopt(handle, SocketOptionLevel.Socket, (SocketOptionName)Interop.Winsock.SO_PROTOCOL_INFOW, (byte*)&info, ref optionLength) == SocketError.SocketError)
+            {
+                throw new SocketException((int)SocketPal.GetLastSocketError());
+            }
+
+            addressFamily = info.iAddressFamily;
+            socketType = info.iSocketType;
+            protocolType = info.iProtocol;
+
+            // There's no API to retrieve this (WSAIsBlocking isn't supported any more).  Assume it's blocking, but we might be wrong.
+            // This affects the result of querying Socket.Blocking, which will mostly only affect user code that happens to query
+            // that property, though there are a few places we check it internally, e.g. as part of NetworkStream argument validation.
+            blocking = true;
+        }
+
         public SocketInformation DuplicateAndClose(int targetProcessId)
         {
             if (NetEventSource.IsEnabled) NetEventSource.Enter(this, targetProcessId);
index 5dc340c..925e4e8 100644 (file)
@@ -107,23 +107,146 @@ namespace System.Net.Sockets
             if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
         }
 
-        // Called by the class to create a socket to accept an incoming request.
-        private Socket(SafeSocketHandle fd)
+        /// <summary>Initializes a new instance of the <see cref="Socket"/> class for the specified socket handle.</summary>
+        /// <param name="handle">The socket handle for the socket that the <see cref="Socket"/> object will encapsulate.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="handle"/> is null.</exception>
+        /// <exception cref="ArgumentException"><paramref name="handle"/> is invalid.</exception>
+        /// <exception cref="SocketException"><paramref name="handle"/> is not a socket or information about the socket could not be accessed.</exception>
+        /// <remarks>
+        /// This method populates the <see cref="Socket"/> instance with data gathered from the supplied <see cref="SafeSocketHandle"/>.
+        /// Different operating systems provide varying levels of support for querying a socket handle or file descriptor for its
+        /// properties and configuration, which means some of the public APIs on the resulting <see cref="Socket"/> instance may
+        /// differ based on operating system, such as <see cref="Socket.ProtocolType"/> and <see cref="Socket.Blocking"/>.
+        /// </remarks>
+        public Socket(SafeSocketHandle handle) :
+            this(ValidateHandle(handle), loadPropertiesFromHandle: true)
         {
-            // NOTE: If this ctor is ever made public/protected, this check will need
-            // to be converted into a runtime exception.
-            Debug.Assert(fd != null && !fd.IsInvalid);
+        }
 
-            if (NetEventSource.IsEnabled) NetEventSource.Enter(this);
+        private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle)
+        {
             InitializeSockets();
 
-            _handle = fd;
+            _handle = handle;
+            _addressFamily = AddressFamily.Unknown;
+            _socketType = SocketType.Unknown;
+            _protocolType = ProtocolType.Unknown;
 
-            _addressFamily = Sockets.AddressFamily.Unknown;
-            _socketType = Sockets.SocketType.Unknown;
-            _protocolType = Sockets.ProtocolType.Unknown;
-            if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
+            if (!loadPropertiesFromHandle)
+            {
+                return;
+            }
+
+            try
+            {
+                // Get properties like address family and blocking mode from the OS.
+                LoadSocketTypeFromHandle(handle, out _addressFamily, out _socketType, out _protocolType, out _willBlockInternal);
+
+                // Determine whether the socket is in listening mode.
+                _isListening =
+                    SocketPal.GetSockOpt(_handle, SocketOptionLevel.Socket, SocketOptionName.AcceptConnection, out int isListening) == SocketError.Success &&
+                    isListening != 0;
+
+                // Try to get the address of the socket.
+                Span<byte> buffer = stackalloc byte[512]; // arbitrary high limit that should suffice for almost all scenarios
+                int bufferLength = buffer.Length;
+                fixed (byte* bufferPtr = buffer)
+                {
+                    if (SocketPal.GetSockName(handle, bufferPtr, &bufferLength) != SocketError.Success)
+                    {
+                        return;
+                    }
+                }
+
+                if (bufferLength > buffer.Length)
+                {
+                    buffer = new byte[buffer.Length];
+                    fixed (byte* bufferPtr = buffer)
+                    {
+                        if (SocketPal.GetSockName(handle, bufferPtr, &bufferLength) != SocketError.Success ||
+                            bufferLength > buffer.Length)
+                        {
+                            return;
+                        }
+                    }
+                }
+
+                buffer = buffer.Slice(0, bufferLength);
+                if (_addressFamily == AddressFamily.Unknown)
+                {
+                    _addressFamily = SocketAddressPal.GetAddressFamily(buffer);
+                }
+#if DEBUG
+                else
+                {
+                    Debug.Assert(_addressFamily == SocketAddressPal.GetAddressFamily(buffer));
+                }
+#endif
+
+                // Try to get the local end point.  That will in turn enable the remote
+                // end point to be retrieved on-demand when the property is accessed.
+                Internals.SocketAddress? socketAddress = null;
+                switch (_addressFamily)
+                {
+                    case AddressFamily.InterNetwork:
+                        _rightEndPoint = new IPEndPoint(
+                            new IPAddress((long)SocketAddressPal.GetIPv4Address(buffer) & 0x0FFFFFFFF),
+                            SocketAddressPal.GetPort(buffer));
+                        break;
+
+                    case AddressFamily.InterNetworkV6:
+                        Span<byte> address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes];
+                        SocketAddressPal.GetIPv6Address(buffer, address, out uint scope);
+                        _rightEndPoint = new IPEndPoint(
+                            new IPAddress(address, scope),
+                            SocketAddressPal.GetPort(buffer));
+                        break;
+
+                    case AddressFamily.Unix:
+                        socketAddress = new Internals.SocketAddress(_addressFamily, buffer);
+                        _rightEndPoint = new UnixDomainSocketEndPoint(IPEndPointExtensions.GetNetSocketAddress(socketAddress));
+                        break;
+                }
+
+                // Try to determine if we're connected, based on querying for a peer, just as we would in RemoteEndPoint,
+                // but ignoring any failures; this is best-effort (RemoteEndPoint also does a catch-all around the Create call).
+                if (_rightEndPoint != null)
+                {
+                    try
+                    {
+                        socketAddress ??= new Internals.SocketAddress(_addressFamily, buffer);
+                        if (SocketPal.GetPeerName(_handle, socketAddress.Buffer, ref socketAddress.InternalSize) != SocketError.Success)
+                        {
+                            return;
+                        }
+
+                        if (socketAddress.InternalSize > socketAddress.Buffer.Length)
+                        {
+                            socketAddress.Buffer = new byte[socketAddress.InternalSize];
+                            if (SocketPal.GetPeerName(_handle, socketAddress.Buffer, ref socketAddress.InternalSize) != SocketError.Success)
+                            {
+                                return;
+                            }
+                        }
+
+                        _remoteEndPoint = _rightEndPoint.Create(socketAddress);
+                        _isConnected = true;
+                    }
+                    catch { }
+                }
+            }
+            catch
+            {
+                _handle = null!;
+                GC.SuppressFinalize(this);
+                throw;
+            }
         }
+
+        private static SafeSocketHandle ValidateHandle(SafeSocketHandle handle) =>
+            handle is null ? throw new ArgumentNullException(nameof(handle)) :
+            handle.IsInvalid ? throw new ArgumentException(SR.Arg_InvalidHandle, nameof(handle)) :
+            handle;
         #endregion
 
         #region Properties
@@ -209,15 +332,18 @@ namespace System.Net.Sockets
 
                 Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint);
 
-                // This may throw ObjectDisposedException.
-                SocketError errorCode = SocketPal.GetSockName(
-                    _handle,
-                    socketAddress.Buffer,
-                    ref socketAddress.InternalSize);
-
-                if (errorCode != SocketError.Success)
+                unsafe
                 {
-                    UpdateStatusAfterSocketErrorAndThrowException(errorCode);
+                    fixed (byte* buffer = socketAddress.Buffer)
+                    fixed (int* bufferSize = &socketAddress.InternalSize)
+                    {
+                        // This may throw ObjectDisposedException.
+                        SocketError errorCode = SocketPal.GetSockName(_handle, buffer, bufferSize);
+                        if (errorCode != SocketError.Success)
+                        {
+                            UpdateStatusAfterSocketErrorAndThrowException(errorCode);
+                        }
+                    }
                 }
 
                 return _rightEndPoint.Create(socketAddress);
@@ -4361,6 +4487,13 @@ namespace System.Net.Sockets
 
             SetToDisconnected();
 
+            // If the safe handle doesn't own the underlying handle, we're done.
+            SafeSocketHandle handle = _handle;
+            if (handle != null && !handle.OwnsHandle)
+            {
+                return;
+            }
+
             // Close the handle in one of several ways depending on the timeout.
             // Ignore ObjectDisposedException just in case the handle somehow gets disposed elsewhere.
             try
@@ -4961,7 +5094,7 @@ namespace System.Net.Sockets
         {
             // Internal state of the socket is inherited from listener.
             Debug.Assert(fd != null && !fd.IsInvalid);
-            Socket socket = new Socket(fd);
+            Socket socket = new Socket(fd, loadPropertiesFromHandle: false);
             return UpdateAcceptSocket(socket, remoteEP);
         }
 
index b012817..3c43b02 100644 (file)
@@ -837,16 +837,9 @@ namespace System.Net.Sockets
             return SocketError.Success;
         }
 
-        public static unsafe SocketError GetSockName(SafeSocketHandle handle, byte[] buffer, ref int nameLen)
+        public static unsafe SocketError GetSockName(SafeSocketHandle handle, byte* buffer, int* nameLen)
         {
-            Interop.Error err;
-            int addrLen = nameLen;
-            fixed (byte* rawBuffer = buffer)
-            {
-                err = Interop.Sys.GetSockName(handle, rawBuffer, &addrLen);
-            }
-
-            nameLen = addrLen;
+            Interop.Error err = Interop.Sys.GetSockName(handle, buffer, nameLen);
             return err == Interop.Error.SUCCESS ? SocketError.Success : GetSocketErrorForErrorCode(err);
         }
 
index db5f880..c3e113b 100644 (file)
@@ -138,9 +138,9 @@ namespace System.Net.Sockets
             return errorCode;
         }
 
-        public static SocketError GetSockName(SafeSocketHandle handle, byte[] buffer, ref int nameLen)
+        public static unsafe SocketError GetSockName(SafeSocketHandle handle, byte* buffer, int* nameLen)
         {
-            SocketError errorCode = Interop.Winsock.getsockname(handle, buffer, ref nameLen);
+            SocketError errorCode = Interop.Winsock.getsockname(handle, buffer, nameLen);
             return errorCode == SocketError.SocketError ? GetLastSocketError() : SocketError.Success;
         }
 
index 9fc617f..86e1b2d 100644 (file)
@@ -2,8 +2,10 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Diagnostics;
 using System.IO;
 using System.IO.Pipes;
+using System.Runtime.InteropServices;
 using System.Threading.Tasks;
 using Microsoft.DotNet.RemoteExecutor;
 using Xunit;
@@ -233,5 +235,236 @@ namespace System.Net.Sockets.Tests
             }
             s.Close();
         }
+
+        [Fact]
+        public void Ctor_SafeHandle_Invalid_ThrowsException()
+        {
+            AssertExtensions.Throws<ArgumentNullException>("handle", () => new Socket(null));
+            AssertExtensions.Throws<ArgumentException>("handle", () => new Socket(new SafeSocketHandle((IntPtr)(-1), false)));
+
+            using (var pipe = new AnonymousPipeServerStream())
+            {
+                SocketException se = Assert.Throws<SocketException>(() => new Socket(new SafeSocketHandle(pipe.ClientSafePipeHandle.DangerousGetHandle(), false)));
+                Assert.Equal(SocketError.NotSocket, se.SocketErrorCode);
+            }
+        }
+
+        [Theory]
+        [InlineData(AddressFamily.ControllerAreaNetwork, SocketType.Raw, ProtocolType.Unspecified)]
+        [InlineData(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)]
+        [InlineData(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)]
+        [InlineData(AddressFamily.InterNetwork, SocketType.Raw, ProtocolType.Unspecified)]
+        [InlineData(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp)]
+        [InlineData(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp)]
+        [InlineData(AddressFamily.InterNetworkV6, SocketType.Raw, ProtocolType.Unspecified)]
+        [InlineData(AddressFamily.Packet, SocketType.Raw, ProtocolType.Raw)]
+        [InlineData(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified)]
+        public void Ctor_SafeHandle_BasicPropertiesPropagate_Success(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
+        {
+            Socket tmpOrig;
+            try
+            {
+                tmpOrig = new Socket(addressFamily, socketType, protocolType);
+            }
+            catch (SocketException e) when (
+                e.SocketErrorCode == SocketError.AccessDenied ||
+                e.SocketErrorCode == SocketError.ProtocolNotSupported ||
+                e.SocketErrorCode == SocketError.AddressFamilyNotSupported)
+            {
+                // We can't test this combination on this platform.
+                return;
+            }
+
+            using Socket orig = tmpOrig;
+            using var copy = new Socket(orig.SafeHandle);
+
+            Assert.False(orig.Connected);
+            Assert.False(copy.Connected);
+
+            Assert.Null(orig.LocalEndPoint);
+            Assert.Null(orig.RemoteEndPoint);
+            Assert.False(orig.IsBound);
+            if (copy.IsBound)
+            {
+                // On Unix, we may successfully obtain an (empty) local end point, even though Bind wasn't called.
+                Debug.Assert(!RuntimeInformation.IsOSPlatform(OSPlatform.Windows));
+                if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) // OSX gets some strange results in some cases, e.g. "@\0\0\0\0\0\0\0\0\0\0\0\0\0" for a UDS
+                {
+                    switch (addressFamily)
+                    {
+                        case AddressFamily.InterNetwork:
+                            Assert.Equal(new IPEndPoint(IPAddress.Any, 0), copy.LocalEndPoint);
+                            break;
+
+                        case AddressFamily.InterNetworkV6:
+                            Assert.Equal(new IPEndPoint(IPAddress.IPv6Any, 0), copy.LocalEndPoint);
+                            break;
+
+                        case AddressFamily.Unix:
+                            Assert.IsType<UnixDomainSocketEndPoint>(copy.LocalEndPoint);
+                            Assert.Equal("", copy.LocalEndPoint.ToString());
+                            break;
+
+                        default:
+                            Assert.Null(copy.LocalEndPoint);
+                            break;
+                    }
+                }
+            }
+            else
+            {
+                Assert.Equal(orig.LocalEndPoint, copy.LocalEndPoint);
+                Assert.Equal(orig.LocalEndPoint, copy.RemoteEndPoint);
+            }
+
+            Assert.Equal(addressFamily, orig.AddressFamily);
+            Assert.Equal(socketType, orig.SocketType);
+            Assert.Equal(protocolType, orig.ProtocolType);
+
+            Assert.Equal(addressFamily, copy.AddressFamily);
+            Assert.Equal(socketType, copy.SocketType);
+            Assert.True(copy.ProtocolType == orig.ProtocolType || copy.ProtocolType == ProtocolType.Unknown, $"Expected: {protocolType} or Unknown, Actual: {copy.ProtocolType}");
+
+            Assert.True(orig.Blocking);
+            Assert.True(copy.Blocking);
+
+            if (orig.AddressFamily == copy.AddressFamily)
+            {
+                AssertEqualOrSameException(() => orig.DontFragment, () => copy.DontFragment);
+                AssertEqualOrSameException(() => orig.MulticastLoopback, () => copy.MulticastLoopback);
+                AssertEqualOrSameException(() => orig.Ttl, () => copy.Ttl);
+            }
+
+            AssertEqualOrSameException(() => orig.EnableBroadcast, () => copy.EnableBroadcast);
+            AssertEqualOrSameException(() => orig.LingerState.Enabled, () => copy.LingerState.Enabled);
+            AssertEqualOrSameException(() => orig.LingerState.LingerTime, () => copy.LingerState.LingerTime);
+            AssertEqualOrSameException(() => orig.NoDelay, () => copy.NoDelay);
+
+            Assert.Equal(orig.Available, copy.Available);
+            Assert.Equal(orig.ExclusiveAddressUse, copy.ExclusiveAddressUse);
+            Assert.Equal(orig.Handle, copy.Handle);
+            Assert.Equal(orig.ReceiveBufferSize, copy.ReceiveBufferSize);
+            Assert.Equal(orig.ReceiveTimeout, copy.ReceiveTimeout);
+            Assert.Equal(orig.SendBufferSize, copy.SendBufferSize);
+            Assert.Equal(orig.SendTimeout, copy.SendTimeout);
+            Assert.Equal(orig.UseOnlyOverlappedIO, copy.UseOnlyOverlappedIO);
+        }
+
+        [Theory]
+        [InlineData(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)]
+        [InlineData(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp)]
+        public async Task Ctor_SafeHandle_Tcp_SendReceive_Success(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
+        {
+            using var orig = new Socket(addressFamily, socketType, protocolType);
+            using var listener = new Socket(addressFamily, socketType, protocolType);
+            listener.Bind(new IPEndPoint(addressFamily == AddressFamily.InterNetwork ? IPAddress.Loopback : IPAddress.IPv6Loopback, 0));
+            listener.Listen(1);
+            await orig.ConnectAsync(listener.LocalEndPoint);
+            using var server = await listener.AcceptAsync();
+
+            using var client = new Socket(orig.SafeHandle);
+
+            Assert.True(client.Connected);
+            Assert.Equal(orig.AddressFamily, client.AddressFamily);
+            Assert.Equal(orig.SocketType, client.SocketType);
+            Assert.True(client.ProtocolType == orig.ProtocolType || client.ProtocolType == ProtocolType.Unknown, $"Expected: {protocolType} or Unknown, Actual: {client.ProtocolType}");
+
+            // Validate accessing end points
+            Assert.Equal(orig.LocalEndPoint, client.LocalEndPoint);
+            Assert.Equal(orig.RemoteEndPoint, client.RemoteEndPoint);
+
+            // Validating accessing other properties
+            Assert.Equal(orig.Available, client.Available);
+            Assert.True(orig.Blocking);
+            Assert.True(client.Blocking);
+            AssertEqualOrSameException(() => orig.DontFragment, () => client.DontFragment);
+            AssertEqualOrSameException(() => orig.EnableBroadcast, () => client.EnableBroadcast);
+            Assert.Equal(orig.ExclusiveAddressUse, client.ExclusiveAddressUse);
+            Assert.Equal(orig.Handle, client.Handle);
+            Assert.Equal(orig.IsBound, client.IsBound);
+            Assert.Equal(orig.LingerState.Enabled, client.LingerState.Enabled);
+            Assert.Equal(orig.LingerState.LingerTime, client.LingerState.LingerTime);
+            AssertEqualOrSameException(() => orig.MulticastLoopback, () => client.MulticastLoopback);
+            Assert.Equal(orig.NoDelay, client.NoDelay);
+            Assert.Equal(orig.ReceiveBufferSize, client.ReceiveBufferSize);
+            Assert.Equal(orig.ReceiveTimeout, client.ReceiveTimeout);
+            Assert.Equal(orig.SendBufferSize, client.SendBufferSize);
+            Assert.Equal(orig.SendTimeout, client.SendTimeout);
+            Assert.Equal(orig.Ttl, client.Ttl);
+            Assert.Equal(orig.UseOnlyOverlappedIO, client.UseOnlyOverlappedIO);
+
+            // Validate setting various properties on the new instance and seeing them roundtrip back to the original.
+            client.ReceiveTimeout = 42;
+            Assert.Equal(client.ReceiveTimeout, orig.ReceiveTimeout);
+
+            // Validate sending and receiving
+            Assert.Equal(1, await client.SendAsync(new byte[1] { 42 }, SocketFlags.None));
+            var buffer = new byte[1];
+            Assert.Equal(1, await server.ReceiveAsync(buffer, SocketFlags.None));
+            Assert.Equal(42, buffer[0]);
+
+            Assert.Equal(1, await server.SendAsync(new byte[1] { 42 }, SocketFlags.None));
+            buffer[0] = 0;
+            Assert.Equal(1, await client.ReceiveAsync(buffer, SocketFlags.None));
+            Assert.Equal(42, buffer[0]);
+        }
+
+        [PlatformSpecific(TestPlatforms.Windows | TestPlatforms.Linux)] // OSX/FreeBSD doesn't support SO_ACCEPTCONN, so we can't query for whether a socket is listening
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
+        public async Task Ctor_SafeHandle_Listening_Success(bool shareSafeHandle)
+        {
+            using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+            listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+            listener.Listen();
+            Assert.Equal(1, listener.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.AcceptConnection));
+
+            using var listenerCopy = new Socket(shareSafeHandle ? listener.SafeHandle : new SafeSocketHandle(listener.Handle, ownsHandle: false));
+            Assert.Equal(1, listenerCopy.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.AcceptConnection));
+
+            Assert.Equal(listener.AddressFamily, listenerCopy.AddressFamily);
+            Assert.Equal(listener.Handle, listenerCopy.Handle);
+            Assert.Equal(listener.IsBound, listenerCopy.IsBound);
+            Assert.Equal(listener.LocalEndPoint, listener.LocalEndPoint);
+            Assert.True(listenerCopy.ProtocolType == listener.ProtocolType || listenerCopy.ProtocolType == ProtocolType.Unknown, $"Expected: {listener.ProtocolType} or Unknown, Actual: {listenerCopy.ProtocolType}");
+            Assert.Equal(listener.SocketType, listenerCopy.SocketType);
+
+            foreach (Socket listenerSocket in new[] { listener, listenerCopy })
+            {
+                using (var client1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+                {
+                    Task connect1 = client1.ConnectAsync(listenerSocket.LocalEndPoint);
+                    using (Socket server1 = listenerSocket.Accept())
+                    {
+                        await connect1;
+                        server1.Send(new byte[] { 42 });
+                        Assert.Equal(1, client1.Receive(new byte[1]));
+                    }
+                }
+            }
+        }
+
+        private static void AssertEqualOrSameException<T>(Func<T> expected, Func<T> actual)
+        {
+            T r1 = default, r2 = default;
+            Exception e1 = null, e2 = null;
+
+            try { r1 = expected(); }
+            catch (Exception e) { e1 = e; };
+
+            try { r2 = actual(); }
+            catch (Exception e) { e2 = e; };
+
+            Assert.Equal(e1 is null, e2 is null);
+            if (e1 is null)
+            {
+                Assert.Equal(r1, r2);
+            }
+            else
+            {
+                Assert.Equal(e1.GetType(), e2.GetType());
+            }
+        }
     }
 }
index 3725c4c..68fcd85 100644 (file)
@@ -44,11 +44,16 @@ namespace System.Net.Sockets.Tests
             }
         }
 
+        public static IEnumerable<object[]> LoopbackWithBool =>
+            from addr in Loopbacks
+            from b in new[] { false, true }
+            select new object[] { addr[0], b };
+
         [ActiveIssue("https://github.com/dotnet/runtime/issues/1712")]
         [OuterLoop]
         [Theory]
-        [MemberData(nameof(Loopbacks))]
-        public async Task SendToRecvFrom_Datagram_UDP(IPAddress loopbackAddress)
+        [MemberData(nameof(LoopbackWithBool))]
+        public async Task SendToRecvFrom_Datagram_UDP(IPAddress loopbackAddress, bool useClone)
         {
             IPAddress leftAddress = loopbackAddress, rightAddress = loopbackAddress;
 
@@ -57,11 +62,13 @@ namespace System.Net.Sockets.Tests
             const int AckTimeout = 10000;
             const int TestTimeout = 30000;
 
-            var left = new Socket(leftAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
-            left.BindToAnonymousPort(leftAddress);
+            using var origLeft = new Socket(leftAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
+            using var origRight = new Socket(rightAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
+            origLeft.BindToAnonymousPort(leftAddress);
+            origRight.BindToAnonymousPort(rightAddress);
 
-            var right = new Socket(rightAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
-            right.BindToAnonymousPort(rightAddress);
+            using var left = useClone ? new Socket(origLeft.SafeHandle) : origLeft;
+            using var right = useClone ? new Socket(origRight.SafeHandle) : origRight;
 
             var leftEndpoint = (IPEndPoint)left.LocalEndPoint;
             var rightEndpoint = (IPEndPoint)right.LocalEndPoint;
@@ -74,25 +81,22 @@ namespace System.Net.Sockets.Tests
             var receivedChecksums = new uint?[DatagramsToSend];
             Task leftThread = Task.Run(async () =>
             {
-                using (left)
+                EndPoint remote = leftEndpoint.Create(leftEndpoint.Serialize());
+                var recvBuffer = new byte[DatagramSize];
+                for (int i = 0; i < DatagramsToSend; i++)
                 {
-                    EndPoint remote = leftEndpoint.Create(leftEndpoint.Serialize());
-                    var recvBuffer = new byte[DatagramSize];
-                    for (int i = 0; i < DatagramsToSend; i++)
-                    {
-                        SocketReceiveFromResult result = await ReceiveFromAsync(
-                            left, new ArraySegment<byte>(recvBuffer), remote);
-                        Assert.Equal(DatagramSize, result.ReceivedBytes);
-                        Assert.Equal(rightEndpoint, result.RemoteEndPoint);
-
-                        int datagramId = recvBuffer[0];
-                        Assert.Null(receivedChecksums[datagramId]);
-                        receivedChecksums[datagramId] = Fletcher32.Checksum(recvBuffer, 0, result.ReceivedBytes);
-
-                        receiverAck.Release();
-                        bool gotAck = await senderAck.WaitAsync(TestTimeout);
-                        Assert.True(gotAck, $"{DateTime.Now}: Timeout waiting {TestTimeout} for senderAck in iteration {i}");
-                    }
+                    SocketReceiveFromResult result = await ReceiveFromAsync(
+                        left, new ArraySegment<byte>(recvBuffer), remote);
+                    Assert.Equal(DatagramSize, result.ReceivedBytes);
+                    Assert.Equal(rightEndpoint, result.RemoteEndPoint);
+
+                    int datagramId = recvBuffer[0];
+                    Assert.Null(receivedChecksums[datagramId]);
+                    receivedChecksums[datagramId] = Fletcher32.Checksum(recvBuffer, 0, result.ReceivedBytes);
+
+                    receiverAck.Release();
+                    bool gotAck = await senderAck.WaitAsync(TestTimeout);
+                    Assert.True(gotAck, $"{DateTime.Now}: Timeout waiting {TestTimeout} for senderAck in iteration {i}");
                 }
             });
 
index ddead9a..3e5e7d9 100644 (file)
@@ -148,6 +148,51 @@ namespace System.Net.Sockets.Tests
         }
 
         [ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))]
+        public void Socket_SendReceive_Clone_Success()
+        {
+            string path = GetRandomNonExistingFilePath();
+            var endPoint = new UnixDomainSocketEndPoint(path);
+            try
+            {
+                using var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
+                using var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
+                {
+                    server.Bind(endPoint);
+                    server.Listen(1);
+                    client.Connect(endPoint);
+
+                    using (Socket accepted = server.Accept())
+                    {
+                        using var clientClone = new Socket(client.SafeHandle);
+                        using var acceptedClone = new Socket(accepted.SafeHandle);
+
+                        Assert.Equal(client.LocalEndPoint.ToString(), clientClone.LocalEndPoint.ToString());
+                        Assert.Equal(client.RemoteEndPoint.ToString(), clientClone.RemoteEndPoint.ToString());
+                        Assert.Equal(accepted.LocalEndPoint.ToString(), acceptedClone.LocalEndPoint.ToString());
+                        Assert.Equal(accepted.RemoteEndPoint.ToString(), acceptedClone.RemoteEndPoint.ToString());
+
+                        var data = new byte[1];
+                        for (int i = 0; i < 10; i++)
+                        {
+                            data[0] = (byte)i;
+
+                            acceptedClone.Send(data);
+                            data[0] = 0;
+
+                            Assert.Equal(1, clientClone.Receive(data));
+                            Assert.Equal(i, data[0]);
+                        }
+                    }
+                }
+            }
+            finally
+            {
+                try { File.Delete(path); }
+                catch { }
+            }
+        }
+
+        [ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))]
         public async Task Socket_SendReceiveAsync_Success()
         {
             string path = GetRandomNonExistingFilePath();