Support IClassFactory2 - LicenseManager use in COM activation (#22846)
authorAaron Robinson <arobins@microsoft.com>
Sat, 2 Mar 2019 06:53:42 +0000 (22:53 -0800)
committerGitHub <noreply@github.com>
Sat, 2 Mar 2019 06:53:42 +0000 (22:53 -0800)
* Implement IClassFactory2

* Add test support for IClassFactory2

* Add support for testing RCW activation via IClassFactory2

22 files changed:
src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs
tests/issues.targets
tests/src/Interop/CMakeLists.txt
tests/src/Interop/COM/NETClients/Licensing/App.manifest [new file with mode: 0644]
tests/src/Interop/COM/NETClients/Licensing/NETClientLicense.csproj [new file with mode: 0644]
tests/src/Interop/COM/NETClients/Licensing/Program.cs [new file with mode: 0644]
tests/src/Interop/COM/NETServer/LicenseTesting.cs [new file with mode: 0644]
tests/src/Interop/COM/NETServer/NETServer.csproj
tests/src/Interop/COM/NativeClients/Licensing.csproj [new file with mode: 0644]
tests/src/Interop/COM/NativeClients/Licensing/App.manifest [new file with mode: 0644]
tests/src/Interop/COM/NativeClients/Licensing/CMakeLists.txt [new file with mode: 0644]
tests/src/Interop/COM/NativeClients/Licensing/CoreShim.X.manifest [new file with mode: 0644]
tests/src/Interop/COM/NativeClients/Licensing/LicenseTests.cpp [new file with mode: 0644]
tests/src/Interop/COM/NativeServer/COMNativeServer.X.manifest
tests/src/Interop/COM/NativeServer/LicenseTesting.h [new file with mode: 0644]
tests/src/Interop/COM/NativeServer/Servers.cpp
tests/src/Interop/COM/NativeServer/Servers.h
tests/src/Interop/COM/ServerContracts/NativeServers.cs
tests/src/Interop/COM/ServerContracts/Server.Contracts.cs
tests/src/Interop/COM/ServerContracts/Server.Contracts.h
tests/src/Interop/COM/ServerContracts/ServerGuids.cs
tests/src/Interop/common/ComHelpers.h

index 1da9ffa..d8689e4 100644 (file)
@@ -115,6 +115,12 @@ namespace Internal.Runtime.InteropServices
             }
 
             Type classType = FindClassType(cxt.ClassId, cxt.AssemblyPath, cxt.AssemblyName, cxt.TypeName);
+
+            if (LicenseInteropProxy.HasLicense(classType))
+            {
+                return new LicenseClassFactory(cxt.ClassId, classType);
+            }
+
             return new BasicClassFactory(cxt.ClassId, classType);
         }
 
@@ -250,57 +256,79 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments:
         }
 
         [ComVisible(true)]
-        internal class BasicClassFactory : IClassFactory2
+        private class BasicClassFactory : IClassFactory
         {
-            private readonly Guid classId;
-            private readonly Type classType;
+            private readonly Guid _classId;
+            private readonly Type _classType;
 
             public BasicClassFactory(Guid clsid, Type classType)
             {
-                this.classId = clsid;
-                this.classType = classType;
+                _classId = clsid;
+                _classType = classType;
             }
 
-            public void CreateInstance(
-                [MarshalAs(UnmanagedType.Interface)] object pUnkOuter,
-                ref Guid riid,
-                [MarshalAs(UnmanagedType.Interface)] out object ppvObject)
+            public static void ValidateInterfaceRequest(Type classType, ref Guid riid, object outer)
             {
-                if (riid != Marshal.IID_IUnknown)
+                Debug.Assert(classType != null);
+                if (riid == Marshal.IID_IUnknown)
                 {
-                    bool found = false;
+                    return;
+                }
 
-                    // Verify the class implements the desired interface
-                    foreach (Type i in this.classType.GetInterfaces())
-                    {
-                        if (i.GUID == riid)
-                        {
-                            found = true;
-                            break;
-                        }
-                    }
+                // Aggregation can only be done when requesting IUnknown.
+                if (outer != null)
+                {
+                    const int CLASS_E_NOAGGREGATION = unchecked((int)0x80040110);
+                    throw new COMException(string.Empty, CLASS_E_NOAGGREGATION);
+                }
+
+                bool found = false;
 
-                    if (!found)
+                // Verify the class implements the desired interface
+                foreach (Type i in classType.GetInterfaces())
+                {
+                    if (i.GUID == riid)
                     {
-                        // E_NOINTERFACE
-                        throw new InvalidCastException();
+                        found = true;
+                        break;
                     }
                 }
 
-                ppvObject = Activator.CreateInstance(this.classType);
+                if (!found)
+                {
+                    // E_NOINTERFACE
+                    throw new InvalidCastException();
+                }
+            }
+
+            public static object CreateAggregatedObject(object pUnkOuter, object comObject)
+            {
+                Debug.Assert(pUnkOuter != null && comObject != null);
+                IntPtr outerPtr = Marshal.GetIUnknownForObject(pUnkOuter);
+
+                try
+                {
+                    IntPtr innerPtr = Marshal.CreateAggregatedObject(outerPtr, comObject);
+                    return Marshal.GetObjectForIUnknown(innerPtr);
+                }
+                finally
+                {
+                    // Decrement the above 'Marshal.GetIUnknownForObject()'
+                    Marshal.ReleaseComObject(pUnkOuter);
+                }
+            }
+
+            public void CreateInstance(
+                [MarshalAs(UnmanagedType.Interface)] object pUnkOuter,
+                ref Guid riid,
+                [MarshalAs(UnmanagedType.Interface)] out object ppvObject)
+            {
+                BasicClassFactory.ValidateInterfaceRequest(_classType, ref riid, pUnkOuter);
+
+                ppvObject = Activator.CreateInstance(_classType);
                 if (pUnkOuter != null)
                 {
-                    try
-                    {
-                        IntPtr outerPtr = Marshal.GetIUnknownForObject(pUnkOuter);
-                        IntPtr innerPtr = Marshal.CreateAggregatedObject(outerPtr, ppvObject);
-                        ppvObject = Marshal.GetObjectForIUnknown(innerPtr);
-                    }
-                    finally
-                    {
-                        // Decrement the above 'Marshal.GetIUnknownForObject()'
-                        Marshal.ReleaseComObject(pUnkOuter);
-                    }
+                    ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject);
                 }
             }
 
@@ -308,15 +336,50 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments:
             {
                 // nop
             }
+        }
 
-            public void GetLicInfo(ref LICINFO pLicInfo)
+        [ComVisible(true)]
+        private class LicenseClassFactory : IClassFactory2
+        {
+            private readonly LicenseInteropProxy _licenseProxy = new LicenseInteropProxy();
+            private readonly Guid _classId;
+            private readonly Type _classType;
+
+            public LicenseClassFactory(Guid clsid, Type classType)
             {
-                throw new NotImplementedException();
+                _classId = clsid;
+                _classType = classType;
+            }
+
+            public void CreateInstance(
+                [MarshalAs(UnmanagedType.Interface)] object pUnkOuter,
+                ref Guid riid,
+                [MarshalAs(UnmanagedType.Interface)] out object ppvObject)
+            {
+                CreateInstanceInner(pUnkOuter, ref riid, key: null, isDesignTime: true, out ppvObject);
+            }
+
+            public void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock)
+            {
+                // nop
+            }
+
+            public void GetLicInfo(ref LICINFO licInfo)
+            {
+                bool runtimeKeyAvail;
+                bool licVerified;
+                _licenseProxy.GetLicInfo(_classType, out runtimeKeyAvail, out licVerified);
+
+                // The LICINFO is a struct with a DWORD size field and two BOOL fields. Each BOOL
+                // is typedef'd from a DWORD, therefore the size is manually computed as below.
+                licInfo.cbLicInfo = sizeof(int) + sizeof(int) + sizeof(int);
+                licInfo.fRuntimeKeyAvail = runtimeKeyAvail;
+                licInfo.fLicVerified = licVerified;
             }
 
             public void RequestLicKey(int dwReserved, [MarshalAs(UnmanagedType.BStr)] out string pBstrKey)
             {
-                throw new NotImplementedException();
+                pBstrKey = _licenseProxy.RequestLicKey(_classType);
             }
 
             public void CreateInstanceLic(
@@ -326,7 +389,184 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments:
                 [MarshalAs(UnmanagedType.BStr)] string bstrKey,
                 [MarshalAs(UnmanagedType.Interface)] out object ppvObject)
             {
-                throw new NotImplementedException();
+                Debug.Assert(pUnkReserved == null);
+                CreateInstanceInner(pUnkOuter, ref riid, bstrKey, isDesignTime: false, out ppvObject);
+            }
+
+            private void CreateInstanceInner(
+                object pUnkOuter,
+                ref Guid riid,
+                string key,
+                bool isDesignTime,
+                out object ppvObject)
+            {
+                BasicClassFactory.ValidateInterfaceRequest(_classType, ref riid, pUnkOuter);
+
+                ppvObject = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime);
+                if (pUnkOuter != null)
+                {
+                    ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject);
+                }
+            }
+        }
+
+        // This is a helper class that supports the CLR's IClassFactory2 marshaling
+        // support.
+        //
+        // When a managed object is exposed to COM, the CLR invokes
+        // AllocateAndValidateLicense() to set up the appropriate
+        // license context and instantiate the object.
+        private class LicenseInteropProxy
+        {
+            private static readonly Type s_licenseAttrType;
+            private static readonly Type s_licenseExceptionType;
+
+            private MethodInfo _createWithContext;
+            private MethodInfo _validateTypeAndReturnDetails;
+            private MethodInfo _createDesignContext;
+            private MethodInfo _createRuntimeContext;
+
+            private Type _licInfoHelper;
+            private MethodInfo _licInfoHelperContains;
+
+            static LicenseInteropProxy()
+            {
+                s_licenseAttrType = Type.GetType("System.ComponentModel.LicenseProviderAttribute, System.ComponentModel.TypeConverter", throwOnError: false);
+                s_licenseExceptionType = Type.GetType("System.ComponentModel.LicenseException, System.ComponentModel.TypeConverter", throwOnError: false);
+            }
+
+            public LicenseInteropProxy()
+            {
+                Type licManager = Type.GetType("System.ComponentModel.LicenseManager, System.ComponentModel.TypeConverter", throwOnError: true);
+
+                Type licContext = Type.GetType("System.ComponentModel.LicenseContext, System.ComponentModel.TypeConverter", throwOnError: true);
+                _createWithContext = licManager.GetMethod("CreateWithContext", new[] { typeof(Type), licContext });
+
+                Type interopHelper = licManager.GetNestedType("LicenseInteropHelper", BindingFlags.NonPublic);
+                _validateTypeAndReturnDetails = interopHelper.GetMethod("ValidateAndRetrieveLicenseDetails", BindingFlags.Static | BindingFlags.Public);
+
+                Type clrLicContext = licManager.GetNestedType("CLRLicenseContext", BindingFlags.NonPublic);
+                _createDesignContext = clrLicContext.GetMethod("CreateDesignContext", BindingFlags.Static | BindingFlags.Public);
+                _createRuntimeContext = clrLicContext.GetMethod("CreateRuntimeContext", BindingFlags.Static | BindingFlags.Public);
+
+                _licInfoHelper = licManager.GetNestedType("LicInfoHelperLicenseContext", BindingFlags.NonPublic);
+                _licInfoHelperContains = _licInfoHelper.GetMethod("Contains", BindingFlags.Instance | BindingFlags.Public);
+            }
+
+            // Determine if the type supports licensing
+            public static bool HasLicense(Type type)
+            {
+                // If the attribute type can't be found, then the type
+                // definitely doesn't support licensing.
+                if (s_licenseAttrType == null)
+                {
+                    return false;
+                }
+
+                return type.IsDefined(s_licenseAttrType, inherit: true);
+            }
+
+            // The CLR invokes this whenever a COM client invokes
+            // IClassFactory2::GetLicInfo on a managed class.
+            //
+            // COM normally doesn't expect this function to fail so this method
+            // should only throw in the case of a catastrophic error (stack, memory, etc.)
+            public void GetLicInfo(Type type, out bool runtimeKeyAvail, out bool licVerified)
+            {
+                runtimeKeyAvail = false;
+                licVerified = false;
+
+                // Types are as follows:
+                // LicenseContext, Type, out License, out string
+                object licContext = Activator.CreateInstance(_licInfoHelper);
+                var parameters = new object[] { licContext, type, /* out */ null, /* out */ null };
+                bool isValid = (bool)_validateTypeAndReturnDetails.Invoke(null, BindingFlags.DoNotWrapExceptions, binder: null, parameters: parameters, culture: null);
+                if (!isValid)
+                {
+                    return;
+                }
+
+                var license = (IDisposable)parameters[2];
+                if (license != null)
+                {
+                    license.Dispose();
+                    licVerified = true;
+                }
+
+                parameters = new object[] { type.AssemblyQualifiedName };
+                runtimeKeyAvail = (bool)_licInfoHelperContains.Invoke(licContext, BindingFlags.DoNotWrapExceptions, binder: null, parameters: parameters, culture: null);
+            }
+
+            // The CLR invokes this whenever a COM client invokes
+            // IClassFactory2::RequestLicKey on a managed class.
+            public string RequestLicKey(Type type)
+            {
+                // License will be null, since we passed no instance,
+                // however we can still retrieve the "first" license
+                // key from the file. This really will only
+                // work for simple COM-compatible license providers
+                // like LicFileLicenseProvider that don't require the
+                // instance to grant a key.
+
+                // Types are as follows:
+                // LicenseContext, Type, out License, out string
+                var parameters = new object[] { /* use global LicenseContext */ null, type, /* out */ null, /* out */ null };
+                bool isValid = (bool)_validateTypeAndReturnDetails.Invoke(null, BindingFlags.DoNotWrapExceptions, binder: null, parameters: parameters, culture: null);
+                if (!isValid)
+                {
+                    throw new COMException(); //E_FAIL
+                }
+
+                var license = (IDisposable)parameters[2];
+                if (license != null)
+                {
+                    license.Dispose();
+                }
+
+                string licenseKey = (string)parameters[3];
+                if (licenseKey == null)
+                {
+                    throw new COMException(); //E_FAIL
+                }
+
+                return licenseKey;
+            }
+
+            // The CLR invokes this whenever a COM client invokes
+            // IClassFactory::CreateInstance() or IClassFactory2::CreateInstanceLic()
+            // on a managed that has a LicenseProvider custom attribute.
+            //
+            // If we are being entered because of a call to ICF::CreateInstance(),
+            // "isDesignTime" will be "true".
+            //
+            // If we are being entered because of a call to ICF::CreateInstanceLic(),
+            // "isDesignTime" will be "false" and "key" will point to a non-null
+            // license key.
+            public object AllocateAndValidateLicense(Type type, string key, bool isDesignTime)
+            {
+                object[] parameters;
+                object licContext;
+                if (isDesignTime)
+                {
+                    parameters = new object[] { type };
+                    licContext = _createDesignContext.Invoke(null, BindingFlags.DoNotWrapExceptions, binder: null, parameters: parameters, culture: null);
+                }
+                else
+                {
+                    parameters = new object[] { type, key };
+                    licContext = _createRuntimeContext.Invoke(null, BindingFlags.DoNotWrapExceptions, binder: null, parameters: parameters, culture: null);
+                }
+
+                try
+                {
+                    parameters = new object[] { type, licContext };
+                    return _createWithContext.Invoke(null, BindingFlags.DoNotWrapExceptions, binder: null, parameters: parameters, culture: null);
+                }
+                catch (Exception exception) when (exception.GetType() == s_licenseExceptionType)
+                {
+                    const int CLASS_E_NOTLICENSED = unchecked((int)0x80040112);
+                    throw new COMException(exception.Message, CLASS_E_NOTLICENSED);
+                }
             }
         }
     }
index a575498..020caf0 100644 (file)
         <ExcludeList Include="$(XunitTestBinBase)/Interop/COM/NativeClients/Primitives/*">
             <Issue>20682</Issue>
         </ExcludeList>
+        <ExcludeList Include="$(XunitTestBinBase)/Interop/COM/NativeClients/Licensing/*">
+            <Issue>20682</Issue>
+        </ExcludeList>
         <ExcludeList Include="$(XunitTestBinBase)/Interop/COM/NETClients/IDispatch/NETClientIDispatch/*">
             <Issue>20682</Issue>
         </ExcludeList>
         <ExcludeList Include="$(XunitTestBinBase)/Interop/COM/NETClients/Aggregation/NETClientAggregation/*">
             <Issue>20682</Issue>
         </ExcludeList>
+        <ExcludeList Include="$(XunitTestBinBase)/Interop/COM/NETClients/Licensing/NETClientLicense/*">
+            <Issue>20682</Issue>
+        </ExcludeList>
         <ExcludeList Include="$(XunitTestBinBase)/Interop/COM/NETClients/Events/NETClientEvents/*">
             <Issue>22784</Issue>
         </ExcludeList>
index 106ef1f..66124de 100644 (file)
@@ -81,6 +81,7 @@ if(WIN32)
     add_subdirectory(ArrayMarshalling/SafeArray)
     add_subdirectory(COM/NativeServer)
     add_subdirectory(COM/NativeClients/Primitives)
+    add_subdirectory(COM/NativeClients/Licensing)
     add_subdirectory(IJW/FakeMscoree)
 
     # IJW isn't supported on ARM64
diff --git a/tests/src/Interop/COM/NETClients/Licensing/App.manifest b/tests/src/Interop/COM/NETClients/Licensing/App.manifest
new file mode 100644 (file)
index 0000000..c0181f3
--- /dev/null
@@ -0,0 +1,18 @@
+<?xml version="1.0" encoding="utf-8"?>
+<assembly manifestVersion="1.0" xmlns="urn:schemas-microsoft-com:asm.v1">
+  <assemblyIdentity
+    type="win32"
+    name="NetClientLicensing"
+    version="1.0.0.0" />
+
+  <dependency>
+    <dependentAssembly>
+      <!-- RegFree COM -->
+      <assemblyIdentity
+          type="win32"
+          name="COMNativeServer.X"
+          version="1.0.0.0"/>
+    </dependentAssembly>
+  </dependency>
+
+</assembly>
diff --git a/tests/src/Interop/COM/NETClients/Licensing/NETClientLicense.csproj b/tests/src/Interop/COM/NETClients/Licensing/NETClientLicense.csproj
new file mode 100644 (file)
index 0000000..5215422
--- /dev/null
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" />
+  <PropertyGroup>
+    <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
+    <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
+    <AssemblyName>NETClientLicense</AssemblyName>
+    <SchemaVersion>2.0</SchemaVersion>
+    <ProjectGuid>{85C57688-DA98-4DE3-AC9B-526E4747434C}</ProjectGuid>
+    <OutputType>Exe</OutputType>
+    <ProjectTypeGuids>{209912F9-0DA1-4184-9CC1-8D583BAF4A28};{87799F5D-CEBD-499D-BDBA-B2C6105CD766}</ProjectTypeGuids>
+    <ApplicationManifest>App.manifest</ApplicationManifest>
+
+    <!-- Blocked on ILAsm supporting embedding resources. See https://github.com/dotnet/coreclr/issues/20819 -->
+    <IlrtTestKind>BuildOnly</IlrtTestKind>
+
+    <!-- Blocked on CrossGen.exe supporting embedding resources. See https://github.com/dotnet/coreclr/issues/21006 -->
+    <CrossGenTest>false</CrossGenTest>
+
+    <!-- Test unsupported outside of windows -->
+    <TestUnsupportedOutsideWindows>true</TestUnsupportedOutsideWindows>
+    <DisableProjectBuild Condition="'$(TargetsUnix)' == 'true'">true</DisableProjectBuild>
+    <!-- This test would require the runincontext.exe to include App.manifest describing the COM interfaces -->
+    <UnloadabilityIncompatible>true</UnloadabilityIncompatible>
+  </PropertyGroup>
+  <!-- Default configurations to help VS understand the configurations -->
+  <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'">
+  </PropertyGroup>
+  <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x64'">
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="Program.cs" />
+    <Compile Include="../../ServerContracts/NativeServers.cs" />
+    <Compile Include="../../ServerContracts/Server.Contracts.cs" />
+    <Compile Include="../../ServerContracts/ServerGuids.cs" />
+  </ItemGroup>
+  <ItemGroup>
+    <ProjectReference Include="../../NativeServer/CMakeLists.txt" />
+    <ProjectReference Include="../../../../Common/CoreCLRTestLibrary/CoreCLRTestLibrary.csproj" />
+  </ItemGroup>
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
+</Project>
diff --git a/tests/src/Interop/COM/NETClients/Licensing/Program.cs b/tests/src/Interop/COM/NETClients/Licensing/Program.cs
new file mode 100644 (file)
index 0000000..68ac881
--- /dev/null
@@ -0,0 +1,147 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace NetClient
+{
+    using System;
+    using System.Collections.Generic;
+    using System.ComponentModel;
+    using System.Reflection;
+    using System.Runtime.InteropServices;
+
+    using TestLibrary;
+    using Server.Contract;
+    using Server.Contract.Servers;
+
+    class Program
+    {
+        static readonly string DefaultLicKey = "__MOCK_LICENSE_KEY__";
+        static void ActivateLicensedObject()
+        {
+            Console.WriteLine($"Calling {nameof(ActivateLicensedObject)}...");
+
+            // Validate activation
+            var licenseTesting = (LicenseTesting)new LicenseTestingClass();
+
+            // Validate license denial
+            licenseTesting.SetNextDenyLicense(true);
+            try
+            {
+                var tmp = (LicenseTesting)new LicenseTestingClass();
+                Assert.Fail("Activation of licensed class should fail");
+            }
+            catch (COMException e)
+            {
+                const int CLASS_E_NOTLICENSED = unchecked((int)0x80040112);
+                Assert.AreEqual(CLASS_E_NOTLICENSED, e.HResult);
+            }
+            finally
+            {
+                licenseTesting.SetNextDenyLicense(false);
+            }
+        }
+
+        class MockLicenseContext : LicenseContext
+        {
+            private readonly Type _type;
+            private string _key;
+
+            public MockLicenseContext(Type type, LicenseUsageMode mode)
+            {
+                UsageMode = mode;
+                _type = type;
+            }
+
+            public override LicenseUsageMode UsageMode { get; }
+
+            public override string GetSavedLicenseKey(Type type, Assembly resourceAssembly)
+            {
+                if (type == _type)
+                {
+                    return _key;
+                }
+
+                return null;
+            }
+
+            public override void SetSavedLicenseKey(Type type, string key)
+            {
+                if (type == _type)
+                {
+                    _key = key;
+                }
+            }
+        }
+
+        static void ActivateUnderDesigntimeContext()
+        {
+            Console.WriteLine($"Calling {nameof(ActivateUnderDesigntimeContext)}...");
+
+            LicenseContext prev = LicenseManager.CurrentContext;
+            try
+            {
+                string licKey = "__TEST__";
+                LicenseManager.CurrentContext = new MockLicenseContext(typeof(LicenseTestingClass), LicenseUsageMode.Designtime);
+                LicenseManager.CurrentContext.SetSavedLicenseKey(typeof(LicenseTestingClass), licKey);
+
+                var licenseTesting = (LicenseTesting)new LicenseTestingClass();
+
+                // During design time the IClassFactory::CreateInstance will be called - no license
+                Assert.AreEqual(null, licenseTesting.GetLicense());
+
+                // Verify the value retrieved from the IClassFactory2::RequestLicKey was what was set
+                Assert.AreEqual(DefaultLicKey, LicenseManager.CurrentContext.GetSavedLicenseKey(typeof(LicenseTestingClass), resourceAssembly: null));
+            }
+            finally
+            {
+                LicenseManager.CurrentContext = prev;
+            }
+        }
+
+        static void ActivateUnderRuntimeContext()
+        {
+            Console.WriteLine($"Calling {nameof(ActivateUnderRuntimeContext)}...");
+
+            LicenseContext prev = LicenseManager.CurrentContext;
+            try
+            {
+                string licKey = "__TEST__";
+                LicenseManager.CurrentContext = new MockLicenseContext(typeof(LicenseTestingClass), LicenseUsageMode.Runtime);
+                LicenseManager.CurrentContext.SetSavedLicenseKey(typeof(LicenseTestingClass), licKey);
+
+                var licenseTesting = (LicenseTesting)new LicenseTestingClass();
+
+                // During runtime the IClassFactory::CreateInstance2 will be called with license from context
+                Assert.AreEqual(licKey, licenseTesting.GetLicense());
+            }
+            finally
+            {
+                LicenseManager.CurrentContext = prev;
+            }
+        }
+
+        static int Main(string[] doNotUse)
+        {
+            // RegFree COM is not supported on Windows Nano
+            if (Utilities.IsWindowsNanoServer)
+            {
+                return 100;
+            }
+
+            try
+            {
+                ActivateLicensedObject();
+                ActivateUnderDesigntimeContext();
+                ActivateUnderRuntimeContext();
+            }
+            catch (Exception e)
+            {
+                Console.WriteLine($"Test Failure: {e}");
+                return 101;
+            }
+
+            return 100;
+        }
+    }
+}
diff --git a/tests/src/Interop/COM/NETServer/LicenseTesting.cs b/tests/src/Interop/COM/NETServer/LicenseTesting.cs
new file mode 100644 (file)
index 0000000..e55ad8a
--- /dev/null
@@ -0,0 +1,77 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.ComponentModel;
+using System.Runtime.InteropServices;
+
+[ComVisible(true)]
+[Guid(Server.Contract.Guids.LicenseTesting)]
+[LicenseProvider(typeof(MockLicenseProvider))]
+public class LicenseTesting : Server.Contract.ILicenseTesting
+{
+    public LicenseTesting()
+    {
+        LicenseManager.Validate(typeof(LicenseTesting), this);
+    }
+
+    public string LicenseUsed { get; set; }
+
+    void Server.Contract.ILicenseTesting.SetNextDenyLicense(bool denyLicense)
+    {
+        MockLicenseProvider.DenyLicense = denyLicense;
+    }
+
+    void Server.Contract.ILicenseTesting.SetNextLicense(string lic)
+    {
+        MockLicenseProvider.License = lic;
+    }
+
+    string Server.Contract.ILicenseTesting.GetLicense()
+    {
+        return LicenseUsed;
+    }
+}
+
+public class MockLicenseProvider : LicenseProvider
+{
+    public static bool DenyLicense { get; set; }
+    public static string License { get; set; }
+
+    public override License GetLicense(LicenseContext context, Type type, object instance, bool allowExceptions)
+    {
+        if (DenyLicense)
+        {
+            if (allowExceptions)
+            {
+                throw new LicenseException(type);
+            }
+            else
+            {
+                return null;
+            }
+        }
+
+        if (type != typeof(LicenseTesting))
+        {
+            throw new Exception();
+        }
+
+        var lic = new MockLicense();
+
+        if (instance != null)
+        {
+            ((LicenseTesting)instance).LicenseUsed = lic.LicenseKey;
+        }
+
+        return lic;
+    }
+
+    private class MockLicense : License
+    {
+        public override string LicenseKey => MockLicenseProvider.License ?? "__MOCK_LICENSE_KEY__";
+
+        public override void Dispose () { }
+    }
+}
index 3548ea1..aa4e791 100644 (file)
@@ -23,6 +23,7 @@
     <Compile Include="StringTesting.cs" />
     <Compile Include="ErrorMarshalTesting.cs" />
     <Compile Include="ColorTesting.cs" />
+    <Compile Include="LicenseTesting.cs" />
     <Compile Include="../ServerContracts/Server.Contracts.cs" />
     <Compile Include="../ServerContracts/ServerGuids.cs" />
   </ItemGroup>
diff --git a/tests/src/Interop/COM/NativeClients/Licensing.csproj b/tests/src/Interop/COM/NativeClients/Licensing.csproj
new file mode 100644 (file)
index 0000000..dba48db
--- /dev/null
@@ -0,0 +1,23 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" />
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), Interop.settings.targets))\Interop.settings.targets" />
+  <PropertyGroup>
+    <IgnoreCoreCLRTestLibraryDependency>true</IgnoreCoreCLRTestLibraryDependency>
+    <CLRTestScriptLocalCoreShim>true</CLRTestScriptLocalCoreShim>
+
+    <TestUnsupportedOutsideWindows>true</TestUnsupportedOutsideWindows>
+    <DisableProjectBuild Condition="'$(TargetsUnix)' == 'true'">true</DisableProjectBuild>
+    <DefineConstants>BLOCK_WINDOWS_NANO</DefineConstants>
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="$(InteropCommonDir)ExeLauncherProgram.cs" />
+  </ItemGroup>
+  <ItemGroup>
+    <ProjectReference Include="Licensing/CMakeLists.txt" />
+    <ProjectReference Include="../NetServer/NetServer.csproj" />
+    <ProjectReference Include="../../../Common/hostpolicymock/CMakeLists.txt" />
+    <ProjectReference Include="../../../Common/CoreCLRTestLibrary/CoreCLRTestLibrary.csproj" />
+  </ItemGroup>
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
+</Project>
\ No newline at end of file
diff --git a/tests/src/Interop/COM/NativeClients/Licensing/App.manifest b/tests/src/Interop/COM/NativeClients/Licensing/App.manifest
new file mode 100644 (file)
index 0000000..613f8bf
--- /dev/null
@@ -0,0 +1,17 @@
+<?xml version="1.0" encoding="utf-8" standalone="yes" ?>
+<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0">
+  <assemblyIdentity
+    type="win32"
+    name="COMClientLicensing"
+    version="1.0.0.0"/>
+
+  <dependency>
+    <dependentAssembly>
+      <!-- RegFree COM - CoreCLR Shim -->
+      <assemblyIdentity
+        type="win32"
+        name="CoreShim.X"
+        version="1.0.0.0"/>
+    </dependentAssembly>
+  </dependency>
+</assembly>
\ No newline at end of file
diff --git a/tests/src/Interop/COM/NativeClients/Licensing/CMakeLists.txt b/tests/src/Interop/COM/NativeClients/Licensing/CMakeLists.txt
new file mode 100644 (file)
index 0000000..237b614
--- /dev/null
@@ -0,0 +1,19 @@
+cmake_minimum_required (VERSION 2.6)
+
+project (COMClientLicensing)
+include_directories( ${INC_PLATFORM_DIR} )
+include_directories( "../../ServerContracts" )
+include_directories( "../../NativeServer" )
+set(SOURCES
+    LicenseTests.cpp
+    App.manifest)
+
+# add the executable
+add_executable (COMClientLicensing ${SOURCES})
+target_link_libraries(COMClientLicensing ${LINK_LIBRARIES_ADDITIONAL})
+
+# Copy CoreShim manifest to project output
+file(GENERATE OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/$<CONFIG>/CoreShim.X.manifest INPUT ${CMAKE_CURRENT_SOURCE_DIR}/CoreShim.X.manifest)
+
+# add the install targets
+install (TARGETS COMClientLicensing DESTINATION bin)
diff --git a/tests/src/Interop/COM/NativeClients/Licensing/CoreShim.X.manifest b/tests/src/Interop/COM/NativeClients/Licensing/CoreShim.X.manifest
new file mode 100644 (file)
index 0000000..27479e2
--- /dev/null
@@ -0,0 +1,16 @@
+<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0">
+
+<assemblyIdentity
+  type="win32"
+  name="CoreShim.X"
+  version="1.0.0.0" />
+
+<file name="CoreShim.dll">
+  <!-- LicenseTesting -->
+  <comClass
+    clsid="{66DB7882-E2B0-471D-92C7-B2B52A0EA535}"
+    threadingModel="Both" />
+</file>
+
+</assembly>
diff --git a/tests/src/Interop/COM/NativeClients/Licensing/LicenseTests.cpp b/tests/src/Interop/COM/NativeClients/Licensing/LicenseTests.cpp
new file mode 100644 (file)
index 0000000..8fbff01
--- /dev/null
@@ -0,0 +1,150 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#include <xplatform.h>
+#include <cassert>
+#include <Server.Contracts.h>
+
+// COM headers
+#include <objbase.h>
+#include <combaseapi.h>
+
+#define COM_CLIENT
+#include <Servers.h>
+
+#define THROW_IF_FAILED(exp) { hr = exp; if (FAILED(hr)) { ::printf("FAILURE: 0x%08x = %s\n", hr, #exp); throw hr; } }
+#define THROW_FAIL_IF_FALSE(exp) { if (!(exp)) { ::printf("FALSE: %s\n", #exp); throw E_FAIL; } }
+
+template<COINIT TM>
+struct ComInit
+{
+    const HRESULT Result;
+
+    ComInit()
+        : Result{ ::CoInitializeEx(nullptr, TM) }
+    { }
+
+    ~ComInit()
+    {
+        if (SUCCEEDED(Result))
+            ::CoUninitialize();
+    }
+};
+
+using ComMTA = ComInit<COINIT_MULTITHREADED>;
+
+void ActivateViaCoCreateInstance();
+void ActivateViaCoGetClassObject();
+
+int __cdecl main()
+{
+    ComMTA init;
+    if (FAILED(init.Result))
+        return -1;
+
+    try
+    {
+        CoreShimComActivation csact{ W("NETServer"), W("LicenseTesting") };
+
+        ActivateViaCoCreateInstance();
+        ActivateViaCoGetClassObject();
+    }
+    catch (HRESULT hr)
+    {
+        ::printf("Test Failure: 0x%08x\n", hr);
+        return 101;
+    }
+
+    return 100;
+}
+
+void ActivateViaCoCreateInstance()
+{
+    ::printf("License test through CoCreateInstance...\n");
+
+    HRESULT hr;
+
+    ILicenseTesting *licenseTesting;
+    THROW_IF_FAILED(::CoCreateInstance(CLSID_LicenseTesting, nullptr, CLSCTX_INPROC, IID_ILicenseTesting, (void**)&licenseTesting));
+    THROW_IF_FAILED(licenseTesting->SetNextDenyLicense(VARIANT_TRUE));
+
+    ILicenseTesting *failToCreate = nullptr;
+    hr = ::CoCreateInstance(CLSID_LicenseTesting, nullptr, CLSCTX_INPROC, IID_ILicenseTesting, (void**)&failToCreate);
+    if (hr != CLASS_E_NOTLICENSED || failToCreate != nullptr)
+    {
+        ::printf("Should fail to activate without license: %#08x\n", hr);
+        throw E_FAIL;
+    }
+
+    // Reset the environment
+    licenseTesting->SetNextDenyLicense(VARIANT_FALSE);
+    licenseTesting->Release();
+}
+
+void ActivateViaCoGetClassObject()
+{
+    ::printf("License test through CoGetClassObject...\n");
+
+    HRESULT hr;
+
+    IClassFactory2 *factory;
+    THROW_IF_FAILED(::CoGetClassObject(CLSID_LicenseTesting, CLSCTX_INPROC, nullptr, IID_IClassFactory2, (void**)&factory));
+
+    // Validate license info
+    LICINFO info;
+    THROW_IF_FAILED(factory->GetLicInfo(&info));
+    THROW_FAIL_IF_FALSE(info.fLicVerified != FALSE);
+    THROW_FAIL_IF_FALSE(info.fRuntimeKeyAvail == FALSE); // Have not populated the cache
+
+    // Initialize to default key.
+    LPCOLESTR key = W("__MOCK_LICENSE_KEY__");
+
+    // Validate license key
+    BSTR lic;
+    THROW_IF_FAILED(factory->RequestLicKey(0, &lic));
+    THROW_FAIL_IF_FALSE(::CompareStringOrdinal(lic, -1, key, -1, FALSE) == CSTR_EQUAL);
+
+    // Create instance
+    IUnknown *test;
+    THROW_IF_FAILED(factory->CreateInstanceLic(nullptr, nullptr, IID_IUnknown, lic, (void**)&test));
+    CoreClrBStrFree(lic);
+
+    ILicenseTesting *licenseTesting;
+    THROW_IF_FAILED(test->QueryInterface(&licenseTesting));
+    test->Release();
+
+    // Validate license key used
+    BSTR licMaybe;
+    THROW_IF_FAILED(licenseTesting->GetLicense(&licMaybe));
+    THROW_FAIL_IF_FALSE(::CompareStringOrdinal(licMaybe, -1, key, -1, FALSE) == CSTR_EQUAL);
+    CoreClrBStrFree(licMaybe);
+    licMaybe = nullptr;
+
+    // Set new license key
+    key = W("__TEST__");
+    THROW_IF_FAILED(licenseTesting->SetNextLicense(key));
+
+    // Free previous instance
+    licenseTesting->Release();
+
+    // Create instance and validate key used
+    THROW_IF_FAILED(factory->RequestLicKey(0, &lic));
+    THROW_FAIL_IF_FALSE(::CompareStringOrdinal(lic, -1, key, -1, FALSE) == CSTR_EQUAL);
+
+    test = nullptr;
+    THROW_IF_FAILED(factory->CreateInstanceLic(nullptr, nullptr, IID_IUnknown, lic, (void**)&test));
+    CoreClrBStrFree(lic);
+
+    licenseTesting = nullptr;
+    THROW_IF_FAILED(test->QueryInterface(&licenseTesting));
+    test->Release();
+
+    // Validate license key used
+    THROW_IF_FAILED(licenseTesting->GetLicense(&licMaybe));
+    THROW_FAIL_IF_FALSE(::CompareStringOrdinal(licMaybe, -1, key, -1, FALSE) == CSTR_EQUAL);
+    CoreClrBStrFree(licMaybe);
+
+    licenseTesting->Release();
+    factory->Release();
+}
index 4509ee9..3c67514 100644 (file)
   <comClass
     clsid="{C222F472-DA5A-4FC6-9321-92F4F7053A65}"
     threadingModel="Both" />
+
+  <!-- LicenseTesting -->
+  <comClass
+    clsid="{66DB7882-E2B0-471D-92C7-B2B52A0EA535}"
+    threadingModel="Both" />
 </file>
 
 </assembly>
diff --git a/tests/src/Interop/COM/NativeServer/LicenseTesting.h b/tests/src/Interop/COM/NativeServer/LicenseTesting.h
new file mode 100644 (file)
index 0000000..b0c058e
--- /dev/null
@@ -0,0 +1,77 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#pragma once
+
+#include <xplatform.h>
+#include "Servers.h"
+
+class LicenseTesting : public UnknownImpl, public ILicenseTesting
+{
+private: // static
+    static bool s_DenyLicense;
+    static BSTR s_License;
+
+public: // static
+    static HRESULT RequestLicKey(BSTR *key)
+    {
+        LPCOLESTR lic = s_License;
+        if (lic == nullptr)
+            lic = W("__MOCK_LICENSE_KEY__");
+
+        *key = TP_SysAllocString(lic);
+        return S_OK;
+    }
+
+private:
+    BSTR _lic;
+
+public:
+    LicenseTesting(_In_opt_ BSTR lic)
+        : _lic{ lic }
+    {
+        if (s_DenyLicense)
+            throw CLASS_E_NOTLICENSED;
+    }
+
+    ~LicenseTesting()
+    {
+        CoreClrBStrFree(_lic);
+    }
+
+public: // ILicenseTesting
+    DEF_FUNC(SetNextDenyLicense)(_In_ VARIANT_BOOL denyLicense)
+    {
+        s_DenyLicense = (denyLicense == VARIANT_FALSE) ? false : true;
+        return S_OK;
+    }
+
+    DEF_FUNC(GetLicense)(_Out_ BSTR *lic)
+    {
+        *lic = TP_SysAllocString(_lic);
+        return S_OK;
+    }
+
+    DEF_FUNC(SetNextLicense)(_In_z_ LPCOLESTR lic)
+    {
+        if (s_License != nullptr)
+            CoreClrBStrFree(s_License);
+
+        s_License = TP_SysAllocString(lic);
+        return S_OK;
+    }
+
+public: // IUnknown
+    STDMETHOD(QueryInterface)(
+        /* [in] */ REFIID riid,
+        /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
+    {
+        return DoQueryInterface(riid, ppvObject, static_cast<ILicenseTesting *>(this));
+    }
+
+    DEFINE_REF_COUNTING();
+};
+
+bool LicenseTesting::s_DenyLicense = false;
+BSTR LicenseTesting::s_License = nullptr;
index 3bf7072..8cf7802 100644 (file)
@@ -214,5 +214,8 @@ STDAPI DllGetClassObject(_In_ REFCLSID rclsid, _In_ REFIID riid, _Out_ LPVOID FA
     if (rclsid == __uuidof(ColorTesting))
         return ClassFactoryBasic<ColorTesting>::Create(riid, ppv);
 
+    if (rclsid == __uuidof(LicenseTesting))
+        return ClassFactoryLicense<LicenseTesting>::Create(riid, ppv);
+
     return CLASS_E_CLASSNOTAVAILABLE;
 }
index 7a2a1ff..61f16ca 100644 (file)
@@ -18,6 +18,7 @@ class DECLSPEC_UUID("0F8ACD0C-ECE0-4F2A-BD1B-6BFCA93A0726") DispatchTesting;
 class DECLSPEC_UUID("4DBD9B61-E372-499F-84DE-EFC70AA8A009") EventTesting;
 class DECLSPEC_UUID("4CEFE36D-F377-4B6E-8C34-819A8BB9CB04") AggregationTesting;
 class DECLSPEC_UUID("C222F472-DA5A-4FC6-9321-92F4F7053A65") ColorTesting;
+class DECLSPEC_UUID("66DB7882-E2B0-471D-92C7-B2B52A0EA535") LicenseTesting;
 
 #define CLSID_NumericTesting __uuidof(NumericTesting)
 #define CLSID_ArrayTesting __uuidof(ArrayTesting)
@@ -27,6 +28,7 @@ class DECLSPEC_UUID("C222F472-DA5A-4FC6-9321-92F4F7053A65") ColorTesting;
 #define CLSID_EventTesting __uuidof(EventTesting)
 #define CLSID_AggregationTesting __uuidof(AggregationTesting)
 #define CLSID_ColorTesting __uuidof(ColorTesting)
+#define CLSID_LicenseTesting __uuidof(LicenseTesting)
 
 #define IID_INumericTesting __uuidof(INumericTesting)
 #define IID_IArrayTesting __uuidof(IArrayTesting)
@@ -37,6 +39,7 @@ class DECLSPEC_UUID("C222F472-DA5A-4FC6-9321-92F4F7053A65") ColorTesting;
 #define IID_IEventTesting __uuidof(IEventTesting)
 #define IID_IAggregationTesting __uuidof(IAggregationTesting)
 #define IID_IColorTesting __uuidof(IColorTesting)
+#define IID_ILicenseTesting __uuidof(ILicenseTesting)
 
 // Class used for COM activation when using CoreShim
 struct CoreShimComActivation
@@ -74,4 +77,5 @@ private:
     #include "EventTesting.h"
     #include "AggregationTesting.h"
     #include "ColorTesting.h"
+    #include "LicenseTesting.h"
 #endif
index e868345..0a14e4f 100644 (file)
@@ -142,6 +142,25 @@ namespace Server.Contract.Servers
     internal class ColorTestingClass
     {
     }
+
+    /// <summary>
+    /// Managed definition of CoClass
+    /// </summary>
+    [ComImport]
+    [CoClass(typeof(LicenseTestingClass))]
+    [Guid("6C9E230E-411F-4219-ABFD-E71F2B84FD50")]
+    internal interface LicenseTesting : Server.Contract.ILicenseTesting
+    {
+    }
+
+    /// <summary>
+    /// Managed activation for CoClass
+    /// </summary>
+    [ComImport]
+    [Guid(Server.Contract.Guids.LicenseTesting)]
+    internal class LicenseTestingClass
+    {
+    }
 }
 
 #pragma warning restore 618 // Must test deprecated features
index f401c48..9c2e680 100644 (file)
@@ -273,6 +273,19 @@ namespace Server.Contract
         bool AreColorsEqual(Color managed, int native);
         Color GetRed();
     }
+
+    [ComVisible(true)]
+    [Guid("6C9E230E-411F-4219-ABFD-E71F2B84FD50")]
+    [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
+    public interface ILicenseTesting
+    {
+        void SetNextDenyLicense([MarshalAs(UnmanagedType.VariantBool)] bool denyLicense);
+
+        [return: MarshalAs(UnmanagedType.BStr)]
+        string GetLicense();
+
+        void SetNextLicense([MarshalAs(UnmanagedType.LPWStr)] string lic);
+    }
 }
 
 #pragma warning restore 618 // Must test deprecated features
index 2502567..1db14e4 100644 (file)
@@ -25,6 +25,8 @@ struct __declspec(uuid("98cc27f0-d521-4f79-8b63-e980e3a92974"))
 /* interface */ IAggregationTesting;
 struct __declspec(uuid("E6D72BA7-0936-4396-8A69-3B76DA1108DA"))
 /* interface */ IColorTesting;
+struct __declspec(uuid("6C9E230E-411F-4219-ABFD-E71F2B84FD50"))
+/* interface */ ILicenseTesting;
 
 //
 // Smart pointer typedef declarations
@@ -38,6 +40,7 @@ _COM_SMARTPTR_TYPEDEF(IDispatchTesting, __uuidof(IDispatchTesting));
 _COM_SMARTPTR_TYPEDEF(IEventTesting, __uuidof(IEventTesting));
 _COM_SMARTPTR_TYPEDEF(IAggregationTesting, __uuidof(IAggregationTesting));
 _COM_SMARTPTR_TYPEDEF(IColorTesting, __uuidof(IColorTesting));
+_COM_SMARTPTR_TYPEDEF(ILicenseTesting, __uuidof(ILicenseTesting));
 
 //
 // Type library items
@@ -486,4 +489,14 @@ IColorTesting : public IUnknown
         _Out_ _Ret_ OLE_COLOR* color) = 0;
 };
 
+struct __declspec(uuid("6C9E230E-411F-4219-ABFD-E71F2B84FD50"))
+ILicenseTesting : IUnknown
+{
+    virtual HRESULT STDMETHODCALLTYPE SetNextDenyLicense(_In_ VARIANT_BOOL denyLicense) = 0;
+
+    virtual HRESULT STDMETHODCALLTYPE GetLicense(_Out_ BSTR *lic) = 0;
+
+    virtual HRESULT STDMETHODCALLTYPE SetNextLicense(_In_z_ LPCOLESTR lic) = 0;
+};
+
 #pragma pack(pop)
index 1269e6a..2ebeed8 100644 (file)
@@ -17,5 +17,6 @@ namespace Server.Contract
         public const string EventTesting = "4DBD9B61-E372-499F-84DE-EFC70AA8A009";
         public const string AggregationTesting = "4CEFE36D-F377-4B6E-8C34-819A8BB9CB04";
         public const string ColorTesting = "C222F472-DA5A-4FC6-9321-92F4F7053A65";
+        public const string LicenseTesting = "66DB7882-E2B0-471D-92C7-B2B52A0EA535";
     }
 }
index c4d9e6c..aa3e43b 100644 (file)
@@ -231,3 +231,103 @@ public: // IUnknown
 
     DEFINE_REF_COUNTING();
 };
+
+// Templated class factory
+// Supplied type must have the following properties to use this template:
+//  1) Have a static method with the following signature:
+//    - HRESULT RequestLicKey(BSTR *key);
+//  2) Have a constructor that takes an optional BSTR value as the key
+template<typename T>
+class ClassFactoryLicense : public UnknownImpl, public IClassFactory2
+{
+public: // static
+    static HRESULT Create(_In_ REFIID riid, _Outptr_ LPVOID FAR* ppv)
+    {
+        try
+        {
+            auto cf = new ClassFactoryLicense();
+            HRESULT hr = cf->QueryInterface(riid, ppv);
+            cf->Release();
+            return hr;
+        }
+        catch (const std::bad_alloc&)
+        {
+            return E_OUTOFMEMORY;
+        }
+    }
+
+public: // IClassFactory
+    STDMETHOD(CreateInstance)(
+        _In_opt_  IUnknown *pUnkOuter,
+        _In_  REFIID riid,
+        _COM_Outptr_  void **ppvObject)
+    {
+        return CreateInstanceLic(pUnkOuter, nullptr, riid, nullptr, ppvObject);
+    }
+
+    STDMETHOD(LockServer)(/* [in] */ BOOL fLock)
+    {
+        assert(false && "Not impl");
+        return E_NOTIMPL;
+    }
+
+public: // IClassFactory2
+    STDMETHOD(GetLicInfo)(
+        /* [out][in] */ __RPC__inout LICINFO *pLicInfo)
+    {
+        // The CLR does not call this function and as such,
+        // returns an error. Note that this is explicitly illegal
+        // in a proper implementation of IClassFactory2.
+        return E_UNEXPECTED;
+    }
+
+    STDMETHOD(RequestLicKey)(
+        /* [in] */ DWORD dwReserved,
+        /* [out] */ __RPC__deref_out_opt BSTR *pBstrKey)
+    {
+        if (dwReserved != 0)
+            return E_UNEXPECTED;
+
+        return T::RequestLicKey(pBstrKey);
+    }
+
+    STDMETHOD(CreateInstanceLic)(
+        /* [annotation][in] */ _In_opt_  IUnknown *pUnkOuter,
+        /* [annotation][in] */ _Reserved_  IUnknown *pUnkReserved,
+        /* [annotation][in] */ __RPC__in  REFIID riid,
+        /* [annotation][in] */ __RPC__in  BSTR bstrKey,
+        /* [annotation][iid_is][out] */ __RPC__deref_out_opt  PVOID *ppvObj)
+    {
+        if (pUnkOuter != nullptr)
+            return CLASS_E_NOAGGREGATION;
+
+        if (pUnkReserved != nullptr)
+            return E_UNEXPECTED;
+
+        try
+        {
+            auto ti = new T(bstrKey);
+            HRESULT hr = ti->QueryInterface(riid, ppvObj);
+            ti->Release();
+            return hr;
+        }
+        catch (HRESULT hr)
+        {
+            return hr;
+        }
+        catch (const std::bad_alloc&)
+        {
+            return E_OUTOFMEMORY;
+        }
+    }
+
+public: // IUnknown
+    STDMETHOD(QueryInterface)(
+        /* [in] */ REFIID riid,
+        /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
+    {
+        return DoQueryInterface(riid, ppvObject, static_cast<IClassFactory *>(this), static_cast<IClassFactory2 *>(this));
+    }
+
+    DEFINE_REF_COUNTING();
+};