public MultipartContent() { }
public MultipartContent(string subtype) { }
public MultipartContent(string subtype, string boundary) { }
+ public System.Net.Http.HeaderEncodingSelector<System.Net.Http.HttpContent>? HeaderEncodingSelector { get { throw null; } set { } }
public virtual void Add(System.Net.Http.HttpContent content) { }
protected override System.IO.Stream CreateContentReadStream(System.Threading.CancellationToken cancellationToken) { throw null; }
protected override System.Threading.Tasks.Task<System.IO.Stream> CreateContentReadStreamAsync() { throw null; }
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
private const string CrLf = "\r\n";
- private static readonly int s_crlfLength = GetEncodedLength(CrLf);
- private static readonly int s_dashDashLength = GetEncodedLength("--");
- private static readonly int s_colonSpaceLength = GetEncodedLength(": ");
- private static readonly int s_commaSpaceLength = GetEncodedLength(", ");
+ private const int CrLfLength = 2;
+ private const int DashDashLength = 2;
+ private const int ColonSpaceLength = 2;
+ private const int CommaSpaceLength = 2;
private readonly List<HttpContent> _nestedContent;
private readonly string _boundary;
#region Serialization
+ /// <summary>
+ /// Gets or sets a callback that returns the <see cref="Encoding"/> to decode the value for the specified response header name,
+ /// or <see langword="null"/> to use the default behavior.
+ /// </summary>
+ public HeaderEncodingSelector<HttpContent>? HeaderEncodingSelector { get; set; }
+
// for-each content
// write "--" + boundary
// for-each content header
try
{
// Write start boundary.
- EncodeStringToStream(stream, "--" + _boundary + CrLf);
+ WriteToStream(stream, "--" + _boundary + CrLf);
// Write each nested content.
- var output = new StringBuilder();
for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++)
{
// Write divider, headers, and content.
HttpContent content = _nestedContent[contentIndex];
- EncodeStringToStream(stream, SerializeHeadersToString(output, contentIndex, content));
+ SerializeHeadersToStream(stream, content, writeDivider: contentIndex != 0);
content.CopyTo(stream, context, cancellationToken);
}
// Write footer boundary.
- EncodeStringToStream(stream, CrLf + "--" + _boundary + "--" + CrLf);
+ WriteToStream(stream, CrLf + "--" + _boundary + "--" + CrLf);
}
catch (Exception ex)
{
await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf, cancellationToken).ConfigureAwait(false);
// Write each nested content.
- var output = new StringBuilder();
+ var output = new MemoryStream();
for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++)
{
// Write divider, headers, and content.
HttpContent content = _nestedContent[contentIndex];
- await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content), cancellationToken).ConfigureAwait(false);
+
+ output.SetLength(0);
+ SerializeHeadersToStream(output, content, writeDivider: contentIndex != 0);
+ output.Position = 0;
+ await output.CopyToAsync(stream, cancellationToken).ConfigureAwait(false);
+
await content.CopyToAsync(stream, context, cancellationToken).ConfigureAwait(false);
}
try
{
var streams = new Stream[2 + (_nestedContent.Count * 2)];
- var scratch = new StringBuilder();
int streamIndex = 0;
// Start boundary.
cancellationToken.ThrowIfCancellationRequested();
HttpContent nestedContent = _nestedContent[contentIndex];
- streams[streamIndex++] = EncodeStringToNewStream(SerializeHeadersToString(scratch, contentIndex, nestedContent));
+ streams[streamIndex++] = EncodeHeadersToNewStream(nestedContent, writeDivider: contentIndex != 0);
Stream readStream;
if (async)
}
}
- private string SerializeHeadersToString(StringBuilder scratch, int contentIndex, HttpContent content)
+ private void SerializeHeadersToStream(Stream stream, HttpContent content, bool writeDivider)
{
- scratch.Clear();
-
// Add divider.
- if (contentIndex != 0) // Write divider for all but the first content.
+ if (writeDivider) // Write divider for all but the first content.
{
- scratch.Append(CrLf + "--"); // const strings
- scratch.Append(_boundary);
- scratch.Append(CrLf);
+ WriteToStream(stream, CrLf + "--"); // const strings
+ WriteToStream(stream, _boundary);
+ WriteToStream(stream, CrLf);
}
// Add headers.
foreach (KeyValuePair<string, IEnumerable<string>> headerPair in content.Headers)
{
- scratch.Append(headerPair.Key);
- scratch.Append(": ");
+ Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding;
+
+ WriteToStream(stream, headerPair.Key);
+ WriteToStream(stream, ": ");
string delim = string.Empty;
foreach (string value in headerPair.Value)
{
- scratch.Append(delim);
- scratch.Append(value);
+ WriteToStream(stream, delim);
+ WriteToStream(stream, value, headerValueEncoding);
delim = ", ";
}
- scratch.Append(CrLf);
+ WriteToStream(stream, CrLf);
}
// Extra CRLF to end headers (even if there are no headers).
- scratch.Append(CrLf);
-
- return scratch.ToString();
- }
-
- private static void EncodeStringToStream(Stream stream, string input)
- {
- byte[] buffer = HttpRuleParser.DefaultHttpEncoding.GetBytes(input);
- stream.Write(buffer);
+ WriteToStream(stream, CrLf);
}
private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken)
return new MemoryStream(HttpRuleParser.DefaultHttpEncoding.GetBytes(input), writable: false);
}
+ private Stream EncodeHeadersToNewStream(HttpContent content, bool writeDivider)
+ {
+ var stream = new MemoryStream();
+ SerializeHeadersToStream(stream, content, writeDivider);
+ stream.Position = 0;
+ return stream;
+ }
+
internal override bool AllowDuplex => false;
protected internal override bool TryComputeLength(out long length)
{
- int boundaryLength = GetEncodedLength(_boundary);
-
- long currentLength = 0;
- long internalBoundaryLength = s_crlfLength + s_dashDashLength + boundaryLength + s_crlfLength;
-
// Start Boundary.
- currentLength += s_dashDashLength + boundaryLength + s_crlfLength;
+ long currentLength = DashDashLength + _boundary.Length + CrLfLength;
- bool first = true;
- foreach (HttpContent content in _nestedContent)
+ if (_nestedContent.Count > 1)
{
- if (first)
- {
- first = false; // First boundary already written.
- }
- else
- {
- // Internal Boundary.
- currentLength += internalBoundaryLength;
- }
+ // Internal boundaries
+ currentLength += (_nestedContent.Count - 1) * (CrLfLength + DashDashLength + _boundary.Length + CrLfLength);
+ }
+ foreach (HttpContent content in _nestedContent)
+ {
// Headers.
foreach (KeyValuePair<string, IEnumerable<string>> headerPair in content.Headers)
{
- currentLength += GetEncodedLength(headerPair.Key) + s_colonSpaceLength;
+ currentLength += headerPair.Key.Length + ColonSpaceLength;
+
+ Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding;
int valueCount = 0;
foreach (string value in headerPair.Value)
{
- currentLength += GetEncodedLength(value);
+ currentLength += headerValueEncoding.GetByteCount(value);
valueCount++;
}
+
if (valueCount > 1)
{
- currentLength += (valueCount - 1) * s_commaSpaceLength;
+ currentLength += (valueCount - 1) * CommaSpaceLength;
}
- currentLength += s_crlfLength;
+ currentLength += CrLfLength;
}
- currentLength += s_crlfLength;
+ currentLength += CrLfLength;
// Content.
- long tempContentLength = 0;
- if (!content.TryComputeLength(out tempContentLength))
+ if (!content.TryComputeLength(out long tempContentLength))
{
length = 0;
return false;
}
// Terminating boundary.
- currentLength += s_crlfLength + s_dashDashLength + boundaryLength + s_dashDashLength + s_crlfLength;
+ currentLength += CrLfLength + DashDashLength + _boundary.Length + DashDashLength + CrLfLength;
length = currentLength;
return true;
}
- private static int GetEncodedLength(string input)
- {
- return HttpRuleParser.DefaultHttpEncoding.GetByteCount(input);
- }
-
private sealed class ContentReadStream : Stream
{
private readonly Stream[] _streams;
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { throw new NotSupportedException(); }
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) { throw new NotSupportedException(); }
}
+
+
+ private static void WriteToStream(Stream stream, string content) =>
+ WriteToStream(stream, content, HttpRuleParser.DefaultHttpEncoding);
+
+ private static void WriteToStream(Stream stream, string content, Encoding encoding)
+ {
+ const int StackallocThreshold = 1024;
+
+ int maxLength = encoding.GetMaxByteCount(content.Length);
+
+ byte[]? rentedBuffer = null;
+ Span<byte> buffer = maxLength <= StackallocThreshold
+ ? stackalloc byte[StackallocThreshold]
+ : (rentedBuffer = ArrayPool<byte>.Shared.Rent(maxLength));
+
+ try
+ {
+ int written = encoding.GetBytes(content, buffer);
+ stream.Write(buffer.Slice(0, written));
+ }
+ finally
+ {
+ if (rentedBuffer != null)
+ {
+ ArrayPool<byte>.Shared.Return(rentedBuffer);
+ }
+ }
+ }
+
#endregion Serialization
}
}
// The .NET Foundation licenses this file to you under the MIT license.
using System.IO;
+using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
Assert.Throws<NotImplementedException>(() => mc.ReadAsStream());
}
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task ReadAsStreamAsync_CustomEncodingSelector_SelectorIsCalledWithCustomState(bool async)
+ {
+ var mc = new MultipartContent();
+
+ var stringContent = new StringContent("foo");
+ stringContent.Headers.Add("StringContent", "foo");
+ mc.Add(stringContent);
+
+ var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("foo"));
+ byteArrayContent.Headers.Add("ByteArrayContent", "foo");
+ mc.Add(byteArrayContent);
+
+ bool seenStringContent = false, seenByteArrayContent = false;
+
+ mc.HeaderEncodingSelector = (name, content) =>
+ {
+ if (ReferenceEquals(content, stringContent) && name == "StringContent")
+ {
+ seenStringContent = true;
+ }
+
+ if (ReferenceEquals(content, byteArrayContent) && name == "ByteArrayContent")
+ {
+ seenByteArrayContent = true;
+ }
+
+ return null;
+ };
+
+ var dummy = new MemoryStream();
+ if (async)
+ {
+ await (await mc.ReadAsStreamAsync()).CopyToAsync(dummy);
+ }
+ else
+ {
+ mc.ReadAsStream().CopyTo(dummy);
+ }
+
+ Assert.True(seenStringContent);
+ Assert.True(seenByteArrayContent);
+ }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task ReadAsStreamAsync_CustomEncodingSelector_CustomEncodingIsUsed(bool async)
+ {
+ var mc = new MultipartContent("subtype", "fooBoundary");
+
+ var stringContent = new StringContent("bar1");
+ stringContent.Headers.Add("latin1", "\uD83D\uDE00");
+ mc.Add(stringContent);
+
+ var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar2"));
+ byteArrayContent.Headers.Add("utf8", "\uD83D\uDE00");
+ mc.Add(byteArrayContent);
+
+ byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar3"));
+ byteArrayContent.Headers.Add("ascii", "\uD83D\uDE00");
+ mc.Add(byteArrayContent);
+
+ byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar4"));
+ byteArrayContent.Headers.Add("default", "\uD83D\uDE00");
+ mc.Add(byteArrayContent);
+
+ mc.HeaderEncodingSelector = (name, _) => name switch
+ {
+ "latin1" => Encoding.Latin1,
+ "utf8" => Encoding.UTF8,
+ "ascii" => Encoding.ASCII,
+ _ => null
+ };
+
+ var ms = new MemoryStream();
+ if (async)
+ {
+ await (await mc.ReadAsStreamAsync()).CopyToAsync(ms);
+ }
+ else
+ {
+ mc.ReadAsStream().CopyTo(ms);
+ }
+
+ byte[] expected = Concat(
+ Encoding.Latin1.GetBytes("--fooBoundary\r\n"),
+ Encoding.Latin1.GetBytes("Content-Type: text/plain; charset=utf-8\r\n"),
+ Encoding.Latin1.GetBytes("latin1: "),
+ Encoding.Latin1.GetBytes("\uD83D\uDE00"),
+ Encoding.Latin1.GetBytes("\r\n\r\n"),
+ Encoding.Latin1.GetBytes("bar1"),
+ Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"),
+ Encoding.Latin1.GetBytes("utf8: "),
+ Encoding.UTF8.GetBytes("\uD83D\uDE00"),
+ Encoding.Latin1.GetBytes("\r\n\r\n"),
+ Encoding.Latin1.GetBytes("bar2"),
+ Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"),
+ Encoding.Latin1.GetBytes("ascii: "),
+ Encoding.ASCII.GetBytes("\uD83D\uDE00"),
+ Encoding.Latin1.GetBytes("\r\n\r\n"),
+ Encoding.Latin1.GetBytes("bar3"),
+ Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"),
+ Encoding.Latin1.GetBytes("default: "),
+ Encoding.Latin1.GetBytes("\uD83D\uDE00"),
+ Encoding.Latin1.GetBytes("\r\n\r\n"),
+ Encoding.Latin1.GetBytes("bar4"),
+ Encoding.Latin1.GetBytes("\r\n--fooBoundary--\r\n"));
+
+ Assert.Equal(expected, ms.ToArray());
+
+ static byte[] Concat(params byte[][] arrays) => arrays.SelectMany(b => b).ToArray();
+ }
+
#region Helpers
private static async Task<string> MultipartContentToStringAsync(MultipartContent content, MultipartContentToStringMode mode, bool async)
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace System.Net.Http.Tests
+{
+ public class MultipartContentTest
+ {
+ public static IEnumerable<object[]> MultipartContent_TestData()
+ {
+ var multipartContents = new List<MultipartContent>();
+
+ var complexContent = new MultipartContent();
+
+ var stringContent = new StringContent("bar1");
+ stringContent.Headers.Add("latin1", "\uD83D\uDE00");
+ complexContent.Add(stringContent);
+
+ var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar2"));
+ byteArrayContent.Headers.Add("utf8", "\uD83D\uDE00");
+ complexContent.Add(byteArrayContent);
+
+ byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar3"));
+ byteArrayContent.Headers.Add("ascii", "\uD83D\uDE00");
+ complexContent.Add(byteArrayContent);
+
+ byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar4"));
+ byteArrayContent.Headers.Add("default", "\uD83D\uDE00");
+ complexContent.Add(byteArrayContent);
+
+ stringContent = new StringContent("bar5");
+ stringContent.Headers.Add("foo", "bar");
+ complexContent.Add(stringContent);
+
+ multipartContents.Add(complexContent);
+ multipartContents.Add(new MultipartContent());
+ multipartContents.Add(new MultipartFormDataContent());
+
+ var encodingSelectors = new HeaderEncodingSelector<HttpContent>[]
+ {
+ (_, _) => null,
+ (_, _) => Encoding.ASCII,
+ (_, _) => Encoding.Latin1,
+ (_, _) => Encoding.UTF8,
+ (name, _) => name switch
+ {
+ "latin1" => Encoding.Latin1,
+ "utf8" => Encoding.UTF8,
+ "ascii" => Encoding.ASCII,
+ _ => null
+ }
+ };
+
+ foreach (MultipartContent multipartContent in multipartContents)
+ {
+ foreach (HeaderEncodingSelector<HttpContent> encodingSelector in encodingSelectors)
+ {
+ multipartContent.HeaderEncodingSelector = encodingSelector;
+ yield return new object[] { multipartContent };
+ }
+ }
+ }
+
+ [Theory]
+ [MemberData(nameof(MultipartContent_TestData))]
+ public async Task MultipartContent_TryComputeLength_ReturnsSameLengthAsCopyToAsync(MultipartContent multipartContent)
+ {
+ Assert.True(multipartContent.TryComputeLength(out long length));
+
+ var copyToStream = new MemoryStream();
+ multipartContent.CopyTo(copyToStream, context: null, cancellationToken: default);
+ Assert.Equal(length, copyToStream.Length);
+
+ var copyToAsyncStream = new MemoryStream();
+ await multipartContent.CopyToAsync(copyToAsyncStream, context: null, cancellationToken: default);
+ Assert.Equal(length, copyToAsyncStream.Length);
+
+ Assert.Equal(copyToStream.ToArray(), copyToAsyncStream.ToArray());
+ }
+ }
+}
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<StringResourcesPath>../../src/Resources/Strings.resx</StringResourcesPath>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
Link="ProductionCode\System\Net\Http\EmptyReadStream.cs" />
<Compile Include="..\..\src\System\Net\Http\FormUrlEncodedContent.cs"
Link="ProductionCode\System\Net\Http\FormUrlEncodedContent.cs" />
+ <Compile Include="..\..\src\System\Net\Http\HeaderEncodingSelector.cs"
+ Link="ProductionCode\System\Net\Http\HeaderEncodingSelector.cs" />
<Compile Include="..\..\src\System\Net\Http\Headers\AltSvcHeaderParser.cs"
Link="ProductionCode\System\Net\Http\Headers\AltSvcHeaderParser.cs" />
<Compile Include="..\..\src\System\Net\Http\Headers\AltSvcHeaderValue.cs"
<Compile Include="Headers\MediaTypeHeaderParserTest.cs" />
<Compile Include="Headers\MediaTypeHeaderValueTest.cs" />
<Compile Include="Headers\MediaTypeWithQualityHeaderValueTest.cs" />
+ <Compile Include="Headers\MultipartContentTest.cs" />
<Compile Include="Headers\NameValueHeaderValueTest.cs" />
<Compile Include="Headers\NameValueWithParametersHeaderValueTest.cs" />
<Compile Include="Headers\ObjectCollectionTest.cs" />