Simplify winsock initialization (#43284)
authorJan Kotas <jkotas@microsoft.com>
Tue, 13 Oct 2020 04:25:37 +0000 (21:25 -0700)
committerGitHub <noreply@github.com>
Tue, 13 Oct 2020 04:25:37 +0000 (21:25 -0700)
* Initialize winsock directly

* Delete test that is not relevant anymore

* Add tests for methods the require initialized winsock

Co-authored-by: Stephen Toub <stoub@microsoft.com>
23 files changed:
src/libraries/Common/src/Interop/Windows/WinSock/Interop.GetAddrInfoExW.cs
src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAStartup.cs
src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Windows.cs
src/libraries/System.Net.NameResolution/src/System.Net.NameResolution.csproj
src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs
src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Unix.cs
src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Win32.cs [deleted file]
src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs
src/libraries/System.Net.NameResolution/tests/PalTests/NameResolutionPalTests.cs
src/libraries/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj
src/libraries/System.Net.NameResolution/tests/UnitTests/Fakes/FakeNameResolutionPal.cs
src/libraries/System.Net.NameResolution/tests/UnitTests/InitializationTest.cs [deleted file]
src/libraries/System.Net.NameResolution/tests/UnitTests/System.Net.NameResolution.Unit.Tests.csproj
src/libraries/System.Net.Ping/src/System.Net.Ping.csproj
src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.Windows.cs
src/libraries/System.Net.Ping/src/System/Net/NetworkInformation/Ping.cs
src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/StartupTests.Windows.cs [new file with mode: 0644]
src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj

index 7610f6d..c5c233c 100644 (file)
@@ -15,8 +15,6 @@ internal static partial class Interop
 
         internal const int NS_ALL = 0;
 
-        internal unsafe delegate void LPLOOKUPSERVICE_COMPLETION_ROUTINE([In] int dwError, [In] int dwBytes, [In] NativeOverlapped* lpOverlapped);
-
         [DllImport(Libraries.Ws2_32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)]
         internal static extern unsafe int GetAddrInfoExW(
             [In] string pName,
@@ -27,7 +25,7 @@ internal static partial class Interop
             [Out] AddressInfoEx** ppResult,
             [In] IntPtr timeout,
             [In] NativeOverlapped* lpOverlapped,
-            [In] LPLOOKUPSERVICE_COMPLETION_ROUTINE lpCompletionRoutine,
+            [In] delegate* unmanaged<int, int, NativeOverlapped*, void> lpCompletionRoutine,
             [Out] IntPtr* lpNameHandle);
 
         [DllImport(Libraries.Ws2_32, ExactSpelling = true)]
index 0c119ce..7f90a3a 100644 (file)
@@ -1,24 +1,50 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Diagnostics;
 using System.Runtime.InteropServices;
 using System.Net.Sockets;
+using System.Threading;
 
 internal static partial class Interop
 {
     internal static partial class Winsock
     {
-        // Important: this API is called once by the System.Net.NameResolution contract implementation.
-        // WSACleanup is not called and will be automatically performed at process shutdown.
-        internal static unsafe SocketError WSAStartup()
+        private static int s_initialized;
+
+        internal static void EnsureInitialized()
         {
-            WSAData d;
-            return WSAStartup(0x0202 /* 2.2 */, &d);
+            // No volatile needed here. Reading stale information is just going to cause a harmless extra startup.
+            if (s_initialized == 0)
+                Initialize();
+
+            static unsafe void Initialize()
+            {
+                WSAData d;
+                SocketError errorCode = WSAStartup(0x0202 /* 2.2 */, &d);
+
+                if (errorCode != SocketError.Success)
+                {
+                    // WSAStartup does not set LastWin32Error
+                    throw new SocketException((int)errorCode);
+                }
+
+                if (Interlocked.CompareExchange(ref s_initialized, 1, 0) != 0)
+                {
+                    // Keep the winsock initialization count balanced if other thread beats us to finish the initialization.
+                    // This cleanup is just for good hygiene. A few extra startups would not matter.
+                    errorCode = WSACleanup();
+                    Debug.Assert(errorCode == SocketError.Success);
+                }
+            }
         }
 
-        [DllImport(Libraries.Ws2_32, SetLastError = true)]
+        [DllImport(Libraries.Ws2_32)]
         private static extern unsafe SocketError WSAStartup(short wVersionRequested, WSAData* lpWSAData);
 
+        [DllImport(Libraries.Ws2_32)]
+        private static extern SocketError WSACleanup();
+
         [StructLayout(LayoutKind.Sequential, Size = 408)]
         private unsafe struct WSAData
         {
index ab388ab..de0465a 100644 (file)
@@ -17,6 +17,8 @@ namespace System.Net
 
         private static bool IsSupported(AddressFamily af)
         {
+            Interop.Winsock.EnsureInitialized();
+
             IntPtr INVALID_SOCKET = (IntPtr)(-1);
             IntPtr socket = INVALID_SOCKET;
             try
index 8fff5c1..7286317 100644 (file)
@@ -36,7 +36,6 @@
   </ItemGroup>
   <ItemGroup Condition=" '$(TargetsWindows)' == 'true'">
     <Compile Include="System\Net\NameResolutionPal.Windows.cs" />
-    <Compile Include="System\Net\NameResolutionPal.Win32.cs" />
     <!-- Debug only -->
     <Compile Include="$(CommonPath)System\Net\DebugSafeHandle.cs"
              Link="Common\System\Net\DebugSafeHandle.cs" />
index 60cc964..942db43 100644 (file)
@@ -17,8 +17,6 @@ namespace System.Net
         /// <summary>Gets the host name of the local machine.</summary>
         public static string GetHostName()
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             ValueStopwatch stopwatch = NameResolutionTelemetry.Log.BeforeResolution(string.Empty);
 
             string name;
@@ -41,8 +39,6 @@ namespace System.Net
 
         public static IPHostEntry GetHostEntry(IPAddress address)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             if (address is null)
             {
                 throw new ArgumentNullException(nameof(address));
@@ -62,8 +58,6 @@ namespace System.Net
 
         public static IPHostEntry GetHostEntry(string hostNameOrAddress)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             if (hostNameOrAddress is null)
             {
                 throw new ArgumentNullException(nameof(hostNameOrAddress));
@@ -107,8 +101,6 @@ namespace System.Net
 
         public static Task<IPHostEntry> GetHostEntryAsync(IPAddress address)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             if (address is null)
             {
                 throw new ArgumentNullException(nameof(address));
@@ -138,8 +130,6 @@ namespace System.Net
 
         public static IPAddress[] GetHostAddresses(string hostNameOrAddress)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             if (hostNameOrAddress is null)
             {
                 throw new ArgumentNullException(nameof(hostNameOrAddress));
@@ -178,8 +168,6 @@ namespace System.Net
         [Obsolete("GetHostByName is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")]
         public static IPHostEntry GetHostByName(string hostName)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             if (hostName is null)
             {
                 throw new ArgumentNullException(nameof(hostName));
@@ -204,8 +192,6 @@ namespace System.Net
         [Obsolete("GetHostByAddress is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")]
         public static IPHostEntry GetHostByAddress(string address)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             if (address is null)
             {
                 throw new ArgumentNullException(nameof(address));
@@ -220,8 +206,6 @@ namespace System.Net
         [Obsolete("GetHostByAddress is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")]
         public static IPHostEntry GetHostByAddress(IPAddress address)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             if (address is null)
             {
                 throw new ArgumentNullException(nameof(address));
@@ -236,8 +220,6 @@ namespace System.Net
         [Obsolete("Resolve is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")]
         public static IPHostEntry Resolve(string hostName)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             if (hostName is null)
             {
                 throw new ArgumentNullException(nameof(hostName));
@@ -430,8 +412,6 @@ namespace System.Net
         // If hostName is an IPString and justReturnParsedIP==true then no reverse lookup will be attempted, but the original address is returned.
         private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justReturnParsedIp, bool throwOnIIPAny, bool justAddresses)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
-
             if (hostName is null)
             {
                 throw new ArgumentNullException(nameof(hostName));
index 316dc76..add1031 100644 (file)
@@ -15,8 +15,6 @@ namespace System.Net
     {
         public const bool SupportsGetAddrInfoAsync = false;
 
-        public static void EnsureSocketsAreInitialized() { } // No-op for Unix
-
         internal static Task GetAddrInfoAsync(string hostName, bool justAddresses) =>
             throw new NotSupportedException();
 
diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Win32.cs b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Win32.cs
deleted file mode 100644 (file)
index 7656d66..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-using System.Runtime.InteropServices;
-
-namespace System.Net
-{
-    internal static partial class NameResolutionPal
-    {
-        private static bool GetAddrInfoExSupportsOverlapped()
-        {
-            if (!NativeLibrary.TryLoad(Interop.Libraries.Ws2_32, typeof(NameResolutionPal).Assembly, null, out IntPtr libHandle))
-                return false;
-
-            // We can't just check that 'GetAddrInfoEx' exists, because it existed before supporting overlapped.
-            // The existence of 'GetAddrInfoExCancel' indicates that overlapped is supported.
-            return NativeLibrary.TryGetExport(libHandle, Interop.Winsock.GetAddrInfoExCancelFunctionName, out _);
-        }
-    }
-}
index 64fe7cc..1de610f 100644 (file)
@@ -13,50 +13,38 @@ namespace System.Net
 {
     internal static partial class NameResolutionPal
     {
-        private static volatile bool s_initialized;
-        private static readonly object s_initializedLock = new object();
+        private static volatile int s_getAddrInfoExSupported;
 
-        private static readonly unsafe Interop.Winsock.LPLOOKUPSERVICE_COMPLETION_ROUTINE s_getAddrInfoExCallback = GetAddressInfoExCallback;
-        private static bool s_getAddrInfoExSupported;
-
-        public static void EnsureSocketsAreInitialized()
+        public static bool SupportsGetAddrInfoAsync
         {
-            if (!s_initialized)
+            get
             {
-                InitializeSockets();
-            }
+                int supported = s_getAddrInfoExSupported;
+                if (supported == 0)
+                {
+                    Initialize();
+                    supported = s_getAddrInfoExSupported;
+                }
+                return supported == 1;
 
-            static void InitializeSockets()
-            {
-                lock (s_initializedLock)
+                static void Initialize()
                 {
-                    if (!s_initialized)
-                    {
-                        SocketError errorCode = Interop.Winsock.WSAStartup();
-                        if (errorCode != SocketError.Success)
-                        {
-                            // WSAStartup does not set LastWin32Error
-                            throw new SocketException((int)errorCode);
-                        }
+                    Interop.Winsock.EnsureInitialized();
 
-                        s_getAddrInfoExSupported = GetAddrInfoExSupportsOverlapped();
-                        s_initialized = true;
-                    }
-                }
-            }
-        }
+                    IntPtr libHandle = NativeLibrary.Load(Interop.Libraries.Ws2_32, typeof(NameResolutionPal).Assembly, null);
 
-        public static bool SupportsGetAddrInfoAsync
-        {
-            get
-            {
-                EnsureSocketsAreInitialized();
-                return s_getAddrInfoExSupported;
+                    // We can't just check that 'GetAddrInfoEx' exists, because it existed before supporting overlapped.
+                    // The existence of 'GetAddrInfoExCancel' indicates that overlapped is supported.
+                    bool supported = NativeLibrary.TryGetExport(libHandle, Interop.Winsock.GetAddrInfoExCancelFunctionName, out _);
+                    Interlocked.CompareExchange(ref s_getAddrInfoExSupported, supported ? 1 : -1, 0);
+                }
             }
         }
 
         public static unsafe SocketError TryGetAddrInfo(string name, bool justAddresses, out string? hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode)
         {
+            Interop.Winsock.EnsureInitialized();
+
             aliases = Array.Empty<string>();
 
             var hints = new Interop.Winsock.AddressInfo { ai_family = AddressFamily.Unspecified }; // Gets all address families
@@ -92,6 +80,8 @@ namespace System.Net
 
         public static unsafe string? TryGetNameInfo(IPAddress addr, out SocketError errorCode, out int nativeErrorCode)
         {
+            Interop.Winsock.EnsureInitialized();
+
             SocketAddress address = new IPEndPoint(addr, 0).Serialize();
             Span<byte> addressBuffer = address.Size <= 64 ? stackalloc byte[64] : new byte[address.Size];
             for (int i = 0; i < address.Size; i++)
@@ -126,6 +116,8 @@ namespace System.Net
 
         public static unsafe string GetHostName()
         {
+            Interop.Winsock.EnsureInitialized();
+
             // We do not cache the result in case the hostname changes.
 
             const int HostNameBufferLength = 256;
@@ -143,6 +135,8 @@ namespace System.Net
 
         public static unsafe Task GetAddrInfoAsync(string hostName, bool justAddresses)
         {
+            Interop.Winsock.EnsureInitialized();
+
             GetAddrInfoExContext* context = GetAddrInfoExContext.AllocateContext();
 
             GetAddrInfoExState state;
@@ -164,7 +158,7 @@ namespace System.Net
             }
 
             SocketError errorCode = (SocketError)Interop.Winsock.GetAddrInfoExW(
-                hostName, null, Interop.Winsock.NS_ALL, IntPtr.Zero, &hints, &context->Result, IntPtr.Zero, &context->Overlapped, s_getAddrInfoExCallback, &context->CancelHandle);
+                hostName, null, Interop.Winsock.NS_ALL, IntPtr.Zero, &hints, &context->Result, IntPtr.Zero, &context->Overlapped, &GetAddressInfoExCallback, &context->CancelHandle);
 
             if (errorCode != SocketError.IOPending)
             {
@@ -174,6 +168,7 @@ namespace System.Net
             return state.Task;
         }
 
+        [UnmanagedCallersOnly]
         private static unsafe void GetAddressInfoExCallback(int error, int bytes, NativeOverlapped* overlapped)
         {
             // Can be casted directly to GetAddrInfoExContext* because the overlapped is its first field
index f88cd14..5709b7d 100644 (file)
@@ -14,7 +14,6 @@ namespace System.Net.NameResolution.PalTests
 
         public NameResolutionPalTests(ITestOutputHelper output)
         {
-            NameResolutionPal.EnsureSocketsAreInitialized();
             _output = output;
         }
 
index 3d260d3..7a5b3c3 100644 (file)
@@ -38,8 +38,6 @@
   <ItemGroup Condition=" '$(TargetsWindows)' == 'true' ">
     <Compile Include="..\..\src\System\Net\NameResolutionPal.Windows.cs"
              Link="ProductionCode\System\Net\NameResolutionPal.Windows.cs" />
-    <Compile Include="..\..\src\System\Net\NameResolutionPal.Win32.cs"
-             Link="ProductionCode\System\Net\NameResolutionPal.Win32.cs" />
     <Compile Include="$(CommonPath)System\Net\InternalException.cs"
              Link="Common\System\Net\InternalException.cs" />
     <Compile Include="$(CommonPath)System\Net\SocketProtocolSupportPal.Windows.cs"
     <Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SocketAddress.cs"
              Link="Common\Interop\Unix\System.Native\Interop.SocketAddress.cs" />
   </ItemGroup>
-</Project>
\ No newline at end of file
+</Project>
index 773c1e3..78f031f 100644 (file)
@@ -10,12 +10,6 @@ namespace System.Net
     {
         public static bool SupportsGetAddrInfoAsync => false;
 
-        internal static int FakesEnsureSocketsAreInitializedCallCount
-        {
-            get;
-            private set;
-        }
-
         internal static int FakesGetHostByNameCallCount
         {
             get;
@@ -24,15 +18,9 @@ namespace System.Net
 
         internal static void FakesReset()
         {
-            FakesEnsureSocketsAreInitializedCallCount = 0;
             FakesGetHostByNameCallCount = 0;
         }
 
-        internal static void EnsureSocketsAreInitialized()
-        {
-            FakesEnsureSocketsAreInitializedCallCount++;
-        }
-
         internal static SocketError TryGetAddrInfo(string name, bool justAddresses, out string hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode)
         {
             throw new NotImplementedException();
diff --git a/src/libraries/System.Net.NameResolution/tests/UnitTests/InitializationTest.cs b/src/libraries/System.Net.NameResolution/tests/UnitTests/InitializationTest.cs
deleted file mode 100644 (file)
index 8b3936d..0000000
+++ /dev/null
@@ -1,126 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-#pragma warning disable 0618 // using obsolete methods
-
-using System.Threading.Tasks;
-
-using Xunit;
-
-namespace System.Net.NameResolution.Tests
-{
-    public class InitializationTests
-    {
-        [Fact]
-        public void Dns_BeginGetHostAddresses_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.BeginGetHostAddresses(null, null, null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_BeginGetHostByName_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.BeginGetHostByName(null, null, null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_BeginGetHostEntry_String_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.BeginGetHostEntry((string)null, null, null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_BeginGetHostEntry_IPAddress_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.BeginGetHostEntry((IPAddress)null, null, null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_GetHostAddresses_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.GetHostAddresses(null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_GetHostByAddress_String_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.GetHostByAddress((string)null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_GetHostByAddress_IPAddress_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.GetHostByAddress((IPAddress)null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_GetHostByName_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.GetHostByName(null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_GetHostEntry_String_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.GetHostEntry((string)null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_GetHostEntry_IPAddress_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.GetHostEntry((IPAddress)null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_Resolve_CallSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.Resolve(null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public async Task Dns_GetHostAddressesAsync_CallsSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            await Assert.ThrowsAnyAsync<Exception>(() => Dns.GetHostAddressesAsync(null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public async Task Dns_GetHostEntryAsync_CallsSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            await Assert.ThrowsAnyAsync<Exception>(() => Dns.GetHostEntryAsync((string)null));
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-
-        [Fact]
-        public void Dns_GetHostName_CallsSocketInit_Ok()
-        {
-            NameResolutionPal.FakesReset();
-            Assert.ThrowsAny<Exception>(() => Dns.GetHostName());
-            Assert.NotEqual(0, NameResolutionPal.FakesEnsureSocketsAreInitializedCallCount);
-        }
-    }
-}
index 747b307..77c8e68 100644 (file)
@@ -20,7 +20,6 @@
   </ItemGroup>
   <ItemGroup>
     <Compile Include="AssemblyInfo.cs" />
-    <Compile Include="InitializationTest.cs" />
     <Compile Include="XunitTestAssemblyAtrributes.cs" />
     <!-- Fakes -->
     <Compile Include="Fakes\FakeContextAwareResult.cs" />
@@ -44,4 +43,4 @@
     <Compile Include="$(CommonPath)Extensions\ValueStopwatch\ValueStopwatch.cs"
              Link="Common\Extensions\ValueStopwatch\ValueStopwatch.cs" />
   </ItemGroup>
-</Project>
\ No newline at end of file
+</Project>
index fa5e3e7..394becd 100644 (file)
@@ -80,6 +80,8 @@
              Link="Common\Interop\Windows\WinSock\Interop.closesocket.cs" />
     <Compile Include="$(CommonPath)Interop\Windows\WinSock\Interop.WSASocketW.cs"
              Link="Common\Interop\Windows\WinSock\Interop.WSASocketW.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\WinSock\Interop.WSAStartup.cs"
+             Link="Common\Interop\Windows\WinSock\Interop.WSAStartup.cs" />
     <Compile Include="$(CommonPath)Interop\Windows\WinSock\Interop.SocketConstructorFlags.cs"
              Link="Common\Interop\Windows\WinSock\Interop.SocketConstructorFlags.cs" />
     <!-- System.Net.Internals -->
index b16f087..3ded687 100644 (file)
@@ -17,8 +17,6 @@ namespace System.Net.NetworkInformation
         private const int MaxUdpPacket = 0xFFFF + 256; // Marshal.SizeOf(typeof(Icmp6EchoReply)) * 2 + ip header info;
 
         private static readonly SafeWaitHandle s_nullSafeWaitHandle = new SafeWaitHandle(IntPtr.Zero, true);
-        private static readonly object s_socketInitializationLock = new object();
-        private static bool s_socketInitialized;
 
         private int _sendSize;  // Needed to determine what the reply size is for ipv6 in callback.
         private bool _ipv6;
@@ -391,24 +389,5 @@ namespace System.Net.NetworkInformation
 
             return new PingReply(address, null, ipStatus, rtt, buffer);
         }
-
-        static partial void InitializeSockets()
-        {
-            if (!Volatile.Read(ref s_socketInitialized))
-            {
-                lock (s_socketInitializationLock)
-                {
-                    if (!s_socketInitialized)
-                    {
-                        // Ensure that WSAStartup has been called once per process.
-                        // The System.Net.NameResolution contract is responsible with the initialization.
-                        Dns.GetHostName();
-
-                        // Cache some settings locally.
-                        s_socketInitialized = true;
-                    }
-                }
-            }
-        }
     }
 }
index 71fca63..fb16acb 100644 (file)
@@ -423,8 +423,6 @@ namespace System.Net.NetworkInformation
         // Tests if the current machine supports the given ip protocol family.
         private void TestIsIpSupported(IPAddress ip)
         {
-            InitializeSockets();
-
             if (ip.AddressFamily == AddressFamily.InterNetwork && !SocketProtocolSupportPal.OSSupportsIPv4)
             {
                 throw new NotSupportedException(SR.net_ipv4_not_installed);
@@ -435,7 +433,6 @@ namespace System.Net.NetworkInformation
             }
         }
 
-        static partial void InitializeSockets();
         partial void InternalDisposeCore();
 
         // Creates a default send buffer if a buffer wasn't specified.  This follows the ping.exe model.
index d8ae07c..8db0795 100644 (file)
              Link="Common\Interop\Windows\WinSock\Interop.WSASocketW.cs" />
     <Compile Include="$(CommonPath)Interop\Windows\WinSock\Interop.WSASocketW.SafeCloseSocket.cs"
              Link="Common\Interop\Windows\WinSock\Interop.WSASocketW.SafeCloseSocket.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\WinSock\Interop.WSAStartup.cs"
+             Link="Common\Interop\Windows\WinSock\Interop.WSAStartup.cs" />
     <Compile Include="$(CommonPath)Interop\Windows\WinSock\Interop.SocketConstructorFlags.cs"
              Link="Common\Interop\Windows\WinSock\Interop.SocketConstructorFlags.cs" />
     <Compile Include="$(CommonPath)Interop\Windows\WinSock\SafeNativeOverlapped.cs"
index 84b5895..4432190 100644 (file)
@@ -20,8 +20,6 @@ namespace System.Net.Sockets
         [SupportedOSPlatform("windows")]
         public Socket(SocketInformation socketInformation)
         {
-            InitializeSockets();
-
             SocketError errorCode = SocketPal.CreateSocket(socketInformation, out _handle,
                 ref _addressFamily, ref _socketType, ref _protocolType);
 
@@ -83,6 +81,10 @@ namespace System.Net.Sockets
         private unsafe void LoadSocketTypeFromHandle(
             SafeSocketHandle handle, out AddressFamily addressFamily, out SocketType socketType, out ProtocolType protocolType, out bool blocking, out bool isListening)
         {
+            // This can be called without winsock initialized. The handle is not going to be a valid socket handle in that case and the code will throw exception anyway.
+            // Initializing winsock will ensure the error SocketError.NotSocket as opposed to SocketError.NotInitialized.
+            Interop.Winsock.EnsureInitialized();
+
             Interop.Winsock.WSAPROTOCOL_INFOW info = default;
             int optionLength = sizeof(Interop.Winsock.WSAPROTOCOL_INFOW);
 
index 0e73552..2e8671d 100644 (file)
@@ -73,12 +73,9 @@ namespace System.Net.Sockets
         // Bool marked true if the native socket option IP_PKTINFO or IPV6_PKTINFO has been set.
         private bool _receivingPacketInformation;
 
-        private static object? s_internalSyncObject;
         private int _closeTimeout = Socket.DefaultCloseTimeout;
         private int _disposed; // 0 == false, anything else == true
 
-        internal static volatile bool s_initialized;
-
         #region Constructors
         public Socket(SocketType socketType, ProtocolType protocolType)
             : this(OSSupportsIPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, socketType, protocolType)
@@ -93,7 +90,6 @@ namespace System.Net.Sockets
         public Socket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
         {
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, addressFamily);
-            InitializeSockets();
 
             SocketError errorCode = SocketPal.CreateSocket(addressFamily, socketType, protocolType, out _handle);
             if (errorCode != SocketError.Success)
@@ -130,8 +126,6 @@ namespace System.Net.Sockets
 
         private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle)
         {
-            InitializeSockets();
-
             _handle = handle;
             _addressFamily = AddressFamily.Unknown;
             _socketType = SocketType.Unknown;
@@ -259,32 +253,9 @@ namespace System.Net.Sockets
         [Obsolete("SupportsIPv6 is obsoleted for this type, please use OSSupportsIPv6 instead. https://go.microsoft.com/fwlink/?linkid=14202")]
         public static bool SupportsIPv6 => OSSupportsIPv6;
 
-        public static bool OSSupportsIPv4
-        {
-            get
-            {
-                InitializeSockets();
-                return SocketProtocolSupportPal.OSSupportsIPv4;
-            }
-        }
-
-        public static bool OSSupportsIPv6
-        {
-            get
-            {
-                InitializeSockets();
-                return SocketProtocolSupportPal.OSSupportsIPv6;
-            }
-        }
-
-        public static bool OSSupportsUnixDomainSockets
-        {
-            get
-            {
-                InitializeSockets();
-                return SocketProtocolSupportPal.OSSupportsUnixDomainSockets;
-            }
-        }
+        public static bool OSSupportsIPv4 => SocketProtocolSupportPal.OSSupportsIPv4;
+        public static bool OSSupportsIPv6 => SocketProtocolSupportPal.OSSupportsIPv6;
+        public static bool OSSupportsUnixDomainSockets => SocketProtocolSupportPal.OSSupportsUnixDomainSockets;
 
         // Gets the amount of data pending in the network's input buffer that can be
         // read from the socket.
@@ -4201,18 +4172,6 @@ namespace System.Net.Sockets
         #endregion
 
         #region Internal and private properties
-        private static object InternalSyncObject
-        {
-            get
-            {
-                if (s_internalSyncObject == null)
-                {
-                    object o = new object();
-                    Interlocked.CompareExchange(ref s_internalSyncObject, o, null);
-                }
-                return s_internalSyncObject;
-            }
-        }
 
         private CacheSet Caches
         {
@@ -4267,26 +4226,6 @@ namespace System.Net.Sockets
             return IPEndPointExtensions.Serialize(remoteEP);
         }
 
-        internal static void InitializeSockets()
-        {
-            if (!s_initialized)
-            {
-                InitializeSocketsCore();
-            }
-
-            static void InitializeSocketsCore()
-            {
-                lock (InternalSyncObject)
-                {
-                    if (!s_initialized)
-                    {
-                        SocketPal.Initialize();
-                        s_initialized = true;
-                    }
-                }
-            }
-        }
-
         private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socketAddress)
         {
             if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectStart(socketAddress);
index 9bbe0f4..62d5719 100644 (file)
@@ -24,12 +24,6 @@ namespace System.Net.Sockets
 
         private static bool GetPlatformSupportsDualModeIPv4PacketInfo() =>
             Interop.Sys.PlatformSupportsDualModeIPv4PacketInfo() != 0;
-
-        public static void Initialize()
-        {
-            // nop.  No initialization required.
-        }
-
         public static SocketError GetSocketErrorForErrorCode(Interop.Error errorCode)
         {
             return SocketErrorPal.GetSocketErrorForNativeError(errorCode);
index c6fa1c0..8ebb6c4 100644 (file)
@@ -26,13 +26,6 @@ namespace System.Net.Sockets
             socketTime.Microseconds = (int)(microseconds % microcnv);
         }
 
-        public static void Initialize()
-        {
-            // Ensure that WSAStartup has been called once per process.
-            // The System.Net.NameResolution contract is responsible for the initialization.
-            Dns.GetHostName();
-        }
-
         public static SocketError GetLastSocketError()
         {
             int win32Error = Marshal.GetLastWin32Error();
@@ -42,6 +35,8 @@ namespace System.Net.Sockets
 
         public static SocketError CreateSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType, out SafeSocketHandle socket)
         {
+            Interop.Winsock.EnsureInitialized();
+
             IntPtr handle = Interop.Winsock.WSASocketW(addressFamily, socketType, protocolType, IntPtr.Zero, 0, Interop.Winsock.SocketConstructorFlags.WSA_FLAG_OVERLAPPED |
                                                                                                                 Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT);
 
@@ -70,6 +65,8 @@ namespace System.Net.Sockets
                 throw new ArgumentException(SR.net_sockets_invalid_socketinformation, nameof(socketInformation));
             }
 
+            Interop.Winsock.EnsureInitialized();
+
             fixed (byte* protocolInfoBytes = socketInformation.ProtocolInformation)
             {
                 // Sockets are non-inheritable in .NET Core.
diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/StartupTests.Windows.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/StartupTests.Windows.cs
new file mode 100644 (file)
index 0000000..d48d767
--- /dev/null
@@ -0,0 +1,95 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.IO.Pipes;
+using Microsoft.DotNet.RemoteExecutor;
+using Xunit;
+
+namespace System.Net.Sockets.Tests
+{
+    [PlatformSpecific(TestPlatforms.Windows)]
+    public class StartupTests
+    {
+        // Socket functionality on Windows requires WSAStartup to have been called, and thus System.Net.Sockets
+        // is responsible for doing so prior to making relevant native calls; this tests entry points.
+        // RemoteExecutor is used so that the individual method is used as early in the process as possible.
+
+        [Fact]
+        public static void OSSupportsIPv4()
+        {
+            bool parentSupported = Socket.OSSupportsIPv4;
+            RemoteExecutor.Invoke(parentSupported =>
+            {
+                Assert.Equal(bool.Parse(parentSupported), Socket.OSSupportsIPv4);
+            }, parentSupported.ToString()).Dispose();
+        }
+
+        [Fact]
+        public static void OSSupportsIPv6()
+        {
+            bool parentSupported = Socket.OSSupportsIPv6;
+            RemoteExecutor.Invoke(parentSupported =>
+            {
+                Assert.Equal(bool.Parse(parentSupported), Socket.OSSupportsIPv6);
+            }, parentSupported.ToString()).Dispose();
+        }
+
+        [Fact]
+        public static void OSSupportsUnixDomainSockets()
+        {
+            bool parentSupported = Socket.OSSupportsUnixDomainSockets;
+            RemoteExecutor.Invoke(parentSupported =>
+            {
+                Assert.Equal(bool.Parse(parentSupported), Socket.OSSupportsUnixDomainSockets);
+            }, parentSupported.ToString()).Dispose();
+        }
+
+#pragma warning disable CS0618 // SupportsIPv4 and SupportsIPv6 are obsolete
+        [Fact]
+        public static void SupportsIPv4()
+        {
+            bool parentSupported = Socket.SupportsIPv4;
+            RemoteExecutor.Invoke(parentSupported =>
+            {
+                Assert.Equal(bool.Parse(parentSupported), Socket.SupportsIPv4);
+            }, parentSupported.ToString()).Dispose();
+        }
+
+        [Fact]
+        public static void SupportsIPv6()
+        {
+            bool parentSupported = Socket.SupportsIPv6;
+            RemoteExecutor.Invoke(parentSupported =>
+            {
+                Assert.Equal(bool.Parse(parentSupported), Socket.SupportsIPv6);
+            }, parentSupported.ToString()).Dispose();
+        }
+#pragma warning restore CS0618
+
+        [Fact]
+        public static void Ctor_SocketType_ProtocolType()
+        {
+            RemoteExecutor.Invoke(() =>
+            {
+                new Socket(SocketType.Stream, ProtocolType.Tcp).Dispose();
+            }).Dispose();
+        }
+
+        [Fact]
+        public static void Ctor_AddressFamily_SocketType_ProtocolType()
+        {
+            RemoteExecutor.Invoke(() =>
+            {
+                new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp).Dispose();
+            }).Dispose();
+        }
+
+        [Fact]
+        public static void Ctor_SafeHandle() => RemoteExecutor.Invoke(() =>
+        {
+            using var pipe = new AnonymousPipeServerStream();
+            SocketException se = Assert.Throws<SocketException>(() => new Socket(new SafeSocketHandle(pipe.ClientSafePipeHandle.DangerousGetHandle(), ownsHandle: false)));
+            Assert.Equal(SocketError.NotSocket, se.SocketErrorCode);
+        }).Dispose();
+    }
+}
index 1e51b4f..8c3c0cc 100644 (file)
@@ -40,6 +40,7 @@
     <Compile Include="SocketTestHelper.cs" />
     <Compile Include="SelectAndPollTests.cs" />
     <Compile Include="SocketInformationTest.cs" />
+    <Compile Include="StartupTests.Windows.cs" />
     <Compile Include="TcpListenerTest.cs" />
     <Compile Include="TelemetryTest.cs" />
     <Compile Include="TimeoutTest.cs" />