Optimize SqlClient tds state to remove handle boxing (dotnet/corefx#34044)
authorWraith2 <Wraith2@users.noreply.github.com>
Fri, 21 Dec 2018 16:07:20 +0000 (16:07 +0000)
committerSaurabh Singh <saurabh.singh@microsoft.com>
Fri, 21 Dec 2018 16:07:20 +0000 (08:07 -0800)
* change TdsParserStateObject to pass packets using a ref struct to avoid boxing of IntPtr in native mode

* add project define for FEATURE_INTEROPSNI on windows non uap builds

* update interop to use SniPacketHandle type name

* rename SNIPacketHandle to SNIPacket

* split PacketHandle and SessionHandle into separate files and implementations

* add comments to PacketHandle and SessionHandle
remove unused packethandle variable in IsConnectionAlive
remove identidal overridden implementations of EmptyReadHandle

* move lazy bool into debug region

* re-add EmptyReadPackt and provide correctly types valid but empty packets in implementations
define IsValidPacket implementations more stringently with type checks

* change implementation switch name to make more sense

* add packet type assertion in IsValiePacket

Commit migrated from https://github.com/dotnet/corefx/commit/d3d8c74919c9ad831ecbae66dd86bd9d60379010

13 files changed:
src/libraries/System.Data.SqlClient/src/System.Data.SqlClient.csproj
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Unix.cs [new file with mode: 0644]
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Windows.cs [new file with mode: 0644]
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Unix.cs [new file with mode: 0644]
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Windows.cs [new file with mode: 0644]
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.Windows.cs
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserSafeHandles.cs
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs
src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectNative.cs

index d227f54..d7bf05f 100644 (file)
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
     <ProjectGuid>{D4550556-4745-457F-BA8F-3EBF3836D6B4}</ProjectGuid>
     <AssemblyName>System.Data.SqlClient</AssemblyName>
     <!-- Manage the SNI toggle for Windows netstandard and UWP -->
     <Compile Include="System\Data\SqlClient\SNI\SNITcpHandle.Windows.cs" />
     <Compile Include="System\Data\SqlClient\TdsParserStateObjectFactory.Windows.cs" />
+    <Compile Include="System\Data\SqlClient\PacketHandle.Windows.cs" />
+    <Compile Include="System\Data\SqlClient\SessionHandle.Windows.cs" />
     <AdditionalFiles Include="$(MSBuildProjectDirectory)/*.analyzerdata.windows" />
   </ItemGroup>
   <ItemGroup Condition="'$(IsUAPAssembly)' == 'true'">
     <Compile Include="System\Data\SqlClient\SNI\LocalDB.uap.cs" />
     <Compile Include="System\Data\ProviderBase\DbConnectionPoolIdentity.Unix.cs" />
     <Compile Include="System\Data\SqlClient\TdsParser.Unix.cs" />
+    <Compile Include="System\Data\SqlClient\PacketHandle.Unix.cs" />
+    <Compile Include="System\Data\SqlClient\SessionHandle.Unix.cs" />
   </ItemGroup>
   <!-- Assets needed on Windows but should be avoided on UAP to avoid sni.dll -->
   <ItemGroup Condition=" '$(TargetsWindows)' == 'true' And '$(IsPartialFacadeAssembly)' != 'true' and '$(IsUAPAssembly)' != 'true'">
     <Compile Include="System\Data\SqlClient\LocalDBAPI.Unix.cs" />
     <Compile Include="System\Data\SqlClient\SNI\LocalDB.Unix.cs" />
     <Compile Include="System\Data\SqlClient\SNI\SNITcpHandle.Unix.cs" />
+    <Compile Include="System\Data\SqlClient\PacketHandle.Unix.cs" />
+    <Compile Include="System\Data\SqlClient\SessionHandle.Unix.cs" />
   </ItemGroup>
   <ItemGroup Condition="'$(TargetsWindows)' == 'true' And '$(IsPartialFacadeAssembly)' != 'true' and '$(IsUAPAssembly)' != 'true'">
     <Reference Include="Microsoft.Win32.Registry" />
diff --git a/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Unix.cs b/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Unix.cs
new file mode 100644 (file)
index 0000000..f9fd9dc
--- /dev/null
@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+
+namespace System.Data.SqlClient
+{
+       // this structure is used for transporting packet handle references between the TdsParserStateObject
+       //  base class and Managed or Native implementations. 
+       // It prevents the native IntPtr type from being boxed and prevents the need to cast from object which loses compile time type safety
+       // It carries type information so that assertions about the type of handle can be made in the implemented abstract methods 
+       // it is a ref struct so that it can only be used to transport the handles and not store them
+
+       // N.B. If you change this type you must also change the version for the other platform
+
+       internal readonly ref struct PacketHandle
+    {
+        public const int NativePointerType = 1;
+        public const int NativePacketType = 2;
+        public const int ManagedPacketType = 3;
+
+        public readonly SNI.SNIPacket ManagedPacket;
+        public readonly int Type;
+
+        private PacketHandle(SNI.SNIPacket managedPacket, int type)
+        {
+            Type = type;
+            ManagedPacket = managedPacket;
+        }
+
+               public static PacketHandle FromManagedPacket(SNI.SNIPacket managedPacket) => new PacketHandle(managedPacket, ManagedPacketType);
+       }
+}
diff --git a/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Windows.cs b/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Windows.cs
new file mode 100644 (file)
index 0000000..f15d26f
--- /dev/null
@@ -0,0 +1,44 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+
+namespace System.Data.SqlClient
+{
+    // this structure is used for transporting packet handle references between the TdsParserStateObject
+    //  base class and Managed or Native implementations. 
+    // It prevents the native IntPtr type from being boxed and prevents the need to cast from object which loses compile time type safety
+    // It carries type information so that assertions about the type of handle can be made in the implemented abstract methods 
+    // it is a ref struct so that it can only be used to transport the handles and not store them
+
+    // N.B. If you change this type you must also change the version for the other platform
+
+    internal readonly ref struct PacketHandle
+    {
+        public const int NativePointerType = 1;
+        public const int NativePacketType = 2;
+        public const int ManagedPacketType = 3;
+
+        public readonly IntPtr NativePointer;
+        public readonly SNIPacket NativePacket;
+
+        public readonly SNI.SNIPacket ManagedPacket;
+        public readonly int Type;
+
+        private PacketHandle(IntPtr nativePointer, SNIPacket nativePacket, SNI.SNIPacket managedPacket, int type)
+        {
+            Type = type;
+            ManagedPacket = managedPacket;
+            NativePointer = nativePointer;
+            NativePacket = nativePacket;
+        }
+
+        public static PacketHandle FromManagedPacket(SNI.SNIPacket managedPacket) => new PacketHandle(default, default, managedPacket, ManagedPacketType);
+
+        public static PacketHandle FromNativePointer(IntPtr nativePointer) => new PacketHandle(nativePointer, default, default, NativePointerType);
+
+        public static PacketHandle FromNativePacket(SNIPacket nativePacket) => new PacketHandle(default, nativePacket, default, NativePacketType);
+
+
+    }
+}
index 4506ff4..5a0e8f7 100644 (file)
@@ -318,7 +318,7 @@ namespace System.Data.SqlClient.SNI
                 _packetEvent.Set();
             }
 
-            ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(packet, 1);
+            ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 1);
         }
 
         /// <summary>
@@ -332,7 +332,7 @@ namespace System.Data.SqlClient.SNI
             {
                 Debug.Assert(_callbackObject != null);
 
-                ((TdsParserStateObject)_callbackObject).WriteAsyncCallback(packet, sniErrorCode);
+                ((TdsParserStateObject)_callbackObject).WriteAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode);
             }
         }
 
@@ -378,7 +378,7 @@ namespace System.Data.SqlClient.SNI
                     _asyncReceives--;
                     Debug.Assert(_callbackObject != null);
 
-                    ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(packet, 0);
+                    ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 0);
                 }
             }
 
index 931d064..a047e94 100644 (file)
@@ -20,8 +20,6 @@ namespace System.Data.SqlClient.SNI
         private int _offset;
         private string _description;
         private SNIAsyncCallback _completionCallback;
-
-        private ArrayPool<byte>  _arrayPool = ArrayPool<byte>.Shared;
         private bool _isBufferFromArrayPool = false;
 
         public SNIPacket() { }
@@ -98,14 +96,14 @@ namespace System.Data.SqlClient.SNI
             {
                 if (_isBufferFromArrayPool)
                 {
-                    _arrayPool.Return(_data);
+                    ArrayPool<byte>.Shared.Return(_data);
                 }
                 _data = null;
             }
 
             if (_data == null)
             {
-                _data = _arrayPool.Rent(capacity);
+                _data = ArrayPool<byte>.Shared.Rent(capacity);
                 _isBufferFromArrayPool = true;
             }
 
@@ -221,7 +219,7 @@ namespace System.Data.SqlClient.SNI
             {
                 if(_isBufferFromArrayPool)
                 {
-                    _arrayPool.Return(_data);
+                    ArrayPool<byte>.Shared.Return(_data);
                 }
                 _data = null;
                 _capacity = 0;
diff --git a/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Unix.cs b/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Unix.cs
new file mode 100644 (file)
index 0000000..5bf099a
--- /dev/null
@@ -0,0 +1,34 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+
+namespace System.Data.SqlClient
+{
+       // this structure is used for transporting packet handle references between the TdsParserStateObject
+       //  base class and Managed or Native implementations. 
+       // It carries type information so that assertions about the type of handle can be made in the 
+       //  implemented abstract methods 
+       // it is a ref struct so that it can only be used to transport the handles and not store them
+
+       // N.B. If you change this type you must also change the version for the other platform
+
+       internal readonly ref struct SessionHandle
+    {
+        public const int NativeHandleType = 1;
+        public const int ManagedHandleType = 2;
+
+        public readonly SNI.SNIHandle ManagedHandle;
+        public readonly int Type;
+
+        public SessionHandle(SNI.SNIHandle managedHandle, int type)
+        {
+            Type = type;
+            ManagedHandle = managedHandle;
+        }
+
+        public bool IsNull => ManagedHandle is null;
+
+               public static SessionHandle FromManagedSession(SNI.SNIHandle managedSessionHandle) => new SessionHandle(managedSessionHandle, ManagedHandleType);
+       }
+}
diff --git a/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Windows.cs b/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Windows.cs
new file mode 100644 (file)
index 0000000..a721596
--- /dev/null
@@ -0,0 +1,39 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+
+namespace System.Data.SqlClient
+{
+    // this structure is used for transporting packet handle references between the TdsParserStateObject
+    //  base class and Managed or Native implementations. 
+    // It carries type information so that assertions about the type of handle can be made in the 
+    //  implemented abstract methods 
+    // it is a ref struct so that it can only be used to transport the handles and not store them
+
+    // N.B. If you change this type you must also change the version for the other platform
+
+    internal readonly ref struct SessionHandle
+    {
+        public const int NativeHandleType = 1;
+        public const int ManagedHandleType = 2;
+
+        public readonly SNI.SNIHandle ManagedHandle;
+        public readonly SNIHandle NativeHandle;
+
+        public readonly int Type;
+
+        public SessionHandle(SNI.SNIHandle managedHandle, SNIHandle nativeHandle, int type)
+        {
+            Type = type;
+            ManagedHandle = managedHandle;
+            NativeHandle = nativeHandle;
+        }
+
+        public bool IsNull => (Type == NativeHandleType) ? NativeHandle is null : ManagedHandle is null;
+
+        public static SessionHandle FromManagedSession(SNI.SNIHandle managedSessionHandle) => new SessionHandle(managedSessionHandle, default, ManagedHandleType);
+
+        public static SessionHandle FromNativeHandle(SNIHandle nativeSessionHandle) => new SessionHandle(default, nativeSessionHandle, NativeHandleType);
+    }
+}
index 12ce515..6fec58a 100644 (file)
@@ -20,20 +20,22 @@ namespace System.Data.SqlClient
             // Have to post read to initialize MARS - will get callback on this when connection goes
             // down or is closed.
 
-            IntPtr temp = IntPtr.Zero;
+            PacketHandle temp = default;
             uint error = TdsEnums.SNI_SUCCESS;
 
             _pMarsPhysicalConObj.IncrementPendingCallbacks();
-            object handle = _pMarsPhysicalConObj.SessionHandle;
-            temp = (IntPtr)_pMarsPhysicalConObj.ReadAsync(out error, ref handle);
+            SessionHandle handle = _pMarsPhysicalConObj.SessionHandle;
+            temp = _pMarsPhysicalConObj.ReadAsync(handle, out error);
 
-            if (temp != IntPtr.Zero)
+            Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
+
+            if (temp.NativePointer != IntPtr.Zero)
             {
                 // Be sure to release packet, otherwise it will be leaked by native.
                 _pMarsPhysicalConObj.ReleasePacket(temp);
             }
-            
-            Debug.Assert(IntPtr.Zero == temp, "unexpected syncReadPacket without corresponding SNIPacketRelease");
+
+            Debug.Assert(IntPtr.Zero == temp.NativePointer, "unexpected syncReadPacket without corresponding SNIPacketRelease");
             if (TdsEnums.SNI_SUCCESS_IO_PENDING != error)
             {
                 Debug.Assert(TdsEnums.SNI_SUCCESS != error, "Unexpected successful read async on physical connection before enabling MARS!");
@@ -118,4 +120,4 @@ namespace System.Data.SqlClient
         }
 
     }    // tdsparser
-}//namespace
\ No newline at end of file
+}//namespace
index 55cd8a1..7dce0de 100644 (file)
@@ -104,7 +104,7 @@ namespace System.Data.SqlClient
 
                 if (null != stateObj)
                 {
-                    stateObj.ReadAsyncCallback(IntPtr.Zero, packet, error);
+                    stateObj.ReadAsyncCallback(IntPtr.Zero, PacketHandle.FromNativePointer(packet), error);
                 }
             }
         }
@@ -125,7 +125,7 @@ namespace System.Data.SqlClient
 
                 if (null != stateObj)
                 {
-                    stateObj.WriteAsyncCallback(IntPtr.Zero, packet, error);
+                    stateObj.WriteAsyncCallback(IntPtr.Zero, PacketHandle.FromNativePointer(packet), error);
                 }
             }
         }
@@ -296,4 +296,4 @@ namespace System.Data.SqlClient
             }
         }
     }
-}
\ No newline at end of file
+}
index fea7447..e6ba9a6 100644 (file)
@@ -392,7 +392,7 @@ namespace System.Data.SqlClient
             get;
         }
 
-        internal abstract object SessionHandle
+        internal abstract SessionHandle SessionHandle
         {
             get;
         }
@@ -761,27 +761,27 @@ namespace System.Data.SqlClient
 
         internal abstract void DisposePacketCache();
 
-        internal abstract bool IsPacketEmpty(object readPacket);
+        internal abstract bool IsPacketEmpty(PacketHandle readPacket);
 
-        internal abstract object ReadSyncOverAsync(int timeoutRemaining, out uint error);
+        internal abstract PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error);
 
-        internal abstract object ReadAsync(out uint error, ref object handle);
+        internal abstract PacketHandle ReadAsync(SessionHandle handle, out uint error);
 
         internal abstract uint CheckConnection();
 
         internal abstract uint SetConnectionBufferSize(ref uint unsignedPacketSize);
 
-        internal abstract void ReleasePacket(object syncReadPacket);
+        internal abstract void ReleasePacket(PacketHandle syncReadPacket);
 
-        protected abstract uint SNIPacketGetData(object packet, byte[] _inBuff, ref uint dataSize);
+        protected abstract uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize);
 
-        internal abstract object GetResetWritePacket();
+        internal abstract PacketHandle GetResetWritePacket();
 
         internal abstract void ClearAllWritePackets();
 
-        internal abstract object AddPacketToPendingList(object packet);
+        internal abstract PacketHandle AddPacketToPendingList(PacketHandle packet);
 
-        protected abstract void RemovePacketFromPendingList(object pointer);
+        protected abstract void RemovePacketFromPendingList(PacketHandle pointer);
 
         internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer);
 
@@ -855,7 +855,7 @@ namespace System.Data.SqlClient
 
             // NOTE: TdsParserSessionPool may call DecrementPendingCallbacks on a TdsParserStateObject which is already disposed
             // This is not dangerous (since the stateObj is no longer in use), but we need to add a workaround in the assert for it
-            Debug.Assert((remaining == -1 && SessionHandle == null) || (0 <= remaining && remaining < 3), string.Format("_pendingCallbacks values is invalid after decrementing: {0}", remaining));
+            Debug.Assert((remaining == -1 && SessionHandle.IsNull) || (0 <= remaining && remaining < 3), string.Format("_pendingCallbacks values is invalid after decrementing: {0}", remaining));
             return remaining;
         }
 
@@ -2069,7 +2069,7 @@ namespace System.Data.SqlClient
                 throw ADP.ClosedConnectionError();
             }
 
-            object readPacket = null;
+            PacketHandle readPacket = default;
 
             uint error;
 
@@ -2291,7 +2291,7 @@ namespace System.Data.SqlClient
 #endif
 
 
-            object readPacket = null;
+            PacketHandle readPacket = default;
 
             uint error = 0;
 
@@ -2317,16 +2317,14 @@ namespace System.Data.SqlClient
                     ChangeNetworkPacketTimeout(msecsRemaining, Timeout.Infinite);
                 }
 
-                object handle = null;
-
                 Interlocked.Increment(ref _readingCount);
 
-                handle = SessionHandle;
-                if (handle != null)
+                SessionHandle handle = SessionHandle;
+                if (!handle.IsNull)
                 {
                     IncrementPendingCallbacks();
 
-                    readPacket = ReadAsync(out error, ref handle);
+                    readPacket = ReadAsync(handle, out error);
 
                     if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error))
                     {
@@ -2335,8 +2333,8 @@ namespace System.Data.SqlClient
                 }
 
                 Interlocked.Decrement(ref _readingCount);
-                
-                if (handle == null)
+
+                if (handle.IsNull)
                 {
                     throw ADP.ClosedConnectionError();
                 }
@@ -2419,8 +2417,8 @@ namespace System.Data.SqlClient
                 {
                     uint error;
                     SniContext = SniContext.Snix_Connect;
-
                     error = CheckConnection();
+
                     if ((error != TdsEnums.SNI_SUCCESS) && (error != TdsEnums.SNI_WAIT_TIMEOUT))
                     {
                         // Connection is dead
@@ -2498,7 +2496,7 @@ namespace System.Data.SqlClient
                         {
                             stateObj.SendAttention(mustTakeWriteLock: true);
 
-                            object syncReadPacket = null;
+                            PacketHandle syncReadPacket = default;
 
                             bool shouldDecrement = false;
                             try
@@ -2570,7 +2568,7 @@ namespace System.Data.SqlClient
             AssertValidState();
         }
 
-        public void ProcessSniPacket(object packet, uint error)
+        public void ProcessSniPacket(PacketHandle packet, uint error)
         {
             if (error != 0)
             {
@@ -2669,13 +2667,12 @@ namespace System.Data.SqlClient
             }
         }
 
-        public void ReadAsyncCallback<T>(T packet, uint error)
+        public void ReadAsyncCallback(PacketHandle packet, uint error)
         {
             ReadAsyncCallback(IntPtr.Zero, packet, error);
         }
 
-
-        public void ReadAsyncCallback<T>(IntPtr key, T packet, uint error)
+        public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error)
         {
             // Key never used.
             // Note - it's possible that when native calls managed that an asynchronous exception
@@ -2755,7 +2752,7 @@ namespace System.Data.SqlClient
             }
         }
 
-        protected abstract bool CheckPacket(object packet, TaskCompletionSource<object> source);
+        protected abstract bool CheckPacket(PacketHandle packet, TaskCompletionSource<object> source);
 
         private void ReadAsyncCallbackCaptureException(TaskCompletionSource<object> source)
         {
@@ -2801,12 +2798,12 @@ namespace System.Data.SqlClient
 
 #pragma warning disable 0420 // a reference to a volatile field will not be treated as volatile
 
-        public void WriteAsyncCallback<T>(T packet, uint sniError)
+        public void WriteAsyncCallback(PacketHandle packet, uint sniError)
         {
             WriteAsyncCallback(IntPtr.Zero, packet, sniError);
         }
 
-        public void WriteAsyncCallback<T>(IntPtr key, T packet, uint sniError)
+        public void WriteAsyncCallback(IntPtr key, PacketHandle packet, uint sniError)
         { // Key never used.
             RemovePacketFromPendingList(packet);
             try
@@ -3218,7 +3215,7 @@ namespace System.Data.SqlClient
 
 #pragma warning disable 0420 // a reference to a volatile field will not be treated as volatile
 
-        private Task SNIWritePacket(object packet, out uint sniError, bool canAccumulate, bool callerHasConnectionLock)
+        private Task SNIWritePacket(PacketHandle packet, out uint sniError, bool canAccumulate, bool callerHasConnectionLock)
         {
             // Check for a stored exception
             var delayedException = Interlocked.Exchange(ref _delayedWriteAsyncCallbackException, null);
@@ -3230,7 +3227,7 @@ namespace System.Data.SqlClient
             Task task = null;
             _writeCompletionSource = null;
 
-            object packetPointer = EmptyReadPacket;
+            PacketHandle packetPointer = EmptyReadPacket;
 
             bool sync = !_parser._asyncWrite;
             if (sync && _asyncWriteCount > 0)
@@ -3351,8 +3348,9 @@ namespace System.Data.SqlClient
             return task;
         }
 
-        internal abstract bool IsValidPacket(object packetPointer);
-        internal abstract uint WritePacket(object packet, bool sync);
+        internal abstract bool IsValidPacket(PacketHandle packetPointer);
+
+        internal abstract uint WritePacket(PacketHandle packet, bool sync);
 
 #pragma warning restore 0420
 
@@ -3369,7 +3367,7 @@ namespace System.Data.SqlClient
                     return;
                 }
 
-                object attnPacket = CreateAndSetAttentionPacket();
+                PacketHandle attnPacket = CreateAndSetAttentionPacket();
 
                 try
                 {
@@ -3427,14 +3425,14 @@ namespace System.Data.SqlClient
             }
         }
 
-        internal abstract object CreateAndSetAttentionPacket();
+        internal abstract PacketHandle CreateAndSetAttentionPacket();
 
-        internal abstract void SetPacketData(object packet, byte[] buffer, int bytesUsed);
+        internal abstract void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed);
 
         private Task WriteSni(bool canAccumulate)
         {
             // Prepare packet, and write to packet.
-            object packet = GetResetWritePacket();
+            PacketHandle packet = GetResetWritePacket();
 
             SetBufferSecureStrings();
             SetPacketData(packet, _outBuff, _outBytesUsed);
@@ -3646,7 +3644,7 @@ namespace System.Data.SqlClient
             }
         }
 
-        protected abstract object EmptyReadPacket { get; }
+        protected abstract PacketHandle EmptyReadPacket { get; }
 
         /// <summary>
         /// Gets the full list of errors and warnings (including the pre-attention ones), then wipes all error and warning lists
index 96832fb..fb57c55 100644 (file)
@@ -19,7 +19,14 @@ namespace System.Data.SqlClient
         //private static bool shouldUseLegacyNetorking;
         //public static bool UseManagedSNI { get; } = AppContext.TryGetSwitch(UseLegacyNetworkingOnWindows, out shouldUseLegacyNetorking) ? !shouldUseLegacyNetorking : true;
 
+#if DEBUG
+        private static Lazy<bool> useManagedSNIOnWindows = new Lazy<bool>(
+            () => bool.TrueString.Equals(Environment.GetEnvironmentVariable("System.Data.SqlClient.UseManagedSNIOnWindows"), StringComparison.InvariantCultureIgnoreCase)
+        );
+        public static bool UseManagedSNI => useManagedSNIOnWindows.Value;
+#else
         public static bool UseManagedSNI { get; } = false;
+#endif
 
         public EncryptionOptions EncryptionOptions
         {
index 103e6fb..151d4e5 100644 (file)
@@ -31,14 +31,12 @@ namespace System.Data.SqlClient.SNI
 
         internal override uint Status => _sessionHandle != null ? _sessionHandle.Status : TdsEnums.SNI_UNINITIALIZED;
 
-        internal override object SessionHandle => _sessionHandle;
+        internal override SessionHandle SessionHandle => SessionHandle.FromManagedSession(_sessionHandle);
 
-        protected override object EmptyReadPacket => null;
-
-        protected override bool CheckPacket(object packet, TaskCompletionSource<object> source)
+        protected override bool CheckPacket(PacketHandle packet, TaskCompletionSource<object> source)
         {
-            SNIPacket p = packet as SNIPacket;
-            return p.IsInvalid || (!p.IsInvalid && source != null);
+            SNIPacket p = packet.ManagedPacket;
+            return p.IsInvalid || source != null;
         }
 
         protected override void CreateSessionHandle(TdsParserStateObject physicalConnection, bool async)
@@ -54,7 +52,7 @@ namespace System.Data.SqlClient.SNI
             return _marsConnection.CreateMarsSession(callbackObject, async);
         }
 
-        protected override uint SNIPacketGetData(object packet, byte[] _inBuff, ref uint dataSize) => SNIProxy.Singleton.PacketGetData(packet as SNIPacket, _inBuff, ref dataSize);
+        protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) => SNIProxy.Singleton.PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize);
 
         internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity)
         {
@@ -72,11 +70,11 @@ namespace System.Data.SqlClient.SNI
             }
         }
 
-        internal void ReadAsyncCallback(SNIPacket packet, uint error) => ReadAsyncCallback(IntPtr.Zero, packet, error);
+        internal void ReadAsyncCallback(SNIPacket packet, uint error) => ReadAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), error);
 
-        internal void WriteAsyncCallback(SNIPacket packet, uint sniError) => WriteAsyncCallback(IntPtr.Zero, packet, sniError);
+        internal void WriteAsyncCallback(SNIPacket packet, uint sniError) => WriteAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), sniError);
 
-        protected override void RemovePacketFromPendingList(object packet)
+        protected override void RemovePacketFromPendingList(PacketHandle packet)
         {
             // No-Op
         }
@@ -125,7 +123,7 @@ namespace System.Data.SqlClient.SNI
 
         internal override bool IsFailedHandle() => _sessionHandle.Status != TdsEnums.SNI_SUCCESS;
 
-        internal override object ReadSyncOverAsync(int timeoutRemaining, out uint error)
+        internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error)
         {
             SNIHandle handle = Handle;
             if (handle == null)
@@ -134,17 +132,19 @@ namespace System.Data.SqlClient.SNI
             }
             SNIPacket packet = null;
             error = SNIProxy.Singleton.ReadSyncOverAsync(handle, out packet, timeoutRemaining);
-            return packet;
+            return PacketHandle.FromManagedPacket(packet);
         }
 
-        internal override bool IsPacketEmpty(object packet)
+        protected override PacketHandle EmptyReadPacket => PacketHandle.FromManagedPacket(null);
+
+        internal override bool IsPacketEmpty(PacketHandle packet)
         {
-            return packet == null;
+            return packet.ManagedPacket == null;
         }
 
-        internal override void ReleasePacket(object syncReadPacket)
+        internal override void ReleasePacket(PacketHandle syncReadPacket)
         {
-            ((SNIPacket)syncReadPacket).Dispose();
+            syncReadPacket.ManagedPacket?.Dispose();
         }
 
         internal override uint CheckConnection()
@@ -153,38 +153,46 @@ namespace System.Data.SqlClient.SNI
             return handle == null ? TdsEnums.SNI_SUCCESS : SNIProxy.Singleton.CheckConnection(handle);
         }
 
-        internal override object ReadAsync(out uint error, ref object handle)
+        internal override PacketHandle ReadAsync(SessionHandle handle, out uint error)
         {
             SNIPacket packet;
-            error = SNIProxy.Singleton.ReadAsync((SNIHandle)handle, out packet);
-            return packet;
+            error = SNIProxy.Singleton.ReadAsync(handle.ManagedHandle, out packet);
+            return PacketHandle.FromManagedPacket(packet);
         }
 
-        internal override object CreateAndSetAttentionPacket()
+        internal override PacketHandle CreateAndSetAttentionPacket()
         {
             if (_sniAsyncAttnPacket == null)
             {
                 SNIPacket attnPacket = new SNIPacket();
-                SetPacketData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN);
+                SetPacketData(PacketHandle.FromManagedPacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN);
                 _sniAsyncAttnPacket = attnPacket;
             }
-            return _sniAsyncAttnPacket;
+            return PacketHandle.FromManagedPacket(_sniAsyncAttnPacket);
         }
 
-        internal override uint WritePacket(object packet, bool sync)
+        internal override uint WritePacket(PacketHandle packet, bool sync)
         {
-            return SNIProxy.Singleton.WritePacket((SNIHandle)Handle, (SNIPacket)packet, sync);
+            return SNIProxy.Singleton.WritePacket(Handle, packet.ManagedPacket, sync);
         }
 
-        internal override object AddPacketToPendingList(object packet)
+        internal override PacketHandle AddPacketToPendingList(PacketHandle packet)
         {
             // No-Op
             return packet;
         }
 
-        internal override bool IsValidPacket(object packetPointer) => (SNIPacket)packetPointer != null && !((SNIPacket)packetPointer).IsInvalid;
+        internal override bool IsValidPacket(PacketHandle packet)
+        {
+            Debug.Assert(packet.Type == PacketHandle.ManagedPacketType, "unexpected packet type when requiring ManagedPacket");
+            return (
+                packet.Type == PacketHandle.ManagedPacketType &&
+                packet.ManagedPacket != null &&
+                !packet.ManagedPacket.IsInvalid
+            );
+        }
 
-        internal override object GetResetWritePacket()
+        internal override PacketHandle GetResetWritePacket()
         {
             if (_sniPacket != null)
             {
@@ -197,7 +205,7 @@ namespace System.Data.SqlClient.SNI
                     _sniPacket = _writePacketCache.Take(Handle);
                 }
             }
-            return _sniPacket;
+            return PacketHandle.FromManagedPacket(_sniPacket);
         }
 
         internal override void ClearAllWritePackets()
@@ -214,8 +222,8 @@ namespace System.Data.SqlClient.SNI
             }
         }
 
-        internal override void SetPacketData(object packet, byte[] buffer, int bytesUsed) => SNIProxy.Singleton.PacketSetData((SNIPacket)packet, buffer, bytesUsed);
-        
+        internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed) => SNIProxy.Singleton.PacketSetData(packet.ManagedPacket, buffer, bytesUsed);
+
         internal override uint SniGetConnectionId(ref Guid clientConnectionId) => SNIProxy.Singleton.GetConnectionId(Handle, ref clientConnectionId);
 
         internal override uint DisabeSsl() => SNIProxy.Singleton.DisableSsl(Handle);
index e830fc0..5c43bdb 100644 (file)
@@ -12,8 +12,6 @@ namespace System.Data.SqlClient
 {
     internal class TdsParserStateObjectNative : TdsParserStateObject
     {
-        private static readonly object s_cachedEmptyReadPacketObjectPointer = (object)IntPtr.Zero;
-
         private SNIHandle _sessionHandle = null;              // the SNI handle we're to work on
 
         private SNIPacket _sniPacket = null;                // Will have to re-vamp this for MARS
@@ -35,9 +33,7 @@ namespace System.Data.SqlClient
 
         internal override uint Status => _sessionHandle != null ? _sessionHandle.Status : TdsEnums.SNI_UNINITIALIZED;
 
-        internal override object SessionHandle => _sessionHandle;
-
-        protected override object EmptyReadPacket => s_cachedEmptyReadPacketObjectPointer;
+        internal override SessionHandle SessionHandle => SessionHandle.FromNativeHandle(_sessionHandle);
 
         protected override void CreateSessionHandle(TdsParserStateObject physicalConnection, bool async)
         {
@@ -99,11 +95,16 @@ namespace System.Data.SqlClient
             _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel);
         }
 
-        protected override uint SNIPacketGetData(object packet, byte[] _inBuff, ref uint dataSize) => SNINativeMethodWrapper.SNIPacketGetData((IntPtr)packet, _inBuff, ref dataSize);
+        protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
+        {
+            Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
+            return SNINativeMethodWrapper.SNIPacketGetData(packet.NativePointer, _inBuff, ref dataSize);
+        }
 
-        protected override bool CheckPacket(object packet, TaskCompletionSource<object> source)
+        protected override bool CheckPacket(PacketHandle packet, TaskCompletionSource<object> source)
         {
-            IntPtr ptr = (IntPtr)(object)packet;
+            Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
+            IntPtr ptr = packet.NativePointer;
             return IntPtr.Zero == ptr || IntPtr.Zero != ptr && source != null;
         }
 
@@ -111,9 +112,10 @@ namespace System.Data.SqlClient
 
         public void WriteAsyncCallback(IntPtr key, IntPtr packet, uint sniError) => WriteAsyncCallback(key, packet, sniError);
 
-        protected override void RemovePacketFromPendingList(object ptr)
+        protected override void RemovePacketFromPendingList(PacketHandle ptr)
         {
-            IntPtr pointer = (IntPtr)ptr;
+            Debug.Assert(ptr.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
+            IntPtr pointer = ptr.NativePointer;
 
             SNIPacket recoveredPacket;
 
@@ -171,7 +173,7 @@ namespace System.Data.SqlClient
 
         internal override bool IsFailedHandle() => _sessionHandle.Status != TdsEnums.SNI_SUCCESS;
 
-        internal override object ReadSyncOverAsync(int timeoutRemaining, out uint error)
+        internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error)
         {
             SNIHandle handle = Handle;
             if (handle == null)
@@ -180,12 +182,22 @@ namespace System.Data.SqlClient
             }
             IntPtr readPacketPtr = IntPtr.Zero;
             error = SNINativeMethodWrapper.SNIReadSyncOverAsync(handle, ref readPacketPtr, GetTimeoutRemaining());
-            return readPacketPtr;
+            return PacketHandle.FromNativePointer(readPacketPtr);
         }
 
-        internal override bool IsPacketEmpty(object readPacket) => IntPtr.Zero == (IntPtr)readPacket;
+        protected override PacketHandle EmptyReadPacket => PacketHandle.FromNativePointer(default);
 
-        internal override void ReleasePacket(object syncReadPacket) => SNINativeMethodWrapper.SNIPacketRelease((IntPtr)syncReadPacket);
+        internal override bool IsPacketEmpty(PacketHandle readPacket)
+        {
+            Debug.Assert(readPacket.Type == PacketHandle.NativePointerType || readPacket.Type == 0, "unexpected packet type when requiring NativePointer");
+            return IntPtr.Zero == readPacket.NativePointer;
+        }
+
+        internal override void ReleasePacket(PacketHandle syncReadPacket)
+        {
+            Debug.Assert(syncReadPacket.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
+            SNINativeMethodWrapper.SNIPacketRelease(syncReadPacket.NativePointer);
+        }
 
         internal override uint CheckConnection()
         {
@@ -193,27 +205,33 @@ namespace System.Data.SqlClient
             return handle == null ? TdsEnums.SNI_SUCCESS : SNINativeMethodWrapper.SNICheckConnection(handle);
         }
 
-        internal override object ReadAsync(out uint error, ref object handle)
+        internal override PacketHandle ReadAsync(SessionHandle handle, out uint error)
         {
+            Debug.Assert(handle.Type == SessionHandle.NativeHandleType, "unexpected handle type when requiring NativePointer");
             IntPtr readPacketPtr = IntPtr.Zero;
-            error = SNINativeMethodWrapper.SNIReadAsync((SNIHandle)handle, ref readPacketPtr);
-            return readPacketPtr;
+            error = SNINativeMethodWrapper.SNIReadAsync(handle.NativeHandle, ref readPacketPtr);
+            return PacketHandle.FromNativePointer(readPacketPtr);
         }
 
-        internal override object CreateAndSetAttentionPacket()
+        internal override PacketHandle CreateAndSetAttentionPacket()
         {
             SNIHandle handle = Handle;
             SNIPacket attnPacket = new SNIPacket(handle);
             _sniAsyncAttnPacket = attnPacket;
-            SetPacketData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN);
-            return attnPacket;
+            SetPacketData(PacketHandle.FromNativePacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN);
+            return PacketHandle.FromNativePacket(attnPacket);
         }
 
-        internal override uint WritePacket(object packet, bool sync) => SNINativeMethodWrapper.SNIWritePacket(Handle, (SNIPacket)packet, sync);
+        internal override uint WritePacket(PacketHandle packet, bool sync)
+        {
+            Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket");
+            return SNINativeMethodWrapper.SNIWritePacket(Handle, packet.NativePacket, sync);
+        }
 
-        internal override object AddPacketToPendingList(object packetToAdd)
+        internal override PacketHandle AddPacketToPendingList(PacketHandle packetToAdd)
         {
-            SNIPacket packet = (SNIPacket)packetToAdd;
+            Debug.Assert(packetToAdd.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket");
+            SNIPacket packet = packetToAdd.NativePacket;
             Debug.Assert(packet == _sniPacket, "Adding a packet other than the current packet to the pending list");
             _sniPacket = null;
             IntPtr pointer = packet.DangerousGetHandle();
@@ -223,12 +241,20 @@ namespace System.Data.SqlClient
                 _pendingWritePackets.Add(pointer, packet);
             }
 
-            return pointer;
+            return PacketHandle.FromNativePointer(pointer);
         }
 
-        internal override bool IsValidPacket(object packetPointer) => (IntPtr)packetPointer != IntPtr.Zero;
+        internal override bool IsValidPacket(PacketHandle packetPointer)
+        {
+            Debug.Assert(packetPointer.Type == PacketHandle.NativePointerType || packetPointer.Type==PacketHandle.NativePacketType, "unexpected packet type when requiring NativePointer");
+            return (
+                (packetPointer.Type == PacketHandle.NativePointerType && packetPointer.NativePointer != IntPtr.Zero)
+                ||
+                (packetPointer.Type == PacketHandle.NativePacketType && packetPointer.NativePacket != null)
+            );
+        }
 
-        internal override object GetResetWritePacket()
+        internal override PacketHandle GetResetWritePacket()
         {
             if (_sniPacket != null)
             {
@@ -241,7 +267,7 @@ namespace System.Data.SqlClient
                     _sniPacket = _writePacketCache.Take(Handle);
                 }
             }
-            return _sniPacket;
+            return PacketHandle.FromNativePacket(_sniPacket);
         }
 
         internal override void ClearAllWritePackets()
@@ -258,8 +284,11 @@ namespace System.Data.SqlClient
             }
         }
 
-        internal override void SetPacketData(object packet, byte[] buffer, int bytesUsed) 
-            => SNINativeMethodWrapper.SNIPacketSetData((SNIPacket)packet, buffer, bytesUsed);
+        internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed)
+        {
+            Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket");
+            SNINativeMethodWrapper.SNIPacketSetData(packet.NativePacket, buffer, bytesUsed);
+        }
 
         internal override uint SniGetConnectionId(ref Guid clientConnectionId) 
             => SNINativeMethodWrapper.SniGetConnectionId(Handle, ref clientConnectionId);