Optimize SqlClient SNIPacket async paths (dotnet/corefx#34184)
authorBen Adams <thundercat@illyriad.co.uk>
Fri, 21 Dec 2018 02:21:00 +0000 (02:21 +0000)
committerSaurabh Singh <saurabh.singh@microsoft.com>
Fri, 21 Dec 2018 02:21:00 +0000 (18:21 -0800)
* Optimize SNIPacket async paths

* Feedback

* NET Core vs NET Std

* csproj

* Feedback

Commit migrated from https://github.com/dotnet/corefx/commit/6870c4ed70f6aa39c5a622c9eaa86bd7e01b8fe7

src/libraries/System.Data.SqlClient/src/System.Data.SqlClient.csproj
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs [new file with mode: 0644]
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs [new file with mode: 0644]
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs

index c60ca96..d227f54 100644 (file)
     <Compile Include="System\Data\ProviderBase\DbConnectionPool.NetCoreApp.cs" />
     <Compile Include="System\Data\SqlClient\SqlConnectionString.NetCoreApp.cs" />
     <Compile Include="System\Data\SqlClient\SqlConnectionStringBuilder.NetCoreApp.cs" />
+    <Compile Include="System\Data\SqlClient\SNI\SNIPacket.NetCoreApp.cs" />
+  </ItemGroup>
+  <ItemGroup Condition="'$(IsPartialFacadeAssembly)' != 'true' AND '$(OSGroup)' != 'AnyOS' AND '$(TargetsNetCoreApp)' != 'true'">
+    <Compile Include="System\Data\SqlClient\SNI\SNIPacket.NetStandard.cs" />
   </ItemGroup>
   <!-- Manage the SNI toggle for Windows netstandard and UWP -->
   <ItemGroup Condition="('$(TargetGroup)' == 'netstandard' OR '$(TargetsNetCoreApp)' == 'true') AND '$(TargetsWindows)' == 'true'">
diff --git a/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs b/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs
new file mode 100644 (file)
index 0000000..6e5cab4
--- /dev/null
@@ -0,0 +1,117 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Buffers;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Data.SqlClient.SNI
+{
+    internal partial class SNIPacket
+    {
+        /// <summary>
+        /// Read data from a stream asynchronously
+        /// </summary>
+        /// <param name="stream">Stream to read from</param>
+        /// <param name="callback">Completion callback</param>
+        public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
+        {
+            // Treat local function as a static and pass all params otherwise as async will allocate
+            async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask<int> valueTask)
+            {
+                bool error = false;
+                try
+                {
+                    packet._length = await valueTask.ConfigureAwait(false);
+                    if (packet._length == 0)
+                    {
+                        SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty);
+                        error = true;
+                    }
+                }
+                catch (Exception ex)
+                {
+                    SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, ex);
+                    error = true;
+                }
+
+                if (error)
+                {
+                    packet.Release();
+                }
+
+                cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
+            }
+
+            ValueTask<int> vt = stream.ReadAsync(new Memory<byte>(_data, 0, _capacity), CancellationToken.None);
+
+            if (vt.IsCompletedSuccessfully)
+            {
+                _length = vt.Result;
+                // Zero length to go via async local function as is error condition
+                if (_length > 0)
+                {
+                    callback(this, TdsEnums.SNI_SUCCESS);
+
+                    // Completed
+                    return;
+                }
+            }
+
+            // Not complete or error call the async local function to complete
+            _ = ReadFromStreamAsync(this, callback, vt);
+        }
+
+        /// <summary>
+        /// Write data to a stream asynchronously
+        /// </summary>
+        /// <param name="stream">Stream to write to</param>
+        public void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false)
+        {
+            // Treat local function as a static and pass all params otherwise as async will allocate
+            async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool disposeAfter, ValueTask valueTask)
+            {
+                uint status = TdsEnums.SNI_SUCCESS;
+                try
+                {
+                    await valueTask.ConfigureAwait(false);
+                }
+                catch (Exception e)
+                {
+                    SNILoadHandle.SingletonInstance.LastError = new SNIError(providers, SNICommon.InternalExceptionError, e);
+                    status = TdsEnums.SNI_ERROR;
+                }
+
+                cb(packet, status);
+
+                if (disposeAfter)
+                {
+                    packet.Dispose();
+                }
+            }
+
+            ValueTask vt = stream.WriteAsync(new Memory<byte>(_data, 0, _length), CancellationToken.None);
+
+            if (vt.IsCompletedSuccessfully)
+            {
+                // Read the result to register as complete for the ValueTask
+                vt.GetAwaiter().GetResult();
+
+                callback(this, TdsEnums.SNI_SUCCESS);
+
+                if (disposeAfterWriteAsync)
+                {
+                    Dispose();
+                }
+
+                // Completed
+                return;
+            }
+
+            // Not complete or error call the async local function to complete
+            _ = WriteToStreamAsync(this, callback, provider, disposeAfterWriteAsync, vt);
+        }
+    }
+}
diff --git a/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs b/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs
new file mode 100644 (file)
index 0000000..bfa48ac
--- /dev/null
@@ -0,0 +1,117 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Buffers;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Data.SqlClient.SNI
+{
+    internal partial class SNIPacket
+    {
+        /// <summary>
+        /// Read data from a stream asynchronously
+        /// </summary>
+        /// <param name="stream">Stream to read from</param>
+        /// <param name="callback">Completion callback</param>
+        public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
+        {
+            // Treat local function as a static and pass all params otherwise as async will allocate
+            async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task<int> task)
+            {
+                bool error = false;
+                try
+                {
+                    packet._length = await task.ConfigureAwait(false);
+                    if (packet._length == 0)
+                    {
+                        SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty);
+                        error = true;
+                    }
+                }
+                catch (Exception ex)
+                {
+                    SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, ex);
+                    error = true;
+                }
+
+                if (error)
+                {
+                    packet.Release();
+                }
+
+                cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
+            }
+
+            Task<int> t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None);
+
+            if ((t.Status & TaskStatus.RanToCompletion) != 0)
+            {
+                _length = t.Result;
+                // Zero length to go via async local function as is error condition
+                if (_length > 0)
+                {
+                    callback(this, TdsEnums.SNI_SUCCESS);
+
+                    // Completed
+                    return;
+                }
+            }
+
+            // Not complete or error call the async local function to complete
+            _ = ReadFromStreamAsync(this, callback, t);
+        }
+
+        /// <summary>
+        /// Write data to a stream asynchronously
+        /// </summary>
+        /// <param name="stream">Stream to write to</param>
+        public void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false)
+        {
+            // Treat local function as a static and pass all params otherwise as async will allocate
+            async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool disposeAfter, Task task)
+            {
+                uint status = TdsEnums.SNI_SUCCESS;
+                try
+                {
+                    await task.ConfigureAwait(false);
+                }
+                catch (Exception e)
+                {
+                    SNILoadHandle.SingletonInstance.LastError = new SNIError(providers, SNICommon.InternalExceptionError, e);
+                    status = TdsEnums.SNI_ERROR;
+                }
+
+                cb(packet, status);
+
+                if (disposeAfter)
+                {
+                    packet.Dispose();
+                }
+            }
+
+            Task t = stream.WriteAsync(_data, 0, _length, CancellationToken.None);
+
+            if ((t.Status & TaskStatus.RanToCompletion) != 0)
+            {
+                // Read the result to register as complete for the Task
+                t.GetAwaiter().GetResult();
+
+                callback(this, TdsEnums.SNI_SUCCESS);
+
+                if (disposeAfterWriteAsync)
+                {
+                    Dispose();
+                }
+
+                // Completed
+                return;
+            }
+
+            // Not complete or error call the async local function to complete
+            _ = WriteToStreamAsync(this, callback, provider, disposeAfterWriteAsync, t);
+        }
+    }
+}
index f7ba249..931d064 100644 (file)
@@ -12,7 +12,7 @@ namespace System.Data.SqlClient.SNI
     /// <summary>
     /// SNI Packet
     /// </summary>
-    internal class SNIPacket : IDisposable, IEquatable<SNIPacket>
+    internal partial class SNIPacket : IDisposable, IEquatable<SNIPacket>
     {
         private byte[] _data;
         private int _length;
@@ -241,46 +241,6 @@ namespace System.Data.SqlClient.SNI
         }
 
         /// <summary>
-        /// Read data from a stream asynchronously
-        /// </summary>
-        /// <param name="stream">Stream to read from</param>
-        /// <param name="callback">Completion callback</param>
-        public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
-        {
-            bool error = false;
-            
-            stream.ReadAsync(_data, 0, _capacity, CancellationToken.None).ContinueWith(t =>
-            {
-                Exception e = t.Exception?.InnerException;
-                if (e != null)
-                {
-                    SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, e);
-                    error = true;
-                }
-                else
-                {
-                    _length = t.Result;
-
-                    if (_length == 0)
-                    {
-                        SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty);
-                        error = true;
-                    }
-                }
-
-                if (error)
-                {
-                    Release();
-                }
-
-                callback(this, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
-            },
-            CancellationToken.None,
-            TaskContinuationOptions.DenyChildAttach,
-            TaskScheduler.Default);
-        }
-
-        /// <summary>
         /// Read data from a stream synchronously
         /// </summary>
         /// <param name="stream">Stream to read from</param>
@@ -299,30 +259,6 @@ namespace System.Data.SqlClient.SNI
         }
 
         /// <summary>
-        /// Write data to a stream asynchronously
-        /// </summary>
-        /// <param name="stream">Stream to write to</param>
-        public async void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false)
-        {
-            uint status = TdsEnums.SNI_SUCCESS;
-            try
-            {
-                await stream.WriteAsync(_data, 0, _length, CancellationToken.None).ConfigureAwait(false);
-            }
-            catch (Exception e)
-            {
-                SNILoadHandle.SingletonInstance.LastError = new SNIError(provider, SNICommon.InternalExceptionError, e);
-                status = TdsEnums.SNI_ERROR;
-            }
-            callback(this, status);
-
-            if (disposeAfterWriteAsync)
-            {
-                Dispose();
-            }
-        }
-
-        /// <summary>
         /// Get hash code
         /// </summary>
         /// <returns>Hash code</returns>