Add HeaderEncodingSelector to MultipartContent (#39169)
author Miha Zupan <mihazupan.zupan1@gmail.com>
Thu, 6 Aug 2020 02:56:52 +0000 (04:56 +0200)
committer GitHub <noreply@github.com>
Thu, 6 Aug 2020 02:56:52 +0000 (04:56 +0200)
* Add HeaderEncodingSelector to MultipartContent

* Test cleanup

* Avoid WriteLatin1 logic duplication

* Move to common HeaderEncodingSelector<TContext>

* Fix indentation

src/libraries/System.Net.Http/ref/System.Net.Http.cs
src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs
src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs
src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs [new file with mode: 0644]
src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj

index f259892..415a52f 100644 (file)
@@ -285,6 +285,7 @@ namespace System.Net.Http
         public MultipartContent() { }
         public MultipartContent(string subtype) { }
         public MultipartContent(string subtype, string boundary) { }
+        public System.Net.Http.HeaderEncodingSelector<System.Net.Http.HttpContent>? HeaderEncodingSelector { get { throw null; } set { } }
         public virtual void Add(System.Net.Http.HttpContent content) { }
         protected override System.IO.Stream CreateContentReadStream(System.Threading.CancellationToken cancellationToken) { throw null; }
         protected override System.Threading.Tasks.Task<System.IO.Stream> CreateContentReadStreamAsync() { throw null; }
index 4f1a1dd..7ea8eb2 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Buffers;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Diagnostics.CodeAnalysis;
@@ -18,10 +19,10 @@ namespace System.Net.Http
 
         private const string CrLf = "\r\n";
 
-        private static readonly int s_crlfLength = GetEncodedLength(CrLf);
-        private static readonly int s_dashDashLength = GetEncodedLength("--");
-        private static readonly int s_colonSpaceLength = GetEncodedLength(": ");
-        private static readonly int s_commaSpaceLength = GetEncodedLength(", ");
+        private const int CrLfLength = 2;
+        private const int DashDashLength = 2;
+        private const int ColonSpaceLength = 2;
+        private const int CommaSpaceLength = 2;
 
         private readonly List<HttpContent> _nestedContent;
         private readonly string _boundary;
@@ -157,6 +158,12 @@ namespace System.Net.Http
 
         #region Serialization
 
+        /// <summary>
+        /// Gets or sets a callback that returns the <see cref="Encoding"/> to encode the value for the specified content header name,
+        /// or <see langword="null"/> to use the default behavior.
+        /// </summary>
+        public HeaderEncodingSelector<HttpContent>? HeaderEncodingSelector { get; set; }
+
         // for-each content
         //   write "--" + boundary
         //   for-each content header
@@ -171,20 +178,19 @@ namespace System.Net.Http
             try
             {
                 // Write start boundary.
-                EncodeStringToStream(stream, "--" + _boundary + CrLf);
+                WriteToStream(stream, "--" + _boundary + CrLf);
 
                 // Write each nested content.
-                var output = new StringBuilder();
                 for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++)
                 {
                     // Write divider, headers, and content.
                     HttpContent content = _nestedContent[contentIndex];
-                    EncodeStringToStream(stream, SerializeHeadersToString(output, contentIndex, content));
+                    SerializeHeadersToStream(stream, content, writeDivider: contentIndex != 0);
                     content.CopyTo(stream, context, cancellationToken);
                 }
 
                 // Write footer boundary.
-                EncodeStringToStream(stream, CrLf + "--" + _boundary + "--" + CrLf);
+                WriteToStream(stream, CrLf + "--" + _boundary + "--" + CrLf);
             }
             catch (Exception ex)
             {
@@ -219,12 +225,17 @@ namespace System.Net.Http
                 await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf, cancellationToken).ConfigureAwait(false);
 
                 // Write each nested content.
-                var output = new StringBuilder();
+                var output = new MemoryStream();
                 for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++)
                 {
                     // Write divider, headers, and content.
                     HttpContent content = _nestedContent[contentIndex];
-                    await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content), cancellationToken).ConfigureAwait(false);
+
+                    output.SetLength(0);
+                    SerializeHeadersToStream(output, content, writeDivider: contentIndex != 0);
+                    output.Position = 0;
+                    await output.CopyToAsync(stream, cancellationToken).ConfigureAwait(false);
+
                     await content.CopyToAsync(stream, context, cancellationToken).ConfigureAwait(false);
                 }
 
@@ -259,7 +270,6 @@ namespace System.Net.Http
             try
             {
                 var streams = new Stream[2 + (_nestedContent.Count * 2)];
-                var scratch = new StringBuilder();
                 int streamIndex = 0;
 
                 // Start boundary.
@@ -271,7 +281,7 @@ namespace System.Net.Http
                     cancellationToken.ThrowIfCancellationRequested();
 
                     HttpContent nestedContent = _nestedContent[contentIndex];
-                    streams[streamIndex++] = EncodeStringToNewStream(SerializeHeadersToString(scratch, contentIndex, nestedContent));
+                    streams[streamIndex++] = EncodeHeadersToNewStream(nestedContent, writeDivider: contentIndex != 0);
 
                     Stream readStream;
                     if (async)
@@ -312,43 +322,35 @@ namespace System.Net.Http
             }
         }
 
-        private string SerializeHeadersToString(StringBuilder scratch, int contentIndex, HttpContent content)
+        private void SerializeHeadersToStream(Stream stream, HttpContent content, bool writeDivider)
         {
-            scratch.Clear();
-
             // Add divider.
-            if (contentIndex != 0) // Write divider for all but the first content.
+            if (writeDivider) // Write divider for all but the first content.
             {
-                scratch.Append(CrLf + "--"); // const strings
-                scratch.Append(_boundary);
-                scratch.Append(CrLf);
+                WriteToStream(stream, CrLf + "--"); // const strings
+                WriteToStream(stream, _boundary);
+                WriteToStream(stream, CrLf);
             }
 
             // Add headers.
             foreach (KeyValuePair<string, IEnumerable<string>> headerPair in content.Headers)
             {
-                scratch.Append(headerPair.Key);
-                scratch.Append(": ");
+                Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding;
+
+                WriteToStream(stream, headerPair.Key);
+                WriteToStream(stream, ": ");
                 string delim = string.Empty;
                 foreach (string value in headerPair.Value)
                 {
-                    scratch.Append(delim);
-                    scratch.Append(value);
+                    WriteToStream(stream, delim);
+                    WriteToStream(stream, value, headerValueEncoding);
                     delim = ", ";
                 }
-                scratch.Append(CrLf);
+                WriteToStream(stream, CrLf);
             }
 
             // Extra CRLF to end headers (even if there are no headers).
-            scratch.Append(CrLf);
-
-            return scratch.ToString();
-        }
-
-        private static void EncodeStringToStream(Stream stream, string input)
-        {
-            byte[] buffer = HttpRuleParser.DefaultHttpEncoding.GetBytes(input);
-            stream.Write(buffer);
+            WriteToStream(stream, CrLf);
         }
 
         private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken)
@@ -362,55 +364,55 @@ namespace System.Net.Http
             return new MemoryStream(HttpRuleParser.DefaultHttpEncoding.GetBytes(input), writable: false);
         }
 
+        private Stream EncodeHeadersToNewStream(HttpContent content, bool writeDivider)
+        {
+            var stream = new MemoryStream();
+            SerializeHeadersToStream(stream, content, writeDivider);
+            stream.Position = 0;
+            return stream;
+        }
+
         internal override bool AllowDuplex => false;
 
         protected internal override bool TryComputeLength(out long length)
         {
-            int boundaryLength = GetEncodedLength(_boundary);
-
-            long currentLength = 0;
-            long internalBoundaryLength = s_crlfLength + s_dashDashLength + boundaryLength + s_crlfLength;
-
             // Start Boundary.
-            currentLength += s_dashDashLength + boundaryLength + s_crlfLength;
+            long currentLength = DashDashLength + _boundary.Length + CrLfLength;
 
-            bool first = true;
-            foreach (HttpContent content in _nestedContent)
+            if (_nestedContent.Count > 1)
             {
-                if (first)
-                {
-                    first = false; // First boundary already written.
-                }
-                else
-                {
-                    // Internal Boundary.
-                    currentLength += internalBoundaryLength;
-                }
+                // Internal boundaries
+                currentLength += (_nestedContent.Count - 1) * (CrLfLength + DashDashLength + _boundary.Length + CrLfLength);
+            }
 
+            foreach (HttpContent content in _nestedContent)
+            {
                 // Headers.
                 foreach (KeyValuePair<string, IEnumerable<string>> headerPair in content.Headers)
                 {
-                    currentLength += GetEncodedLength(headerPair.Key) + s_colonSpaceLength;
+                    currentLength += headerPair.Key.Length + ColonSpaceLength;
+
+                    Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding;
 
                     int valueCount = 0;
                     foreach (string value in headerPair.Value)
                     {
-                        currentLength += GetEncodedLength(value);
+                        currentLength += headerValueEncoding.GetByteCount(value);
                         valueCount++;
                     }
+
                     if (valueCount > 1)
                     {
-                        currentLength += (valueCount - 1) * s_commaSpaceLength;
+                        currentLength += (valueCount - 1) * CommaSpaceLength;
                     }
 
-                    currentLength += s_crlfLength;
+                    currentLength += CrLfLength;
                 }
 
-                currentLength += s_crlfLength;
+                currentLength += CrLfLength;
 
                 // Content.
-                long tempContentLength = 0;
-                if (!content.TryComputeLength(out tempContentLength))
+                if (!content.TryComputeLength(out long tempContentLength))
                 {
                     length = 0;
                     return false;
@@ -419,17 +421,12 @@ namespace System.Net.Http
             }
 
             // Terminating boundary.
-            currentLength += s_crlfLength + s_dashDashLength + boundaryLength + s_dashDashLength + s_crlfLength;
+            currentLength += CrLfLength + DashDashLength + _boundary.Length + DashDashLength + CrLfLength;
 
             length = currentLength;
             return true;
         }
 
-        private static int GetEncodedLength(string input)
-        {
-            return HttpRuleParser.DefaultHttpEncoding.GetByteCount(input);
-        }
-
         private sealed class ContentReadStream : Stream
         {
             private readonly Stream[] _streams;
@@ -671,6 +668,36 @@ namespace System.Net.Http
             public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { throw new NotSupportedException(); }
             public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) { throw new NotSupportedException(); }
         }
+
+
+        private static void WriteToStream(Stream stream, string content) =>
+            WriteToStream(stream, content, HttpRuleParser.DefaultHttpEncoding);
+
+        private static void WriteToStream(Stream stream, string content, Encoding encoding)
+        {
+            const int StackallocThreshold = 1024;
+
+            int maxLength = encoding.GetMaxByteCount(content.Length);
+
+            byte[]? rentedBuffer = null;
+            Span<byte> buffer = maxLength <= StackallocThreshold
+                ? stackalloc byte[StackallocThreshold]
+                : (rentedBuffer = ArrayPool<byte>.Shared.Rent(maxLength));
+
+            try
+            {
+                int written = encoding.GetBytes(content, buffer);
+                stream.Write(buffer.Slice(0, written));
+            }
+            finally
+            {
+                if (rentedBuffer != null)
+                {
+                    ArrayPool<byte>.Shared.Return(rentedBuffer);
+                }
+            }
+        }
+
         #endregion Serialization
     }
 }
index 7565e5b..a945cc3 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.IO;
+using System.Linq;
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
@@ -387,6 +388,122 @@ namespace System.Net.Http.Functional.Tests
             Assert.Throws<NotImplementedException>(() => mc.ReadAsStream());
         }
 
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ReadAsStreamAsync_CustomEncodingSelector_SelectorIsCalledWithCustomState(bool async)
+        {
+            var mc = new MultipartContent();
+
+            var stringContent = new StringContent("foo");
+            stringContent.Headers.Add("StringContent", "foo");
+            mc.Add(stringContent);
+
+            var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("foo"));
+            byteArrayContent.Headers.Add("ByteArrayContent", "foo");
+            mc.Add(byteArrayContent);
+
+            bool seenStringContent = false, seenByteArrayContent = false;
+
+            mc.HeaderEncodingSelector = (name, content) =>
+            {
+                if (ReferenceEquals(content, stringContent) && name == "StringContent")
+                {
+                    seenStringContent = true;
+                }
+
+                if (ReferenceEquals(content, byteArrayContent) && name == "ByteArrayContent")
+                {
+                    seenByteArrayContent = true;
+                }
+
+                return null;
+            };
+
+            var dummy = new MemoryStream();
+            if (async)
+            {
+                await (await mc.ReadAsStreamAsync()).CopyToAsync(dummy);
+            }
+            else
+            {
+                mc.ReadAsStream().CopyTo(dummy);
+            }
+
+            Assert.True(seenStringContent);
+            Assert.True(seenByteArrayContent);
+        }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ReadAsStreamAsync_CustomEncodingSelector_CustomEncodingIsUsed(bool async)
+        {
+            var mc = new MultipartContent("subtype", "fooBoundary");
+
+            var stringContent = new StringContent("bar1");
+            stringContent.Headers.Add("latin1", "\uD83D\uDE00");
+            mc.Add(stringContent);
+
+            var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar2"));
+            byteArrayContent.Headers.Add("utf8", "\uD83D\uDE00");
+            mc.Add(byteArrayContent);
+
+            byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar3"));
+            byteArrayContent.Headers.Add("ascii", "\uD83D\uDE00");
+            mc.Add(byteArrayContent);
+
+            byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar4"));
+            byteArrayContent.Headers.Add("default", "\uD83D\uDE00");
+            mc.Add(byteArrayContent);
+
+            mc.HeaderEncodingSelector = (name, _) => name switch
+            {
+                "latin1" => Encoding.Latin1,
+                "utf8" => Encoding.UTF8,
+                "ascii" => Encoding.ASCII,
+                _ => null
+            };
+
+            var ms = new MemoryStream();
+            if (async)
+            {
+                await (await mc.ReadAsStreamAsync()).CopyToAsync(ms);
+            }
+            else
+            {
+                mc.ReadAsStream().CopyTo(ms);
+            }
+
+            byte[] expected = Concat(
+                Encoding.Latin1.GetBytes("--fooBoundary\r\n"),
+                Encoding.Latin1.GetBytes("Content-Type: text/plain; charset=utf-8\r\n"),
+                Encoding.Latin1.GetBytes("latin1: "),
+                Encoding.Latin1.GetBytes("\uD83D\uDE00"),
+                Encoding.Latin1.GetBytes("\r\n\r\n"),
+                Encoding.Latin1.GetBytes("bar1"),
+                Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"),
+                Encoding.Latin1.GetBytes("utf8: "),
+                Encoding.UTF8.GetBytes("\uD83D\uDE00"),
+                Encoding.Latin1.GetBytes("\r\n\r\n"),
+                Encoding.Latin1.GetBytes("bar2"),
+                Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"),
+                Encoding.Latin1.GetBytes("ascii: "),
+                Encoding.ASCII.GetBytes("\uD83D\uDE00"),
+                Encoding.Latin1.GetBytes("\r\n\r\n"),
+                Encoding.Latin1.GetBytes("bar3"),
+                Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"),
+                Encoding.Latin1.GetBytes("default: "),
+                Encoding.Latin1.GetBytes("\uD83D\uDE00"),
+                Encoding.Latin1.GetBytes("\r\n\r\n"),
+                Encoding.Latin1.GetBytes("bar4"),
+                Encoding.Latin1.GetBytes("\r\n--fooBoundary--\r\n"));
+
+            Assert.Equal(expected, ms.ToArray());
+
+            static byte[] Concat(params byte[][] arrays) => arrays.SelectMany(b => b).ToArray();
+        }
+
         #region Helpers
 
         private static async Task<string> MultipartContentToStringAsync(MultipartContent content, MultipartContentToStringMode mode, bool async)
diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs
new file mode 100644 (file)
index 0000000..e89d38c
--- /dev/null
@@ -0,0 +1,86 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace System.Net.Http.Tests
+{
+    public class MultipartContentTest
+    {
+        public static IEnumerable<object[]> MultipartContent_TestData()
+        {
+            var multipartContents = new List<MultipartContent>();
+
+            var complexContent = new MultipartContent();
+
+            var stringContent = new StringContent("bar1");
+            stringContent.Headers.Add("latin1", "\uD83D\uDE00");
+            complexContent.Add(stringContent);
+
+            var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar2"));
+            byteArrayContent.Headers.Add("utf8", "\uD83D\uDE00");
+            complexContent.Add(byteArrayContent);
+
+            byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar3"));
+            byteArrayContent.Headers.Add("ascii", "\uD83D\uDE00");
+            complexContent.Add(byteArrayContent);
+
+            byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar4"));
+            byteArrayContent.Headers.Add("default", "\uD83D\uDE00");
+            complexContent.Add(byteArrayContent);
+
+            stringContent = new StringContent("bar5");
+            stringContent.Headers.Add("foo", "bar");
+            complexContent.Add(stringContent);
+
+            multipartContents.Add(complexContent);
+            multipartContents.Add(new MultipartContent());
+            multipartContents.Add(new MultipartFormDataContent());
+
+            var encodingSelectors = new HeaderEncodingSelector<HttpContent>[]
+            {
+                (_, _) => null,
+                (_, _) => Encoding.ASCII,
+                (_, _) => Encoding.Latin1,
+                (_, _) => Encoding.UTF8,
+                (name, _) => name switch
+                {
+                    "latin1" => Encoding.Latin1,
+                    "utf8" => Encoding.UTF8,
+                    "ascii" => Encoding.ASCII,
+                    _ => null
+                }
+            };
+
+            foreach (MultipartContent multipartContent in multipartContents)
+            {
+                foreach (HeaderEncodingSelector<HttpContent> encodingSelector in encodingSelectors)
+                {
+                    multipartContent.HeaderEncodingSelector = encodingSelector;
+                    yield return new object[] { multipartContent };
+                }
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(MultipartContent_TestData))]
+        public async Task MultipartContent_TryComputeLength_ReturnsSameLengthAsCopyToAsync(MultipartContent multipartContent)
+        {
+            Assert.True(multipartContent.TryComputeLength(out long length));
+
+            var copyToStream = new MemoryStream();
+            multipartContent.CopyTo(copyToStream, context: null, cancellationToken: default);
+            Assert.Equal(length, copyToStream.Length);
+
+            var copyToAsyncStream = new MemoryStream();
+            await multipartContent.CopyToAsync(copyToAsyncStream, context: null, cancellationToken: default);
+            Assert.Equal(length, copyToAsyncStream.Length);
+
+            Assert.Equal(copyToStream.ToArray(), copyToAsyncStream.ToArray());
+        }
+    }
+}
index c4d1682..19de137 100644 (file)
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
     <StringResourcesPath>../../src/Resources/Strings.resx</StringResourcesPath>
     <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
@@ -90,6 +90,8 @@
          Link="ProductionCode\System\Net\Http\EmptyReadStream.cs" />
     <Compile Include="..\..\src\System\Net\Http\FormUrlEncodedContent.cs"
              Link="ProductionCode\System\Net\Http\FormUrlEncodedContent.cs" />
+    <Compile Include="..\..\src\System\Net\Http\HeaderEncodingSelector.cs"
+             Link="ProductionCode\System\Net\Http\HeaderEncodingSelector.cs" />
     <Compile Include="..\..\src\System\Net\Http\Headers\AltSvcHeaderParser.cs"
              Link="ProductionCode\System\Net\Http\Headers\AltSvcHeaderParser.cs" />
     <Compile Include="..\..\src\System\Net\Http\Headers\AltSvcHeaderValue.cs"
     <Compile Include="Headers\MediaTypeHeaderParserTest.cs" />
     <Compile Include="Headers\MediaTypeHeaderValueTest.cs" />
     <Compile Include="Headers\MediaTypeWithQualityHeaderValueTest.cs" />
+    <Compile Include="Headers\MultipartContentTest.cs" />
     <Compile Include="Headers\NameValueHeaderValueTest.cs" />
     <Compile Include="Headers\NameValueWithParametersHeaderValueTest.cs" />
     <Compile Include="Headers\ObjectCollectionTest.cs" />