From 3c0d755886d2130e1b0db01524c140cb6c6b6fe6 Mon Sep 17 00:00:00 2001 From: Morgan Brown Date: Fri, 1 Mar 2019 08:46:22 -0800 Subject: [PATCH] Implement Serialization Guard Add Serialization Guard API and consume it in CoreLib targets --- src/System.Private.CoreLib/Resources/Strings.resx | 10 +- .../shared/System.Private.CoreLib.Shared.projitems | 3 + .../shared/System/IO/FileStream.cs | 9 ++ .../System/LocalAppContextSwitches.Common.cs | 18 ++- .../shared/System/LocalAppContextSwitches.cs | 10 ++ .../DeserializationBlockedException.cs | 41 +++++++ .../Runtime/Serialization/DeserializationToken.cs | 37 ++++++ .../Serialization/DeserializationTracker.cs | 17 +++ .../Runtime/Serialization/SerializationInfo.cs | 134 +++++++++++++++++++++ .../src/System/Reflection/Assembly.CoreCLR.cs | 4 + .../src/System/Threading/Thread.CoreCLR.cs | 4 + src/vm/comsynchronizable.cpp | 23 ++++ src/vm/comsynchronizable.h | 1 + src/vm/ecalllist.h | 1 + src/vm/mscorlib.h | 2 +- src/vm/threads.cpp | 33 +++++ src/vm/threads.h | 6 + 17 files changed, 349 insertions(+), 4 deletions(-) create mode 100644 src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationBlockedException.cs create mode 100644 src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationToken.cs create mode 100644 src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationTracker.cs diff --git a/src/System.Private.CoreLib/Resources/Strings.resx b/src/System.Private.CoreLib/Resources/Strings.resx index 155cb43..a912b83 100644 --- a/src/System.Private.CoreLib/Resources/Strings.resx +++ b/src/System.Private.CoreLib/Resources/Strings.resx @@ -2476,7 +2476,7 @@ A resolver is already set for the assembly. - + Cannot remove the last element from an empty collection. @@ -3751,4 +3751,10 @@ ArrayWithOffset: offset exceeds array size. - + + An action was attempted during deserialization that could lead to a security vulnerability. The action has been aborted. + + + An action was attempted during deserialization that could lead to a security vulnerability. The action has been aborted. To allow the action, set the '{0}' AppContext switch to true. + + \ No newline at end of file diff --git a/src/System.Private.CoreLib/shared/System.Private.CoreLib.Shared.projitems b/src/System.Private.CoreLib/shared/System.Private.CoreLib.Shared.projitems index 77b9589..e0bf112 100644 --- a/src/System.Private.CoreLib/shared/System.Private.CoreLib.Shared.projitems +++ b/src/System.Private.CoreLib/shared/System.Private.CoreLib.Shared.projitems @@ -690,6 +690,9 @@ + + + diff --git a/src/System.Private.CoreLib/shared/System/IO/FileStream.cs b/src/System.Private.CoreLib/shared/System/IO/FileStream.cs index 5c8913b..bc84d98 100644 --- a/src/System.Private.CoreLib/shared/System/IO/FileStream.cs +++ b/src/System.Private.CoreLib/shared/System/IO/FileStream.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Runtime.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Win32.SafeHandles; @@ -64,6 +65,9 @@ namespace System.IO /// Whether the file stream's handle has been exposed. private bool _exposedHandle; + /// Caches whether Serialization Guard has been disabled for file writes + private static int s_cachedSerializationSwitch = 0; + [Obsolete("This constructor has been deprecated. Please use new FileStream(SafeFileHandle handle, FileAccess access) instead. https://go.microsoft.com/fwlink/?linkid=14202")] public FileStream(IntPtr handle, FileAccess access) : this(handle, access, true, DefaultBufferSize, false) @@ -228,6 +232,11 @@ namespace System.IO if ((options & FileOptions.Asynchronous) != 0) _useAsyncIO = true; + if ((access & FileAccess.Write) == FileAccess.Write) + { + SerializationInfo.ThrowIfDeserializationInProgress("AllowFileWrites", ref s_cachedSerializationSwitch); + } + _fileHandle = OpenHandle(mode, share, options); try diff --git a/src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.Common.cs b/src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.Common.cs index 521848f..6c0e750 100644 --- a/src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.Common.cs +++ b/src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.Common.cs @@ -25,7 +25,12 @@ namespace System private static bool GetCachedSwitchValueInternal(string switchName, ref int cachedSwitchValue) { bool isSwitchEnabled; - AppContext.TryGetSwitch(switchName, out isSwitchEnabled); + + bool hasSwitch = AppContext.TryGetSwitch(switchName, out isSwitchEnabled); + if (!hasSwitch) + { + isSwitchEnabled = GetSwitchDefaultValue(switchName); + } AppContext.TryGetSwitch(@"TestSwitch.LocalAppContext.DisableCaching", out bool disableCaching); if (!disableCaching) @@ -35,5 +40,16 @@ namespace System return isSwitchEnabled; } + + // Provides default values for switches if they're not always false by default + private static bool GetSwitchDefaultValue(string switchName) + { + if (switchName == "Switch.System.Runtime.Serialization.SerializationGuard") + { + return true; + } + + return false; + } } } diff --git a/src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.cs b/src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.cs index 8194f87..b0999da 100644 --- a/src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.cs +++ b/src/System.Private.CoreLib/shared/System/LocalAppContextSwitches.cs @@ -47,5 +47,15 @@ namespace System return GetCachedSwitchValue("Switch.System.Diagnostics.EventSource.PreserveEventListnerObjectIdentity", ref s_preserveEventListnerObjectIdentity); } } + + private static int s_serializationGuard; + public static bool SerializationGuard + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + return GetCachedSwitchValue("Switch.System.Runtime.Serialization.SerializationGuard", ref s_serializationGuard); + } + } } } diff --git a/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationBlockedException.cs b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationBlockedException.cs new file mode 100644 index 0000000..8c8f780 --- /dev/null +++ b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationBlockedException.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace System.Runtime.Serialization +{ + // Thrown when a dangerous action would be performed during deserialization + [Serializable] + [System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")] + public sealed class DeserializationBlockedException : SerializationException + { + // Creates a new DeserializationBlockedException with its message + // string set to a default message. + public DeserializationBlockedException() + : base(SR.Serialization_DangerousDeserialization) + { + HResult = HResults.COR_E_SERIALIZATION; + } + + // Creates a new DeserializationBlockedException with a message indicating an opt-out switch + // for a particular part of SerializationGuard + public DeserializationBlockedException(string message) + : base(message) + { + HResult = HResults.COR_E_SERIALIZATION; + } + + public DeserializationBlockedException(Exception innerException) + : base(SR.Serialization_DangerousDeserialization, innerException) + { + HResult = HResults.COR_E_SERIALIZATION; + } + + private DeserializationBlockedException(SerializationInfo info, StreamingContext context) + : base(info, context) + { + } + } +} diff --git a/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationToken.cs b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationToken.cs new file mode 100644 index 0000000..e879ef5 --- /dev/null +++ b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationToken.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace System.Runtime.Serialization +{ + // Tracks whether deserialization is currently in progress + public readonly struct DeserializationToken : IDisposable + { + private readonly DeserializationTracker _tracker; + + internal DeserializationToken(DeserializationTracker tracker) + { + _tracker = tracker; + } + + // If this token owned the DeserializationTracker, turn off DeserializationInProgress tracking + public void Dispose() + { + if (_tracker != null && _tracker.DeserializationInProgress) + { + lock (_tracker) + { + if (_tracker.DeserializationInProgress) + { + _tracker.DeserializationInProgress = false; + SerializationInfo.AsyncDeserializationInProgress.Value = false; + } + } + } + } + } +} diff --git a/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationTracker.cs b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationTracker.cs new file mode 100644 index 0000000..fb4ac51 --- /dev/null +++ b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/DeserializationTracker.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.CompilerServices; + +namespace System.Runtime.Serialization +{ + // Tracks whether deserialization is currently in progress + internal sealed class DeserializationTracker + { + // True if the thread this tracker applies to is currently deserializing + // potentially untrusted data + internal bool DeserializationInProgress { get; set; } + } +} diff --git a/src/System.Private.CoreLib/shared/System/Runtime/Serialization/SerializationInfo.cs b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/SerializationInfo.cs index 8af5c4b..fb6ac44 100644 --- a/src/System.Private.CoreLib/shared/System/Runtime/Serialization/SerializationInfo.cs +++ b/src/System.Private.CoreLib/shared/System/Runtime/Serialization/SerializationInfo.cs @@ -2,8 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections; using System.Collections.Generic; using System.Diagnostics; +using System.Security; +using System.Threading; namespace System.Runtime.Serialization { @@ -23,6 +26,137 @@ namespace System.Runtime.Serialization private string _rootTypeName; private string _rootTypeAssemblyName; private Type _rootType; + + internal static AsyncLocal AsyncDeserializationInProgress { get; } = new AsyncLocal(); + +#if !CORECLR + // On AoT, assume private members are reflection blocked, so there's no further protection required + // for the thread's DeserializationTracker + [ThreadStatic] + private static DeserializationTracker t_deserializationTracker; + + private static DeserializationTracker GetThreadDeserializationTracker() + { + if (t_deserializationTracker == null) + { + t_deserializationTracker = new DeserializationTracker(); + } + + return t_deserializationTracker; + } +#endif // !CORECLR + + // Returns true if deserialization is currently in progress + public static bool DeserializationInProgress + { + [DynamicSecurityMethod] // Methods containing StackCrawlMark local var must be marked DynamicSecurityMethod + get + { + if (AsyncDeserializationInProgress.Value) + { + return true; + } + +#if CORECLR + StackCrawlMark stackMark = StackCrawlMark.LookForMe; + DeserializationTracker tracker = Thread.GetThreadDeserializationTracker(ref stackMark); +#else + DeserializationTracker tracker = GetThreadDeserializationTracker(); +#endif + bool result = tracker.DeserializationInProgress; + return result; + } + } + + // Throws a DeserializationBlockedException if dangerous deserialization is currently + // in progress + public static void ThrowIfDeserializationInProgress() + { + if (DeserializationInProgress) + { + throw new DeserializationBlockedException(); + } + } + + // Throws a DeserializationBlockedException if dangerous deserialization is currently + // in progress and the AppContext switch Switch.System.Runtime.Serialization.SerializationGuard.{switchSuffix} + // is not true. The value of the switch is cached in cachedValue to avoid repeated lookups: + // 0: No value cached + // 1: The switch is true + // -1: The switch is false + public static void ThrowIfDeserializationInProgress(string switchSuffix, ref int cachedValue) + { + const string SwitchPrefix = "Switch.System.Runtime.Serialization.SerializationGuard."; + if (switchSuffix == null) + { + throw new ArgumentNullException(nameof(switchSuffix)); + } + if (String.IsNullOrWhiteSpace(switchSuffix)) + { + throw new ArgumentException(SR.Argument_EmptyName, nameof(switchSuffix)); + } + + if (cachedValue == 0) + { + bool isEnabled = false; + if (AppContext.TryGetSwitch(SwitchPrefix + switchSuffix, out isEnabled) && isEnabled) + { + cachedValue = 1; + } + else + { + cachedValue = -1; + } + } + + if (cachedValue == 1) + { + return; + } + else if (cachedValue == -1) + { + if (DeserializationInProgress) + { + throw new DeserializationBlockedException(SR.Format(SR.Serialization_DangerousDeserialization_Switch, SwitchPrefix + switchSuffix)); + } + } + else + { + throw new ArgumentOutOfRangeException(nameof(cachedValue)); + } + } + + // Declares that the current thread and async context have begun deserialization. + // In this state, if the SerializationGuard or other related AppContext switches are set, + // actions likely to be dangerous during deserialization, such as starting a process will be blocked. + // Returns a DeserializationToken that must be disposed to remove the deserialization state. + [DynamicSecurityMethod] // Methods containing StackCrawlMark local var must be marked DynamicSecurityMethod + public static DeserializationToken StartDeserialization() + { + if (LocalAppContextSwitches.SerializationGuard) + { +#if CORECLR + StackCrawlMark stackMark = StackCrawlMark.LookForMe; + DeserializationTracker tracker = Thread.GetThreadDeserializationTracker(ref stackMark); +#else + DeserializationTracker tracker = GetThreadDeserializationTracker(); +#endif + if (!tracker.DeserializationInProgress) + { + lock (tracker) + { + if (!tracker.DeserializationInProgress) + { + AsyncDeserializationInProgress.Value = true; + tracker.DeserializationInProgress = true; + return new DeserializationToken(tracker); + } + } + } + } + + return new DeserializationToken(null); + } [CLSCompliant(false)] public SerializationInfo(Type type, IFormatterConverter converter) diff --git a/src/System.Private.CoreLib/src/System/Reflection/Assembly.CoreCLR.cs b/src/System.Private.CoreLib/src/System/Reflection/Assembly.CoreCLR.cs index 30acc44..83c2c0a 100644 --- a/src/System.Private.CoreLib/src/System/Reflection/Assembly.CoreCLR.cs +++ b/src/System.Private.CoreLib/src/System/Reflection/Assembly.CoreCLR.cs @@ -19,6 +19,7 @@ namespace System.Reflection private static object s_syncRootLoadFrom = new object(); private static List s_LoadFromAssemblyList = new List(); private static object s_syncLoadFromAssemblyList = new object(); + private static int s_cachedSerializationSwitch = 0; private static Assembly LoadFromResolveHandler(object sender, ResolveEventArgs args) { @@ -163,6 +164,9 @@ namespace System.Reflection throw new NotSupportedException(SR.Format(SR.NotSupported_AppX, "Assembly.Load(byte[], ...)")); #endif + SerializationInfo.ThrowIfDeserializationInProgress("AllowAssembliesFromByteArrays", + ref s_cachedSerializationSwitch); + AssemblyLoadContext alc = new IndividualAssemblyLoadContext(); MemoryStream assemblyStream = new MemoryStream(rawAssembly); MemoryStream symbolStream = (rawSymbolStore != null) ? new MemoryStream(rawSymbolStore) : null; diff --git a/src/System.Private.CoreLib/src/System/Threading/Thread.CoreCLR.cs b/src/System.Private.CoreLib/src/System/Threading/Thread.CoreCLR.cs index 46ca988..950b143 100644 --- a/src/System.Private.CoreLib/src/System/Threading/Thread.CoreCLR.cs +++ b/src/System.Private.CoreLib/src/System/Threading/Thread.CoreCLR.cs @@ -7,6 +7,7 @@ using System.Globalization; using System.Runtime.CompilerServices; using System.Runtime.ConstrainedExecution; using System.Runtime.InteropServices; +using System.Runtime.Serialization; namespace System.Threading { @@ -335,6 +336,9 @@ namespace System.Threading [DllImport(JitHelpers.QCall, CharSet = CharSet.Unicode)] private static extern void InformThreadNameChange(ThreadHandle t, string name, int len); + [MethodImplAttribute(MethodImplOptions.InternalCall)] + internal static extern DeserializationTracker GetThreadDeserializationTracker(ref StackCrawlMark stackMark); + /// Returns true if the thread has been started and is not dead. public extern bool IsAlive { diff --git a/src/vm/comsynchronizable.cpp b/src/vm/comsynchronizable.cpp index 0cd7b85..a460306 100644 --- a/src/vm/comsynchronizable.cpp +++ b/src/vm/comsynchronizable.cpp @@ -1423,6 +1423,29 @@ BOOL QCALLTYPE ThreadNative::YieldThread() return ret; } +FCIMPL1(Object*, ThreadNative::GetThreadDeserializationTracker, StackCrawlMark* stackMark) +{ + FCALL_CONTRACT; + OBJECTREF refRetVal = NULL; + HELPER_METHOD_FRAME_BEGIN_RET_1(refRetVal) + + // To avoid reflection trying to bypass deserialization tracking, check the caller + // and only allow SerializationInfo to call into this method. + MethodTable* pCallerMT = SystemDomain::GetCallersType(stackMark); + if (pCallerMT != MscorlibBinder::GetClass(CLASS__SERIALIZATION_INFO)) + { + COMPlusThrowArgumentException(W("stackMark"), NULL); + } + + Thread* pThread = GetThread(); + + refRetVal = ObjectFromHandle(pThread->GetOrCreateDeserializationTracker()); + + HELPER_METHOD_FRAME_END(); + + return OBJECTREFToObject(refRetVal); +} +FCIMPLEND FCIMPL0(INT32, ThreadNative::GetCurrentProcessorNumber) { diff --git a/src/vm/comsynchronizable.h b/src/vm/comsynchronizable.h index 15c0218..d0f3bb3 100644 --- a/src/vm/comsynchronizable.h +++ b/src/vm/comsynchronizable.h @@ -100,6 +100,7 @@ public: static FCDECL1(void, DisableComObjectEagerCleanup, ThreadBaseObject* pThis); #endif //FEATURE_COMINTEROP static FCDECL1(FC_BOOL_RET,IsThreadpoolThread, ThreadBaseObject* thread); + static FCDECL1(Object*, GetThreadDeserializationTracker, StackCrawlMark* stackMark); static FCDECL0(INT32, GetCurrentProcessorNumber); diff --git a/src/vm/ecalllist.h b/src/vm/ecalllist.h index 9a204e6..35fbae2 100644 --- a/src/vm/ecalllist.h +++ b/src/vm/ecalllist.h @@ -655,6 +655,7 @@ FCFuncStart(gThreadFuncs) FCFuncElement("Join", ThreadNative::Join) QCFuncElement("GetOptimalMaxSpinWaitsPerSpinIterationInternal", ThreadNative::GetOptimalMaxSpinWaitsPerSpinIteration) FCFuncElement("GetCurrentProcessorNumber", ThreadNative::GetCurrentProcessorNumber) + FCFuncElement("GetThreadDeserializationTracker", ThreadNative::GetThreadDeserializationTracker) FCFuncEnd() FCFuncStart(gThreadPoolFuncs) diff --git a/src/vm/mscorlib.h b/src/vm/mscorlib.h index d7b5f6e..b78684d 100644 --- a/src/vm/mscorlib.h +++ b/src/vm/mscorlib.h @@ -455,7 +455,7 @@ DEFINE_CLASS(IOBJECTREFERENCE, Serialization, IObjectReference) DEFINE_CLASS(IDESERIALIZATIONCB, Serialization, IDeserializationCallback) DEFINE_CLASS(STREAMING_CONTEXT, Serialization, StreamingContext) DEFINE_CLASS(SERIALIZATION_INFO, Serialization, SerializationInfo) - +DEFINE_CLASS(DESERIALIZATION_TRACKER, Serialization, DeserializationTracker) DEFINE_CLASS(IENUMERATOR, Collections, IEnumerator) diff --git a/src/vm/threads.cpp b/src/vm/threads.cpp index 7b99d2b..d2aab7c 100644 --- a/src/vm/threads.cpp +++ b/src/vm/threads.cpp @@ -1626,6 +1626,7 @@ Thread::Thread() memset(&m_activityId, 0, sizeof(m_activityId)); #endif // FEATURE_PERFTRACING m_HijackReturnKind = RT_Illegal; + m_DeserializationTracker = NULL; } //-------------------------------------------------------------------- @@ -2641,6 +2642,11 @@ Thread::~Thread() // Destroy any handles that we're using to hold onto exception objects SafeSetThrowables(NULL); + if (m_DeserializationTracker != NULL) + { + DestroyGlobalStrongHandle(m_DeserializationTracker); + } + DestroyShortWeakHandle(m_ExposedObject); DestroyStrongHandle(m_StrongHndToExposedObject); } @@ -9396,3 +9402,30 @@ ULONGLONG Thread::QueryThreadProcessorUsage() return ullCurrentUsage - ullPreviousUsage; } #endif // FEATURE_APPDOMAIN_RESOURCE_MONITORING + +OBJECTHANDLE Thread::GetOrCreateDeserializationTracker() +{ + CONTRACTL + { + THROWS; + GC_TRIGGERS; + MODE_COOPERATIVE; + } + CONTRACTL_END; + +#if !defined (DACCESS_COMPILE) + if (m_DeserializationTracker != NULL) + { + return m_DeserializationTracker; + } + + _ASSERTE(this == GetThread()); + + MethodTable* pMT = MscorlibBinder::GetClass(CLASS__DESERIALIZATION_TRACKER); + m_DeserializationTracker = CreateGlobalStrongHandle(AllocateObject(pMT)); + + _ASSERTE(m_DeserializationTracker != NULL); +#endif // !defined (DACCESS_COMPILE) + + return m_DeserializationTracker; +} diff --git a/src/vm/threads.h b/src/vm/threads.h index 2fa7569..90a8931 100644 --- a/src/vm/threads.h +++ b/src/vm/threads.h @@ -4970,6 +4970,12 @@ public: m_HijackReturnKind = returnKind; } #endif // FEATURE_HIJACK + +public: + OBJECTHANDLE GetOrCreateDeserializationTracker(); + +private: + OBJECTHANDLE m_DeserializationTracker; }; // End of class Thread -- 2.7.4