Implement Serialization Guard
authorMorgan Brown <morganbr@users.noreply.github.com>
Fri, 1 Mar 2019 16:46:22 +0000 (08:46 -0800)
committerJeremy Barton <jbarton@microsoft.com>
Fri, 1 Mar 2019 16:46:22 +0000 (08:46 -0800)
Add Serialization Guard API and consume it in CoreLib targets

17 files changed:
src/System.Private.CoreLib/Resources/Strings.resx
src/System.Private.CoreLib/shared/System.Private.CoreLib.Shared.projitems
src/System.Private.CoreLib/shared/System/IO/FileStream.cs
src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.Common.cs
src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.cs
src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationBlockedException.cs [new file with mode: 0644]
src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationToken.cs [new file with mode: 0644]
src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationTracker.cs [new file with mode: 0644]
src/System.Private.CoreLib/shared/System/Runtime/Serialization/SerializationInfo.cs
src/System.Private.CoreLib/src/System/Reflection/Assembly.CoreCLR.cs
src/System.Private.CoreLib/src/System/Threading/Thread.CoreCLR.cs
src/vm/comsynchronizable.cpp
src/vm/comsynchronizable.h
src/vm/ecalllist.h
src/vm/mscorlib.h
src/vm/threads.cpp
src/vm/threads.h

index 155cb43..a912b83 100644 (file)
   <data name="InvalidOperation_CannotRegisterSecondResolver" xml:space="preserve">
     <value>A resolver is already set for the assembly.</value>
   </data>
-    <data name="InvalidOperation_CannotRemoveLastFromEmptyCollection" xml:space="preserve">
+  <data name="InvalidOperation_CannotRemoveLastFromEmptyCollection" xml:space="preserve">
     <value>Cannot remove the last element from an empty collection.</value>
   </data>
   <data name="InvalidOperation_CannotRestoreUnsupressedFlow" xml:space="preserve">
   <data name="IndexOutOfRange_ArrayWithOffset" xml:space="preserve">
     <value>ArrayWithOffset: offset exceeds array size.</value>
   </data>
-</root>
+  <data name="Serialization_DangerousDeserialization" xml:space="preserve">
+    <value>An action was attempted during deserialization that could lead to a security vulnerability. The action has been aborted.</value>
+  </data>
+  <data name="Serialization_DangerousDeserialization_Switch" xml:space="preserve">
+    <value>An action was attempted during deserialization that could lead to a security vulnerability. The action has been aborted. To allow the action, set the '{0}' AppContext switch to true.</value>
+  </data>
+</root>
\ No newline at end of file
index 77b9589..e0bf112 100644 (file)
     <Compile Include="$(MSBuildThisFileDirectory)System\Runtime\Intrinsics\X86\Enums.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Runtime\Loader\AssemblyLoadContext.cs" Condition="'$(TargetsCoreRT)' != 'true'" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Runtime\Remoting\ObjectHandle.cs" />
+    <Compile Include="$(MSBuildThisFileDirectory)System\Runtime\Serialization\DeserializationBlockedException.cs" />
+    <Compile Include="$(MSBuildThisFileDirectory)System\Runtime\Serialization\DeserializationToken.cs" />
+    <Compile Include="$(MSBuildThisFileDirectory)System\Runtime\Serialization\DeserializationTracker.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Runtime\Serialization\IDeserializationCallback.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Runtime\Serialization\IFormatterConverter.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Runtime\Serialization\IObjectReference.cs" />
index 5c8913b..bc84d98 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Runtime.Serialization;
 using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.Win32.SafeHandles;
@@ -64,6 +65,9 @@ namespace System.IO
         /// <summary>Whether the file stream's handle has been exposed.</summary>
         private bool _exposedHandle;
 
+        /// <summary>Caches whether Serialization Guard has been disabled for file writes</summary>
+        private static int s_cachedSerializationSwitch = 0;
+
         [Obsolete("This constructor has been deprecated.  Please use new FileStream(SafeFileHandle handle, FileAccess access) instead.  https://go.microsoft.com/fwlink/?linkid=14202")]
         public FileStream(IntPtr handle, FileAccess access)
             : this(handle, access, true, DefaultBufferSize, false)
@@ -228,6 +232,11 @@ namespace System.IO
             if ((options & FileOptions.Asynchronous) != 0)
                 _useAsyncIO = true;
 
+            if ((access & FileAccess.Write) == FileAccess.Write)
+            {
+                SerializationInfo.ThrowIfDeserializationInProgress("AllowFileWrites", ref s_cachedSerializationSwitch);
+            }
+
             _fileHandle = OpenHandle(mode, share, options);
 
             try
index 521848f..6c0e750 100644 (file)
@@ -25,7 +25,12 @@ namespace System
         private static bool GetCachedSwitchValueInternal(string switchName, ref int cachedSwitchValue)
         {
             bool isSwitchEnabled;
-            AppContext.TryGetSwitch(switchName, out isSwitchEnabled);
+            
+            bool hasSwitch = AppContext.TryGetSwitch(switchName, out isSwitchEnabled);
+            if (!hasSwitch)
+            {
+                isSwitchEnabled = GetSwitchDefaultValue(switchName);
+            }
 
             AppContext.TryGetSwitch(@"TestSwitch.LocalAppContext.DisableCaching", out bool disableCaching);
             if (!disableCaching)
@@ -35,5 +40,16 @@ namespace System
 
             return isSwitchEnabled;
         }
+
+        // Provides default values for switches if they're not always false by default
+        private static bool GetSwitchDefaultValue(string switchName)
+        {
+            if (switchName == "Switch.System.Runtime.Serialization.SerializationGuard")
+            {
+                return true;
+            }
+            
+            return false;
+        }
     }
 }
index 8194f87..b0999da 100644 (file)
@@ -47,5 +47,15 @@ namespace System
                 return GetCachedSwitchValue("Switch.System.Diagnostics.EventSource.PreserveEventListnerObjectIdentity", ref s_preserveEventListnerObjectIdentity);
             }
         }
+
+        private static int s_serializationGuard;
+        public static bool SerializationGuard
+        {
+            [MethodImpl(MethodImplOptions.AggressiveInlining)]
+            get
+            {
+                return GetCachedSwitchValue("Switch.System.Runtime.Serialization.SerializationGuard", ref s_serializationGuard);
+            }
+        }
     }
 }
diff --git a/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationBlockedException.cs b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationBlockedException.cs
new file mode 100644 (file)
index 0000000..8c8f780
--- /dev/null
@@ -0,0 +1,41 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace System.Runtime.Serialization
+{
+    // Thrown when a dangerous action would be performed during deserialization 
+    [Serializable]
+    [System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
+    public sealed class DeserializationBlockedException : SerializationException
+    {
+        // Creates a new DeserializationBlockedException with its message 
+        // string set to a default message.
+        public DeserializationBlockedException()
+            : base(SR.Serialization_DangerousDeserialization)
+        {
+            HResult = HResults.COR_E_SERIALIZATION;
+        }
+
+        // Creates a new DeserializationBlockedException with a message indicating an opt-out switch
+        // for a particular part of SerializationGuard
+        public DeserializationBlockedException(string message)
+            : base(message)
+        {
+            HResult = HResults.COR_E_SERIALIZATION;
+        }
+
+        public DeserializationBlockedException(Exception innerException)
+            : base(SR.Serialization_DangerousDeserialization, innerException)
+        {
+            HResult = HResults.COR_E_SERIALIZATION;
+        }
+
+        private DeserializationBlockedException(SerializationInfo info, StreamingContext context)
+            : base(info, context)
+        {
+        }
+    }
+}
diff --git a/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationToken.cs b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationToken.cs
new file mode 100644 (file)
index 0000000..e879ef5
--- /dev/null
@@ -0,0 +1,37 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Runtime.CompilerServices;
+using System.Threading;
+
+namespace System.Runtime.Serialization
+{
+    // Tracks whether deserialization is currently in progress
+    public readonly struct DeserializationToken : IDisposable
+    {
+        private readonly DeserializationTracker _tracker;
+
+        internal DeserializationToken(DeserializationTracker tracker)
+        {
+            _tracker = tracker;
+        }
+
+        // If this token owned the DeserializationTracker, turn off DeserializationInProgress tracking
+        public void Dispose()
+        {
+            if (_tracker != null && _tracker.DeserializationInProgress)
+            {
+                lock (_tracker)
+                {
+                    if (_tracker.DeserializationInProgress)
+                    {
+                        _tracker.DeserializationInProgress = false;
+                        SerializationInfo.AsyncDeserializationInProgress.Value = false;
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationTracker.cs b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationTracker.cs
new file mode 100644 (file)
index 0000000..fb4ac51
--- /dev/null
@@ -0,0 +1,17 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Runtime.CompilerServices;
+
+namespace System.Runtime.Serialization
+{
+    // Tracks whether deserialization is currently in progress
+    internal sealed class DeserializationTracker
+    {
+        // True if the thread this tracker applies to is currently deserializing
+        // potentially untrusted data
+        internal bool DeserializationInProgress { get; set; }
+    }
+}
index 8af5c4b..fb6ac44 100644 (file)
@@ -2,8 +2,11 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Collections;
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Security;
+using System.Threading;
 
 namespace System.Runtime.Serialization
 {
@@ -23,6 +26,137 @@ namespace System.Runtime.Serialization
         private string _rootTypeName;
         private string _rootTypeAssemblyName;
         private Type _rootType;
+        
+        internal static AsyncLocal<bool> AsyncDeserializationInProgress { get; } = new AsyncLocal<bool>();        
+
+#if !CORECLR
+        // On AoT, assume private members are reflection blocked, so there's no further protection required
+        // for the thread's DeserializationTracker
+        [ThreadStatic]
+        private static DeserializationTracker t_deserializationTracker;
+
+        private static DeserializationTracker GetThreadDeserializationTracker()
+        {
+            if (t_deserializationTracker == null)
+            {
+                t_deserializationTracker = new DeserializationTracker();
+            }
+
+            return t_deserializationTracker;
+        }
+#endif // !CORECLR
+
+        // Returns true if deserialization is currently in progress
+        public static bool DeserializationInProgress
+        {
+            [DynamicSecurityMethod] // Methods containing StackCrawlMark local var must be marked DynamicSecurityMethod
+            get
+            {
+                if (AsyncDeserializationInProgress.Value)
+                {
+                    return true;
+                }
+
+#if CORECLR
+                StackCrawlMark stackMark = StackCrawlMark.LookForMe;
+                DeserializationTracker tracker = Thread.GetThreadDeserializationTracker(ref stackMark);
+#else
+                DeserializationTracker tracker = GetThreadDeserializationTracker();
+#endif
+                bool result = tracker.DeserializationInProgress;
+                return result;
+            }
+        }
+
+        // Throws a DeserializationBlockedException if dangerous deserialization is currently
+        // in progress
+        public static void ThrowIfDeserializationInProgress()
+        {
+            if (DeserializationInProgress)
+            {
+                throw new DeserializationBlockedException();
+            }
+        }
+
+        // Throws a DeserializationBlockedException if dangerous deserialization is currently
+        // in progress and the AppContext switch Switch.System.Runtime.Serialization.SerializationGuard.{switchSuffix}
+        // is not true. The value of the switch is cached in cachedValue to avoid repeated lookups:
+        // 0: No value cached
+        // 1: The switch is true
+        // -1: The switch is false
+        public static void ThrowIfDeserializationInProgress(string switchSuffix, ref int cachedValue)
+        {
+            const string SwitchPrefix = "Switch.System.Runtime.Serialization.SerializationGuard.";
+            if (switchSuffix == null)
+            {
+                throw new ArgumentNullException(nameof(switchSuffix));
+            }
+            if (String.IsNullOrWhiteSpace(switchSuffix))
+            {
+                throw new ArgumentException(SR.Argument_EmptyName, nameof(switchSuffix));
+            }
+
+            if (cachedValue == 0)
+            {
+                bool isEnabled = false;
+                if (AppContext.TryGetSwitch(SwitchPrefix + switchSuffix, out isEnabled) && isEnabled)
+                {
+                    cachedValue = 1;
+                }
+                else
+                {
+                    cachedValue = -1;
+                }
+            }
+
+            if (cachedValue == 1)
+            {
+                return;
+            }
+            else if (cachedValue == -1)
+            {
+                if (DeserializationInProgress)
+                {
+                    throw new DeserializationBlockedException(SR.Format(SR.Serialization_DangerousDeserialization_Switch, SwitchPrefix + switchSuffix));
+                }
+            }
+            else
+            {
+                throw new ArgumentOutOfRangeException(nameof(cachedValue));
+            }
+        }
+
+        // Declares that the current thread and async context have begun deserialization.
+        // In this state, if the SerializationGuard or other related AppContext switches are set,
+        // actions likely to be dangerous during deserialization, such as starting a process will be blocked.
+        // Returns a DeserializationToken that must be disposed to remove the deserialization state.        
+        [DynamicSecurityMethod] // Methods containing StackCrawlMark local var must be marked DynamicSecurityMethod
+        public static DeserializationToken StartDeserialization()
+        {
+            if (LocalAppContextSwitches.SerializationGuard)
+            {
+#if CORECLR
+                StackCrawlMark stackMark = StackCrawlMark.LookForMe;
+                DeserializationTracker tracker = Thread.GetThreadDeserializationTracker(ref stackMark);
+#else
+                DeserializationTracker tracker = GetThreadDeserializationTracker();
+#endif
+                if  (!tracker.DeserializationInProgress)
+                {
+                    lock (tracker)
+                    {
+                        if (!tracker.DeserializationInProgress)
+                        {
+                            AsyncDeserializationInProgress.Value = true;
+                            tracker.DeserializationInProgress = true;
+                            return new DeserializationToken(tracker);
+                        }
+                    }
+                }
+            }
+            
+            return new DeserializationToken(null);
+        }
 
         [CLSCompliant(false)]
         public SerializationInfo(Type type, IFormatterConverter converter) 
index 30acc44..83c2c0a 100644 (file)
@@ -19,6 +19,7 @@ namespace System.Reflection
         private static object s_syncRootLoadFrom = new object();
         private static List<string> s_LoadFromAssemblyList = new List<string>();
         private static object s_syncLoadFromAssemblyList = new object();
+        private static int s_cachedSerializationSwitch = 0;
 
         private static Assembly LoadFromResolveHandler(object sender, ResolveEventArgs args)
         {
@@ -163,6 +164,9 @@ namespace System.Reflection
                 throw new NotSupportedException(SR.Format(SR.NotSupported_AppX, "Assembly.Load(byte[], ...)"));
 #endif
 
+            SerializationInfo.ThrowIfDeserializationInProgress("AllowAssembliesFromByteArrays", 
+                ref s_cachedSerializationSwitch);
+
             AssemblyLoadContext alc = new IndividualAssemblyLoadContext();
             MemoryStream assemblyStream = new MemoryStream(rawAssembly);
             MemoryStream symbolStream = (rawSymbolStore != null) ? new MemoryStream(rawSymbolStore) : null;
index 46ca988..950b143 100644 (file)
@@ -7,6 +7,7 @@ using System.Globalization;
 using System.Runtime.CompilerServices;
 using System.Runtime.ConstrainedExecution;
 using System.Runtime.InteropServices;
+using System.Runtime.Serialization;
 
 namespace System.Threading
 {
@@ -335,6 +336,9 @@ namespace System.Threading
         [DllImport(JitHelpers.QCall, CharSet = CharSet.Unicode)]
         private static extern void InformThreadNameChange(ThreadHandle t, string name, int len);
 
+        [MethodImplAttribute(MethodImplOptions.InternalCall)]
+        internal static extern DeserializationTracker GetThreadDeserializationTracker(ref StackCrawlMark stackMark);
+
         /// <summary>Returns true if the thread has been started and is not dead.</summary>
         public extern bool IsAlive
         {
index 0cd7b85..a460306 100644 (file)
@@ -1423,6 +1423,29 @@ BOOL QCALLTYPE ThreadNative::YieldThread()
     return ret;
 }
 
+FCIMPL1(Object*, ThreadNative::GetThreadDeserializationTracker, StackCrawlMark* stackMark)
+{
+    FCALL_CONTRACT;
+    OBJECTREF refRetVal = NULL;
+    HELPER_METHOD_FRAME_BEGIN_RET_1(refRetVal)
+
+    // To avoid reflection trying to bypass deserialization tracking, check the caller
+    // and only allow SerializationInfo to call into this method.
+    MethodTable* pCallerMT = SystemDomain::GetCallersType(stackMark);
+    if (pCallerMT != MscorlibBinder::GetClass(CLASS__SERIALIZATION_INFO))
+    {
+        COMPlusThrowArgumentException(W("stackMark"), NULL);
+    }
+
+    Thread* pThread = GetThread();
+
+    refRetVal = ObjectFromHandle(pThread->GetOrCreateDeserializationTracker());
+
+    HELPER_METHOD_FRAME_END();
+
+    return OBJECTREFToObject(refRetVal);
+}
+FCIMPLEND
 
 FCIMPL0(INT32, ThreadNative::GetCurrentProcessorNumber)
 {
index 15c0218..d0f3bb3 100644 (file)
@@ -100,6 +100,7 @@ public:
     static FCDECL1(void,    DisableComObjectEagerCleanup,   ThreadBaseObject* pThis);
 #endif //FEATURE_COMINTEROP
     static FCDECL1(FC_BOOL_RET,IsThreadpoolThread,             ThreadBaseObject* thread);
+    static FCDECL1(Object*, GetThreadDeserializationTracker, StackCrawlMark* stackMark);
 
     static FCDECL0(INT32,   GetCurrentProcessorNumber);
 
index 9a204e6..35fbae2 100644 (file)
@@ -655,6 +655,7 @@ FCFuncStart(gThreadFuncs)
     FCFuncElement("Join", ThreadNative::Join)
     QCFuncElement("GetOptimalMaxSpinWaitsPerSpinIterationInternal", ThreadNative::GetOptimalMaxSpinWaitsPerSpinIteration)
     FCFuncElement("GetCurrentProcessorNumber", ThreadNative::GetCurrentProcessorNumber)
+    FCFuncElement("GetThreadDeserializationTracker", ThreadNative::GetThreadDeserializationTracker)
 FCFuncEnd()
 
 FCFuncStart(gThreadPoolFuncs)
index d7b5f6e..b78684d 100644 (file)
@@ -455,7 +455,7 @@ DEFINE_CLASS(IOBJECTREFERENCE,      Serialization,          IObjectReference)
 DEFINE_CLASS(IDESERIALIZATIONCB,    Serialization,          IDeserializationCallback)
 DEFINE_CLASS(STREAMING_CONTEXT,     Serialization,          StreamingContext)
 DEFINE_CLASS(SERIALIZATION_INFO,    Serialization,          SerializationInfo)
-
+DEFINE_CLASS(DESERIALIZATION_TRACKER, Serialization, DeserializationTracker)
 
 
 DEFINE_CLASS(IENUMERATOR,           Collections,            IEnumerator)
index 7b99d2b..d2aab7c 100644 (file)
@@ -1626,6 +1626,7 @@ Thread::Thread()
     memset(&m_activityId, 0, sizeof(m_activityId));
 #endif // FEATURE_PERFTRACING
     m_HijackReturnKind = RT_Illegal;
+    m_DeserializationTracker = NULL;
 }
 
 //--------------------------------------------------------------------
@@ -2641,6 +2642,11 @@ Thread::~Thread()
         // Destroy any handles that we're using to hold onto exception objects
         SafeSetThrowables(NULL);
 
+        if (m_DeserializationTracker != NULL)
+        {
+            DestroyGlobalStrongHandle(m_DeserializationTracker);
+        }
+
         DestroyShortWeakHandle(m_ExposedObject);
         DestroyStrongHandle(m_StrongHndToExposedObject);
     }
@@ -9396,3 +9402,30 @@ ULONGLONG Thread::QueryThreadProcessorUsage()
     return ullCurrentUsage - ullPreviousUsage;
 }
 #endif // FEATURE_APPDOMAIN_RESOURCE_MONITORING
+
+OBJECTHANDLE Thread::GetOrCreateDeserializationTracker()
+{
+    CONTRACTL
+    {
+        THROWS;
+        GC_TRIGGERS;
+        MODE_COOPERATIVE;
+    }
+    CONTRACTL_END;
+
+#if !defined (DACCESS_COMPILE)
+    if (m_DeserializationTracker != NULL)
+    {
+        return m_DeserializationTracker;
+    }
+
+    _ASSERTE(this == GetThread());
+
+    MethodTable* pMT = MscorlibBinder::GetClass(CLASS__DESERIALIZATION_TRACKER);
+    m_DeserializationTracker = CreateGlobalStrongHandle(AllocateObject(pMT));
+
+    _ASSERTE(m_DeserializationTracker != NULL);
+#endif // !defined (DACCESS_COMPILE)
+
+    return m_DeserializationTracker;
+}
index 2fa7569..90a8931 100644 (file)
@@ -4970,6 +4970,12 @@ public:
         m_HijackReturnKind = returnKind;
     }
 #endif // FEATURE_HIJACK
+
+public:
+    OBJECTHANDLE GetOrCreateDeserializationTracker();
+
+private:
+    OBJECTHANDLE m_DeserializationTracker;
 };
 
 // End of class Thread