fix regression in ChannelBinding/NTLM auth (#40222)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Mon, 3 Aug 2020 20:08:46 +0000 (13:08 -0700)
committerGitHub <noreply@github.com>
Mon, 3 Aug 2020 20:08:46 +0000 (13:08 -0700)
* fix regression in ChannelBinding/NTLM auth

* fix index

src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs

index bef2693..40ef8b2 100644 (file)
@@ -453,28 +453,50 @@ namespace System.Net.Security
                     if (inSecBuffers.Count > 2)
                     {
                         inUnmanagedBuffer[2].BufferType = inSecBuffers._item2.Type;
-                        inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length;
-                        inUnmanagedBuffer[2].pvBuffer = inSecBuffers._item2.UnmanagedToken != null ?
-                            (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle() :
-                            (IntPtr)pinnedToken2;
+                        if (inSecBuffers._item2.UnmanagedToken != null)
+                        {
+                            Debug.Assert(inSecBuffers._item2.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS);
+                            inUnmanagedBuffer[2].pvBuffer = (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle();
+                            inUnmanagedBuffer[2].cbBuffer = ((ChannelBinding)inSecBuffers._item2.UnmanagedToken).Size;
+                        }
+                        else
+                        {
+                            inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length;
+                            inUnmanagedBuffer[2].pvBuffer = (IntPtr)pinnedToken2;
+                        }
+
                     }
 
                     if (inSecBuffers.Count > 1)
                     {
                         inUnmanagedBuffer[1].BufferType = inSecBuffers._item1.Type;
-                        inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length;
-                        inUnmanagedBuffer[1].pvBuffer = inSecBuffers._item1.UnmanagedToken != null ?
-                            (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle() :
-                            (IntPtr)pinnedToken1;
+                        if (inSecBuffers._item1.UnmanagedToken != null)
+                        {
+                            Debug.Assert(inSecBuffers._item1.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS);
+                            inUnmanagedBuffer[1].pvBuffer = (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle();
+                            inUnmanagedBuffer[1].cbBuffer = ((ChannelBinding)inSecBuffers._item1.UnmanagedToken).Size;
+                        }
+                        else
+                        {
+                            inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length;
+                            inUnmanagedBuffer[1].pvBuffer = (IntPtr)pinnedToken1;
+                        }
                     }
 
                     if (inSecBuffers.Count > 0)
                     {
                         inUnmanagedBuffer[0].BufferType = inSecBuffers._item0.Type;
-                        inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length;
-                        inUnmanagedBuffer[0].pvBuffer = inSecBuffers._item0.UnmanagedToken != null ?
-                            (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle() :
-                            (IntPtr)pinnedToken0;
+                        if (inSecBuffers._item0.UnmanagedToken != null)
+                        {
+                            Debug.Assert(inSecBuffers._item0.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS);
+                            inUnmanagedBuffer[0].pvBuffer = (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle();
+                            inUnmanagedBuffer[0].cbBuffer = ((ChannelBinding)inSecBuffers._item0.UnmanagedToken).Size;
+                        }
+                        else
+                        {
+                            inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length;
+                            inUnmanagedBuffer[0].pvBuffer = (IntPtr)pinnedToken0;
+                        }
                     }
 
                     fixed (byte* pinnedOutBytes = outSecBuffer.token)
@@ -685,28 +707,50 @@ namespace System.Net.Security
                     if (inSecBuffers.Count > 2)
                     {
                         inUnmanagedBuffer[2].BufferType = inSecBuffers._item2.Type;
-                        inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length;
-                        inUnmanagedBuffer[2].pvBuffer = inSecBuffers._item2.UnmanagedToken != null ?
-                            (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle() :
-                            (IntPtr)pinnedToken2;
+                        if (inSecBuffers._item2.UnmanagedToken != null)
+                        {
+                            Debug.Assert(inSecBuffers._item2.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS);
+                            inUnmanagedBuffer[2].pvBuffer = (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle();
+                            inUnmanagedBuffer[2].cbBuffer = ((ChannelBinding)inSecBuffers._item2.UnmanagedToken).Size;
+                        }
+                        else
+                        {
+                            inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length;
+                            inUnmanagedBuffer[2].pvBuffer = (IntPtr)pinnedToken2;
+                        }
+
                     }
 
                     if (inSecBuffers.Count > 1)
                     {
                         inUnmanagedBuffer[1].BufferType = inSecBuffers._item1.Type;
-                        inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length;
-                        inUnmanagedBuffer[1].pvBuffer = inSecBuffers._item1.UnmanagedToken != null ?
-                            (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle() :
-                            (IntPtr)pinnedToken1;
+                        if (inSecBuffers._item1.UnmanagedToken != null)
+                        {
+                            Debug.Assert(inSecBuffers._item1.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS);
+                            inUnmanagedBuffer[1].pvBuffer = (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle();
+                            inUnmanagedBuffer[1].cbBuffer = ((ChannelBinding)inSecBuffers._item1.UnmanagedToken).Size;
+                        }
+                        else
+                        {
+                            inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length;
+                            inUnmanagedBuffer[1].pvBuffer = (IntPtr)pinnedToken1;
+                        }
                     }
 
                     if (inSecBuffers.Count > 0)
                     {
                         inUnmanagedBuffer[0].BufferType = inSecBuffers._item0.Type;
-                        inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length;
-                        inUnmanagedBuffer[0].pvBuffer = inSecBuffers._item0.UnmanagedToken != null ?
-                            (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle() :
-                            (IntPtr)pinnedToken0;
+                        if (inSecBuffers._item0.UnmanagedToken != null)
+                        {
+                            Debug.Assert(inSecBuffers._item0.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS);
+                            inUnmanagedBuffer[0].pvBuffer = (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle();
+                            inUnmanagedBuffer[0].cbBuffer = ((ChannelBinding)inSecBuffers._item0.UnmanagedToken).Size;
+                        }
+                        else
+                        {
+                            inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length;
+                            inUnmanagedBuffer[0].pvBuffer = (IntPtr)pinnedToken0;
+                        }
                     }
 
                     fixed (byte* pinnedOutBytes = outSecBuffer.token)