From 53d70b08ab088c318e4137ea1f1b67ac0fb06a60 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Thu, 6 Aug 2020 04:56:52 +0200 Subject: [PATCH] Add HeaderEncodingSelector to MultipartContent (#39169) * Add HeaderEncodingSelector to MultipartContent * Test cleanup * Avoid WriteLatin1 logic duplication * Move to common HeaderEncodingSelector * Fix indentation --- .../System.Net.Http/ref/System.Net.Http.cs | 1 + .../src/System/Net/Http/MultipartContent.cs | 153 ++++++++++++--------- .../tests/FunctionalTests/MultipartContentTest.cs | 117 ++++++++++++++++ .../UnitTests/Headers/MultipartContentTest.cs | 86 ++++++++++++ .../UnitTests/System.Net.Http.Unit.Tests.csproj | 5 +- 5 files changed, 298 insertions(+), 64 deletions(-) create mode 100644 src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs diff --git a/src/libraries/System.Net.Http/ref/System.Net.Http.cs b/src/libraries/System.Net.Http/ref/System.Net.Http.cs index f259892..415a52f 100644 --- a/src/libraries/System.Net.Http/ref/System.Net.Http.cs +++ b/src/libraries/System.Net.Http/ref/System.Net.Http.cs @@ -285,6 +285,7 @@ namespace System.Net.Http public MultipartContent() { } public MultipartContent(string subtype) { } public MultipartContent(string subtype, string boundary) { } + public System.Net.Http.HeaderEncodingSelector? 
HeaderEncodingSelector { get { throw null; } set { } } public virtual void Add(System.Net.Http.HttpContent content) { } protected override System.IO.Stream CreateContentReadStream(System.Threading.CancellationToken cancellationToken) { throw null; } protected override System.Threading.Tasks.Task CreateContentReadStreamAsync() { throw null; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs index 4f1a1dd..7ea8eb2 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -18,10 +19,10 @@ namespace System.Net.Http private const string CrLf = "\r\n"; - private static readonly int s_crlfLength = GetEncodedLength(CrLf); - private static readonly int s_dashDashLength = GetEncodedLength("--"); - private static readonly int s_colonSpaceLength = GetEncodedLength(": "); - private static readonly int s_commaSpaceLength = GetEncodedLength(", "); + private const int CrLfLength = 2; + private const int DashDashLength = 2; + private const int ColonSpaceLength = 2; + private const int CommaSpaceLength = 2; private readonly List _nestedContent; private readonly string _boundary; @@ -157,6 +158,12 @@ namespace System.Net.Http #region Serialization + /// <summary> + /// Gets or sets a callback that returns the <see cref="Encoding"/> used to encode the value of the specified nested content header, + /// or <see langword="null"/> to use the default behavior. + /// </summary> + public HeaderEncodingSelector? 
HeaderEncodingSelector { get; set; } + // for-each content // write "--" + boundary // for-each content header @@ -171,20 +178,19 @@ namespace System.Net.Http try { // Write start boundary. - EncodeStringToStream(stream, "--" + _boundary + CrLf); + WriteToStream(stream, "--" + _boundary + CrLf); // Write each nested content. - var output = new StringBuilder(); for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++) { // Write divider, headers, and content. HttpContent content = _nestedContent[contentIndex]; - EncodeStringToStream(stream, SerializeHeadersToString(output, contentIndex, content)); + SerializeHeadersToStream(stream, content, writeDivider: contentIndex != 0); content.CopyTo(stream, context, cancellationToken); } // Write footer boundary. - EncodeStringToStream(stream, CrLf + "--" + _boundary + "--" + CrLf); + WriteToStream(stream, CrLf + "--" + _boundary + "--" + CrLf); } catch (Exception ex) { @@ -219,12 +225,17 @@ namespace System.Net.Http await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf, cancellationToken).ConfigureAwait(false); // Write each nested content. - var output = new StringBuilder(); + var output = new MemoryStream(); for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++) { // Write divider, headers, and content. HttpContent content = _nestedContent[contentIndex]; - await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content), cancellationToken).ConfigureAwait(false); + + output.SetLength(0); + SerializeHeadersToStream(output, content, writeDivider: contentIndex != 0); + output.Position = 0; + await output.CopyToAsync(stream, cancellationToken).ConfigureAwait(false); + await content.CopyToAsync(stream, context, cancellationToken).ConfigureAwait(false); } @@ -259,7 +270,6 @@ namespace System.Net.Http try { var streams = new Stream[2 + (_nestedContent.Count * 2)]; - var scratch = new StringBuilder(); int streamIndex = 0; // Start boundary. 
@@ -271,7 +281,7 @@ namespace System.Net.Http cancellationToken.ThrowIfCancellationRequested(); HttpContent nestedContent = _nestedContent[contentIndex]; - streams[streamIndex++] = EncodeStringToNewStream(SerializeHeadersToString(scratch, contentIndex, nestedContent)); + streams[streamIndex++] = EncodeHeadersToNewStream(nestedContent, writeDivider: contentIndex != 0); Stream readStream; if (async) @@ -312,43 +322,35 @@ namespace System.Net.Http } } - private string SerializeHeadersToString(StringBuilder scratch, int contentIndex, HttpContent content) + private void SerializeHeadersToStream(Stream stream, HttpContent content, bool writeDivider) { - scratch.Clear(); - // Add divider. - if (contentIndex != 0) // Write divider for all but the first content. + if (writeDivider) // Write divider for all but the first content. { - scratch.Append(CrLf + "--"); // const strings - scratch.Append(_boundary); - scratch.Append(CrLf); + WriteToStream(stream, CrLf + "--"); // const strings + WriteToStream(stream, _boundary); + WriteToStream(stream, CrLf); } // Add headers. foreach (KeyValuePair> headerPair in content.Headers) { - scratch.Append(headerPair.Key); - scratch.Append(": "); + Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding; + + WriteToStream(stream, headerPair.Key); + WriteToStream(stream, ": "); string delim = string.Empty; foreach (string value in headerPair.Value) { - scratch.Append(delim); - scratch.Append(value); + WriteToStream(stream, delim); + WriteToStream(stream, value, headerValueEncoding); delim = ", "; } - scratch.Append(CrLf); + WriteToStream(stream, CrLf); } // Extra CRLF to end headers (even if there are no headers). 
- scratch.Append(CrLf); - - return scratch.ToString(); - } - - private static void EncodeStringToStream(Stream stream, string input) - { - byte[] buffer = HttpRuleParser.DefaultHttpEncoding.GetBytes(input); - stream.Write(buffer); + WriteToStream(stream, CrLf); } private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken) @@ -362,55 +364,55 @@ namespace System.Net.Http return new MemoryStream(HttpRuleParser.DefaultHttpEncoding.GetBytes(input), writable: false); } + private Stream EncodeHeadersToNewStream(HttpContent content, bool writeDivider) + { + var stream = new MemoryStream(); + SerializeHeadersToStream(stream, content, writeDivider); + stream.Position = 0; + return stream; + } + internal override bool AllowDuplex => false; protected internal override bool TryComputeLength(out long length) { - int boundaryLength = GetEncodedLength(_boundary); - - long currentLength = 0; - long internalBoundaryLength = s_crlfLength + s_dashDashLength + boundaryLength + s_crlfLength; - // Start Boundary. - currentLength += s_dashDashLength + boundaryLength + s_crlfLength; + long currentLength = DashDashLength + _boundary.Length + CrLfLength; - bool first = true; - foreach (HttpContent content in _nestedContent) + if (_nestedContent.Count > 1) { - if (first) - { - first = false; // First boundary already written. - } - else - { - // Internal Boundary. - currentLength += internalBoundaryLength; - } + // Internal boundaries + currentLength += (_nestedContent.Count - 1) * (CrLfLength + DashDashLength + _boundary.Length + CrLfLength); + } + foreach (HttpContent content in _nestedContent) + { // Headers. foreach (KeyValuePair> headerPair in content.Headers) { - currentLength += GetEncodedLength(headerPair.Key) + s_colonSpaceLength; + currentLength += headerPair.Key.Length + ColonSpaceLength; + + Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? 
HttpRuleParser.DefaultHttpEncoding; int valueCount = 0; foreach (string value in headerPair.Value) { - currentLength += GetEncodedLength(value); + currentLength += headerValueEncoding.GetByteCount(value); valueCount++; } + if (valueCount > 1) { - currentLength += (valueCount - 1) * s_commaSpaceLength; + currentLength += (valueCount - 1) * CommaSpaceLength; } - currentLength += s_crlfLength; + currentLength += CrLfLength; } - currentLength += s_crlfLength; + currentLength += CrLfLength; // Content. - long tempContentLength = 0; - if (!content.TryComputeLength(out tempContentLength)) + if (!content.TryComputeLength(out long tempContentLength)) { length = 0; return false; @@ -419,17 +421,12 @@ namespace System.Net.Http } // Terminating boundary. - currentLength += s_crlfLength + s_dashDashLength + boundaryLength + s_dashDashLength + s_crlfLength; + currentLength += CrLfLength + DashDashLength + _boundary.Length + DashDashLength + CrLfLength; length = currentLength; return true; } - private static int GetEncodedLength(string input) - { - return HttpRuleParser.DefaultHttpEncoding.GetByteCount(input); - } - private sealed class ContentReadStream : Stream { private readonly Stream[] _streams; @@ -671,6 +668,36 @@ namespace System.Net.Http public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { throw new NotSupportedException(); } public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { throw new NotSupportedException(); } } + + + private static void WriteToStream(Stream stream, string content) => + WriteToStream(stream, content, HttpRuleParser.DefaultHttpEncoding); + + private static void WriteToStream(Stream stream, string content, Encoding encoding) + { + const int StackallocThreshold = 1024; + + int maxLength = encoding.GetMaxByteCount(content.Length); + + byte[]? rentedBuffer = null; + Span buffer = maxLength <= StackallocThreshold + ? 
stackalloc byte[StackallocThreshold] + : (rentedBuffer = ArrayPool.Shared.Rent(maxLength)); + + try + { + int written = encoding.GetBytes(content, buffer); + stream.Write(buffer.Slice(0, written)); + } + finally + { + if (rentedBuffer != null) + { + ArrayPool.Shared.Return(rentedBuffer); + } + } + } + #endregion Serialization } } diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs index 7565e5b..a945cc3 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -387,6 +388,122 @@ namespace System.Net.Http.Functional.Tests Assert.Throws(() => mc.ReadAsStream()); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadAsStreamAsync_CustomEncodingSelector_SelectorIsCalledWithCustomState(bool async) + { + var mc = new MultipartContent(); + + var stringContent = new StringContent("foo"); + stringContent.Headers.Add("StringContent", "foo"); + mc.Add(stringContent); + + var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("foo")); + byteArrayContent.Headers.Add("ByteArrayContent", "foo"); + mc.Add(byteArrayContent); + + bool seenStringContent = false, seenByteArrayContent = false; + + mc.HeaderEncodingSelector = (name, content) => + { + if (ReferenceEquals(content, stringContent) && name == "StringContent") + { + seenStringContent = true; + } + + if (ReferenceEquals(content, byteArrayContent) && name == "ByteArrayContent") + { + seenByteArrayContent = true; + } + + return null; + }; + + var dummy = new MemoryStream(); + if (async) + { + await (await mc.ReadAsStreamAsync()).CopyToAsync(dummy); + } + else + { + 
mc.ReadAsStream().CopyTo(dummy); + } + + Assert.True(seenStringContent); + Assert.True(seenByteArrayContent); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadAsStreamAsync_CustomEncodingSelector_CustomEncodingIsUsed(bool async) + { + var mc = new MultipartContent("subtype", "fooBoundary"); + + var stringContent = new StringContent("bar1"); + stringContent.Headers.Add("latin1", "\uD83D\uDE00"); + mc.Add(stringContent); + + var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar2")); + byteArrayContent.Headers.Add("utf8", "\uD83D\uDE00"); + mc.Add(byteArrayContent); + + byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar3")); + byteArrayContent.Headers.Add("ascii", "\uD83D\uDE00"); + mc.Add(byteArrayContent); + + byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar4")); + byteArrayContent.Headers.Add("default", "\uD83D\uDE00"); + mc.Add(byteArrayContent); + + mc.HeaderEncodingSelector = (name, _) => name switch + { + "latin1" => Encoding.Latin1, + "utf8" => Encoding.UTF8, + "ascii" => Encoding.ASCII, + _ => null + }; + + var ms = new MemoryStream(); + if (async) + { + await (await mc.ReadAsStreamAsync()).CopyToAsync(ms); + } + else + { + mc.ReadAsStream().CopyTo(ms); + } + + byte[] expected = Concat( + Encoding.Latin1.GetBytes("--fooBoundary\r\n"), + Encoding.Latin1.GetBytes("Content-Type: text/plain; charset=utf-8\r\n"), + Encoding.Latin1.GetBytes("latin1: "), + Encoding.Latin1.GetBytes("\uD83D\uDE00"), + Encoding.Latin1.GetBytes("\r\n\r\n"), + Encoding.Latin1.GetBytes("bar1"), + Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"), + Encoding.Latin1.GetBytes("utf8: "), + Encoding.UTF8.GetBytes("\uD83D\uDE00"), + Encoding.Latin1.GetBytes("\r\n\r\n"), + Encoding.Latin1.GetBytes("bar2"), + Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"), + Encoding.Latin1.GetBytes("ascii: "), + Encoding.ASCII.GetBytes("\uD83D\uDE00"), + Encoding.Latin1.GetBytes("\r\n\r\n"), + 
Encoding.Latin1.GetBytes("bar3"), + Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"), + Encoding.Latin1.GetBytes("default: "), + Encoding.Latin1.GetBytes("\uD83D\uDE00"), + Encoding.Latin1.GetBytes("\r\n\r\n"), + Encoding.Latin1.GetBytes("bar4"), + Encoding.Latin1.GetBytes("\r\n--fooBoundary--\r\n")); + + Assert.Equal(expected, ms.ToArray()); + + static byte[] Concat(params byte[][] arrays) => arrays.SelectMany(b => b).ToArray(); + } + #region Helpers private static async Task MultipartContentToStringAsync(MultipartContent content, MultipartContentToStringMode mode, bool async) diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs new file mode 100644 index 0000000..e89d38c --- /dev/null +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs @@ -0,0 +1,86 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.Http.Tests +{ + public class MultipartContentTest + { + public static IEnumerable MultipartContent_TestData() + { + var multipartContents = new List(); + + var complexContent = new MultipartContent(); + + var stringContent = new StringContent("bar1"); + stringContent.Headers.Add("latin1", "\uD83D\uDE00"); + complexContent.Add(stringContent); + + var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar2")); + byteArrayContent.Headers.Add("utf8", "\uD83D\uDE00"); + complexContent.Add(byteArrayContent); + + byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar3")); + byteArrayContent.Headers.Add("ascii", "\uD83D\uDE00"); + complexContent.Add(byteArrayContent); + + byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar4")); + byteArrayContent.Headers.Add("default", "\uD83D\uDE00"); + complexContent.Add(byteArrayContent); + + stringContent = new StringContent("bar5"); + stringContent.Headers.Add("foo", "bar"); + complexContent.Add(stringContent); + + multipartContents.Add(complexContent); + multipartContents.Add(new MultipartContent()); + multipartContents.Add(new MultipartFormDataContent()); + + var encodingSelectors = new HeaderEncodingSelector[] + { + (_, _) => null, + (_, _) => Encoding.ASCII, + (_, _) => Encoding.Latin1, + (_, _) => Encoding.UTF8, + (name, _) => name switch + { + "latin1" => Encoding.Latin1, + "utf8" => Encoding.UTF8, + "ascii" => Encoding.ASCII, + _ => null + } + }; + + foreach (MultipartContent multipartContent in multipartContents) + { + foreach (HeaderEncodingSelector encodingSelector in encodingSelectors) + { + multipartContent.HeaderEncodingSelector = encodingSelector; + yield return new object[] { multipartContent }; + } + } + } + + [Theory] + [MemberData(nameof(MultipartContent_TestData))] + public async Task 
MultipartContent_TryComputeLength_ReturnsSameLengthAsCopyToAsync(MultipartContent multipartContent) + { + Assert.True(multipartContent.TryComputeLength(out long length)); + + var copyToStream = new MemoryStream(); + multipartContent.CopyTo(copyToStream, context: null, cancellationToken: default); + Assert.Equal(length, copyToStream.Length); + + var copyToAsyncStream = new MemoryStream(); + await multipartContent.CopyToAsync(copyToAsyncStream, context: null, cancellationToken: default); + Assert.Equal(length, copyToAsyncStream.Length); + + Assert.Equal(copyToStream.ToArray(), copyToAsyncStream.ToArray()); + } + } +} diff --git a/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj b/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj index c4d1682..19de137 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj +++ b/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj @@ -1,4 +1,4 @@ - + ../../src/Resources/Strings.resx true @@ -90,6 +90,8 @@ Link="ProductionCode\System\Net\Http\EmptyReadStream.cs" /> + + -- 2.7.4