Add asynchronous overload of WindowsIdentity.RunImpersonated (#1152)
authorMarco Rossignoli <marco.rossignoli@gmail.com>
Wed, 22 Jan 2020 22:09:18 +0000 (23:09 +0100)
committerStephen Toub <stoub@microsoft.com>
Wed, 22 Jan 2020 22:09:18 +0000 (17:09 -0500)
* add RunImpersonateAsync overloads

* skip on full framework

* run test netcoreapp only

* skip tests on nanoserver

* address PR feedback

* address PR feedback

* address PR feedback

* address PR feedback

* Update WindowsIdentity.cs

update comment

* Update WindowsIdentity.cs

update comment

* Update WindowsIdentity.cs

Fix comment

* Update WindowsIdentity.cs

fix comment

* Fix comment

src/libraries/System.Security.Principal.Windows/ref/System.Security.Principal.Windows.cs
src/libraries/System.Security.Principal.Windows/src/System/Security/Principal/WindowsIdentity.cs
src/libraries/System.Security.Principal.Windows/tests/System.Security.Principal.Windows.Tests.csproj
src/libraries/System.Security.Principal.Windows/tests/WindowsIdentityImpersonatedTests.netcoreapp.cs [new file with mode: 0644]

index 941aa57..d06d3b0 100644 (file)
@@ -9,7 +9,7 @@ namespace Microsoft.Win32.SafeHandles
 {
     public sealed partial class SafeAccessTokenHandle : System.Runtime.InteropServices.SafeHandle
     {
-        public SafeAccessTokenHandle(System.IntPtr handle) : base (default(System.IntPtr), default(bool)) { }
+        public SafeAccessTokenHandle(System.IntPtr handle) : base(default(System.IntPtr), default(bool)) { }
         public static Microsoft.Win32.SafeHandles.SafeAccessTokenHandle InvalidHandle { get { throw null; } }
         public override bool IsInvalid { get { throw null; } }
         protected override bool ReleaseHandle() { throw null; }
@@ -264,6 +264,8 @@ namespace System.Security.Principal
         public static System.Security.Principal.WindowsIdentity GetCurrent(System.Security.Principal.TokenAccessLevels desiredAccess) { throw null; }
         public static void RunImpersonated(Microsoft.Win32.SafeHandles.SafeAccessTokenHandle safeAccessTokenHandle, System.Action action) { }
         public static T RunImpersonated<T>(Microsoft.Win32.SafeHandles.SafeAccessTokenHandle safeAccessTokenHandle, System.Func<T> func) { throw null; }
+        public static System.Threading.Tasks.Task RunImpersonatedAsync(Microsoft.Win32.SafeHandles.SafeAccessTokenHandle safeAccessTokenHandle, Func<System.Threading.Tasks.Task> func) { throw null; }
+        public static System.Threading.Tasks.Task<T> RunImpersonatedAsync<T>(Microsoft.Win32.SafeHandles.SafeAccessTokenHandle safeAccessTokenHandle, Func<System.Threading.Tasks.Task<T>> func) { throw null; }
         void System.Runtime.Serialization.IDeserializationCallback.OnDeserialization(object sender) { }
         void System.Runtime.Serialization.ISerializable.GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
     }
index 4d1f130..05d949d 100644 (file)
@@ -21,6 +21,7 @@ using QUOTA_LIMITS = Interop.SspiCli.QUOTA_LIMITS;
 using SECURITY_LOGON_TYPE = Interop.SspiCli.SECURITY_LOGON_TYPE;
 using TOKEN_SOURCE = Interop.SspiCli.TOKEN_SOURCE;
 using System.Runtime.Serialization;
+using System.Threading.Tasks;
 
 namespace System.Security.Principal
 {
@@ -697,6 +698,24 @@ namespace System.Security.Principal
             return result;
         }
 
+        /// <summary>
+        /// Runs the specified asynchronous action as the impersonated Windows identity
+        /// </summary>
+        /// <param name="safeAccessTokenHandle">The SafeAccessTokenHandle of the impersonated Windows identity.</param>
+        /// <param name="func">The <see cref="System.Func{Task}"/> to run.</param>
+        /// <returns>A <see cref="Task"/> that represents the asynchronous operation of the provided <see cref="System.Func{Task}"/>.</returns>
+        public static Task RunImpersonatedAsync(SafeAccessTokenHandle safeAccessTokenHandle, Func<Task> func)
+            => RunImpersonated(safeAccessTokenHandle, func);
+
+        /// <summary>
+        /// Runs the specified asynchronous action as the impersonated Windows identity
+        /// </summary>
+        /// <typeparam name="T">The type of the object to return.</typeparam>
+        /// <param name="safeAccessTokenHandle">The SafeAccessTokenHandle of the impersonated Windows identity.</param>
+        /// <param name="func">The <see cref="System.Func{Task}"/> of <see cref="System.Threading.Tasks.Task{T}"/> to run.</param>
+        /// <returns>A <see cref="Task{T}"/> that represents the asynchronous operation of the <see cref="System.Func{Task}"/> of <see cref="System.Threading.Tasks.Task{T}"/> provided.</returns>
+        public static Task<T> RunImpersonatedAsync<T>(SafeAccessTokenHandle safeAccessTokenHandle, Func<Task<T>> func)
+            => RunImpersonated(safeAccessTokenHandle, func);
 
         protected virtual void Dispose(bool disposing)
         {
index bba20c6..2215325 100644 (file)
@@ -7,6 +7,7 @@
     <Compile Include="NTAccount.cs" />
     <Compile Include="SecurityIdentifierTests.cs" />
     <Compile Include="WindowsIdentityTests.cs" />
+    <Compile Include="WindowsIdentityImpersonatedTests.netcoreapp.cs" Condition="'$(TargetsNetCoreApp)' == 'true'" />
     <Compile Include="WindowsPrincipalTests.cs" />
     <Compile Include="WellKnownSidTypeTests.cs" />
   </ItemGroup>
diff --git a/src/libraries/System.Security.Principal.Windows/tests/WindowsIdentityImpersonatedTests.netcoreapp.cs b/src/libraries/System.Security.Principal.Windows/tests/WindowsIdentityImpersonatedTests.netcoreapp.cs
new file mode 100644 (file)
index 0000000..6282668
--- /dev/null
@@ -0,0 +1,187 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.ComponentModel;
+using System.Runtime.InteropServices;
+using System.Security.Cryptography;
+using System.Security.Principal;
+using System.Threading.Tasks;
+using Microsoft.Win32.SafeHandles;
+using Xunit;
+
+// On nano server netapi32.dll is not present
+// we'll skip all tests on that platform
+public class WindowsIdentityImpersonatedTests : IClassFixture<WindowsIdentityFixture>
+{
+    private readonly WindowsIdentityFixture _fixture;
+
+    public WindowsIdentityImpersonatedTests(WindowsIdentityFixture windowsIdentityFixture)
+    {
+        _fixture = windowsIdentityFixture;
+
+        Assert.False(_fixture.TestAccount.AccountTokenHandle.IsInvalid);
+        Assert.False(string.IsNullOrEmpty(_fixture.TestAccount.AccountName));
+    }
+
+    [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotWindowsNanoServer))]
+    [OuterLoop]
+    public async Task RunImpersonatedAsync_TaskAndTaskOfT()
+    {
+        WindowsIdentity currentWindowsIdentity = WindowsIdentity.GetCurrent();
+
+        await WindowsIdentity.RunImpersonatedAsync(_fixture.TestAccount.AccountTokenHandle, async () =>
+        {
+            Asserts(currentWindowsIdentity);
+            await Task.Delay(100);
+            Asserts(currentWindowsIdentity);
+        });
+
+        Assert.Equal(WindowsIdentity.GetCurrent().Name, currentWindowsIdentity.Name);
+
+        int result = await WindowsIdentity.RunImpersonatedAsync(_fixture.TestAccount.AccountTokenHandle, async () =>
+        {
+            Asserts(currentWindowsIdentity);
+            await Task.Delay(100);
+            Asserts(currentWindowsIdentity);
+            return 42;
+        });
+
+        Assert.Equal(42, result);
+        Assert.Equal(WindowsIdentity.GetCurrent().Name, currentWindowsIdentity.Name);
+
+        // Assertions
+        void Asserts(WindowsIdentity currentWindowsIdentity)
+        {
+            Assert.Equal(_fixture.TestAccount.AccountName, WindowsIdentity.GetCurrent().Name);
+            Assert.NotEqual(currentWindowsIdentity.Name, WindowsIdentity.GetCurrent().Name);
+        }
+    }
+}
+
+public class WindowsIdentityFixture : IDisposable
+{
+    public WindowsTestAccount TestAccount { get; private set; }
+
+    public WindowsIdentityFixture()
+    {
+        TestAccount = new WindowsTestAccount("CorFxTstWiIde01kiu");
+    }
+
+    public void Dispose()
+    {
+        TestAccount.Dispose();
+    }
+}
+
+public sealed class WindowsTestAccount : IDisposable
+{
+    private readonly string _userName;
+    private SafeAccessTokenHandle _accountTokenHandle;
+    public SafeAccessTokenHandle AccountTokenHandle => _accountTokenHandle;
+    public string AccountName { get; private set; }
+
+    public WindowsTestAccount(string userName)
+    {
+        _userName = userName;
+        CreateUser();
+    }
+
+    private void CreateUser()
+    {
+        string testAccountPassword;
+        using (RandomNumberGenerator rng = new RNGCryptoServiceProvider())
+        {
+            byte[] randomBytes = new byte[33];
+            rng.GetBytes(randomBytes);
+
+            // Add special chars to ensure it satisfies password requirements.
+            testAccountPassword = Convert.ToBase64String(randomBytes) + "_-As@!%*(1)4#2";
+
+            USER_INFO_1 userInfo = new USER_INFO_1
+            {
+                usri1_name = _userName,
+                usri1_password = testAccountPassword,
+                usri1_priv = 1
+            };
+
+            // Create user and remove/create if already exists
+            uint result = NetUserAdd(null, 1, ref userInfo, out uint param_err);
+
+            // error codes https://docs.microsoft.com/en-us/windows/desktop/netmgmt/network-management-error-codes
+            // 0 == NERR_Success
+            if (result == 2224) // NERR_UserExists
+            {
+                result = NetUserDel(null, userInfo.usri1_name);
+                if (result != 0)
+                {
+                    throw new Win32Exception((int)result);
+                }
+                result = NetUserAdd(null, 1, ref userInfo, out param_err);
+                if (result != 0)
+                {
+                    throw new Win32Exception((int)result);
+                }
+            }
+
+            const int LOGON32_PROVIDER_DEFAULT = 0;
+            const int LOGON32_LOGON_INTERACTIVE = 2;
+
+            if (!LogonUser(_userName, ".", testAccountPassword, LOGON32_LOGON_INTERACTIVE, LOGON32_PROVIDER_DEFAULT, out _accountTokenHandle))
+            {
+                _accountTokenHandle = null;
+                throw new Exception($"Failed to get SafeAccessTokenHandle for test account {_userName}", new Win32Exception());
+            }
+
+            bool gotRef = false;
+            try
+            {
+                _accountTokenHandle.DangerousAddRef(ref gotRef);
+                IntPtr logonToken = _accountTokenHandle.DangerousGetHandle();
+                AccountName = new WindowsIdentity(logonToken).Name;
+            }
+            finally
+            {
+                if (gotRef)
+                    _accountTokenHandle.DangerousRelease();
+            }
+        }
+    }
+
+    [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)]
+    private static extern bool LogonUser(string userName, string domain, string password, int logonType, int logonProvider, out SafeAccessTokenHandle safeAccessTokenHandle);
+
+    [DllImport("netapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+    internal static extern uint NetUserAdd([MarshalAs(UnmanagedType.LPWStr)]string servername, uint level, ref USER_INFO_1 buf, out uint parm_err);
+
+    [DllImport("netapi32.dll")]
+    internal static extern uint NetUserDel([MarshalAs(UnmanagedType.LPWStr)]string servername, [MarshalAs(UnmanagedType.LPWStr)]string username);
+
+    [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+    internal struct USER_INFO_1
+    {
+        public string usri1_name;
+        public string usri1_password;
+        public uint usri1_password_age;
+        public uint usri1_priv;
+        public string usri1_home_dir;
+        public string usri1_comment;
+        public uint usri1_flags;
+        public string usri1_script_path;
+    }
+
+    public void Dispose()
+    {
+        _accountTokenHandle?.Dispose();
+
+        uint result = NetUserDel(null, _userName);
+
+        // 2221= NERR_UserNotFound
+        if (result != 0 && result != 2221)
+        {
+            throw new Win32Exception((int)result);
+        }
+    }
+}
+