From 02e872ef3cf48c7fbdccc10bd52386c28b7a8676 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Thu, 30 Jul 2020 10:36:38 +0200 Subject: [PATCH] Implement Header encoding selectors on SocketsHttpHandler (#39468) * Expose HeaderEncodingSelectors on SocketsHttpHandler * Implement header encoding selectors in SocketsHttpHandler * Add header encoding tests * Add summaries for new APIs * Use Stream.Write(byte[], int, int) overloads for framework compat * Add dummy API implementation to browser target * HPack/QPack fixes * Move HeaderEncodingSelector delegate to namespace, add TContext * Encoding fixes * Remove unused using * Simplify test * HttpConnection PR feedback * Simplify fast-path detection --- .../Http/aspnetcore/Http2/Hpack/HPackEncoder.cs | 140 ++++++++++++++------- .../Http/aspnetcore/Http3/QPack/QPackEncoder.cs | 111 ++++++++++++---- .../tests/System/Net/Http/GenericLoopbackServer.cs | 4 +- .../Common/tests/System/Net/Http/HPackEncoder.cs | 19 +-- .../System/Net/Http/Http2LoopbackConnection.cs | 2 +- .../tests/System/Net/Http/Http3LoopbackStream.cs | 2 +- .../Common/tests/System/Net/Http/LoopbackServer.cs | 84 ++++++++++--- .../tests/System/Net/Http/QPackTestEncoder.cs | 12 +- .../System.Net.Http/ref/System.Net.Http.cs | 3 + .../System.Net.Http/src/System.Net.Http.csproj | 1 + .../Http/BrowserHttpHandler/SocketsHttpHandler.cs | 12 ++ .../src/System/Net/Http/HeaderEncodingSelector.cs | 15 +++ .../System/Net/Http/Headers/HeaderDescriptor.cs | 5 +- .../Net/Http/SocketsHttpHandler/Http2Connection.cs | 34 ++--- .../Net/Http/SocketsHttpHandler/Http2Stream.cs | 8 +- .../Http/SocketsHttpHandler/Http3RequestStream.cs | 37 ++++-- .../Net/Http/SocketsHttpHandler/HttpConnection.cs | 71 +++++++++-- .../Http/SocketsHttpHandler/HttpConnectionBase.cs | 13 +- .../SocketsHttpHandler/HttpConnectionSettings.cs | 5 + .../Http/SocketsHttpHandler/SocketsHttpHandler.cs | 31 ++++- .../HttpClientHandlerTest.Headers.cs | 128 +++++++++++++++++++ .../tests/UnitTests/HPack/HPackRoundtripTests.cs | 41 +++--- .../tests/UnitTests/Headers/HeaderEncodingTest.cs | 37 ++++++ .../tests/UnitTests/Headers/KnownHeadersTest.cs | 8 +- .../UnitTests/System.Net.Http.Unit.Tests.csproj | 1 + 25 files changed, 656 insertions(+), 168 deletions(-) create mode 100644 src/libraries/System.Net.Http/src/System/Net/Http/HeaderEncodingSelector.cs create mode 100644 src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs diff --git a/src/libraries/Common/src/System/Net/Http/aspnetcore/Http2/Hpack/HPackEncoder.cs b/src/libraries/Common/src/System/Net/Http/aspnetcore/Http2/Hpack/HPackEncoder.cs index d2fbc52..4c3ac29 100644 --- a/src/libraries/Common/src/System/Net/Http/aspnetcore/Http2/Hpack/HPackEncoder.cs +++ b/src/libraries/Common/src/System/Net/Http/aspnetcore/Http2/Hpack/HPackEncoder.cs @@ -4,6 +4,7 @@ #nullable enable using System.Collections.Generic; using System.Diagnostics; +using System.Text; namespace System.Net.Http.HPack { @@ -96,7 +97,7 @@ namespace System.Net.Http.HPack if (IntegerEncoder.Encode(index, 4, destination, out int indexLength)) { Debug.Assert(indexLength >= 1); - if (EncodeStringLiteral(value, destination.Slice(indexLength), out int nameLength)) + if (EncodeStringLiteral(value, valueEncoding: null, destination.Slice(indexLength), out int nameLength)) { bytesWritten = indexLength + nameLength; return true; @@ -128,7 +129,7 @@ namespace System.Net.Http.HPack if (IntegerEncoder.Encode(index, 4, destination, out int indexLength)) { Debug.Assert(indexLength >= 1); - if (EncodeStringLiteral(value, destination.Slice(indexLength), out int nameLength)) + if (EncodeStringLiteral(value, valueEncoding: null, destination.Slice(indexLength), out int nameLength)) { bytesWritten = indexLength + nameLength; return true; @@ -160,7 +161,7 @@ namespace System.Net.Http.HPack if (IntegerEncoder.Encode(index, 6, destination, out int indexLength)) { Debug.Assert(indexLength >= 1); - if (EncodeStringLiteral(value, destination.Slice(indexLength), out int nameLength)) + if (EncodeStringLiteral(value, valueEncoding: null, destination.Slice(indexLength), out int nameLength)) { bytesWritten = indexLength + nameLength; return true; @@ -276,7 +277,7 @@ namespace System.Net.Http.HPack { destination[0] = mask; if (EncodeLiteralHeaderName(name, destination.Slice(1), out int nameLength) && - EncodeStringLiteral(value, destination.Slice(1 + nameLength), out int valueLength)) + EncodeStringLiteral(value, valueEncoding: null, destination.Slice(1 + nameLength), out int valueLength)) { bytesWritten = 1 + nameLength + valueLength; return true; @@ -290,6 +291,11 @@ namespace System.Net.Http.HPack /// Encodes a "Literal Header Field without Indexing - New Name". public static bool EncodeLiteralHeaderFieldWithoutIndexingNewName(string name, ReadOnlySpan values, string separator, Span destination, out int bytesWritten) { + return EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, separator, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeLiteralHeaderFieldWithoutIndexingNewName(string name, ReadOnlySpan values, string separator, Encoding? valueEncoding, Span destination, out int bytesWritten) + { // From https://tools.ietf.org/html/rfc7541#section-6.2.2 // ------------------------------------------------------ // 0 1 2 3 4 5 6 7 @@ -309,7 +315,7 @@ namespace System.Net.Http.HPack { destination[0] = 0; if (EncodeLiteralHeaderName(name, destination.Slice(1), out int nameLength) && - EncodeStringLiterals(values, separator, destination.Slice(1 + nameLength), out int valueLength)) + EncodeStringLiterals(values, separator, valueEncoding, destination.Slice(1 + nameLength), out int valueLength)) { bytesWritten = 1 + nameLength + valueLength; return true; @@ -395,27 +401,20 @@ namespace System.Net.Http.HPack return false; } - private static bool EncodeStringLiteralValue(string value, Span destination, out int bytesWritten) + private static void EncodeValueStringPart(string value, Span destination) { - if (value.Length <= destination.Length) + Debug.Assert(destination.Length >= value.Length); + + for (int i = 0; i < value.Length; i++) { - for (int i = 0; i < value.Length; i++) + char c = value[i]; + if ((c & 0xFF80) != 0) { - char c = value[i]; - if ((c & 0xFF80) != 0) - { - throw new HttpRequestException(SR.net_http_request_invalid_char_encoding); - } - - destination[i] = (byte)c; + throw new HttpRequestException(SR.net_http_request_invalid_char_encoding); } - bytesWritten = value.Length; - return true; + destination[i] = (byte)c; } - - bytesWritten = 0; - return false; } public static bool EncodeStringLiteral(ReadOnlySpan value, Span destination, out int bytesWritten) @@ -454,6 +453,11 @@ namespace System.Net.Http.HPack public static bool EncodeStringLiteral(string value, Span destination, out int bytesWritten) { + return EncodeStringLiteral(value, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeStringLiteral(string value, Encoding? valueEncoding, Span destination, out int bytesWritten) + { // From https://tools.ietf.org/html/rfc7541#section-5.2 // ------------------------------------------------------ // 0 1 2 3 4 5 6 7 @@ -466,13 +470,28 @@ namespace System.Net.Http.HPack if (destination.Length != 0) { destination[0] = 0; // TODO: Use Huffman encoding - if (IntegerEncoder.Encode(value.Length, 7, destination, out int integerLength)) + + int encodedStringLength = valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1) + ? value.Length + : valueEncoding.GetByteCount(value); + + if (IntegerEncoder.Encode(encodedStringLength, 7, destination, out int integerLength)) { Debug.Assert(integerLength >= 1); - - if (EncodeStringLiteralValue(value, destination.Slice(integerLength), out int valueLength)) + destination = destination.Slice(integerLength); + if (encodedStringLength <= destination.Length) { - bytesWritten = integerLength + valueLength; + if (valueEncoding is null) + { + EncodeValueStringPart(value, destination); + } + else + { + int written = valueEncoding.GetBytes(value, destination); + Debug.Assert(written == encodedStringLength); + } + + bytesWritten = integerLength + encodedStringLength; return true; } } @@ -503,55 +522,86 @@ namespace System.Net.Http.HPack public static bool EncodeStringLiterals(ReadOnlySpan values, string? separator, Span destination, out int bytesWritten) { + return EncodeStringLiterals(values, separator, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeStringLiterals(ReadOnlySpan values, string? separator, Encoding? valueEncoding, Span destination, out int bytesWritten) + { bytesWritten = 0; if (values.Length == 0) { - return EncodeStringLiteral("", destination, out bytesWritten); + return EncodeStringLiteral("", valueEncoding: null, destination, out bytesWritten); } else if (values.Length == 1) { - return EncodeStringLiteral(values[0], destination, out bytesWritten); + return EncodeStringLiteral(values[0], valueEncoding, destination, out bytesWritten); } if (destination.Length != 0) { - int valueLength = 0; + Debug.Assert(separator != null); + int valueLength; // Calculate length of all parts and separators. - foreach (string part in values) + if (valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1)) { - valueLength = checked((int)(valueLength + part.Length)); + valueLength = checked((int)(values.Length - 1) * separator.Length); + foreach (string part in values) + { + valueLength = checked((int)(valueLength + part.Length)); + } + } + else + { + valueLength = checked((int)(values.Length - 1) * valueEncoding.GetByteCount(separator)); + foreach (string part in values) + { + valueLength = checked((int)(valueLength + valueEncoding.GetByteCount(part))); + } } - - Debug.Assert(separator != null); - valueLength = checked((int)(valueLength + (values.Length - 1) * separator.Length)); destination[0] = 0; if (IntegerEncoder.Encode(valueLength, 7, destination, out int integerLength)) { Debug.Assert(integerLength >= 1); - - int encodedLength = 0; - for (int j = 0; j < values.Length; j++) + destination = destination.Slice(integerLength); + if (destination.Length >= valueLength) { - if (j != 0 && !EncodeStringLiteralValue(separator, destination.Slice(integerLength), out encodedLength)) + if (valueEncoding is null) { - return false; + string value = values[0]; + EncodeValueStringPart(value, destination); + destination = destination.Slice(value.Length); + + for (int i = 1; i < values.Length; i++) + { + EncodeValueStringPart(separator, destination); + destination = destination.Slice(separator.Length); + + value = values[i]; + EncodeValueStringPart(value, destination); + destination = destination.Slice(value.Length); + } } + else + { + int written = valueEncoding.GetBytes(values[0], destination); + destination = destination.Slice(written); - integerLength += encodedLength; + for (int i = 1; i < values.Length; i++) + { + written = valueEncoding.GetBytes(separator, destination); + destination = destination.Slice(written); - if (!EncodeStringLiteralValue(values[j], destination.Slice(integerLength), out encodedLength)) - { - return false; + written = valueEncoding.GetBytes(values[i], destination); + destination = destination.Slice(written); + } } - integerLength += encodedLength; + bytesWritten = integerLength + valueLength; + return true; } - - bytesWritten = integerLength; - return true; } } diff --git a/src/libraries/Common/src/System/Net/Http/aspnetcore/Http3/QPack/QPackEncoder.cs b/src/libraries/Common/src/System/Net/Http/aspnetcore/Http3/QPack/QPackEncoder.cs index be43dc3..68e04ed 100644 --- a/src/libraries/Common/src/System/Net/Http/aspnetcore/Http3/QPack/QPackEncoder.cs +++ b/src/libraries/Common/src/System/Net/Http/aspnetcore/Http3/QPack/QPackEncoder.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Net.Http.HPack; +using System.Text; namespace System.Net.Http.QPack { @@ -60,6 +61,11 @@ namespace System.Net.Http.QPack // - H is constant 0 here, as we do not yet perform Huffman coding. public static bool EncodeLiteralHeaderFieldWithStaticNameReference(int index, string value, Span destination, out int bytesWritten) { + return EncodeLiteralHeaderFieldWithStaticNameReference(index, value, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeLiteralHeaderFieldWithStaticNameReference(int index, string value, Encoding? valueEncoding, Span destination, out int bytesWritten) + { // Requires at least two bytes (one for name reference header, one for value length) if (destination.Length >= 2) { @@ -68,7 +74,7 @@ namespace System.Net.Http.QPack { destination = destination.Slice(headerBytesWritten); - if (EncodeValueString(value, destination, out int valueBytesWritten)) + if (EncodeValueString(value, valueEncoding, destination, out int valueBytesWritten)) { bytesWritten = headerBytesWritten + valueBytesWritten; return true; @@ -81,7 +87,7 @@ namespace System.Net.Http.QPack } /// - /// Encodes just the name part of a Literal Header Field With Static Name Reference. Must call after to encode the header's value. + /// Encodes just the name part of a Literal Header Field With Static Name Reference. Must call after to encode the header's value. /// public static byte[] EncodeLiteralHeaderFieldWithStaticNameReferenceToArray(int index) { @@ -119,7 +125,12 @@ namespace System.Net.Http.QPack // - H is constant 0 here, as we do not yet perform Huffman coding. public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, string value, Span destination, out int bytesWritten) { - if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(value, destination.Slice(nameLength), out int valueLength)) + return EncodeLiteralHeaderFieldWithoutNameReference(name, value, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, string value, Encoding? valueEncoding, Span destination, out int bytesWritten) + { + if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(value, valueEncoding, destination.Slice(nameLength), out int valueLength)) { bytesWritten = nameLength + valueLength; return true; @@ -136,7 +147,12 @@ namespace System.Net.Http.QPack /// public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, ReadOnlySpan values, string valueSeparator, Span destination, out int bytesWritten) { - if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(values, valueSeparator, destination.Slice(nameLength), out int valueLength)) + return EncodeLiteralHeaderFieldWithoutNameReference(name, values, valueSeparator, valueEncoding: null, destination, out bytesWritten); + } + + public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, ReadOnlySpan values, string valueSeparator, Encoding? valueEncoding, Span destination, out int bytesWritten) + { + if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(values, valueSeparator, valueEncoding, destination.Slice(nameLength), out int valueLength)) { bytesWritten = nameLength + valueLength; return true; @@ -147,7 +163,7 @@ namespace System.Net.Http.QPack } /// - /// Encodes just the value part of a Literawl Header Field Without Static Name Reference. Must call after to encode the header's value. + /// Encodes just the value part of a Literawl Header Field Without Static Name Reference. Must call after to encode the header's value. /// public static byte[] EncodeLiteralHeaderFieldWithoutNameReferenceToArray(string name) { @@ -169,19 +185,32 @@ namespace System.Net.Http.QPack return temp.Slice(0, bytesWritten).ToArray(); } - private static bool EncodeValueString(string s, Span buffer, out int length) + private static bool EncodeValueString(string s, Encoding? valueEncoding, Span buffer, out int length) { if (buffer.Length != 0) { buffer[0] = 0; - if (IntegerEncoder.Encode(s.Length, 7, buffer, out int nameLength)) + + int encodedStringLength = valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1) + ? s.Length + : valueEncoding.GetByteCount(s); + + if (IntegerEncoder.Encode(encodedStringLength, 7, buffer, out int nameLength)) { buffer = buffer.Slice(nameLength); - if (buffer.Length >= s.Length) + if (buffer.Length >= encodedStringLength) { - EncodeValueStringPart(s, buffer); + if (valueEncoding is null) + { + EncodeValueStringPart(s, buffer); + } + else + { + int written = valueEncoding.GetBytes(s, buffer); + Debug.Assert(written == encodedStringLength); + } - length = nameLength + s.Length; + length = nameLength + encodedStringLength; return true; } } @@ -196,24 +225,41 @@ namespace System.Net.Http.QPack /// public static bool EncodeValueString(ReadOnlySpan values, string? separator, Span buffer, out int length) { + return EncodeValueString(values, separator, valueEncoding: null, buffer, out length); + } + + public static bool EncodeValueString(ReadOnlySpan values, string? separator, Encoding? valueEncoding, Span buffer, out int length) + { if (values.Length == 1) { - return EncodeValueString(values[0], buffer, out length); + return EncodeValueString(values[0], valueEncoding, buffer, out length); } if (values.Length == 0) { // TODO: this will be called with a string array from HttpHeaderCollection. Can we ever get a 0-length array from that? Assert if not. - return EncodeValueString(string.Empty, buffer, out length); + return EncodeValueString(string.Empty, valueEncoding: null, buffer, out length); } if (buffer.Length > 0) { Debug.Assert(separator != null); - int valueLength = separator.Length * (values.Length - 1); - for (int i = 0; i < values.Length; ++i) + int valueLength; + if (valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1)) + { + valueLength = separator.Length * (values.Length - 1); + foreach (string part in values) + { + valueLength += part.Length; + } + } + else { - valueLength += values[i].Length; + valueLength = valueEncoding.GetByteCount(separator) * (values.Length - 1); + foreach (string part in values) + { + valueLength += valueEncoding.GetByteCount(part); + } } buffer[0] = 0; @@ -222,18 +268,35 @@ namespace System.Net.Http.QPack buffer = buffer.Slice(nameLength); if (buffer.Length >= valueLength) { - string value = values[0]; - EncodeValueStringPart(value, buffer); - buffer = buffer.Slice(value.Length); - - for (int i = 1; i < values.Length; ++i) + if (valueEncoding is null) { - EncodeValueStringPart(separator, buffer); - buffer = buffer.Slice(separator.Length); - - value = values[i]; + string value = values[0]; EncodeValueStringPart(value, buffer); buffer = buffer.Slice(value.Length); + + for (int i = 1; i < values.Length; i++) + { + EncodeValueStringPart(separator, buffer); + buffer = buffer.Slice(separator.Length); + + value = values[i]; + EncodeValueStringPart(value, buffer); + buffer = buffer.Slice(value.Length); + } + } + else + { + int written = valueEncoding.GetBytes(values[0], buffer); + buffer = buffer.Slice(written); + + for (int i = 1; i < values.Length; i++) + { + written = valueEncoding.GetBytes(separator, buffer); + buffer = buffer.Slice(written); + + written = valueEncoding.GetBytes(values[i], buffer); + buffer = buffer.Slice(written); + } } length = nameLength + valueLength; diff --git a/src/libraries/Common/tests/System/Net/Http/GenericLoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/GenericLoopbackServer.cs index d15ab88..899b909 100644 --- a/src/libraries/Common/tests/System/Net/Http/GenericLoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/GenericLoopbackServer.cs @@ -106,13 +106,15 @@ namespace System.Net.Test.Common public string Value { get; } public bool HuffmanEncoded { get; } public byte[] Raw { get; } + public Encoding ValueEncoding { get; } - public HttpHeaderData(string name, string value, bool huffmanEncoded = false, byte[] raw = null) + public HttpHeaderData(string name, string value, bool huffmanEncoded = false, byte[] raw = null, Encoding valueEncoding = null) { Name = name; Value = value; HuffmanEncoded = huffmanEncoded; Raw = raw; + ValueEncoding = valueEncoding; } public override string ToString() => Name == null ? "" : (Name + ": " + (Value ?? string.Empty)); diff --git a/src/libraries/Common/tests/System/Net/Http/HPackEncoder.cs b/src/libraries/Common/tests/System/Net/Http/HPackEncoder.cs index bbdf89f..9ecaf44 100644 --- a/src/libraries/Common/tests/System/Net/Http/HPackEncoder.cs +++ b/src/libraries/Common/tests/System/Net/Http/HPackEncoder.cs @@ -51,7 +51,7 @@ namespace System.Net.Test.Common public static int EncodeHeader(int nameIdx, string value, HPackFlags flags, Span headerBlock) { Debug.Assert(nameIdx > 0); - return EncodeHeaderImpl(nameIdx, null, value, flags, headerBlock); + return EncodeHeaderImpl(nameIdx, null, value, valueEncoding: null, flags, headerBlock); } /// @@ -63,10 +63,15 @@ namespace System.Net.Test.Common /// The number of bytes written to . public static int EncodeHeader(string name, string value, HPackFlags flags, Span headerBlock) { - return EncodeHeaderImpl(0, name, value, flags, headerBlock); + return EncodeHeader(name, value, valueEncoding: null, flags, headerBlock); } - private static int EncodeHeaderImpl(int nameIdx, string name, string value, HPackFlags flags, Span headerBlock) + public static int EncodeHeader(string name, string value, Encoding valueEncoding, HPackFlags flags, Span headerBlock) + { + return EncodeHeaderImpl(0, name, value, valueEncoding, flags, headerBlock); + } + + private static int EncodeHeaderImpl(int nameIdx, string name, string value, Encoding valueEncoding, HPackFlags flags, Span headerBlock) { const HPackFlags IndexingMask = HPackFlags.NeverIndexed | HPackFlags.NewIndexed | HPackFlags.WithoutIndexing; @@ -97,16 +102,16 @@ namespace System.Net.Test.Common if (name != null) { - bytesGenerated += EncodeString(name, headerBlock.Slice(bytesGenerated), (flags & HPackFlags.HuffmanEncodeName) != 0); + bytesGenerated += EncodeString(name, Encoding.ASCII, headerBlock.Slice(bytesGenerated), (flags & HPackFlags.HuffmanEncodeName) != 0); } - bytesGenerated += EncodeString(value, headerBlock.Slice(bytesGenerated), (flags & HPackFlags.HuffmanEncodeValue) != 0); + bytesGenerated += EncodeString(value, valueEncoding, headerBlock.Slice(bytesGenerated), (flags & HPackFlags.HuffmanEncodeValue) != 0); return bytesGenerated; } - public static int EncodeString(string value, Span headerBlock, bool huffmanEncode) + public static int EncodeString(string value, Encoding valueEncoding, Span headerBlock, bool huffmanEncode) { - byte[] data = Encoding.ASCII.GetBytes(value); + byte[] data = (valueEncoding ?? Encoding.ASCII).GetBytes(value); byte prefix; if (!huffmanEncode) diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs index ad60410..9ee6479 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs @@ -662,7 +662,7 @@ namespace System.Net.Test.Common { foreach (HttpHeaderData headerData in headers) { - bytesGenerated += HPackEncoder.EncodeHeader(headerData.Name, headerData.Value, headerData.HuffmanEncoded ? HPackFlags.HuffmanEncode : HPackFlags.None, headerBlock.AsSpan(bytesGenerated)); + bytesGenerated += HPackEncoder.EncodeHeader(headerData.Name, headerData.Value, headerData.ValueEncoding, headerData.HuffmanEncoded ? HPackFlags.HuffmanEncode : HPackFlags.None, headerBlock.AsSpan(bytesGenerated)); } } diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs index f13cbd6..d0e8d43 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs @@ -78,7 +78,7 @@ namespace System.Net.Test.Common foreach (HttpHeaderData header in headers) { - bytesWritten += QPackTestEncoder.EncodeHeader(buffer.AsSpan(bytesWritten), header.Name, header.Value, header.HuffmanEncoded ? QPackFlags.HuffmanEncode : QPackFlags.None); + bytesWritten += QPackTestEncoder.EncodeHeader(buffer.AsSpan(bytesWritten), header.Name, header.Value, header.ValueEncoding, header.HuffmanEncoded ? QPackFlags.HuffmanEncode : QPackFlags.None); } await SendFrameAsync(HeadersFrame, buffer.AsMemory(0, bytesWritten)).ConfigureAwait(false); diff --git a/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs index 76c8459..2d34c3b 100644 --- a/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs @@ -15,6 +15,9 @@ namespace System.Net.Test.Common { public sealed partial class LoopbackServer : GenericLoopbackServer, IDisposable { + private static readonly byte[] s_newLineBytes = new byte[] { (byte)'\r', (byte)'\n' }; + private static readonly byte[] s_colonSpaceBytes = new byte[] { (byte)':', (byte)' ' }; + private Socket _listenSocket; private Options _options; private Uri _uri; @@ -533,6 +536,16 @@ namespace System.Net.Test.Common public async Task ReadLineAsync() { + byte[] lineBytes = await ReadLineBytesAsync().ConfigureAwait(false); + + if (lineBytes is null) + return null; + + return Encoding.ASCII.GetString(lineBytes); + } + + private async Task ReadLineBytesAsync() + { int index = 0; int startSearch = _readStart; @@ -578,7 +591,7 @@ namespace System.Net.Test.Common if (_readBuffer[_readStart + stringLength] == '\n') { stringLength--; } if (_readBuffer[_readStart + stringLength] == '\r') { stringLength--; } - string line = System.Text.Encoding.ASCII.GetString(_readBuffer, _readStart, stringLength + 1); + byte[] line = _readBuffer.AsSpan(_readStart, stringLength + 1).ToArray(); _readStart = index + 1; return line; } @@ -625,6 +638,32 @@ namespace System.Net.Test.Common return lines; } + private async Task> ReadRequestHeaderBytesAsync() + { + var lines = new List(); + + byte[] line; + + while (true) + { + line = await ReadLineBytesAsync().ConfigureAwait(false); + + if (line is null || line.Length == 0) + { + break; + } + + lines.Add(line); + } + + if (line == null) + { + throw new IOException("Unexpected EOF trying to read request header"); + } + + return lines; + } + public async Task SendResponseAsync(string response) { await _writer.WriteAsync(response).ConfigureAwait(false); @@ -663,24 +702,24 @@ namespace System.Net.Test.Common public override async Task ReadRequestDataAsync(bool readBody = true) { - List headerLines = null; HttpRequestData requestData = new HttpRequestData(); - headerLines = await ReadRequestHeaderAsync().ConfigureAwait(false); + List headerLines = await ReadRequestHeaderBytesAsync().ConfigureAwait(false); // Parse method and path - string[] splits = headerLines[0].Split(' '); + string[] splits = Encoding.ASCII.GetString(headerLines[0]).Split(' '); requestData.Method = splits[0]; requestData.Path = splits[1]; // Convert header lines to key/value pairs // Skip first line since it's the status line - foreach (var line in headerLines.Skip(1)) + foreach (byte[] lineBytes in headerLines.Skip(1)) { + string line = Encoding.ASCII.GetString(lineBytes); int offset = line.IndexOf(':'); string name = line.Substring(0, offset); string value = line.Substring(offset + 1).TrimStart(); - requestData.Headers.Add(new HttpHeaderData(name, value)); + requestData.Headers.Add(new HttpHeaderData(name, value, raw: lineBytes)); } if (requestData.Method != "GET") @@ -760,7 +799,7 @@ namespace System.Net.Test.Common public override async Task SendResponseAsync(HttpStatusCode? statusCode = HttpStatusCode.OK, IList headers = null, string content = null, bool isFinal = true, int requestId = 0) { - string headerString = null; + MemoryStream headerBytes = new MemoryStream(); int contentLength = -1; bool isChunked = false; bool hasContentLength = false; @@ -785,22 +824,39 @@ namespace System.Net.Test.Common isChunked = true; } - headerString = headerString + $"{headerData.Name}: {headerData.Value}\r\n"; + byte[] nameBytes = Encoding.ASCII.GetBytes(headerData.Name); + headerBytes.Write(nameBytes, 0, nameBytes.Length); + headerBytes.Write(s_colonSpaceBytes, 0, s_colonSpaceBytes.Length); + + byte[] valueBytes = (headerData.ValueEncoding ?? Encoding.ASCII).GetBytes(headerData.Value); + headerBytes.Write(valueBytes, 0, valueBytes.Length); + headerBytes.Write(s_newLineBytes, 0, s_newLineBytes.Length); } } bool endHeaders = content != null || isFinal; if (statusCode != null) { - // If we need to send status line, prepped it to headers and possibly add missing headers to the end. - headerString = + byte[] temp = headerBytes.ToArray(); + + headerBytes.SetLength(0); + + byte[] headerStartBytes = Encoding.ASCII.GetBytes( $"HTTP/1.1 {(int)statusCode} {GetStatusDescription((HttpStatusCode)statusCode)}\r\n" + - (!hasContentLength && !isChunked && content != null ? $"Content-length: {content.Length}\r\n" : "") + - headerString + - (endHeaders ? "\r\n" : ""); + (!hasContentLength && !isChunked && content != null ? $"Content-length: {content.Length}\r\n" : "")); + + headerBytes.Write(headerStartBytes, 0, headerStartBytes.Length); + headerBytes.Write(temp, 0, temp.Length); + + if (endHeaders) + { + headerBytes.Write(s_newLineBytes, 0, s_newLineBytes.Length); + } } - await SendResponseAsync(headerString).ConfigureAwait(false); + headerBytes.Position = 0; + await headerBytes.CopyToAsync(_stream).ConfigureAwait(false); + if (content != null) { await SendResponseBodyAsync(content, isFinal: isFinal, requestId: requestId).ConfigureAwait(false); diff --git a/src/libraries/Common/tests/System/Net/Http/QPackTestEncoder.cs b/src/libraries/Common/tests/System/Net/Http/QPackTestEncoder.cs index 3df439e..a73f420 100644 --- a/src/libraries/Common/tests/System/Net/Http/QPackTestEncoder.cs +++ b/src/libraries/Common/tests/System/Net/Http/QPackTestEncoder.cs @@ -49,7 +49,7 @@ namespace System.Net.Test.Common return EncodeInteger(buffer, nameValueIdx, prefix, prefixMask); } - public static int EncodeHeader(Span buffer, int nameIdx, string value, QPackFlags flags = QPackFlags.StaticIndex) + public static int EncodeHeader(Span buffer, int nameIdx, string value, Encoding valueEncoding, QPackFlags flags = QPackFlags.StaticIndex) { byte prefix, prefixMask; @@ -76,12 +76,12 @@ namespace System.Net.Test.Common } int nameLen = EncodeInteger(buffer, nameIdx, prefix, prefixMask); - int valueLen = EncodeString(buffer.Slice(nameLen), value, flags.HasFlag(QPackFlags.HuffmanEncodeValue)); + int valueLen = EncodeString(buffer.Slice(nameLen), value, valueEncoding, flags.HasFlag(QPackFlags.HuffmanEncodeValue)); return nameLen + valueLen; } - public static int EncodeHeader(Span buffer, string name, string value, QPackFlags flags = QPackFlags.None) + public static int EncodeHeader(Span buffer, string name, string value, Encoding valueEncoding, QPackFlags flags = QPackFlags.None) { byte[] data = Encoding.ASCII.GetBytes(name); byte prefix; @@ -116,14 +116,14 @@ namespace System.Net.Test.Common bytesGenerated += data.Length; // write value string. - bytesGenerated += EncodeString(buffer.Slice(bytesGenerated), value, flags.HasFlag(QPackFlags.HuffmanEncodeValue)); + bytesGenerated += EncodeString(buffer.Slice(bytesGenerated), value, valueEncoding, flags.HasFlag(QPackFlags.HuffmanEncodeValue)); return bytesGenerated; } - public static int EncodeString(Span buffer, string value, bool huffmanCoded = false) + public static int EncodeString(Span buffer, string value, Encoding valueEncoding, bool huffmanCoded = false) { - return HPackEncoder.EncodeString(value, buffer, huffmanCoded); + return HPackEncoder.EncodeString(value, valueEncoding, buffer, huffmanCoded); } public static int EncodeInteger(Span buffer, int value, byte prefix, byte prefixMask) diff --git a/src/libraries/System.Net.Http/ref/System.Net.Http.cs b/src/libraries/System.Net.Http/ref/System.Net.Http.cs index 8e2c957..8f72cd9 100644 --- a/src/libraries/System.Net.Http/ref/System.Net.Http.cs +++ b/src/libraries/System.Net.Http/ref/System.Net.Http.cs @@ -39,6 +39,7 @@ namespace System.Net.Http public FormUrlEncodedContent(System.Collections.Generic.IEnumerable> nameValueCollection) : base (default(byte[])) { } protected override System.Threading.Tasks.Task SerializeToStreamAsync(System.IO.Stream stream, System.Net.TransportContext? context, System.Threading.CancellationToken cancellationToken) { throw null; } } + public delegate System.Text.Encoding? HeaderEncodingSelector(string headerName, TContext context); public partial class HttpClient : System.Net.Http.HttpMessageInvoker { public HttpClient() : base (default(System.Net.Http.HttpMessageHandler)) { } @@ -346,7 +347,9 @@ namespace System.Net.Http public bool PreAuthenticate { get { throw null; } set { } } public System.Collections.Generic.IDictionary Properties { get { throw null; } } public System.Net.IWebProxy? Proxy { get { throw null; } set { } } + public System.Net.Http.HeaderEncodingSelector? RequestHeaderEncodingSelector { get { throw null; } set { } } public System.TimeSpan ResponseDrainTimeout { get { throw null; } set { } } + public System.Net.Http.HeaderEncodingSelector? ResponseHeaderEncodingSelector { get { throw null; } set { } } [System.Diagnostics.CodeAnalysis.AllowNullAttribute] public System.Net.Security.SslClientAuthenticationOptions SslOptions { get { throw null; } set { } } public bool UseCookies { get { throw null; } set { } } diff --git a/src/libraries/System.Net.Http/src/System.Net.Http.csproj b/src/libraries/System.Net.Http/src/System.Net.Http.csproj index 60e0afe..a30bc93 100644 --- a/src/libraries/System.Net.Http/src/System.Net.Http.csproj +++ b/src/libraries/System.Net.Http/src/System.Net.Http.csproj @@ -29,6 +29,7 @@ + diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs b/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs index 657d003..8c462ee 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs @@ -142,6 +142,18 @@ namespace System.Net.Http public IDictionary Properties => throw new PlatformNotSupportedException(); + public HeaderEncodingSelector? RequestHeaderEncodingSelector + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } + + public HeaderEncodingSelector? ResponseHeaderEncodingSelector + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } + protected internal override Task SendAsync( HttpRequestMessage request, CancellationToken cancellationToken) => throw new PlatformNotSupportedException(); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HeaderEncodingSelector.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HeaderEncodingSelector.cs new file mode 100644 index 0000000..cb984b8 --- /dev/null +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HeaderEncodingSelector.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; + +namespace System.Net.Http +{ + /// + /// Represents a method that specifies the to use when interpreting header values. + /// + /// Name of the header to specify the for. + /// The we are enoding/decoding the headers for. + /// to use or to use the default behavior. + public delegate Encoding? HeaderEncodingSelector(string headerName, TContext context); +} diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs index 8970f78..229490b 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Text; using System.Text.Unicode; namespace System.Net.Http.Headers @@ -116,7 +117,7 @@ namespace System.Net.Http.Headers return new HeaderDescriptor(_knownHeader.Name); } - public string GetHeaderValue(ReadOnlySpan headerValue) + public string GetHeaderValue(ReadOnlySpan headerValue, Encoding? valueEncoding) { if (headerValue.Length == 0) { @@ -156,7 +157,7 @@ namespace System.Net.Http.Headers } } - return HttpRuleParser.DefaultHttpEncoding.GetString(headerValue); + return (valueEncoding ?? HttpRuleParser.DefaultHttpEncoding).GetString(headerValue); } internal static string? GetKnownContentType(ReadOnlySpan contentTypeValue) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 1bd39f4..c9ad97f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -974,12 +974,12 @@ namespace System.Net.Http headerBuffer.Commit(bytesWritten); } - private void WriteLiteralHeader(string name, ReadOnlySpan values, ref ArrayBuffer headerBuffer) + private void WriteLiteralHeader(string name, ReadOnlySpan values, Encoding? valueEncoding, ref ArrayBuffer headerBuffer) { if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(name)}={name}, {nameof(values)}={string.Join(", ", values.ToArray())}"); int bytesWritten; - while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, headerBuffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, valueEncoding, headerBuffer.AvailableSpan, out bytesWritten)) { headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1); } @@ -987,12 +987,12 @@ namespace System.Net.Http headerBuffer.Commit(bytesWritten); } - private void WriteLiteralHeaderValues(ReadOnlySpan values, string? separator, ref ArrayBuffer headerBuffer) + private void WriteLiteralHeaderValues(ReadOnlySpan values, string? separator, Encoding? valueEncoding, ref ArrayBuffer headerBuffer) { if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(values)}={string.Join(separator, values.ToArray())}"); int bytesWritten; - while (!HPackEncoder.EncodeStringLiterals(values, separator, headerBuffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeStringLiterals(values, separator, valueEncoding, headerBuffer.AvailableSpan, out bytesWritten)) { headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1); } @@ -1000,12 +1000,12 @@ namespace System.Net.Http headerBuffer.Commit(bytesWritten); } - private void WriteLiteralHeaderValue(string value, ref ArrayBuffer headerBuffer) + private void WriteLiteralHeaderValue(string value, Encoding? valueEncoding, ref ArrayBuffer headerBuffer) { if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(value)}={value}"); int bytesWritten; - while (!HPackEncoder.EncodeStringLiteral(value, headerBuffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeStringLiteral(value, valueEncoding, headerBuffer.AvailableSpan, out bytesWritten)) { headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1); } @@ -1026,7 +1026,7 @@ namespace System.Net.Http headerBuffer.Commit(bytes.Length); } - private void WriteHeaderCollection(HttpHeaders headers, ref ArrayBuffer headerBuffer) + private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders headers, ref ArrayBuffer headerBuffer) { if (NetEventSource.Log.IsEnabled()) Trace(""); @@ -1035,6 +1035,8 @@ namespace System.Net.Http return; } + HeaderEncodingSelector? encodingSelector = _pool.Settings._requestHeaderEncodingSelector; + ref string[]? tmpHeaderValuesArray = ref t_headerValues; foreach (KeyValuePair header in headers.HeaderStore) { @@ -1042,6 +1044,8 @@ namespace System.Net.Http Debug.Assert(headerValuesCount > 0, "No values for header??"); ReadOnlySpan headerValues = tmpHeaderValuesArray.AsSpan(0, headerValuesCount); + Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, request); + KnownHeader? knownHeader = header.Key.KnownHeader; if (knownHeader != null) { @@ -1058,7 +1062,7 @@ namespace System.Net.Http if (string.Equals(value, "trailers", StringComparison.OrdinalIgnoreCase)) { WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer); - WriteLiteralHeaderValue(value, ref headerBuffer); + WriteLiteralHeaderValue(value, valueEncoding, ref headerBuffer); break; } } @@ -1081,13 +1085,13 @@ namespace System.Net.Http } } - WriteLiteralHeaderValues(headerValues, separator, ref headerBuffer); + WriteLiteralHeaderValues(headerValues, separator, valueEncoding, ref headerBuffer); } } else { // The header is not known: fall back to just encoding the header name and value(s). - WriteLiteralHeader(header.Key.Name, headerValues, ref headerBuffer); + WriteLiteralHeader(header.Key.Name, headerValues, valueEncoding, ref headerBuffer); } } } @@ -1142,7 +1146,7 @@ namespace System.Net.Http if (request.HasHeaders) { - WriteHeaderCollection(request.Headers, ref headerBuffer); + WriteHeaderCollection(request, request.Headers, ref headerBuffer); } // Determine cookies to send. @@ -1152,7 +1156,9 @@ namespace System.Net.Http if (cookiesFromContainer != string.Empty) { WriteBytes(KnownHeaders.Cookie.Http2EncodedName, ref headerBuffer); - WriteLiteralHeaderValue(cookiesFromContainer, ref headerBuffer); + + Encoding? cookieEncoding = _pool.Settings._requestHeaderEncodingSelector?.Invoke(KnownHeaders.Cookie.Name, request); + WriteLiteralHeaderValue(cookiesFromContainer, cookieEncoding, ref headerBuffer); } } @@ -1163,12 +1169,12 @@ namespace System.Net.Http if (normalizedMethod.MustHaveRequestBody) { WriteBytes(KnownHeaders.ContentLength.Http2EncodedName, ref headerBuffer); - WriteLiteralHeaderValue("0", ref headerBuffer); + WriteLiteralHeaderValue("0", valueEncoding: null, ref headerBuffer); } } else { - WriteHeaderCollection(request.Content.Headers, ref headerBuffer); + WriteHeaderCollection(request, request.Content.Headers, ref headerBuffer); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index 5ac829b..254fec2 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -542,24 +542,26 @@ namespace System.Net.Http throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, Encoding.ASCII.GetString(name))); } + Encoding? valueEncoding = _connection._pool.Settings._responseHeaderEncodingSelector?.Invoke(descriptor.Name, _request); + // Note we ignore the return value from TryAddWithoutValidation; // if the header can't be added, we silently drop it. if (_responseProtocolState == ResponseProtocolState.ExpectingTrailingHeaders) { Debug.Assert(_trailers != null); - string headerValue = descriptor.GetHeaderValue(value); + string headerValue = descriptor.GetHeaderValue(value, valueEncoding); _trailers.TryAddWithoutValidation((descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue); } else if ((descriptor.HeaderType & HttpHeaderType.Content) == HttpHeaderType.Content) { Debug.Assert(_response != null && _response.Content != null); - string headerValue = descriptor.GetHeaderValue(value); + string headerValue = descriptor.GetHeaderValue(value, valueEncoding); _response.Content.Headers.TryAddWithoutValidation(descriptor, headerValue); } else { Debug.Assert(_response != null); - string headerValue = _connection.GetResponseHeaderValueWithCaching(descriptor, value); + string headerValue = _connection.GetResponseHeaderValueWithCaching(descriptor, value, valueEncoding); _response.Headers.TryAddWithoutValidation((descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 07c2840..72e8090 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -548,7 +548,8 @@ namespace System.Net.Http string cookiesFromContainer = _connection.Pool.Settings._cookieContainer!.GetCookieHeader(request.RequestUri); if (cookiesFromContainer != string.Empty) { - BufferLiteralHeaderWithStaticNameReference(H3StaticTable.Cookie, cookiesFromContainer); + Encoding? valueEncoding = _connection.Pool.Settings._requestHeaderEncodingSelector?.Invoke(HttpKnownHeaderNames.Cookie, request); + BufferLiteralHeaderWithStaticNameReference(H3StaticTable.Cookie, cookiesFromContainer, valueEncoding); } } @@ -590,12 +591,16 @@ namespace System.Net.Http return; } + HeaderEncodingSelector? encodingSelector = _connection.Pool.Settings._requestHeaderEncodingSelector; + foreach (KeyValuePair header in headers.HeaderStore) { int headerValuesCount = HttpHeaders.GetValuesAsStrings(header.Key, header.Value, ref _headerValues); Debug.Assert(headerValuesCount > 0, "No values for header??"); ReadOnlySpan headerValues = _headerValues.AsSpan(0, headerValuesCount); + Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, _request); + KnownHeader? knownHeader = header.Key.KnownHeader; if (knownHeader != null) { @@ -612,7 +617,7 @@ namespace System.Net.Http { if (string.Equals(value, "trailers", StringComparison.OrdinalIgnoreCase)) { - BufferLiteralHeaderWithoutNameReference("TE", value); + BufferLiteralHeaderWithoutNameReference("TE", value, valueEncoding); break; } } @@ -635,13 +640,13 @@ namespace System.Net.Http } } - BufferLiteralHeaderValues(headerValues, separator); + BufferLiteralHeaderValues(headerValues, separator, valueEncoding); } } else { // The header is not known: fall back to just encoding the header name and value(s). - BufferLiteralHeaderWithoutNameReference(header.Key.Name, headerValues, ", "); + BufferLiteralHeaderWithoutNameReference(header.Key.Name, headerValues, HttpHeaderParser.DefaultSeparator, valueEncoding); } } } @@ -656,40 +661,40 @@ namespace System.Net.Http _sendBuffer.Commit(bytesWritten); } - private void BufferLiteralHeaderWithStaticNameReference(int nameIndex, string value) + private void BufferLiteralHeaderWithStaticNameReference(int nameIndex, string value, Encoding? valueEncoding = null) { int bytesWritten; - while (!QPackEncoder.EncodeLiteralHeaderFieldWithStaticNameReference(nameIndex, value, _sendBuffer.AvailableSpan, out bytesWritten)) + while (!QPackEncoder.EncodeLiteralHeaderFieldWithStaticNameReference(nameIndex, value, valueEncoding, _sendBuffer.AvailableSpan, out bytesWritten)) { _sendBuffer.Grow(); } _sendBuffer.Commit(bytesWritten); } - private void BufferLiteralHeaderWithoutNameReference(string name, ReadOnlySpan values, string separator) + private void BufferLiteralHeaderWithoutNameReference(string name, ReadOnlySpan values, string separator, Encoding? valueEncoding) { int bytesWritten; - while (!QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReference(name, values, separator, _sendBuffer.AvailableSpan, out bytesWritten)) + while (!QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReference(name, values, separator, valueEncoding, _sendBuffer.AvailableSpan, out bytesWritten)) { _sendBuffer.Grow(); } _sendBuffer.Commit(bytesWritten); } - private void BufferLiteralHeaderWithoutNameReference(string name, string value) + private void BufferLiteralHeaderWithoutNameReference(string name, string value, Encoding? valueEncoding) { int bytesWritten; - while (!QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReference(name, value, _sendBuffer.AvailableSpan, out bytesWritten)) + while (!QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReference(name, value, valueEncoding, _sendBuffer.AvailableSpan, out bytesWritten)) { _sendBuffer.Grow(); } _sendBuffer.Commit(bytesWritten); } - private void BufferLiteralHeaderValues(ReadOnlySpan values, string? separator) + private void BufferLiteralHeaderValues(ReadOnlySpan values, string? separator, Encoding? valueEncoding) { int bytesWritten; - while (!QPackEncoder.EncodeValueString(values, separator, _sendBuffer.AvailableSpan, out bytesWritten)) + while (!QPackEncoder.EncodeValueString(values, separator, valueEncoding, _sendBuffer.AvailableSpan, out bytesWritten)) { _sendBuffer.Grow(); } @@ -917,7 +922,13 @@ namespace System.Net.Http } else { - string headerValue = staticValue ?? _connection.GetResponseHeaderValueWithCaching(descriptor, literalValue); + string? headerValue = staticValue; + + if (headerValue is null) + { + Encoding? encoding = _connection.Pool.Settings._responseHeaderEncodingSelector?.Invoke(descriptor.Name, _request); + headerValue = _connection.GetResponseHeaderValueWithCaching(descriptor, literalValue, encoding); + } switch (_headerState) { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index 33383a1..131af80 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -197,6 +197,8 @@ namespace System.Net.Http private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFromContainer, bool async) { + Debug.Assert(_currentRequest != null); + if (headers.HeaderStore != null) { foreach (KeyValuePair header in headers.HeaderStore) @@ -215,12 +217,14 @@ namespace System.Net.Http Debug.Assert(headerValuesCount > 0, "No values for header??"); if (headerValuesCount > 0) { - await WriteStringAsync(_headerValues[0], async).ConfigureAwait(false); + Encoding? valueEncoding = _pool.Settings._requestHeaderEncodingSelector?.Invoke(header.Key.Name, _currentRequest); + + await WriteStringAsync(_headerValues[0], async, valueEncoding).ConfigureAwait(false); if (cookiesFromContainer != null && header.Key.KnownHeader == KnownHeaders.Cookie) { await WriteTwoBytesAsync((byte)';', (byte)' ', async).ConfigureAwait(false); - await WriteStringAsync(cookiesFromContainer, async).ConfigureAwait(false); + await WriteStringAsync(cookiesFromContainer, async, valueEncoding).ConfigureAwait(false); cookiesFromContainer = null; } @@ -238,7 +242,7 @@ namespace System.Net.Http for (int i = 1; i < headerValuesCount; i++) { await WriteAsciiStringAsync(separator, async).ConfigureAwait(false); - await WriteStringAsync(_headerValues[i], async).ConfigureAwait(false); + await WriteStringAsync(_headerValues[i], async, valueEncoding).ConfigureAwait(false); } } } @@ -251,7 +255,10 @@ namespace System.Net.Http { await WriteAsciiStringAsync(HttpKnownHeaderNames.Cookie, async).ConfigureAwait(false); await WriteTwoBytesAsync((byte)':', (byte)' ', async).ConfigureAwait(false); - await WriteStringAsync(cookiesFromContainer, async).ConfigureAwait(false); + + Encoding? valueEncoding = _pool.Settings._requestHeaderEncodingSelector?.Invoke(HttpKnownHeaderNames.Cookie, _currentRequest); + await WriteStringAsync(cookiesFromContainer, async, valueEncoding).ConfigureAwait(false); + await WriteTwoBytesAsync((byte)'\r', (byte)'\n', async).ConfigureAwait(false); } } @@ -984,22 +991,25 @@ namespace System.Net.Http pos++; } + Debug.Assert(response.RequestMessage != null); + Encoding? valueEncoding = connection._pool.Settings._responseHeaderEncodingSelector?.Invoke(descriptor.Name, response.RequestMessage); + // Note we ignore the return value from TryAddWithoutValidation. If the header can't be added, we silently drop it. ReadOnlySpan value = line.Slice(pos); if (isFromTrailer) { - string headerValue = descriptor.GetHeaderValue(value); + string headerValue = descriptor.GetHeaderValue(value, valueEncoding); response.TrailingHeaders.TryAddWithoutValidation((descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue); } else if ((descriptor.HeaderType & HttpHeaderType.Content) == HttpHeaderType.Content) { - string headerValue = descriptor.GetHeaderValue(value); + string headerValue = descriptor.GetHeaderValue(value, valueEncoding); response.Content!.Headers.TryAddWithoutValidation(descriptor, headerValue); } else { // Request headers returned on the response must be treated as custom headers. - string headerValue = connection.GetResponseHeaderValueWithCaching(descriptor, value); + string headerValue = connection.GetResponseHeaderValueWithCaching(descriptor, value, valueEncoding); response.Headers.TryAddWithoutValidation( (descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue); @@ -1182,23 +1192,23 @@ namespace System.Net.Http _writeOffset += bytes.Length; return Task.CompletedTask; } - return WriteBytesSlowAsync(bytes, async); + return WriteBytesSlowAsync(bytes, bytes.Length, async); } - private async Task WriteBytesSlowAsync(byte[] bytes, bool async) + private async Task WriteBytesSlowAsync(byte[] bytes, int length, bool async) { int offset = 0; while (true) { - int remaining = bytes.Length - offset; + int remaining = length - offset; int toCopy = Math.Min(remaining, _writeBuffer.Length - _writeOffset); Buffer.BlockCopy(bytes, offset, _writeBuffer, _writeOffset, toCopy); _writeOffset += toCopy; offset += toCopy; - Debug.Assert(offset <= bytes.Length, $"Expected {nameof(offset)} to be <= {bytes.Length}, got {offset}"); + Debug.Assert(offset <= length, $"Expected {nameof(offset)} to be <= {length}, got {offset}"); Debug.Assert(_writeOffset <= _writeBuffer.Length, $"Expected {nameof(_writeOffset)} to be <= {_writeBuffer.Length}, got {_writeOffset}"); - if (offset == bytes.Length) + if (offset == length) { break; } @@ -1235,6 +1245,43 @@ namespace System.Net.Http return WriteStringAsyncSlow(s, async); } + private Task WriteStringAsync(string s, bool async, Encoding? encoding) + { + if (encoding is null) + { + return WriteStringAsync(s, async); + } + + // If there's enough space in the buffer to just copy all of the string's bytes, do so. + if (encoding.GetMaxByteCount(s.Length) <= _writeBuffer.Length - _writeOffset) + { + _writeOffset += encoding.GetBytes(s, _writeBuffer.AsSpan(_writeOffset)); + return Task.CompletedTask; + } + + // Otherwise, fall back to doing a normal slow string write + return WriteStringWithEncodingAsyncSlow(s, async, encoding); + } + + private async Task WriteStringWithEncodingAsyncSlow(string s, bool async, Encoding encoding) + { + // Avoid calculating the length if the rented array would be small anyway + int length = s.Length <= 512 + ? encoding.GetMaxByteCount(s.Length) + : encoding.GetByteCount(s); + + byte[] rentedBuffer = ArrayPool.Shared.Rent(length); + try + { + int written = encoding.GetBytes(s, rentedBuffer); + await WriteBytesSlowAsync(rentedBuffer, written, async).ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(rentedBuffer); + } + } + private Task WriteAsciiStringAsync(string s, bool async) { // If there's enough space in the buffer to just copy all of the string's bytes, do so. diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs index b4b0aa1..df8191b 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs @@ -7,6 +7,7 @@ using System.IO; using System.Net.Http.Headers; using System.Net.Security; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; using System.Threading.Tasks; @@ -20,19 +21,19 @@ namespace System.Net.Http private string? _lastServerHeaderValue; /// Uses , but first special-cases several known headers for which we can use caching. - public string GetResponseHeaderValueWithCaching(HeaderDescriptor descriptor, ReadOnlySpan value) + public string GetResponseHeaderValueWithCaching(HeaderDescriptor descriptor, ReadOnlySpan value, Encoding? valueEncoding) { return - ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value) : - ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value) : - descriptor.GetHeaderValue(value); + ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value, valueEncoding) : + ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value, valueEncoding) : + descriptor.GetHeaderValue(value, valueEncoding); - static string GetOrAddCachedValue([NotNull] ref string? cache, HeaderDescriptor descriptor, ReadOnlySpan value) + static string GetOrAddCachedValue([NotNull] ref string? cache, HeaderDescriptor descriptor, ReadOnlySpan value, Encoding? encoding) { string? lastValue = cache; if (lastValue is null || !ByteArrayHelpers.EqualsOrdinalAscii(lastValue, value)) { - cache = lastValue = descriptor.GetHeaderValue(value); + cache = lastValue = descriptor.GetHeaderValue(value, encoding); } return lastValue; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs index 558e6cf..507af88 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs @@ -46,6 +46,9 @@ namespace System.Net.Http internal TimeSpan _expect100ContinueTimeout = HttpHandlerDefaults.DefaultExpect100ContinueTimeout; internal TimeSpan _connectTimeout = HttpHandlerDefaults.DefaultConnectTimeout; + internal HeaderEncodingSelector? _requestHeaderEncodingSelector; + internal HeaderEncodingSelector? _responseHeaderEncodingSelector; + internal Version _maxHttpVersion; internal bool _allowUnencryptedHttp2; @@ -110,6 +113,8 @@ namespace System.Net.Http _useProxy = _useProxy, _allowUnencryptedHttp2 = _allowUnencryptedHttp2, _assumePrenegotiatedHttp3ForTesting = _assumePrenegotiatedHttp3ForTesting, + _requestHeaderEncodingSelector = _requestHeaderEncodingSelector, + _responseHeaderEncodingSelector = _responseHeaderEncodingSelector, _enableMultipleHttp2Connections = _enableMultipleHttp2Connections, _connectionFactory = _connectionFactory, _plaintextFilter = _plaintextFilter diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs index de391e9..8eb5a92 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs @@ -3,11 +3,12 @@ using System.Collections.Generic; using System.Diagnostics; +using System.Net.Connections; using System.Net.Security; using System.Threading; using System.Threading.Tasks; using System.Diagnostics.CodeAnalysis; -using System.Net.Connections; +using System.Text; namespace System.Net.Http { @@ -319,6 +320,34 @@ namespace System.Net.Http public IDictionary Properties => _settings._properties ?? (_settings._properties = new Dictionary()); + /// + /// Gets or sets a callback that returns the to encode the value for the specified request header name, + /// or to use the default behavior. + /// + public HeaderEncodingSelector? RequestHeaderEncodingSelector + { + get => _settings._requestHeaderEncodingSelector; + set + { + CheckDisposedOrStarted(); + _settings._requestHeaderEncodingSelector = value; + } + } + + /// + /// Gets or sets a callback that returns the to decode the value for the specified response header name, + /// or to use the default behavior. + /// + public HeaderEncodingSelector? ResponseHeaderEncodingSelector + { + get => _settings._responseHeaderEncodingSelector; + set + { + CheckDisposedOrStarted(); + _settings._responseHeaderEncodingSelector = value; + } + } + protected override void Dispose(bool disposing) { if (disposing && !_disposed) diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Headers.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Headers.cs index afffb68..3ed0f55 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Headers.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Headers.cs @@ -6,6 +6,7 @@ using System.IO; using System.Linq; using System.Net.Http.Headers; using System.Net.Test.Common; +using System.Text; using System.Threading.Tasks; using Xunit; @@ -332,5 +333,132 @@ namespace System.Net.Http.Functional.Tests }); }); } + + private static readonly (string Name, Encoding ValueEncoding, string[] Values)[] s_nonAsciiHeaders = new[] + { + ("foo", Encoding.ASCII, new[] { "bar" }), + ("header-0", Encoding.UTF8, new[] { "\uD83D\uDE03", "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" }), + ("Cache-Control", Encoding.UTF8, new[] { "no-cache" }), + ("header-1", Encoding.UTF8, new[] { "\uD83D\uDE03" }), + ("Some-Header1", Encoding.Latin1, new[] { "\uD83D\uDE03", "UTF8-best-fit-to-latin1" }), + ("Some-Header2", Encoding.Latin1, new[] { "\u00FF", "\u00C4nd", "Ascii\u00A9" }), + ("Some-Header3", Encoding.ASCII, new[] { "\u00FF", "\u00C4nd", "Ascii\u00A9", "Latin1-best-fit-to-ascii" }), + ("header-2", Encoding.UTF8, new[] { "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" }), + ("header-3", Encoding.UTF8, new[] { "\uFFFD" }), + ("header-4", Encoding.UTF8, new[] { "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A", "\uD83D\uDE03" }), + ("Cookie", Encoding.UTF8, new[] { "Cookies", "\uD83C\uDF6A", "everywhere" }), + ("Set-Cookie", Encoding.UTF8, new[] { "\uD83C\uDDF8\uD83C\uDDEE" }), + ("header-5", Encoding.UTF8, new[] { "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A", "foo", "\uD83D\uDE03", "bar" }), + ("bar", Encoding.UTF8, new[] { "foo" }) + }; + + [Fact] + public async Task SendAsync_CustomRequestEncodingSelector_CanSendNonAsciiHeaderValues() + { + await LoopbackServerFactory.CreateClientAndServerAsync( + async uri => + { + var requestMessage = new HttpRequestMessage(HttpMethod.Get, uri) + { + Version = UseVersion + }; + + foreach ((string name, _, string[] values) in s_nonAsciiHeaders) + { + requestMessage.Headers.Add(name, values); + } + + List seenHeaderNames = new List(); + + using HttpClientHandler handler = CreateHttpClientHandler(); + var underlyingHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler); + + underlyingHandler.RequestHeaderEncodingSelector = (name, request) => + { + Assert.NotNull(name); + Assert.Same(request, requestMessage); + seenHeaderNames.Add(name); + return Assert.Single(s_nonAsciiHeaders, h => h.Name.Equals(name, StringComparison.OrdinalIgnoreCase)).ValueEncoding; + }; + + using HttpClient client = CreateHttpClient(handler); + + await client.SendAsync(requestMessage); + + foreach ((string name, _, _) in s_nonAsciiHeaders) + { + Assert.Contains(name, seenHeaderNames); + } + }, + async server => + { + HttpRequestData requestData = await server.HandleRequestAsync(); + + Assert.All(requestData.Headers, + h => Assert.False(h.HuffmanEncoded, "Expose raw decoded bytes once HuffmanEncoding is supported")); + + foreach ((string name, Encoding valueEncoding, string[] values) in s_nonAsciiHeaders) + { + byte[] valueBytes = valueEncoding.GetBytes(string.Join(", ", values)); + Assert.Single(requestData.Headers, + h => h.Name.Equals(name, StringComparison.OrdinalIgnoreCase) && h.Raw.AsSpan().IndexOf(valueBytes) != -1); + } + }); + } + + [Fact] + public async Task SendAsync_CustomResponseEncodingSelector_CanReceiveNonAsciiHeaderValues() + { + await LoopbackServerFactory.CreateClientAndServerAsync( + async uri => + { + var requestMessage = new HttpRequestMessage(HttpMethod.Get, uri) + { + Version = UseVersion + }; + + List seenHeaderNames = new List(); + + using HttpClientHandler handler = CreateHttpClientHandler(); + var underlyingHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler); + + underlyingHandler.ResponseHeaderEncodingSelector = (name, request) => + { + Assert.NotNull(name); + Assert.Same(request, requestMessage); + seenHeaderNames.Add(name); + + if (s_nonAsciiHeaders.Any(h => h.Name.Equals(name, StringComparison.OrdinalIgnoreCase))) + { + return Assert.Single(s_nonAsciiHeaders, h => h.Name.Equals(name, StringComparison.OrdinalIgnoreCase)).ValueEncoding; + } + + // Not one of our custom headers + return null; + }; + + using HttpClient client = CreateHttpClient(handler); + + using HttpResponseMessage response = await client.SendAsync(requestMessage); + + foreach ((string name, Encoding valueEncoding, string[] values) in s_nonAsciiHeaders) + { + Assert.Contains(name, seenHeaderNames); + IEnumerable receivedValues = Assert.Single(response.Headers, h => h.Key.Equals(name, StringComparison.OrdinalIgnoreCase)).Value; + string value = Assert.Single(receivedValues); + + string expected = valueEncoding.GetString(valueEncoding.GetBytes(string.Join(", ", values))); + Assert.Equal(expected, value, StringComparer.OrdinalIgnoreCase); + } + }, + async server => + { + List headerData = s_nonAsciiHeaders + .Select(h => new HttpHeaderData(h.Name, string.Join(", ", h.Values), valueEncoding: h.ValueEncoding)) + .ToList(); + + await server.HandleRequestAsync(headers: headerData); + }); + } } } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs index 730cab3..3fd84d6 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs @@ -14,25 +14,35 @@ namespace System.Net.Http.Unit.Tests.HPack { public class HPackRoundtripTests { - public static IEnumerable TestHeaders() { - yield return new object[] { new HttpRequestHeaders() { { "header", "value" } } }; - yield return new object[] { new HttpRequestHeaders() { { "header", new[] { "value1", "value2" } } } }; + yield return new object[] { new HttpRequestHeaders() { { "header", "value" } }, null }; + yield return new object[] { new HttpRequestHeaders() { { "header", "value" } }, Encoding.ASCII }; + yield return new object[] { new HttpRequestHeaders() { { "header", new[] { "value1", "value2" } } }, null }; + yield return new object[] { new HttpRequestHeaders() { { "header", new[] { "value1", "value2" } } }, Encoding.ASCII }; yield return new object[] { new HttpRequestHeaders() { { "header-0", new[] { "value1", "value2" } }, { "header-0", "value3" }, { "header-1", "value1" }, { "header-2", new[] { "value1", "value2" } }, - } }; + }, null }; + yield return new object[] { new HttpRequestHeaders() { { "header", "foo" } }, Encoding.UTF8 }; + yield return new object[] { new HttpRequestHeaders() { { "header", "\uD83D\uDE03" } }, Encoding.UTF8 }; + yield return new object[] { new HttpRequestHeaders() + { + { "header-0", new[] { "\uD83D\uDE03", "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" } }, + { "header-1", "\uD83D\uDE03" }, + { "header-2", "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" }, + { "header-3", new[] { "\uD83D\uDE03", "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" } } + }, Encoding.UTF8 }; } [Theory, MemberData(nameof(TestHeaders))] - public void HPack_HeaderEncodeDecodeRoundtrip_ShouldMatchOriginalInput(HttpHeaders headers) + public void HPack_HeaderEncodeDecodeRoundtrip_ShouldMatchOriginalInput(HttpHeaders headers, Encoding? valueEncoding) { - Memory encoding = HPackEncode(headers); - HttpHeaders decodedHeaders = HPackDecode(encoding); + Memory encoding = HPackEncode(headers, valueEncoding); + HttpHeaders decodedHeaders = HPackDecode(encoding, valueEncoding); // Assert: decoded headers are structurally equal to original headers Assert.Equal(headers.Count(), decodedHeaders.Count()); @@ -44,7 +54,7 @@ namespace System.Net.Http.Unit.Tests.HPack } // adapted from Header serialization code in Http2Connection.cs - private static Memory HPackEncode(HttpHeaders headers) + private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEncoding) { var buffer = new ArrayBuffer(4); FillAvailableSpaceWithOnes(buffer); @@ -101,7 +111,7 @@ namespace System.Net.Http.Unit.Tests.HPack void WriteLiteralHeaderValues(ReadOnlySpan values, string separator) { int bytesWritten; - while (!HPackEncoder.EncodeStringLiterals(values, separator, buffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeStringLiterals(values, separator, valueEncoding, buffer.AvailableSpan, out bytesWritten)) { buffer.EnsureAvailableSpace(buffer.AvailableLength + 1); FillAvailableSpaceWithOnes(buffer); @@ -113,7 +123,7 @@ namespace System.Net.Http.Unit.Tests.HPack void WriteLiteralHeader(string name, ReadOnlySpan values) { int bytesWritten; - while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, buffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, valueEncoding, buffer.AvailableSpan, out bytesWritten)) { buffer.EnsureAvailableSpace(buffer.AvailableLength + 1); FillAvailableSpaceWithOnes(buffer); @@ -127,12 +137,12 @@ namespace System.Net.Http.Unit.Tests.HPack } // adapted from header deserialization code in Http2Connection.cs - private static HttpHeaders HPackDecode(Memory memory) + private static HttpHeaders HPackDecode(Memory memory, Encoding? valueEncoding) { var header = new HttpRequestHeaders(); var hpackDecoder = new HPackDecoder(maxDynamicTableSize: 0, maxHeadersLength: HttpHandlerDefaults.DefaultMaxResponseHeadersLength * 1024); - hpackDecoder.Decode(memory.Span, true, new HeaderHandler(header)); + hpackDecoder.Decode(memory.Span, true, new HeaderHandler(header, valueEncoding)); return header; } @@ -140,9 +150,12 @@ namespace System.Net.Http.Unit.Tests.HPack private class HeaderHandler : IHttpHeadersHandler { HttpRequestHeaders _headers; - public HeaderHandler(HttpRequestHeaders headers) + Encoding? _valueEncoding; + + public HeaderHandler(HttpRequestHeaders headers, Encoding? valueEncoding) { _headers = headers; + _valueEncoding = valueEncoding; } public void OnHeader(ReadOnlySpan name, ReadOnlySpan value) @@ -152,7 +165,7 @@ namespace System.Net.Http.Unit.Tests.HPack throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, Encoding.ASCII.GetString(name))); } - string headerValue = descriptor.GetHeaderValue(value); + string headerValue = descriptor.GetHeaderValue(value, _valueEncoding); _headers.TryAddWithoutValidation(descriptor, headerValue.Split(',').Select(x => x.Trim())); } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs new file mode 100644 index 0000000..d7f41e0 --- /dev/null +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Http.Headers; +using System.Text; +using Xunit; + +namespace System.Net.Http.Tests +{ + public class HeaderEncodingTest + { + [Theory] + [InlineData("")] + [InlineData("foo")] + [InlineData("\uD83D\uDE03")] + [InlineData("\0")] + [InlineData("\x01")] + [InlineData("\xFF")] + [InlineData("\uFFFF")] + [InlineData("\uFFFD")] + [InlineData("\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A")] + public void RoundTripsUtf8(string input) + { + byte[] encoded = Encoding.UTF8.GetBytes(input); + + Assert.True(HeaderDescriptor.TryGet("custom-header", out HeaderDescriptor descriptor)); + Assert.Null(descriptor.KnownHeader); + string roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8); + Assert.Equal(input, roundtrip); + + Assert.True(HeaderDescriptor.TryGet("Cache-Control", out descriptor)); + Assert.NotNull(descriptor.KnownHeader); + roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8); + Assert.Equal(input, roundtrip); + } + } +} diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs index a877602..78fada7 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs @@ -222,11 +222,11 @@ namespace System.Net.Http.Tests { Assert.NotNull(knownHeader); - string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray()); + string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray(), valueEncoding: null); Assert.NotNull(v1); Assert.Equal(value, v1, StringComparer.OrdinalIgnoreCase); - string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray()); + string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray(), valueEncoding: null); Assert.Same(v1, v2); } } @@ -239,8 +239,8 @@ namespace System.Net.Http.Tests KnownHeader knownHeader = KnownHeaders.TryGetKnownHeader(name); Assert.NotNull(knownHeader); - string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray()); - string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray()); + string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray(), valueEncoding: null); + string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray(), valueEncoding: null); Assert.Equal(value, v1); Assert.Equal(value, v2); Assert.NotSame(v1, v2); diff --git a/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj b/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj index 1a1199a..f5419d0 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj +++ b/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj @@ -273,6 +273,7 @@ + -- 2.7.4