Disallow unrestricted polymorphic deserialization in DataSet (#39304)
authorLevi Broderick <GrabYourPitchforks@users.noreply.github.com>
Wed, 15 Jul 2020 06:50:33 +0000 (23:50 -0700)
committerGitHub <noreply@github.com>
Wed, 15 Jul 2020 06:50:33 +0000 (23:50 -0700)
Fixes CVE-2020-1147
https://portal.msrc.microsoft.com/en-us/security-guidance/advisory/CVE-2020-1147
See also https://go.microsoft.com/fwlink/?linkid=2132227.

12 files changed:
src/libraries/System.Data.Common/src/Resources/Strings.resx
src/libraries/System.Data.Common/src/System.Data.Common.csproj
src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs
src/libraries/System.Data.Common/src/System/Data/DataColumn.cs
src/libraries/System.Data.Common/src/System/Data/DataException.cs
src/libraries/System.Data.Common/src/System/Data/DataSet.cs
src/libraries/System.Data.Common/src/System/Data/DataTable.cs
src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs
src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs [new file with mode: 0644]
src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs [new file with mode: 0644]
src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj
src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs [new file with mode: 0644]

index 4535a6944bffc1bde19f3170a2320218865f13a2..671bf490bc1351f14f0af453c1037a17f20f8eff 100644 (file)
   <data name="Data_ArgumentOutOfRange" xml:space="preserve"><value>'{0}' argument is out of range.</value></data>
   <data name="Data_ArgumentNull" xml:space="preserve"><value>'{0}' argument cannot be null.</value></data>
   <data name="Data_ArgumentContainsNull" xml:space="preserve"><value>'{0}' argument contains null value.</value></data>
+  <data name="Data_TypeNotAllowed" xml:space="preserve"><value>Type '{0}' is not allowed here. See https://go.microsoft.com/fwlink/?linkid=2132227 for more details.</value></data>
   <data name="DataColumns_OutOfRange" xml:space="preserve"><value>Cannot find column {0}.</value></data>
   <data name="DataColumns_Add1" xml:space="preserve"><value>Column '{0}' already belongs to this DataTable.</value></data>
   <data name="DataColumns_Add2" xml:space="preserve"><value>Column '{0}' already belongs to another DataTable.</value></data>
index cbde6845684d1d98c3860104e00f84873205a522..a872f6cefe4519c73d05d00970833d407498bf7d 100644 (file)
     <Compile Include="System\Data\KeyRestrictionBehavior.cs" />
     <Compile Include="System\Data\LinqDataView.cs" />
     <Compile Include="System\Data\LoadOption.cs" />
+    <Compile Include="System\Data\LocalAppContextSwitches.cs" />
+    <Compile Include="$(CommonPath)System\LocalAppContextSwitches.Common.cs">
+      <Link>Common\System\LocalAppContextSwitches.Common.cs</Link>
+    </Compile>
     <Compile Include="System\Data\MappingType.cs" />
     <Compile Include="System\Data\MergeFailedEvent.cs" />
     <Compile Include="System\Data\MergeFailedEventHandler.cs" />
     <Compile Include="System\Data\StrongTypingException.cs" />
     <Compile Include="System\Data\TypedTableBase.cs" />
     <Compile Include="System\Data\TypedTableBaseExtensions.cs" />
+    <Compile Include="System\Data\TypeLimiter.cs" />
     <Compile Include="System\Data\UniqueConstraint.cs" />
     <Compile Include="System\Data\UpdateRowSource.cs" />
     <Compile Include="System\Data\Common\UInt64Storage.cs" />
     <Compile Include="System\Data\ProviderBase\SchemaMapping.cs" />
   </ItemGroup>
   <ItemGroup>
-    <Reference Include="System.Collections" />
+    <ProjectReference Include="$(CoreLibProject)" />
+    <ProjectReference Include="..\..\System.Collections\src\System.Collections.csproj" />
+    <ProjectReference Include="..\..\System.Collections.NonGeneric\src\System.Collections.NonGeneric.csproj" />
+    <ProjectReference Include="..\..\System.ComponentModel.TypeConverter\src\System.ComponentModel.TypeConverter.csproj" />
+    <ProjectReference Include="..\..\System.Runtime\src\System.Runtime.csproj" />
+    <ProjectReference Include="..\..\System.Runtime.Extensions\src\System.Runtime.Extensions.csproj" />
+    <ProjectReference Include="..\..\System.Private.Uri\src\System.Private.Uri.csproj" />
     <Reference Include="System.Collections.Concurrent" />
-    <Reference Include="System.Collections.NonGeneric" />
     <Reference Include="System.ComponentModel" />
     <Reference Include="System.ComponentModel.Primitives" />
-    <Reference Include="System.ComponentModel.TypeConverter" />
-    <Reference Include="System.Diagnostics.Tracing" />
+    <Reference Include="System.Drawing.Primitives" />
     <Reference Include="System.Linq" />
     <Reference Include="System.Linq.Expressions" />
-    <Reference Include="System.Memory" />
     <Reference Include="System.ObjectModel" />
-    <Reference Include="System.Runtime" />
-    <Reference Include="System.Runtime.Extensions" />
     <Reference Include="System.Runtime.Numerics" />
     <Reference Include="System.Runtime.Serialization.Formatters" />
-    <Reference Include="System.Text.Encoding.Extensions" />
     <Reference Include="System.Text.RegularExpressions" />
-    <Reference Include="System.Threading" />
-    <Reference Include="System.Threading.Thread" />
     <Reference Include="System.Transactions.Local" />
     <Reference Include="System.Xml.ReaderWriter" />
     <Reference Include="System.Xml.XmlSerializer" />
index f220cf84f7605058675a3d147ce0628c57129309..6e16308edca4993270c10ee0dc2e3340818e9c57 100644 (file)
@@ -403,6 +403,9 @@ namespace System.Data.Common
 
                         if (type == typeof(object))
                             throw ExceptionBuilder.CanNotDeserializeObjectType();
+
+                        TypeLimiter.EnsureTypeIsAllowed(type);
+
                         if (!isBaseCLRType)
                         {
                             retValue = System.Activator.CreateInstance(type, true)!;
index 1b09c0a75cbe4427927084abaac87c9003e28187..2547216825d58d77cae846b453e65beb9a0d859c 100644 (file)
@@ -143,6 +143,7 @@ namespace System.Data
 
         private void UpdateColumnType(Type type, StorageType typeCode)
         {
+            TypeLimiter.EnsureTypeIsAllowed(type);
             _dataType = type;
             _storageType = typeCode;
             if (StorageType.DateTime != typeCode)
index 261005c21ffc9ec71bd64800334cd4be31ea2b09..4df7e0be286e4a5c718ed9af21def84c8d5562b9 100644 (file)
@@ -350,6 +350,7 @@ namespace System.Data
         public static Exception ArgumentOutOfRange(string paramName) => _ArgumentOutOfRange(paramName, SR.Format(SR.Data_ArgumentOutOfRange, paramName));
         public static Exception BadObjectPropertyAccess(string error) => _InvalidOperation(SR.Format(SR.DataConstraint_BadObjectPropertyAccess, error));
         public static Exception ArgumentContainsNull(string paramName) => _Argument(paramName, SR.Format(SR.Data_ArgumentContainsNull, paramName));
+        public static Exception TypeNotAllowed(Type type) => _InvalidOperation(SR.Format(SR.Data_TypeNotAllowed, type.AssemblyQualifiedName));
 
 
         //
index 3bf231f8856f3cc73f71fa948af5896d1f778ea3..6ade9cb9fa7e6f38d0887e8824e1e506ce8461b4 100644 (file)
@@ -1964,9 +1964,11 @@ namespace System.Data
 
         internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving)
         {
+            IDisposable? restrictedScope = null;
             long logScopeId = DataCommonEventSource.Log.EnterScope("<ds.DataSet.ReadXml|INFO> {0}, denyResolving={1}", ObjectID, denyResolving);
             try
             {
+                restrictedScope = TypeLimiter.EnterRestrictedScope(this);
                 DataTable.DSRowDiffIdUsageSection rowDiffIdUsage = default;
                 try
                 {
@@ -2234,6 +2236,7 @@ namespace System.Data
             }
             finally
             {
+                restrictedScope?.Dispose();
                 DataCommonEventSource.Log.ExitScope(logScopeId);
             }
         }
@@ -2470,9 +2473,11 @@ namespace System.Data
 
         internal XmlReadMode ReadXml(XmlReader? reader, XmlReadMode mode, bool denyResolving)
         {
+            IDisposable? restictedScope = null;
             long logScopeId = DataCommonEventSource.Log.EnterScope("<ds.DataSet.ReadXml|INFO> {0}, mode={1}, denyResolving={2}", ObjectID, mode, denyResolving);
             try
             {
+                restictedScope = TypeLimiter.EnterRestrictedScope(this);
                 XmlReadMode ret = mode;
 
                 if (reader == null)
@@ -2714,6 +2719,7 @@ namespace System.Data
             }
             finally
             {
+                restictedScope?.Dispose();
                 DataCommonEventSource.Log.ExitScope(logScopeId);
             }
         }
index e87718342def94c694e544257f484ba36816ef5e..851ac8900f70313c6576cc370a6dd2429ac63eb8 100644 (file)
@@ -5657,9 +5657,11 @@ namespace System.Data
 
         internal XmlReadMode ReadXml(XmlReader? reader, bool denyResolving)
         {
+            IDisposable? restrictedScope = null;
             long logScopeId = DataCommonEventSource.Log.EnterScope("<ds.DataTable.ReadXml|INFO> {0}, denyResolving={1}", ObjectID, denyResolving);
             try
             {
+                restrictedScope = TypeLimiter.EnterRestrictedScope(this);
                 RowDiffIdUsageSection rowDiffIdUsage = default;
                 try
                 {
@@ -5894,15 +5896,18 @@ namespace System.Data
             }
             finally
             {
+                restrictedScope?.Dispose();
                 DataCommonEventSource.Log.ExitScope(logScopeId);
             }
         }
 
         internal XmlReadMode ReadXml(XmlReader? reader, XmlReadMode mode, bool denyResolving)
         {
+            IDisposable? restrictedScope = null;
             RowDiffIdUsageSection rowDiffIdUsage = default;
             try
             {
+                restrictedScope = TypeLimiter.EnterRestrictedScope(this);
                 bool fSchemaFound = false;
                 bool fDataFound = false;
                 bool fIsXdr = false;
@@ -6188,6 +6193,7 @@ namespace System.Data
             }
             finally
             {
+                restrictedScope?.Dispose();
                 // prepare and cleanup rowDiffId hashtable
                 rowDiffIdUsage.Cleanup();
             }
index f2515033abb68c413e1b24d3549b6a1a8a2f183b..6a666e60718f4b25738b937f45fdab5519df21c5 100644 (file)
@@ -5,6 +5,7 @@ using System.Collections.Generic;
 using System.Data.Common;
 using System.Data.SqlTypes;
 using System.Diagnostics;
+using System.Runtime.Serialization;
 
 namespace System.Data
 {
@@ -15,6 +16,7 @@ namespace System.Data
         internal int _argumentCount;
         internal const int initialCapacity = 1;
         internal ExpressionNode[]? _arguments;
+        private readonly TypeLimiter? _capturedLimiter;
 
         private static readonly Function[] s_funcs = new Function[] {
             new Function("Abs", FunctionId.Abs, typeof(object), true, false, 1, typeof(object), null, null),
@@ -39,6 +41,12 @@ namespace System.Data
 
         internal FunctionNode(DataTable? table, string name) : base(table)
         {
+            // Because FunctionNode instances are created eagerly but evaluated lazily,
+            // we need to capture the deserialization scope here. The scope could be
+            // null if no deserialization is in progress.
+
+            _capturedLimiter = TypeLimiter.Capture();
+
             _name = name;
             for (int i = 0; i < s_funcs.Length; i++)
             {
@@ -288,6 +296,11 @@ namespace System.Data
                 throw ExprException.InvalidType(typeName);
             }
 
+            // ReadXml might not be on the current call stack. So we'll use the TypeLimiter
+            // that was captured when this FunctionNode instance was created.
+
+            TypeLimiter.EnsureTypeIsAllowed(dataType, _capturedLimiter);
+
             return dataType;
         }
 
@@ -493,10 +506,17 @@ namespace System.Data
                             {
                                 return SqlConvert.ChangeType2((decimal)SqlConvert.ChangeType2(argumentValues[0], StorageType.Decimal, typeof(decimal), FormatProvider), mytype, type, FormatProvider);
                             }
-                            return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider);
                         }
 
-                        return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider);
+                        // The Convert function can be called lazily, outside of a previous Serialization Guard scope.
+                        // If there was a type limiter scope on the stack at the time this Convert function was created,
+                        // we must manually re-enter the Serialization Guard scope.
+
+                        DeserializationToken deserializationToken = (_capturedLimiter != null) ? SerializationInfo.StartDeserialization() : default;
+                        using (deserializationToken)
+                        {
+                            return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider);
+                        }
                     }
 
                     return argumentValues[0];
diff --git a/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs b/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs
new file mode 100644 (file)
index 0000000..ed77bea
--- /dev/null
@@ -0,0 +1,17 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Runtime.CompilerServices;
+
+namespace System
+{
+    internal static partial class LocalAppContextSwitches
+    {
+        private static int s_allowArbitraryTypeInstantiation;
+        public static bool AllowArbitraryTypeInstantiation
+        {
+            [MethodImpl(MethodImplOptions.AggressiveInlining)]
+            get => GetCachedSwitchValue("Switch.System.Data.AllowArbitraryDataSetTypeInstantiation", ref s_allowArbitraryTypeInstantiation);
+        }
+    }
+}
diff --git a/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs b/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs
new file mode 100644 (file)
index 0000000..59258ea
--- /dev/null
@@ -0,0 +1,304 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using System.Diagnostics;
+using System.Drawing;
+using System.Linq;
+using System.Numerics;
+using System.Runtime.Serialization;
+
+namespace System.Data
+{
+    internal sealed class TypeLimiter
+    {
+        [ThreadStatic]
+        private static Scope? s_activeScope;
+
+        private Scope m_instanceScope;
+
+        private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes";
+
+        private TypeLimiter(Scope scope)
+        {
+            Debug.Assert(scope != null);
+            m_instanceScope = scope;
+        }
+
+        private static bool IsTypeLimitingDisabled
+            => LocalAppContextSwitches.AllowArbitraryTypeInstantiation;
+
+        /// <summary>
+        /// Captures the current <see cref="TypeLimiter"/> instance so that future
+        /// type checks can be performed against the allow list that was active during
+        /// the current deserialization scope.
+        /// </summary>
+        /// <remarks>
+        /// Returns null if no limiter is active.
+        /// </remarks>
+        public static TypeLimiter? Capture()
+        {
+            Scope? activeScope = s_activeScope;
+            return (activeScope != null) ? new TypeLimiter(activeScope) : null;
+        }
+
+        /// <summary>
+        /// Ensures the requested type is allowed by the rules of the active
+        /// deserialization scope. If a captured scope is provided, we'll use
+        /// that previously captured scope rather than the thread-static active
+        /// scope.
+        /// </summary>
+        /// <exception cref="InvalidOperationException">
+        /// If <paramref name="type"/> is not allowed.
+        /// </exception>
+        public static void EnsureTypeIsAllowed(Type? type, TypeLimiter? capturedLimiter = null)
+        {
+            if (type is null)
+            {
+                return; // nothing to check
+            }
+
+            Scope? capturedScope = capturedLimiter?.m_instanceScope ?? s_activeScope;
+            if (capturedScope is null)
+            {
+                return; // we're not in a restricted scope
+            }
+
+            if (capturedScope.IsAllowedType(type))
+            {
+                return; // type was explicitly allowed
+            }
+
+            // We encountered a type that wasn't in the allow list.
+            // Throw an exception to fail the current operation.
+
+            throw ExceptionBuilder.TypeNotAllowed(type);
+        }
+
+        public static IDisposable? EnterRestrictedScope(DataSet dataSet)
+        {
+            if (IsTypeLimitingDisabled)
+            {
+                return null; // protections aren't enabled
+            }
+
+            Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataSet));
+            s_activeScope = newScope;
+            return newScope;
+        }
+
+        public static IDisposable? EnterRestrictedScope(DataTable dataTable)
+        {
+            if (IsTypeLimitingDisabled)
+            {
+                return null; // protections aren't enabled
+            }
+
+            Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataTable));
+            s_activeScope = newScope;
+            return newScope;
+        }
+
+        /// <summary>
+        /// Given a <see cref="DataTable"/>, returns all of the <see cref="DataColumn.DataType"/>
+        /// values declared on the instance.
+        /// </summary>
+        private static IEnumerable<Type> GetPreviouslyDeclaredDataTypes(DataTable dataTable)
+        {
+            return (dataTable != null)
+                ? dataTable.Columns.Cast<DataColumn>().Select(column => column.DataType)
+                : Enumerable.Empty<Type>();
+        }
+
+        /// <summary>
+        /// Given a <see cref="DataSet"/>, returns all of the <see cref="DataColumn.DataType"/>
+        /// values declared on the instance.
+        /// </summary>
+        private static IEnumerable<Type> GetPreviouslyDeclaredDataTypes(DataSet dataSet)
+        {
+            return (dataSet != null)
+                ? dataSet.Tables.Cast<DataTable>().SelectMany(table => GetPreviouslyDeclaredDataTypes(table))
+                : Enumerable.Empty<Type>();
+        }
+
+        private sealed class Scope : IDisposable
+        {
+            /// <summary>
+            /// Types which are always allowed, unconditionally.
+            /// </summary>
+            private static readonly HashSet<Type> s_allowedTypes = new HashSet<Type>()
+            {
+                /* primitives */
+                typeof(bool),
+                typeof(char),
+                typeof(sbyte),
+                typeof(byte),
+                typeof(short),
+                typeof(ushort),
+                typeof(int),
+                typeof(uint),
+                typeof(long),
+                typeof(ulong),
+                typeof(float),
+                typeof(double),
+                typeof(decimal),
+                typeof(DateTime),
+                typeof(DateTimeOffset),
+                typeof(TimeSpan),
+                typeof(string),
+                typeof(Guid),
+                typeof(SqlBinary),
+                typeof(SqlBoolean),
+                typeof(SqlByte),
+                typeof(SqlBytes),
+                typeof(SqlChars),
+                typeof(SqlDateTime),
+                typeof(SqlDecimal),
+                typeof(SqlDouble),
+                typeof(SqlGuid),
+                typeof(SqlInt16),
+                typeof(SqlInt32),
+                typeof(SqlInt64),
+                typeof(SqlMoney),
+                typeof(SqlSingle),
+                typeof(SqlString),
+
+                /* non-primitives, but common */
+                typeof(object),
+                typeof(Type),
+                typeof(BigInteger),
+                typeof(Uri),
+
+                /* frequently used System.Drawing types */
+                typeof(Color),
+                typeof(Point),
+                typeof(PointF),
+                typeof(Rectangle),
+                typeof(RectangleF),
+                typeof(Size),
+                typeof(SizeF),
+            };
+
+            /// <summary>
+            /// Types which are allowed within the context of this scope.
+            /// </summary>
+            private HashSet<Type> m_allowedTypes;
+
+            /// <summary>
+            /// This thread's previous scope.
+            /// </summary>
+            private readonly Scope? m_previousScope;
+
+            /// <summary>
+            /// The Serialization Guard token associated with this scope.
+            /// </summary>
+            private readonly DeserializationToken m_deserializationToken;
+
+            internal Scope(Scope? previousScope, IEnumerable<Type> allowedTypes)
+            {
+                Debug.Assert(allowedTypes != null);
+
+                m_previousScope = previousScope;
+                m_allowedTypes = new HashSet<Type>(allowedTypes.Where(type => type != null));
+                m_deserializationToken = SerializationInfo.StartDeserialization();
+            }
+
+            public void Dispose()
+            {
+                if (this != s_activeScope)
+                {
+                    // Stacks should never be popped out of order.
+                    // We want to trap this condition in production.
+                    Debug.Fail("Scope was popped out of order.");
+                    throw new ObjectDisposedException(GetType().FullName);
+                }
+
+                m_deserializationToken.Dispose(); // it's a readonly struct, but Dispose still works properly
+                s_activeScope = m_previousScope; // could be null
+            }
+
+            public bool IsAllowedType(Type type)
+            {
+                Debug.Assert(type != null);
+
+                // Is the incoming type unconditionally allowed?
+
+                if (IsTypeUnconditionallyAllowed(type))
+                {
+                    return true;
+                }
+
+                // The incoming type is allowed if the current scope or any nested inner
+                // scope allowed it.
+
+                for (Scope? currentScope = this; currentScope != null; currentScope = currentScope.m_previousScope)
+                {
+                    if (currentScope.m_allowedTypes.Contains(type))
+                    {
+                        return true;
+                    }
+                }
+
+                // Did the application programmatically allow this type to be deserialized?
+
+                Type[]? appDomainAllowedTypes = (Type[]?)AppDomain.CurrentDomain.GetData(AppDomainDataSetDefaultAllowedTypesKey);
+                if (appDomainAllowedTypes != null)
+                {
+                    for (int i = 0; i < appDomainAllowedTypes.Length; i++)
+                    {
+                        if (type == appDomainAllowedTypes[i])
+                        {
+                            return true;
+                        }
+                    }
+                }
+
+                // All checks failed
+
+                return false;
+            }
+
+            private static bool IsTypeUnconditionallyAllowed(Type type)
+            {
+            TryAgain:
+                Debug.Assert(type != null);
+
+                // Check the list of unconditionally allowed types.
+
+                if (s_allowedTypes.Contains(type))
+                {
+                    return true;
+                }
+
+                // Enums are also always allowed, as we optimistically assume the app
+                // developer didn't define a dangerous enum type.
+
+                if (type.IsEnum)
+                {
+                    return true;
+                }
+
+                // Allow single-dimensional arrays of any unconditionally allowed type.
+
+                if (type.IsSZArray)
+                {
+                    type = type.GetElementType()!;
+                    goto TryAgain;
+                }
+
+                // Allow generic lists of any unconditionally allowed type.
+
+                if (type.IsGenericType && !type.IsGenericTypeDefinition && type.GetGenericTypeDefinition() == typeof(List<>))
+                {
+                    type = type.GetGenericArguments()[0];
+                    goto TryAgain;
+                }
+
+                // All checks failed.
+
+                return false;
+            }
+        }
+    }
+}
index 533f17e39b195d67dbbc5e712462af2d9fc6e9b2..75ae4978aa7cdb95390ea2a7ee46abce3493adf4 100644 (file)
     <Compile Include="System\Data\VersionNotFoundException.cs" />
     <Compile Include="System\Data\XmlDataLoaderTest.cs" />
     <Compile Include="System\Data\XmlDataReaderTest.cs" />
+    <Compile Include="System\Data\RestrictedTypeHandlingTests.cs" />
     <Compile Include="System\Xml\XmlDataDocumentTests.cs" />
     <Compile Include="$(CommonTestPath)System\Diagnostics\Tracing\TestEventListener.cs"
              Link="Common\System\Diagnostics\Tracing\TestEventListener.cs" />
diff --git a/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs b/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs
new file mode 100644 (file)
index 0000000..54644b6
--- /dev/null
@@ -0,0 +1,445 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using System.Drawing;
+using System.IO;
+using System.Numerics;
+using System.Runtime.Serialization;
+using System.Text;
+using System.Xml;
+using System.Xml.Schema;
+using System.Xml.Serialization;
+using Xunit;
+using Xunit.Sdk;
+
+namespace System.Data.Tests
+{
+    // !! Important !!
+    // These tests manipulate global state, so they cannot be run in parallel with one another.
+    // We rely on xunit's default behavior of not parallelizing unit tests declared on the same
+    // test class: see https://xunit.net/docs/running-tests-in-parallel.html.
+    public class RestrictedTypeHandlingTests
+    {
+        private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes";
+
+        private static readonly Type[] _alwaysAllowedTypes = new Type[]
+        {
+            /* primitives */
+            typeof(bool),
+            typeof(char),
+            typeof(sbyte),
+            typeof(byte),
+            typeof(short),
+            typeof(ushort),
+            typeof(int),
+            typeof(uint),
+            typeof(long),
+            typeof(ulong),
+            typeof(float),
+            typeof(double),
+            typeof(decimal),
+            typeof(DateTime),
+            typeof(DateTimeOffset),
+            typeof(TimeSpan),
+            typeof(string),
+            typeof(Guid),
+            typeof(SqlBinary),
+            typeof(SqlBoolean),
+            typeof(SqlByte),
+            typeof(SqlBytes),
+            typeof(SqlChars),
+            typeof(SqlDateTime),
+            typeof(SqlDecimal),
+            typeof(SqlDouble),
+            typeof(SqlGuid),
+            typeof(SqlInt16),
+            typeof(SqlInt32),
+            typeof(SqlInt64),
+            typeof(SqlMoney),
+            typeof(SqlSingle),
+            typeof(SqlString),
+
+            /* non-primitives, but common */
+            typeof(object),
+            typeof(Type),
+            typeof(BigInteger),
+            typeof(Uri),
+
+            /* frequently used System.Drawing types */
+            typeof(Color),
+            typeof(Point),
+            typeof(PointF),
+            typeof(Rectangle),
+            typeof(RectangleF),
+            typeof(Size),
+            typeof(SizeF),
+
+            /* to test that enums are allowed */
+            typeof(StringComparison),
+        };
+
+        public static IEnumerable<object[]> AllowedTypes()
+        {
+            foreach (Type type in _alwaysAllowedTypes)
+            {
+                yield return new object[] { type }; // T
+                yield return new object[] { type.MakeArrayType() }; // T[] (SZArray)
+                yield return new object[] { type.MakeArrayType().MakeArrayType() }; // T[][] (jagged array)
+                yield return new object[] { typeof(List<>).MakeGenericType(type) }; // List<T>
+            }
+        }
+
+        public static IEnumerable<object[]> ForbiddenTypes()
+        {
+            // StringBuilder isn't in the allow list
+
+            yield return new object[] { typeof(StringBuilder) };
+            yield return new object[] { typeof(StringBuilder[]) };
+
+            // multi-dim arrays and non-sz arrays are forbidden
+
+            yield return new object[] { typeof(int[,]) };
+            yield return new object[] { Array.CreateInstance(typeof(int), new[] { 1 }, new[] { 1 }).GetType() };
+
+            // HashSet<T> isn't in the allow list
+
+            yield return new object[] { typeof(HashSet<int>) };
+
+            // DataSet / DataTable / SqlXml aren't in the allow list
+
+            yield return new object[] { typeof(DataSet) };
+            yield return new object[] { typeof(DataTable) };
+            yield return new object[] { typeof(SqlXml) };
+
+            // Enum, Array, and other base types aren't allowed
+
+            yield return new object[] { typeof(Enum) };
+            yield return new object[] { typeof(Array) };
+            yield return new object[] { typeof(ValueType) };
+            yield return new object[] { typeof(void) };
+        }
+
+        [Theory]
+        [MemberData(nameof(AllowedTypes))]
+        public void DataTable_ReadXml_AllowsKnownTypes(Type type)
+        {
+            // Arrange
+
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", type);
+
+            string asXml = WriteXmlWithSchema(table.WriteXml);
+
+            // Act
+
+            table = ReadXml<DataTable>(asXml);
+
+            // Assert
+
+            Assert.Equal("MyTable", table.TableName);
+            Assert.Equal(1, table.Columns.Count);
+            Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+            Assert.Equal(type, table.Columns[0].DataType);
+        }
+
+        [Theory]
+        [MemberData(nameof(ForbiddenTypes))]
+        public void DataTable_ReadXml_ForbidsUnknownTypes(Type type)
+        {
+            // Arrange
+
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", type);
+
+            string asXml = WriteXmlWithSchema(table.WriteXml);
+
+            // Act & assert
+
+            Assert.Throws<InvalidOperationException>(() => ReadXml<DataTable>(asXml));
+        }
+
+        [Fact]
+        public void DataTable_ReadXml_HandlesXmlSerializableTypes()
+        {
+            // Arrange
+
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", typeof(object));
+            table.Rows.Add(new MyXmlSerializableClass());
+
+            string asXml = WriteXmlWithSchema(table.WriteXml, XmlWriteMode.IgnoreSchema);
+
+            // Act & assert
+            // MyXmlSerializableClass shouldn't be allowed as a member for a column
+            // typed as 'object'.
+
+            table.Rows.Clear();
+            Assert.Throws<InvalidOperationException>(() => table.ReadXml(new StringReader(asXml)));
+        }
+
+        [Theory]
+        [MemberData(nameof(ForbiddenTypes))]
+        public void DataTable_ReadXmlSchema_AllowsUnknownTypes(Type type)
+        {
+            // Arrange
+
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", type);
+
+            string asXml = WriteXmlWithSchema(table.WriteXml);
+
+            // Act
+
+            table = new DataTable();
+            table.ReadXmlSchema(new StringReader(asXml));
+
+            // Assert
+
+            Assert.Equal("MyTable", table.TableName);
+            Assert.Equal(1, table.Columns.Count);
+            Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+            Assert.Equal(type, table.Columns[0].DataType);
+        }
+
+        [Fact]
+        public void DataTable_HonorsGloballyDefinedAllowList()
+        {
+            // Arrange
+
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", typeof(MyCustomClass));
+
+            string asXml = WriteXmlWithSchema(table.WriteXml);
+
+            // Act & assert 1
+            // First call should fail since MyCustomClass not allowed
+
+            Assert.Throws<InvalidOperationException>(() => ReadXml<DataTable>(asXml));
+
+            // Act & assert 2
+            // Deserialization should succeed since it's now in the allow list
+
+            try
+            {
+                AppDomain.CurrentDomain.SetData(AppDomainDataSetDefaultAllowedTypesKey, new Type[]
+                {
+                    typeof(MyCustomClass)
+                });
+
+                table = ReadXml<DataTable>(asXml);
+
+                Assert.Equal("MyTable", table.TableName);
+                Assert.Equal(1, table.Columns.Count);
+                Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+                Assert.Equal(typeof(MyCustomClass), table.Columns[0].DataType);
+            }
+            finally
+            {
+                AppDomain.CurrentDomain.SetData(AppDomainDataSetDefaultAllowedTypesKey, null);
+            }
+        }
+
+        [Fact]
+        public void DataColumn_ConvertExpression_SubjectToAllowList_Success()
+        {
+            // Arrange
+
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", typeof(object), "CONVERT('42', 'System.Int32')");
+
+            string asXml = WriteXmlWithSchema(table.WriteXml);
+
+            // Act
+
+            table = ReadXml<DataTable>(asXml);
+
+            // Assert
+
+            Assert.Equal("MyTable", table.TableName);
+            Assert.Equal(1, table.Columns.Count);
+            Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+            Assert.Equal(typeof(object), table.Columns[0].DataType);
+            Assert.Equal("CONVERT('42', 'System.Int32')", table.Columns[0].Expression);
+        }
+
+        [Fact]
+        public void DataColumn_ConvertExpression_SubjectToAllowList_Failure()
+        {
+            // Arrange
+
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("ColumnA", typeof(object));
+            table.Columns.Add("ColumnB", typeof(object), "CONVERT(ColumnA, 'System.Text.StringBuilder')");
+
+            string asXml = WriteXmlWithSchema(table.WriteXml);
+
+            // Act
+            // 'StringBuilder' isn't in the allow list, but we're not yet hydrating the Type
+            // object so we won't check it just yet.
+
+            table = ReadXml<DataTable>(asXml);
+
+            // Assert - the CONVERT function node should have captured the active allow list
+            // at construction and should apply it now.
+
+            Assert.Throws<InvalidOperationException>(() => table.Rows.Add(new StringBuilder()));
+        }
+
+        [Theory]
+        [MemberData(nameof(AllowedTypes))]
+        public void DataSet_ReadXml_AllowsKnownTypes(Type type)
+        {
+            // Arrange
+
+            DataSet set = new DataSet("MySet");
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", type);
+            set.Tables.Add(table);
+
+            string asXml = WriteXmlWithSchema(set.WriteXml);
+
+            // Act
+
+            table = null;
+            set = ReadXml<DataSet>(asXml);
+
+            // Assert
+
+            Assert.Equal("MySet", set.DataSetName);
+            Assert.Equal(1, set.Tables.Count);
+
+            table = set.Tables[0];
+            Assert.Equal("MyTable", table.TableName);
+            Assert.Equal(1, table.Columns.Count);
+            Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+            Assert.Equal(type, table.Columns[0].DataType);
+        }
+
+        [Theory]
+        [MemberData(nameof(ForbiddenTypes))]
+        public void DataSet_ReadXml_ForbidsUnknownTypes(Type type)
+        {
+            // Arrange
+
+            DataSet set = new DataSet("MySet");
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", type);
+            set.Tables.Add(table);
+
+            string asXml = WriteXmlWithSchema(set.WriteXml);
+
+            // Act & assert
+
+            Assert.Throws<InvalidOperationException>(() => ReadXml<DataSet>(asXml));
+        }
+
+        [Theory]
+        [MemberData(nameof(ForbiddenTypes))]
+        public void DataSet_ReadXmlSchema_AllowsUnknownTypes(Type type)
+        {
+            // Arrange
+
+            DataSet set = new DataSet("MySet");
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", type);
+            set.Tables.Add(table);
+
+            string asXml = WriteXmlWithSchema(set.WriteXml);
+
+            // Act
+
+            set = new DataSet();
+            set.ReadXmlSchema(new StringReader(asXml));
+
+            // Assert
+
+            Assert.Equal("MySet", set.DataSetName);
+            Assert.Equal(1, set.Tables.Count);
+
+            table = set.Tables[0];
+            Assert.Equal("MyTable", table.TableName);
+            Assert.Equal(1, table.Columns.Count);
+            Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+            Assert.Equal(type, table.Columns[0].DataType);
+        }
+
+        [Fact]
+        public void SerializationGuard_BlocksFileAccessOnDeserialize()
+        {
+            // Arrange
+
+            DataTable table = new DataTable("MyTable");
+            table.Columns.Add("MyColumn", typeof(MyCustomClassThatWritesToAFile));
+            table.Rows.Add(new MyCustomClassThatWritesToAFile());
+
+            string asXml = WriteXmlWithSchema(table.WriteXml);
+            table.Rows.Clear();
+
+            // Act & assert
+
+            Assert.Throws<SerializationException>(() => table.ReadXml(new StringReader(asXml)));
+        }
+
+        private static string WriteXmlWithSchema(Action<TextWriter, XmlWriteMode> writeMethod, XmlWriteMode xmlWriteMode = XmlWriteMode.WriteSchema)
+        {
+            StringWriter writer = new StringWriter();
+            writeMethod(writer, xmlWriteMode);
+            return writer.ToString();
+        }
+
+        private static T ReadXml<T>(string xml) where T : IXmlSerializable, new()
+        {
+            T newObj = new T();
+            newObj.ReadXml(new XmlTextReader(new StringReader(xml)) { XmlResolver = null }); // suppress DTDs, same as runtime code
+            return newObj;
+        }
+
+        private sealed class MyCustomClass
+        {
+        }
+
+        public sealed class MyXmlSerializableClass : IXmlSerializable
+        {
+            public XmlSchema GetSchema()
+            {
+                return null;
+            }
+
+            public void ReadXml(XmlReader reader)
+            {
+                return; // no-op
+            }
+
+            public void WriteXml(XmlWriter writer)
+            {
+                writer.WriteElementString("MyElement", "MyValue");
+            }
+        }
+
+        private sealed class MyCustomClassThatWritesToAFile : IXmlSerializable
+        {
+            public XmlSchema GetSchema()
+            {
+                return null;
+            }
+
+            public void ReadXml(XmlReader reader)
+            {
+                // This should be called within a Serialization Guard scope, so the file write
+                // should fail.
+
+                string tempPath = Path.GetTempFileName();
+                File.WriteAllText(tempPath, "This better not be written...");
+                File.Delete(tempPath);
+                throw new XunitException("Unreachable code (SerializationGuard should have kicked in)");
+            }
+
+            public void WriteXml(XmlWriter writer)
+            {
+                writer.WriteElementString("MyElement", "MyValue");
+            }
+        }
+    }
+}