Throttle allocations in BinaryReader (#36348)
authorLevi Broderick <GrabYourPitchforks@users.noreply.github.com>
Wed, 13 May 2020 18:02:13 +0000 (11:02 -0700)
committerGitHub <noreply@github.com>
Wed, 13 May 2020 18:02:13 +0000 (18:02 +0000)
src/libraries/Common/src/System/Text/StringBuilderCache.cs
src/libraries/System.Private.CoreLib/src/System/IO/BinaryReader.cs
src/libraries/System.Private.DataContractSerialization/src/System/Xml/XmlBufferReader.cs

index aac2c2e..4d96cb3 100644 (file)
@@ -11,7 +11,7 @@ namespace System.Text
         // The value 360 was chosen in discussion with performance experts as a compromise between using
         // as litle memory per thread as possible and still covering a large part of short-lived
         // StringBuilder creations on the startup path of VS designers.
-        private const int MaxBuilderSize = 360;
+        internal const int MaxBuilderSize = 360;
         private const int DefaultCapacity = 16; // == StringBuilder.DefaultCapacity
 
         // WARNING: We allow diagnostic tools to directly inspect this member (t_cachedInstance).
index 1055013..8b9a53f 100644 (file)
@@ -306,7 +306,10 @@ namespace System.IO
                     return new string(_charBuffer, 0, charsRead);
                 }
 
-                sb ??= StringBuilderCache.Acquire(stringLength); // Actual string length in chars may be smaller.
+                // Since we could be reading from an untrusted data source, limit the initial size of the
+                // StringBuilder instance we're about to get or create. It'll expand automatically as needed.
+
+                sb ??= StringBuilderCache.Acquire(Math.Min(stringLength, StringBuilderCache.MaxBuilderSize)); // Actual string length in chars may be smaller.
                 sb.Append(_charBuffer, 0, charsRead);
                 currPos += n;
             } while (currPos < stringLength);
index 24920b8..43d74a8 100644 (file)
@@ -213,28 +213,37 @@ namespace System.Xml
         {
             if (_stream == null)
                 return false;
-            DiagnosticUtility.DebugAssert(_offset <= int.MaxValue - count, "");
-            int newOffsetMax = _offset + count;
-            if (newOffsetMax < _offsetMax)
-                return true;
-            DiagnosticUtility.DebugAssert(newOffsetMax <= _windowOffsetMax, "");
-            if (newOffsetMax > _buffer.Length)
-            {
-                byte[] newBuffer = new byte[Math.Max(newOffsetMax, _buffer.Length * 2)];
-                System.Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _offsetMax);
-                _buffer = newBuffer;
-                _streamBuffer = newBuffer;
-            }
-            int needed = newOffsetMax - _offsetMax;
-            while (needed > 0)
+
+            // The data could be coming from an untrusted source, so we use a standard
+            // "multiply by 2" growth algorithm to avoid overly large memory utilization.
+            // Constant value of 256 comes from MemoryStream implementation.
+
+            do
             {
-                int actual = _stream.Read(_buffer, _offsetMax, needed);
-                if (actual == 0)
-                    return false;
-                _offsetMax += actual;
-                needed -= actual;
-            }
-            return true;
+                DiagnosticUtility.DebugAssert(_offset <= int.MaxValue - count, "");
+                int newOffsetMax = _offset + count;
+                if (newOffsetMax <= _offsetMax)
+                    return true;
+                DiagnosticUtility.DebugAssert(newOffsetMax <= _windowOffsetMax, "");
+                if (newOffsetMax > _buffer.Length)
+                {
+                    byte[] newBuffer = new byte[Math.Max(256, _buffer.Length * 2)];
+                    System.Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _offsetMax);
+                    newOffsetMax = Math.Min(newOffsetMax, newBuffer.Length);
+                    _buffer = newBuffer;
+                    _streamBuffer = newBuffer;
+                }
+                int needed = newOffsetMax - _offsetMax;
+                DiagnosticUtility.DebugAssert(needed > 0, "");
+                do
+                {
+                    int actual = _stream.Read(_buffer, _offsetMax, needed);
+                    if (actual == 0)
+                        return false;
+                    _offsetMax += actual;
+                    needed -= actual;
+                } while (needed > 0);
+            } while (true);
         }
 
         public void Advance(int count)