Implement Header encoding selectors on SocketsHttpHandler (#39468)
authorMiha Zupan <mihazupan.zupan1@gmail.com>
Thu, 30 Jul 2020 08:36:38 +0000 (10:36 +0200)
committerGitHub <noreply@github.com>
Thu, 30 Jul 2020 08:36:38 +0000 (10:36 +0200)
* Expose HeaderEncodingSelectors on SocketsHttpHandler

* Implement header encoding selectors in SocketsHttpHandler

* Add header encoding tests

* Add summaries for new APIs

* Use Stream.Write(byte[], int, int) overloads for framework compat

* Add dummy API implementation to browser target

* HPack/QPack fixes

* Move HeaderEncodingSelector delegate to namespace, add TContext

* Encoding fixes

* Remove unused using

* Simplify test

* HttpConnection PR feedback

* Simplify fast-path detection

25 files changed:
src/libraries/Common/src/System/Net/Http/aspnetcore/Http2/Hpack/HPackEncoder.cs
src/libraries/Common/src/System/Net/Http/aspnetcore/Http3/QPack/QPackEncoder.cs
src/libraries/Common/tests/System/Net/Http/GenericLoopbackServer.cs
src/libraries/Common/tests/System/Net/Http/HPackEncoder.cs
src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs
src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs
src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs
src/libraries/Common/tests/System/Net/Http/QPackTestEncoder.cs
src/libraries/System.Net.Http/ref/System.Net.Http.cs
src/libraries/System.Net.Http/src/System.Net.Http.csproj
src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs
src/libraries/System.Net.Http/src/System/Net/Http/HeaderEncodingSelector.cs [new file with mode: 0644]
src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Headers.cs
src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs
src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs [new file with mode: 0644]
src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs
src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj

index d2fbc52..4c3ac29 100644 (file)
@@ -4,6 +4,7 @@
 #nullable enable
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Text;
 
 namespace System.Net.Http.HPack
 {
@@ -96,7 +97,7 @@ namespace System.Net.Http.HPack
                 if (IntegerEncoder.Encode(index, 4, destination, out int indexLength))
                 {
                     Debug.Assert(indexLength >= 1);
-                    if (EncodeStringLiteral(value, destination.Slice(indexLength), out int nameLength))
+                    if (EncodeStringLiteral(value, valueEncoding: null, destination.Slice(indexLength), out int nameLength))
                     {
                         bytesWritten = indexLength + nameLength;
                         return true;
@@ -128,7 +129,7 @@ namespace System.Net.Http.HPack
                 if (IntegerEncoder.Encode(index, 4, destination, out int indexLength))
                 {
                     Debug.Assert(indexLength >= 1);
-                    if (EncodeStringLiteral(value, destination.Slice(indexLength), out int nameLength))
+                    if (EncodeStringLiteral(value, valueEncoding: null, destination.Slice(indexLength), out int nameLength))
                     {
                         bytesWritten = indexLength + nameLength;
                         return true;
@@ -160,7 +161,7 @@ namespace System.Net.Http.HPack
                 if (IntegerEncoder.Encode(index, 6, destination, out int indexLength))
                 {
                     Debug.Assert(indexLength >= 1);
-                    if (EncodeStringLiteral(value, destination.Slice(indexLength), out int nameLength))
+                    if (EncodeStringLiteral(value, valueEncoding: null, destination.Slice(indexLength), out int nameLength))
                     {
                         bytesWritten = indexLength + nameLength;
                         return true;
@@ -276,7 +277,7 @@ namespace System.Net.Http.HPack
             {
                 destination[0] = mask;
                 if (EncodeLiteralHeaderName(name, destination.Slice(1), out int nameLength) &&
-                    EncodeStringLiteral(value, destination.Slice(1 + nameLength), out int valueLength))
+                    EncodeStringLiteral(value, valueEncoding: null, destination.Slice(1 + nameLength), out int valueLength))
                 {
                     bytesWritten = 1 + nameLength + valueLength;
                     return true;
@@ -290,6 +291,11 @@ namespace System.Net.Http.HPack
         /// <summary>Encodes a "Literal Header Field without Indexing - New Name".</summary>
         public static bool EncodeLiteralHeaderFieldWithoutIndexingNewName(string name, ReadOnlySpan<string> values, string separator, Span<byte> destination, out int bytesWritten)
         {
+            return EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, separator, valueEncoding: null, destination, out bytesWritten);
+        }
+
+        public static bool EncodeLiteralHeaderFieldWithoutIndexingNewName(string name, ReadOnlySpan<string> values, string separator, Encoding? valueEncoding, Span<byte> destination, out int bytesWritten)
+        {
             // From https://tools.ietf.org/html/rfc7541#section-6.2.2
             // ------------------------------------------------------
             //   0   1   2   3   4   5   6   7
@@ -309,7 +315,7 @@ namespace System.Net.Http.HPack
             {
                 destination[0] = 0;
                 if (EncodeLiteralHeaderName(name, destination.Slice(1), out int nameLength) &&
-                    EncodeStringLiterals(values, separator, destination.Slice(1 + nameLength), out int valueLength))
+                    EncodeStringLiterals(values, separator, valueEncoding, destination.Slice(1 + nameLength), out int valueLength))
                 {
                     bytesWritten = 1 + nameLength + valueLength;
                     return true;
@@ -395,27 +401,20 @@ namespace System.Net.Http.HPack
             return false;
         }
 
-        private static bool EncodeStringLiteralValue(string value, Span<byte> destination, out int bytesWritten)
+        private static void EncodeValueStringPart(string value, Span<byte> destination)
         {
-            if (value.Length <= destination.Length)
+            Debug.Assert(destination.Length >= value.Length);
+
+            for (int i = 0; i < value.Length; i++)
             {
-                for (int i = 0; i < value.Length; i++)
+                char c = value[i];
+                if ((c & 0xFF80) != 0)
                 {
-                    char c = value[i];
-                    if ((c & 0xFF80) != 0)
-                    {
-                        throw new HttpRequestException(SR.net_http_request_invalid_char_encoding);
-                    }
-
-                    destination[i] = (byte)c;
+                    throw new HttpRequestException(SR.net_http_request_invalid_char_encoding);
                 }
 
-                bytesWritten = value.Length;
-                return true;
+                destination[i] = (byte)c;
             }
-
-            bytesWritten = 0;
-            return false;
         }
 
         public static bool EncodeStringLiteral(ReadOnlySpan<byte> value, Span<byte> destination, out int bytesWritten)
@@ -454,6 +453,11 @@ namespace System.Net.Http.HPack
 
         public static bool EncodeStringLiteral(string value, Span<byte> destination, out int bytesWritten)
         {
+            return EncodeStringLiteral(value, valueEncoding: null, destination, out bytesWritten);
+        }
+
+        public static bool EncodeStringLiteral(string value, Encoding? valueEncoding, Span<byte> destination, out int bytesWritten)
+        {
             // From https://tools.ietf.org/html/rfc7541#section-5.2
             // ------------------------------------------------------
             //   0   1   2   3   4   5   6   7
@@ -466,13 +470,28 @@ namespace System.Net.Http.HPack
             if (destination.Length != 0)
             {
                 destination[0] = 0; // TODO: Use Huffman encoding
-                if (IntegerEncoder.Encode(value.Length, 7, destination, out int integerLength))
+
+                int encodedStringLength = valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1)
+                    ? value.Length
+                    : valueEncoding.GetByteCount(value);
+
+                if (IntegerEncoder.Encode(encodedStringLength, 7, destination, out int integerLength))
                 {
                     Debug.Assert(integerLength >= 1);
-
-                    if (EncodeStringLiteralValue(value, destination.Slice(integerLength), out int valueLength))
+                    destination = destination.Slice(integerLength);
+                    if (encodedStringLength <= destination.Length)
                     {
-                        bytesWritten = integerLength + valueLength;
+                        if (valueEncoding is null)
+                        {
+                            EncodeValueStringPart(value, destination);
+                        }
+                        else
+                        {
+                            int written = valueEncoding.GetBytes(value, destination);
+                            Debug.Assert(written == encodedStringLength);
+                        }
+
+                        bytesWritten = integerLength + encodedStringLength;
                         return true;
                     }
                 }
@@ -503,55 +522,86 @@ namespace System.Net.Http.HPack
 
         public static bool EncodeStringLiterals(ReadOnlySpan<string> values, string? separator, Span<byte> destination, out int bytesWritten)
         {
+            return EncodeStringLiterals(values, separator, valueEncoding: null, destination, out bytesWritten);
+        }
+
+        public static bool EncodeStringLiterals(ReadOnlySpan<string> values, string? separator, Encoding? valueEncoding, Span<byte> destination, out int bytesWritten)
+        {
             bytesWritten = 0;
 
             if (values.Length == 0)
             {
-                return EncodeStringLiteral("", destination, out bytesWritten);
+                return EncodeStringLiteral("", valueEncoding: null, destination, out bytesWritten);
             }
             else if (values.Length == 1)
             {
-                return EncodeStringLiteral(values[0], destination, out bytesWritten);
+                return EncodeStringLiteral(values[0], valueEncoding, destination, out bytesWritten);
             }
 
             if (destination.Length != 0)
             {
-                int valueLength = 0;
+                Debug.Assert(separator != null);
+                int valueLength;
 
                 // Calculate length of all parts and separators.
-                foreach (string part in values)
+                if (valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1))
                 {
-                    valueLength = checked((int)(valueLength + part.Length));
+                    valueLength = checked((int)(values.Length - 1) * separator.Length);
+                    foreach (string part in values)
+                    {
+                        valueLength = checked((int)(valueLength + part.Length));
+                    }
+                }
+                else
+                {
+                    valueLength = checked((int)(values.Length - 1) * valueEncoding.GetByteCount(separator));
+                    foreach (string part in values)
+                    {
+                        valueLength = checked((int)(valueLength + valueEncoding.GetByteCount(part)));
+                    }
                 }
-
-                Debug.Assert(separator != null);
-                valueLength = checked((int)(valueLength + (values.Length - 1) * separator.Length));
 
                 destination[0] = 0;
                 if (IntegerEncoder.Encode(valueLength, 7, destination, out int integerLength))
                 {
                     Debug.Assert(integerLength >= 1);
-
-                    int encodedLength = 0;
-                    for (int j = 0; j < values.Length; j++)
+                    destination = destination.Slice(integerLength);
+                    if (destination.Length >= valueLength)
                     {
-                        if (j != 0 && !EncodeStringLiteralValue(separator, destination.Slice(integerLength), out encodedLength))
+                        if (valueEncoding is null)
                         {
-                            return false;
+                            string value = values[0];
+                            EncodeValueStringPart(value, destination);
+                            destination = destination.Slice(value.Length);
+
+                            for (int i = 1; i < values.Length; i++)
+                            {
+                                EncodeValueStringPart(separator, destination);
+                                destination = destination.Slice(separator.Length);
+
+                                value = values[i];
+                                EncodeValueStringPart(value, destination);
+                                destination = destination.Slice(value.Length);
+                            }
                         }
+                        else
+                        {
+                            int written = valueEncoding.GetBytes(values[0], destination);
+                            destination = destination.Slice(written);
 
-                        integerLength += encodedLength;
+                            for (int i = 1; i < values.Length; i++)
+                            {
+                                written = valueEncoding.GetBytes(separator, destination);
+                                destination = destination.Slice(written);
 
-                        if (!EncodeStringLiteralValue(values[j], destination.Slice(integerLength), out encodedLength))
-                        {
-                            return false;
+                                written = valueEncoding.GetBytes(values[i], destination);
+                                destination = destination.Slice(written);
+                            }
                         }
 
-                        integerLength += encodedLength;
+                        bytesWritten = integerLength + valueLength;
+                        return true;
                     }
-
-                    bytesWritten = integerLength;
-                    return true;
                 }
             }
 
index be43dc3..68e04ed 100644 (file)
@@ -5,6 +5,7 @@
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Net.Http.HPack;
+using System.Text;
 
 namespace System.Net.Http.QPack
 {
@@ -60,6 +61,11 @@ namespace System.Net.Http.QPack
         // - H is constant 0 here, as we do not yet perform Huffman coding.
         public static bool EncodeLiteralHeaderFieldWithStaticNameReference(int index, string value, Span<byte> destination, out int bytesWritten)
         {
+            return EncodeLiteralHeaderFieldWithStaticNameReference(index, value, valueEncoding: null, destination, out bytesWritten);
+        }
+
+        public static bool EncodeLiteralHeaderFieldWithStaticNameReference(int index, string value, Encoding? valueEncoding, Span<byte> destination, out int bytesWritten)
+        {
             // Requires at least two bytes (one for name reference header, one for value length)
             if (destination.Length >= 2)
             {
@@ -68,7 +74,7 @@ namespace System.Net.Http.QPack
                 {
                     destination = destination.Slice(headerBytesWritten);
 
-                    if (EncodeValueString(value, destination, out int valueBytesWritten))
+                    if (EncodeValueString(value, valueEncoding, destination, out int valueBytesWritten))
                     {
                         bytesWritten = headerBytesWritten + valueBytesWritten;
                         return true;
@@ -81,7 +87,7 @@ namespace System.Net.Http.QPack
         }
 
         /// <summary>
-        /// Encodes just the name part of a Literal Header Field With Static Name Reference. Must call <see cref="EncodeValueString(string, Span{byte}, out int)"/> after to encode the header's value.
+        /// Encodes just the name part of a Literal Header Field With Static Name Reference. Must call <see cref="EncodeValueString(string, Encoding?, Span{byte}, out int)"/> after to encode the header's value.
         /// </summary>
         public static byte[] EncodeLiteralHeaderFieldWithStaticNameReferenceToArray(int index)
         {
@@ -119,7 +125,12 @@ namespace System.Net.Http.QPack
         // - H is constant 0 here, as we do not yet perform Huffman coding.
         public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, string value, Span<byte> destination, out int bytesWritten)
         {
-            if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(value, destination.Slice(nameLength), out int valueLength))
+            return EncodeLiteralHeaderFieldWithoutNameReference(name, value, valueEncoding: null, destination, out bytesWritten);
+        }
+
+        public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, string value, Encoding? valueEncoding, Span<byte> destination, out int bytesWritten)
+        {
+            if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(value, valueEncoding, destination.Slice(nameLength), out int valueLength))
             {
                 bytesWritten = nameLength + valueLength;
                 return true;
@@ -136,7 +147,12 @@ namespace System.Net.Http.QPack
         /// </summary>
         public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, ReadOnlySpan<string> values, string valueSeparator, Span<byte> destination, out int bytesWritten)
         {
-            if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(values, valueSeparator, destination.Slice(nameLength), out int valueLength))
+            return EncodeLiteralHeaderFieldWithoutNameReference(name, values, valueSeparator, valueEncoding: null, destination, out bytesWritten);
+        }
+
+        public static bool EncodeLiteralHeaderFieldWithoutNameReference(string name, ReadOnlySpan<string> values, string valueSeparator, Encoding? valueEncoding, Span<byte> destination, out int bytesWritten)
+        {
+            if (EncodeNameString(name, destination, out int nameLength) && EncodeValueString(values, valueSeparator, valueEncoding, destination.Slice(nameLength), out int valueLength))
             {
                 bytesWritten = nameLength + valueLength;
                 return true;
@@ -147,7 +163,7 @@ namespace System.Net.Http.QPack
         }
 
         /// <summary>
-        /// Encodes just the value part of a Literawl Header Field Without Static Name Reference. Must call <see cref="EncodeValueString(string, Span{byte}, out int)"/> after to encode the header's value.
+        /// Encodes just the value part of a Literawl Header Field Without Static Name Reference. Must call <see cref="EncodeValueString(string, Encoding?, Span{byte}, out int)"/> after to encode the header's value.
         /// </summary>
         public static byte[] EncodeLiteralHeaderFieldWithoutNameReferenceToArray(string name)
         {
@@ -169,19 +185,32 @@ namespace System.Net.Http.QPack
             return temp.Slice(0, bytesWritten).ToArray();
         }
 
-        private static bool EncodeValueString(string s, Span<byte> buffer, out int length)
+        private static bool EncodeValueString(string s, Encoding? valueEncoding, Span<byte> buffer, out int length)
         {
             if (buffer.Length != 0)
             {
                 buffer[0] = 0;
-                if (IntegerEncoder.Encode(s.Length, 7, buffer, out int nameLength))
+
+                int encodedStringLength = valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1)
+                    ? s.Length
+                    : valueEncoding.GetByteCount(s);
+
+                if (IntegerEncoder.Encode(encodedStringLength, 7, buffer, out int nameLength))
                 {
                     buffer = buffer.Slice(nameLength);
-                    if (buffer.Length >= s.Length)
+                    if (buffer.Length >= encodedStringLength)
                     {
-                        EncodeValueStringPart(s, buffer);
+                        if (valueEncoding is null)
+                        {
+                            EncodeValueStringPart(s, buffer);
+                        }
+                        else
+                        {
+                            int written = valueEncoding.GetBytes(s, buffer);
+                            Debug.Assert(written == encodedStringLength);
+                        }
 
-                        length = nameLength + s.Length;
+                        length = nameLength + encodedStringLength;
                         return true;
                     }
                 }
@@ -196,24 +225,41 @@ namespace System.Net.Http.QPack
         /// </summary>
         public static bool EncodeValueString(ReadOnlySpan<string> values, string? separator, Span<byte> buffer, out int length)
         {
+            return EncodeValueString(values, separator, valueEncoding: null, buffer, out length);
+        }
+
+        public static bool EncodeValueString(ReadOnlySpan<string> values, string? separator, Encoding? valueEncoding, Span<byte> buffer, out int length)
+        {
             if (values.Length == 1)
             {
-                return EncodeValueString(values[0], buffer, out length);
+                return EncodeValueString(values[0], valueEncoding, buffer, out length);
             }
 
             if (values.Length == 0)
             {
                 // TODO: this will be called with a string array from HttpHeaderCollection. Can we ever get a 0-length array from that? Assert if not.
-                return EncodeValueString(string.Empty, buffer, out length);
+                return EncodeValueString(string.Empty, valueEncoding: null, buffer, out length);
             }
 
             if (buffer.Length > 0)
             {
                 Debug.Assert(separator != null);
-                int valueLength = separator.Length * (values.Length - 1);
-                for (int i = 0; i < values.Length; ++i)
+                int valueLength;
+                if (valueEncoding is null || ReferenceEquals(valueEncoding, Encoding.Latin1))
+                {
+                    valueLength = separator.Length * (values.Length - 1);
+                    foreach (string part in values)
+                    {
+                        valueLength += part.Length;
+                    }
+                }
+                else
                 {
-                    valueLength += values[i].Length;
+                    valueLength = valueEncoding.GetByteCount(separator) * (values.Length - 1);
+                    foreach (string part in values)
+                    {
+                        valueLength += valueEncoding.GetByteCount(part);
+                    }
                 }
 
                 buffer[0] = 0;
@@ -222,18 +268,35 @@ namespace System.Net.Http.QPack
                     buffer = buffer.Slice(nameLength);
                     if (buffer.Length >= valueLength)
                     {
-                        string value = values[0];
-                        EncodeValueStringPart(value, buffer);
-                        buffer = buffer.Slice(value.Length);
-
-                        for (int i = 1; i < values.Length; ++i)
+                        if (valueEncoding is null)
                         {
-                            EncodeValueStringPart(separator, buffer);
-                            buffer = buffer.Slice(separator.Length);
-
-                            value = values[i];
+                            string value = values[0];
                             EncodeValueStringPart(value, buffer);
                             buffer = buffer.Slice(value.Length);
+
+                            for (int i = 1; i < values.Length; i++)
+                            {
+                                EncodeValueStringPart(separator, buffer);
+                                buffer = buffer.Slice(separator.Length);
+
+                                value = values[i];
+                                EncodeValueStringPart(value, buffer);
+                                buffer = buffer.Slice(value.Length);
+                            }
+                        }
+                        else
+                        {
+                            int written = valueEncoding.GetBytes(values[0], buffer);
+                            buffer = buffer.Slice(written);
+
+                            for (int i = 1; i < values.Length; i++)
+                            {
+                                written = valueEncoding.GetBytes(separator, buffer);
+                                buffer = buffer.Slice(written);
+
+                                written = valueEncoding.GetBytes(values[i], buffer);
+                                buffer = buffer.Slice(written);
+                            }
                         }
 
                         length = nameLength + valueLength;
index d15ab88..899b909 100644 (file)
@@ -106,13 +106,15 @@ namespace System.Net.Test.Common
         public string Value { get; }
         public bool HuffmanEncoded { get; }
         public byte[] Raw { get; }
+        public Encoding ValueEncoding { get; }
 
-        public HttpHeaderData(string name, string value, bool huffmanEncoded = false, byte[] raw = null)
+        public HttpHeaderData(string name, string value, bool huffmanEncoded = false, byte[] raw = null, Encoding valueEncoding = null)
         {
             Name = name;
             Value = value;
             HuffmanEncoded = huffmanEncoded;
             Raw = raw;
+            ValueEncoding = valueEncoding;
         }
 
         public override string ToString() => Name == null ? "<empty>" : (Name + ": " + (Value ?? string.Empty));
index bbdf89f..9ecaf44 100644 (file)
@@ -51,7 +51,7 @@ namespace System.Net.Test.Common
         public static int EncodeHeader(int nameIdx, string value, HPackFlags flags, Span<byte> headerBlock)
         {
             Debug.Assert(nameIdx > 0);
-            return EncodeHeaderImpl(nameIdx, null, value, flags, headerBlock);
+            return EncodeHeaderImpl(nameIdx, null, value, valueEncoding: null, flags, headerBlock);
         }
 
         /// <summary>
@@ -63,10 +63,15 @@ namespace System.Net.Test.Common
         /// <returns>The number of bytes written to <paramref name="headerBlock"/>.</returns>
         public static int EncodeHeader(string name, string value, HPackFlags flags, Span<byte> headerBlock)
         {
-            return EncodeHeaderImpl(0, name, value, flags, headerBlock);
+            return EncodeHeader(name, value, valueEncoding: null, flags, headerBlock);
         }
 
-        private static int EncodeHeaderImpl(int nameIdx, string name, string value, HPackFlags flags, Span<byte> headerBlock)
+        public static int EncodeHeader(string name, string value, Encoding valueEncoding, HPackFlags flags, Span<byte> headerBlock)
+        {
+            return EncodeHeaderImpl(0, name, value, valueEncoding, flags, headerBlock);
+        }
+
+        private static int EncodeHeaderImpl(int nameIdx, string name, string value, Encoding valueEncoding, HPackFlags flags, Span<byte> headerBlock)
         {
             const HPackFlags IndexingMask = HPackFlags.NeverIndexed | HPackFlags.NewIndexed | HPackFlags.WithoutIndexing;
 
@@ -97,16 +102,16 @@ namespace System.Net.Test.Common
 
             if (name != null)
             {
-                bytesGenerated += EncodeString(name, headerBlock.Slice(bytesGenerated), (flags & HPackFlags.HuffmanEncodeName) != 0);
+                bytesGenerated += EncodeString(name, Encoding.ASCII, headerBlock.Slice(bytesGenerated), (flags & HPackFlags.HuffmanEncodeName) != 0);
             }
 
-            bytesGenerated += EncodeString(value, headerBlock.Slice(bytesGenerated), (flags & HPackFlags.HuffmanEncodeValue) != 0);
+            bytesGenerated += EncodeString(value, valueEncoding, headerBlock.Slice(bytesGenerated), (flags & HPackFlags.HuffmanEncodeValue) != 0);
             return bytesGenerated;
         }
 
-        public static int EncodeString(string value, Span<byte> headerBlock, bool huffmanEncode)
+        public static int EncodeString(string value, Encoding valueEncoding, Span<byte> headerBlock, bool huffmanEncode)
         {
-            byte[] data = Encoding.ASCII.GetBytes(value);
+            byte[] data = (valueEncoding ?? Encoding.ASCII).GetBytes(value);
             byte prefix;
 
             if (!huffmanEncode)
index ad60410..9ee6479 100644 (file)
@@ -662,7 +662,7 @@ namespace System.Net.Test.Common
             {
                 foreach (HttpHeaderData headerData in headers)
                 {
-                    bytesGenerated += HPackEncoder.EncodeHeader(headerData.Name, headerData.Value, headerData.HuffmanEncoded ? HPackFlags.HuffmanEncode : HPackFlags.None, headerBlock.AsSpan(bytesGenerated));
+                    bytesGenerated += HPackEncoder.EncodeHeader(headerData.Name, headerData.Value, headerData.ValueEncoding, headerData.HuffmanEncoded ? HPackFlags.HuffmanEncode : HPackFlags.None, headerBlock.AsSpan(bytesGenerated));
                 }
             }
 
index f13cbd6..d0e8d43 100644 (file)
@@ -78,7 +78,7 @@ namespace System.Net.Test.Common
 
             foreach (HttpHeaderData header in headers)
             {
-                bytesWritten += QPackTestEncoder.EncodeHeader(buffer.AsSpan(bytesWritten), header.Name, header.Value, header.HuffmanEncoded ? QPackFlags.HuffmanEncode : QPackFlags.None);
+                bytesWritten += QPackTestEncoder.EncodeHeader(buffer.AsSpan(bytesWritten), header.Name, header.Value, header.ValueEncoding, header.HuffmanEncoded ? QPackFlags.HuffmanEncode : QPackFlags.None);
             }
 
             await SendFrameAsync(HeadersFrame, buffer.AsMemory(0, bytesWritten)).ConfigureAwait(false);
index 76c8459..2d34c3b 100644 (file)
@@ -15,6 +15,9 @@ namespace System.Net.Test.Common
 {
     public sealed partial class LoopbackServer : GenericLoopbackServer, IDisposable
     {
+        private static readonly byte[] s_newLineBytes = new byte[] { (byte)'\r', (byte)'\n' };
+        private static readonly byte[] s_colonSpaceBytes = new byte[] { (byte)':', (byte)' ' };
+
         private Socket _listenSocket;
         private Options _options;
         private Uri _uri;
@@ -533,6 +536,16 @@ namespace System.Net.Test.Common
 
             public async Task<string> ReadLineAsync()
             {
+                byte[] lineBytes = await ReadLineBytesAsync().ConfigureAwait(false);
+
+                if (lineBytes is null)
+                    return null;
+
+                return Encoding.ASCII.GetString(lineBytes);
+            }
+
+            private async Task<byte[]> ReadLineBytesAsync()
+            {
                 int index = 0;
                 int startSearch = _readStart;
 
@@ -578,7 +591,7 @@ namespace System.Net.Test.Common
                     if (_readBuffer[_readStart + stringLength] == '\n') { stringLength--; }
                     if (_readBuffer[_readStart + stringLength] == '\r') { stringLength--; }
 
-                    string line = System.Text.Encoding.ASCII.GetString(_readBuffer, _readStart, stringLength + 1);
+                    byte[] line = _readBuffer.AsSpan(_readStart, stringLength + 1).ToArray();
                     _readStart = index + 1;
                     return line;
                 }
@@ -625,6 +638,32 @@ namespace System.Net.Test.Common
                 return lines;
             }
 
+            private async Task<List<byte[]>> ReadRequestHeaderBytesAsync()
+            {
+                var lines = new List<byte[]>();
+
+                byte[] line;
+
+                while (true)
+                {
+                    line = await ReadLineBytesAsync().ConfigureAwait(false);
+
+                    if (line is null || line.Length == 0)
+                    {
+                        break;
+                    }
+
+                    lines.Add(line);
+                }
+
+                if (line == null)
+                {
+                    throw new IOException("Unexpected EOF trying to read request header");
+                }
+
+                return lines;
+            }
+
             public async Task SendResponseAsync(string response)
             {
                 await _writer.WriteAsync(response).ConfigureAwait(false);
@@ -663,24 +702,24 @@ namespace System.Net.Test.Common
 
             public override async Task<HttpRequestData> ReadRequestDataAsync(bool readBody = true)
             {
-                List<string> headerLines = null;
                 HttpRequestData requestData = new HttpRequestData();
 
-                headerLines = await ReadRequestHeaderAsync().ConfigureAwait(false);
+                List<byte[]> headerLines = await ReadRequestHeaderBytesAsync().ConfigureAwait(false);
 
                 // Parse method and path
-                string[] splits = headerLines[0].Split(' ');
+                string[] splits = Encoding.ASCII.GetString(headerLines[0]).Split(' ');
                 requestData.Method = splits[0];
                 requestData.Path = splits[1];
 
                 // Convert header lines to key/value pairs
                 // Skip first line since it's the status line
-                foreach (var line in headerLines.Skip(1))
+                foreach (byte[] lineBytes in headerLines.Skip(1))
                 {
+                    string line = Encoding.ASCII.GetString(lineBytes);
                     int offset = line.IndexOf(':');
                     string name = line.Substring(0, offset);
                     string value = line.Substring(offset + 1).TrimStart();
-                    requestData.Headers.Add(new HttpHeaderData(name, value));
+                    requestData.Headers.Add(new HttpHeaderData(name, value, raw: lineBytes));
                 }
 
                 if (requestData.Method != "GET")
@@ -760,7 +799,7 @@ namespace System.Net.Test.Common
 
             public override async Task SendResponseAsync(HttpStatusCode? statusCode = HttpStatusCode.OK, IList<HttpHeaderData> headers = null, string content = null, bool isFinal = true, int requestId = 0)
             {
-                string headerString = null;
+                MemoryStream headerBytes = new MemoryStream();
                 int contentLength = -1;
                 bool isChunked = false;
                 bool hasContentLength  = false;
@@ -785,22 +824,39 @@ namespace System.Net.Test.Common
                             isChunked = true;
                         }
 
-                        headerString = headerString + $"{headerData.Name}: {headerData.Value}\r\n";
+                        byte[] nameBytes = Encoding.ASCII.GetBytes(headerData.Name);
+                        headerBytes.Write(nameBytes, 0, nameBytes.Length);
+                        headerBytes.Write(s_colonSpaceBytes, 0, s_colonSpaceBytes.Length);
+
+                        byte[] valueBytes = (headerData.ValueEncoding ?? Encoding.ASCII).GetBytes(headerData.Value);
+                        headerBytes.Write(valueBytes, 0, valueBytes.Length);
+                        headerBytes.Write(s_newLineBytes, 0, s_newLineBytes.Length);
                     }
                 }
 
                 bool endHeaders = content != null || isFinal;
                 if (statusCode != null)
                 {
-                    // If we need to send status line, prepped it to headers and possibly add missing headers to the end.
-                    headerString =
+                    byte[] temp = headerBytes.ToArray();
+
+                    headerBytes.SetLength(0);
+
+                    byte[] headerStartBytes = Encoding.ASCII.GetBytes(
                         $"HTTP/1.1 {(int)statusCode} {GetStatusDescription((HttpStatusCode)statusCode)}\r\n" +
-                        (!hasContentLength && !isChunked && content != null ? $"Content-length: {content.Length}\r\n" : "") +
-                        headerString +
-                        (endHeaders ? "\r\n" : "");
+                        (!hasContentLength && !isChunked && content != null ? $"Content-length: {content.Length}\r\n" : ""));
+
+                    headerBytes.Write(headerStartBytes, 0, headerStartBytes.Length);
+                    headerBytes.Write(temp, 0, temp.Length);
+
+                    if (endHeaders)
+                    {
+                        headerBytes.Write(s_newLineBytes, 0, s_newLineBytes.Length);
+                    }
                 }
 
-                await SendResponseAsync(headerString).ConfigureAwait(false);
+                headerBytes.Position = 0;
+                await headerBytes.CopyToAsync(_stream).ConfigureAwait(false);
+
                 if (content != null)
                 {
                     await SendResponseBodyAsync(content, isFinal: isFinal, requestId: requestId).ConfigureAwait(false);
index 3df439e..a73f420 100644 (file)
@@ -49,7 +49,7 @@ namespace System.Net.Test.Common
             return EncodeInteger(buffer, nameValueIdx, prefix, prefixMask);
         }
 
-        public static int EncodeHeader(Span<byte> buffer, int nameIdx, string value, QPackFlags flags = QPackFlags.StaticIndex)
+        public static int EncodeHeader(Span<byte> buffer, int nameIdx, string value, Encoding valueEncoding, QPackFlags flags = QPackFlags.StaticIndex)
         {
             byte prefix, prefixMask;
 
@@ -76,12 +76,12 @@ namespace System.Net.Test.Common
             }
 
             int nameLen = EncodeInteger(buffer, nameIdx, prefix, prefixMask);
-            int valueLen = EncodeString(buffer.Slice(nameLen), value, flags.HasFlag(QPackFlags.HuffmanEncodeValue));
+            int valueLen = EncodeString(buffer.Slice(nameLen), value, valueEncoding, flags.HasFlag(QPackFlags.HuffmanEncodeValue));
 
             return nameLen + valueLen;
         }
 
-        public static int EncodeHeader(Span<byte> buffer, string name, string value, QPackFlags flags = QPackFlags.None)
+        public static int EncodeHeader(Span<byte> buffer, string name, string value, Encoding valueEncoding, QPackFlags flags = QPackFlags.None)
         {
             byte[] data = Encoding.ASCII.GetBytes(name);
             byte prefix;
@@ -116,14 +116,14 @@ namespace System.Net.Test.Common
             bytesGenerated += data.Length;
 
             // write value string.
-            bytesGenerated += EncodeString(buffer.Slice(bytesGenerated), value, flags.HasFlag(QPackFlags.HuffmanEncodeValue));
+            bytesGenerated += EncodeString(buffer.Slice(bytesGenerated), value, valueEncoding, flags.HasFlag(QPackFlags.HuffmanEncodeValue));
 
             return bytesGenerated;
         }
 
-        public static int EncodeString(Span<byte> buffer, string value, bool huffmanCoded = false)
+        public static int EncodeString(Span<byte> buffer, string value, Encoding valueEncoding, bool huffmanCoded = false)
         {
-            return HPackEncoder.EncodeString(value, buffer, huffmanCoded);
+            return HPackEncoder.EncodeString(value, valueEncoding, buffer, huffmanCoded);
         }
 
         public static int EncodeInteger(Span<byte> buffer, int value, byte prefix, byte prefixMask)
index 8e2c957..8f72cd9 100644 (file)
@@ -39,6 +39,7 @@ namespace System.Net.Http
         public FormUrlEncodedContent(System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<string?, string?>> nameValueCollection) : base (default(byte[])) { }
         protected override System.Threading.Tasks.Task SerializeToStreamAsync(System.IO.Stream stream, System.Net.TransportContext? context, System.Threading.CancellationToken cancellationToken) { throw null; }
     }
+    public delegate System.Text.Encoding? HeaderEncodingSelector<TContext>(string headerName, TContext context);
     public partial class HttpClient : System.Net.Http.HttpMessageInvoker
     {
         public HttpClient() : base (default(System.Net.Http.HttpMessageHandler)) { }
@@ -346,7 +347,9 @@ namespace System.Net.Http
         public bool PreAuthenticate { get { throw null; } set { } }
         public System.Collections.Generic.IDictionary<string, object?> Properties { get { throw null; } }
         public System.Net.IWebProxy? Proxy { get { throw null; } set { } }
+        public System.Net.Http.HeaderEncodingSelector<System.Net.Http.HttpRequestMessage>? RequestHeaderEncodingSelector { get { throw null; } set { } }
         public System.TimeSpan ResponseDrainTimeout { get { throw null; } set { } }
+        public System.Net.Http.HeaderEncodingSelector<System.Net.Http.HttpRequestMessage>? ResponseHeaderEncodingSelector { get { throw null; } set { } }
         [System.Diagnostics.CodeAnalysis.AllowNullAttribute]
         public System.Net.Security.SslClientAuthenticationOptions SslOptions { get { throw null; } set { } }
         public bool UseCookies { get { throw null; } set { } }
index 60e0afe..a30bc93 100644 (file)
@@ -29,6 +29,7 @@
     <Compile Include="System\Net\Http\EmptyContent.cs" />
     <Compile Include="System\Net\Http\EmptyReadStream.cs" />
     <Compile Include="System\Net\Http\FormUrlEncodedContent.cs" />
+    <Compile Include="System\Net\Http\HeaderEncodingSelector.cs" />
     <Compile Include="System\Net\Http\Headers\AltSvcHeaderParser.cs" />
     <Compile Include="System\Net\Http\Headers\AltSvcHeaderValue.cs" />
     <Compile Include="System\Net\Http\Headers\KnownHeader.cs" />
index 657d003..8c462ee 100644 (file)
@@ -142,6 +142,18 @@ namespace System.Net.Http
 
         public IDictionary<string, object?> Properties => throw new PlatformNotSupportedException();
 
+        public HeaderEncodingSelector<HttpRequestMessage>? RequestHeaderEncodingSelector
+        {
+            get => throw new PlatformNotSupportedException();
+            set => throw new PlatformNotSupportedException();
+        }
+
+        public HeaderEncodingSelector<HttpRequestMessage>? ResponseHeaderEncodingSelector
+        {
+            get => throw new PlatformNotSupportedException();
+            set => throw new PlatformNotSupportedException();
+        }
+
         protected internal override Task<HttpResponseMessage> SendAsync(
             HttpRequestMessage request, CancellationToken cancellationToken) => throw new PlatformNotSupportedException();
 
diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HeaderEncodingSelector.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HeaderEncodingSelector.cs
new file mode 100644 (file)
index 0000000..cb984b8
--- /dev/null
@@ -0,0 +1,15 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Text;
+
+namespace System.Net.Http
+{
+    /// <summary>
+    /// Represents a method that specifies the <see cref="Encoding"/> to use when interpreting header values.
+    /// </summary>
+    /// <param name="headerName">Name of the header to specify the <see cref="Encoding"/> for.</param>
+    /// <param name="context">The <typeparamref name="TContext"/> we are enoding/decoding the headers for.</param>
+    /// <returns><see cref="Encoding"/> to use or <see langword="null"/> to use the default behavior.</returns>
+    public delegate Encoding? HeaderEncodingSelector<TContext>(string headerName, TContext context);
+}
index 8970f78..229490b 100644 (file)
@@ -4,6 +4,7 @@
 using System.Buffers;
 using System.Diagnostics;
 using System.Diagnostics.CodeAnalysis;
+using System.Text;
 using System.Text.Unicode;
 
 namespace System.Net.Http.Headers
@@ -116,7 +117,7 @@ namespace System.Net.Http.Headers
             return new HeaderDescriptor(_knownHeader.Name);
         }
 
-        public string GetHeaderValue(ReadOnlySpan<byte> headerValue)
+        public string GetHeaderValue(ReadOnlySpan<byte> headerValue, Encoding? valueEncoding)
         {
             if (headerValue.Length == 0)
             {
@@ -156,7 +157,7 @@ namespace System.Net.Http.Headers
                 }
             }
 
-            return HttpRuleParser.DefaultHttpEncoding.GetString(headerValue);
+            return (valueEncoding ?? HttpRuleParser.DefaultHttpEncoding).GetString(headerValue);
         }
 
         internal static string? GetKnownContentType(ReadOnlySpan<byte> contentTypeValue)
index 1bd39f4..c9ad97f 100644 (file)
@@ -974,12 +974,12 @@ namespace System.Net.Http
             headerBuffer.Commit(bytesWritten);
         }
 
-        private void WriteLiteralHeader(string name, ReadOnlySpan<string> values, ref ArrayBuffer headerBuffer)
+        private void WriteLiteralHeader(string name, ReadOnlySpan<string> values, Encoding? valueEncoding, ref ArrayBuffer headerBuffer)
         {
             if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(name)}={name}, {nameof(values)}={string.Join(", ", values.ToArray())}");
 
             int bytesWritten;
-            while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, headerBuffer.AvailableSpan, out bytesWritten))
+            while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, valueEncoding, headerBuffer.AvailableSpan, out bytesWritten))
             {
                 headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
             }
@@ -987,12 +987,12 @@ namespace System.Net.Http
             headerBuffer.Commit(bytesWritten);
         }
 
-        private void WriteLiteralHeaderValues(ReadOnlySpan<string> values, string? separator, ref ArrayBuffer headerBuffer)
+        private void WriteLiteralHeaderValues(ReadOnlySpan<string> values, string? separator, Encoding? valueEncoding, ref ArrayBuffer headerBuffer)
         {
             if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(values)}={string.Join(separator, values.ToArray())}");
 
             int bytesWritten;
-            while (!HPackEncoder.EncodeStringLiterals(values, separator, headerBuffer.AvailableSpan, out bytesWritten))
+            while (!HPackEncoder.EncodeStringLiterals(values, separator, valueEncoding, headerBuffer.AvailableSpan, out bytesWritten))
             {
                 headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
             }
@@ -1000,12 +1000,12 @@ namespace System.Net.Http
             headerBuffer.Commit(bytesWritten);
         }
 
-        private void WriteLiteralHeaderValue(string value, ref ArrayBuffer headerBuffer)
+        private void WriteLiteralHeaderValue(string value, Encoding? valueEncoding, ref ArrayBuffer headerBuffer)
         {
             if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(value)}={value}");
 
             int bytesWritten;
-            while (!HPackEncoder.EncodeStringLiteral(value, headerBuffer.AvailableSpan, out bytesWritten))
+            while (!HPackEncoder.EncodeStringLiteral(value, valueEncoding, headerBuffer.AvailableSpan, out bytesWritten))
             {
                 headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
             }
@@ -1026,7 +1026,7 @@ namespace System.Net.Http
             headerBuffer.Commit(bytes.Length);
         }
 
-        private void WriteHeaderCollection(HttpHeaders headers, ref ArrayBuffer headerBuffer)
+        private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders headers, ref ArrayBuffer headerBuffer)
         {
             if (NetEventSource.Log.IsEnabled()) Trace("");
 
@@ -1035,6 +1035,8 @@ namespace System.Net.Http
                 return;
             }
 
+            HeaderEncodingSelector<HttpRequestMessage>? encodingSelector = _pool.Settings._requestHeaderEncodingSelector;
+
             ref string[]? tmpHeaderValuesArray = ref t_headerValues;
             foreach (KeyValuePair<HeaderDescriptor, object> header in headers.HeaderStore)
             {
@@ -1042,6 +1044,8 @@ namespace System.Net.Http
                 Debug.Assert(headerValuesCount > 0, "No values for header??");
                 ReadOnlySpan<string> headerValues = tmpHeaderValuesArray.AsSpan(0, headerValuesCount);
 
+                Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, request);
+
                 KnownHeader? knownHeader = header.Key.KnownHeader;
                 if (knownHeader != null)
                 {
@@ -1058,7 +1062,7 @@ namespace System.Net.Http
                                 if (string.Equals(value, "trailers", StringComparison.OrdinalIgnoreCase))
                                 {
                                     WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer);
-                                    WriteLiteralHeaderValue(value, ref headerBuffer);
+                                    WriteLiteralHeaderValue(value, valueEncoding, ref headerBuffer);
                                     break;
                                 }
                             }
@@ -1081,13 +1085,13 @@ namespace System.Net.Http
                             }
                         }
 
-                        WriteLiteralHeaderValues(headerValues, separator, ref headerBuffer);
+                        WriteLiteralHeaderValues(headerValues, separator, valueEncoding, ref headerBuffer);
                     }
                 }
                 else
                 {
                     // The header is not known: fall back to just encoding the header name and value(s).
-                    WriteLiteralHeader(header.Key.Name, headerValues, ref headerBuffer);
+                    WriteLiteralHeader(header.Key.Name, headerValues, valueEncoding, ref headerBuffer);
                 }
             }
         }
@@ -1142,7 +1146,7 @@ namespace System.Net.Http
 
             if (request.HasHeaders)
             {
-                WriteHeaderCollection(request.Headers, ref headerBuffer);
+                WriteHeaderCollection(request, request.Headers, ref headerBuffer);
             }
 
             // Determine cookies to send.
@@ -1152,7 +1156,9 @@ namespace System.Net.Http
                 if (cookiesFromContainer != string.Empty)
                 {
                     WriteBytes(KnownHeaders.Cookie.Http2EncodedName, ref headerBuffer);
-                    WriteLiteralHeaderValue(cookiesFromContainer, ref headerBuffer);
+
+                    Encoding? cookieEncoding = _pool.Settings._requestHeaderEncodingSelector?.Invoke(KnownHeaders.Cookie.Name, request);
+                    WriteLiteralHeaderValue(cookiesFromContainer, cookieEncoding, ref headerBuffer);
                 }
             }
 
@@ -1163,12 +1169,12 @@ namespace System.Net.Http
                 if (normalizedMethod.MustHaveRequestBody)
                 {
                     WriteBytes(KnownHeaders.ContentLength.Http2EncodedName, ref headerBuffer);
-                    WriteLiteralHeaderValue("0", ref headerBuffer);
+                    WriteLiteralHeaderValue("0", valueEncoding: null, ref headerBuffer);
                 }
             }
             else
             {
-                WriteHeaderCollection(request.Content.Headers, ref headerBuffer);
+                WriteHeaderCollection(request, request.Content.Headers, ref headerBuffer);
             }
         }
 
index 5ac829b..254fec2 100644 (file)
@@ -542,24 +542,26 @@ namespace System.Net.Http
                             throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, Encoding.ASCII.GetString(name)));
                         }
 
+                        Encoding? valueEncoding = _connection._pool.Settings._responseHeaderEncodingSelector?.Invoke(descriptor.Name, _request);
+
                         // Note we ignore the return value from TryAddWithoutValidation;
                         // if the header can't be added, we silently drop it.
                         if (_responseProtocolState == ResponseProtocolState.ExpectingTrailingHeaders)
                         {
                             Debug.Assert(_trailers != null);
-                            string headerValue = descriptor.GetHeaderValue(value);
+                            string headerValue = descriptor.GetHeaderValue(value, valueEncoding);
                             _trailers.TryAddWithoutValidation((descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue);
                         }
                         else if ((descriptor.HeaderType & HttpHeaderType.Content) == HttpHeaderType.Content)
                         {
                             Debug.Assert(_response != null && _response.Content != null);
-                            string headerValue = descriptor.GetHeaderValue(value);
+                            string headerValue = descriptor.GetHeaderValue(value, valueEncoding);
                             _response.Content.Headers.TryAddWithoutValidation(descriptor, headerValue);
                         }
                         else
                         {
                             Debug.Assert(_response != null);
-                            string headerValue = _connection.GetResponseHeaderValueWithCaching(descriptor, value);
+                            string headerValue = _connection.GetResponseHeaderValueWithCaching(descriptor, value, valueEncoding);
                             _response.Headers.TryAddWithoutValidation((descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue);
                         }
                     }
index 07c2840..72e8090 100644 (file)
@@ -548,7 +548,8 @@ namespace System.Net.Http
                 string cookiesFromContainer = _connection.Pool.Settings._cookieContainer!.GetCookieHeader(request.RequestUri);
                 if (cookiesFromContainer != string.Empty)
                 {
-                    BufferLiteralHeaderWithStaticNameReference(H3StaticTable.Cookie, cookiesFromContainer);
+                    Encoding? valueEncoding = _connection.Pool.Settings._requestHeaderEncodingSelector?.Invoke(HttpKnownHeaderNames.Cookie, request);
+                    BufferLiteralHeaderWithStaticNameReference(H3StaticTable.Cookie, cookiesFromContainer, valueEncoding);
                 }
             }
 
@@ -590,12 +591,16 @@ namespace System.Net.Http
                 return;
             }
 
+            HeaderEncodingSelector<HttpRequestMessage>? encodingSelector = _connection.Pool.Settings._requestHeaderEncodingSelector;
+
             foreach (KeyValuePair<HeaderDescriptor, object> header in headers.HeaderStore)
             {
                 int headerValuesCount = HttpHeaders.GetValuesAsStrings(header.Key, header.Value, ref _headerValues);
                 Debug.Assert(headerValuesCount > 0, "No values for header??");
                 ReadOnlySpan<string> headerValues = _headerValues.AsSpan(0, headerValuesCount);
 
+                Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, _request);
+
                 KnownHeader? knownHeader = header.Key.KnownHeader;
                 if (knownHeader != null)
                 {
@@ -612,7 +617,7 @@ namespace System.Net.Http
                             {
                                 if (string.Equals(value, "trailers", StringComparison.OrdinalIgnoreCase))
                                 {
-                                    BufferLiteralHeaderWithoutNameReference("TE", value);
+                                    BufferLiteralHeaderWithoutNameReference("TE", value, valueEncoding);
                                     break;
                                 }
                             }
@@ -635,13 +640,13 @@ namespace System.Net.Http
                             }
                         }
 
-                        BufferLiteralHeaderValues(headerValues, separator);
+                        BufferLiteralHeaderValues(headerValues, separator, valueEncoding);
                     }
                 }
                 else
                 {
                     // The header is not known: fall back to just encoding the header name and value(s).
-                    BufferLiteralHeaderWithoutNameReference(header.Key.Name, headerValues, ", ");
+                    BufferLiteralHeaderWithoutNameReference(header.Key.Name, headerValues, HttpHeaderParser.DefaultSeparator, valueEncoding);
                 }
             }
         }
@@ -656,40 +661,40 @@ namespace System.Net.Http
             _sendBuffer.Commit(bytesWritten);
         }
 
-        private void BufferLiteralHeaderWithStaticNameReference(int nameIndex, string value)
+        private void BufferLiteralHeaderWithStaticNameReference(int nameIndex, string value, Encoding? valueEncoding = null)
         {
             int bytesWritten;
-            while (!QPackEncoder.EncodeLiteralHeaderFieldWithStaticNameReference(nameIndex, value, _sendBuffer.AvailableSpan, out bytesWritten))
+            while (!QPackEncoder.EncodeLiteralHeaderFieldWithStaticNameReference(nameIndex, value, valueEncoding, _sendBuffer.AvailableSpan, out bytesWritten))
             {
                 _sendBuffer.Grow();
             }
             _sendBuffer.Commit(bytesWritten);
         }
 
-        private void BufferLiteralHeaderWithoutNameReference(string name, ReadOnlySpan<string> values, string separator)
+        private void BufferLiteralHeaderWithoutNameReference(string name, ReadOnlySpan<string> values, string separator, Encoding? valueEncoding)
         {
             int bytesWritten;
-            while (!QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReference(name, values, separator, _sendBuffer.AvailableSpan, out bytesWritten))
+            while (!QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReference(name, values, separator, valueEncoding, _sendBuffer.AvailableSpan, out bytesWritten))
             {
                 _sendBuffer.Grow();
             }
             _sendBuffer.Commit(bytesWritten);
         }
 
-        private void BufferLiteralHeaderWithoutNameReference(string name, string value)
+        private void BufferLiteralHeaderWithoutNameReference(string name, string value, Encoding? valueEncoding)
         {
             int bytesWritten;
-            while (!QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReference(name, value, _sendBuffer.AvailableSpan, out bytesWritten))
+            while (!QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReference(name, value, valueEncoding, _sendBuffer.AvailableSpan, out bytesWritten))
             {
                 _sendBuffer.Grow();
             }
             _sendBuffer.Commit(bytesWritten);
         }
 
-        private void BufferLiteralHeaderValues(ReadOnlySpan<string> values, string? separator)
+        private void BufferLiteralHeaderValues(ReadOnlySpan<string> values, string? separator, Encoding? valueEncoding)
         {
             int bytesWritten;
-            while (!QPackEncoder.EncodeValueString(values, separator, _sendBuffer.AvailableSpan, out bytesWritten))
+            while (!QPackEncoder.EncodeValueString(values, separator, valueEncoding, _sendBuffer.AvailableSpan, out bytesWritten))
             {
                 _sendBuffer.Grow();
             }
@@ -917,7 +922,13 @@ namespace System.Net.Http
             }
             else
             {
-                string headerValue = staticValue ?? _connection.GetResponseHeaderValueWithCaching(descriptor, literalValue);
+                string? headerValue = staticValue;
+
+                if (headerValue is null)
+                {
+                    Encoding? encoding = _connection.Pool.Settings._responseHeaderEncodingSelector?.Invoke(descriptor.Name, _request);
+                    headerValue = _connection.GetResponseHeaderValueWithCaching(descriptor, literalValue, encoding);
+                }
 
                 switch (_headerState)
                 {
index 33383a1..131af80 100644 (file)
@@ -197,6 +197,8 @@ namespace System.Net.Http
 
         private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFromContainer, bool async)
         {
+            Debug.Assert(_currentRequest != null);
+
             if (headers.HeaderStore != null)
             {
                 foreach (KeyValuePair<HeaderDescriptor, object> header in headers.HeaderStore)
@@ -215,12 +217,14 @@ namespace System.Net.Http
                     Debug.Assert(headerValuesCount > 0, "No values for header??");
                     if (headerValuesCount > 0)
                     {
-                        await WriteStringAsync(_headerValues[0], async).ConfigureAwait(false);
+                        Encoding? valueEncoding = _pool.Settings._requestHeaderEncodingSelector?.Invoke(header.Key.Name, _currentRequest);
+
+                        await WriteStringAsync(_headerValues[0], async, valueEncoding).ConfigureAwait(false);
 
                         if (cookiesFromContainer != null && header.Key.KnownHeader == KnownHeaders.Cookie)
                         {
                             await WriteTwoBytesAsync((byte)';', (byte)' ', async).ConfigureAwait(false);
-                            await WriteStringAsync(cookiesFromContainer, async).ConfigureAwait(false);
+                            await WriteStringAsync(cookiesFromContainer, async, valueEncoding).ConfigureAwait(false);
 
                             cookiesFromContainer = null;
                         }
@@ -238,7 +242,7 @@ namespace System.Net.Http
                             for (int i = 1; i < headerValuesCount; i++)
                             {
                                 await WriteAsciiStringAsync(separator, async).ConfigureAwait(false);
-                                await WriteStringAsync(_headerValues[i], async).ConfigureAwait(false);
+                                await WriteStringAsync(_headerValues[i], async, valueEncoding).ConfigureAwait(false);
                             }
                         }
                     }
@@ -251,7 +255,10 @@ namespace System.Net.Http
             {
                 await WriteAsciiStringAsync(HttpKnownHeaderNames.Cookie, async).ConfigureAwait(false);
                 await WriteTwoBytesAsync((byte)':', (byte)' ', async).ConfigureAwait(false);
-                await WriteStringAsync(cookiesFromContainer, async).ConfigureAwait(false);
+
+                Encoding? valueEncoding = _pool.Settings._requestHeaderEncodingSelector?.Invoke(HttpKnownHeaderNames.Cookie, _currentRequest);
+                await WriteStringAsync(cookiesFromContainer, async, valueEncoding).ConfigureAwait(false);
+
                 await WriteTwoBytesAsync((byte)'\r', (byte)'\n', async).ConfigureAwait(false);
             }
         }
@@ -984,22 +991,25 @@ namespace System.Net.Http
                 pos++;
             }
 
+            Debug.Assert(response.RequestMessage != null);
+            Encoding? valueEncoding = connection._pool.Settings._responseHeaderEncodingSelector?.Invoke(descriptor.Name, response.RequestMessage);
+
             // Note we ignore the return value from TryAddWithoutValidation. If the header can't be added, we silently drop it.
             ReadOnlySpan<byte> value = line.Slice(pos);
             if (isFromTrailer)
             {
-                string headerValue = descriptor.GetHeaderValue(value);
+                string headerValue = descriptor.GetHeaderValue(value, valueEncoding);
                 response.TrailingHeaders.TryAddWithoutValidation((descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue);
             }
             else if ((descriptor.HeaderType & HttpHeaderType.Content) == HttpHeaderType.Content)
             {
-                string headerValue = descriptor.GetHeaderValue(value);
+                string headerValue = descriptor.GetHeaderValue(value, valueEncoding);
                 response.Content!.Headers.TryAddWithoutValidation(descriptor, headerValue);
             }
             else
             {
                 // Request headers returned on the response must be treated as custom headers.
-                string headerValue = connection.GetResponseHeaderValueWithCaching(descriptor, value);
+                string headerValue = connection.GetResponseHeaderValueWithCaching(descriptor, value, valueEncoding);
                 response.Headers.TryAddWithoutValidation(
                     (descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor,
                     headerValue);
@@ -1182,23 +1192,23 @@ namespace System.Net.Http
                 _writeOffset += bytes.Length;
                 return Task.CompletedTask;
             }
-            return WriteBytesSlowAsync(bytes, async);
+            return WriteBytesSlowAsync(bytes, bytes.Length, async);
         }
 
-        private async Task WriteBytesSlowAsync(byte[] bytes, bool async)
+        private async Task WriteBytesSlowAsync(byte[] bytes, int length, bool async)
         {
             int offset = 0;
             while (true)
             {
-                int remaining = bytes.Length - offset;
+                int remaining = length - offset;
                 int toCopy = Math.Min(remaining, _writeBuffer.Length - _writeOffset);
                 Buffer.BlockCopy(bytes, offset, _writeBuffer, _writeOffset, toCopy);
                 _writeOffset += toCopy;
                 offset += toCopy;
 
-                Debug.Assert(offset <= bytes.Length, $"Expected {nameof(offset)} to be <= {bytes.Length}, got {offset}");
+                Debug.Assert(offset <= length, $"Expected {nameof(offset)} to be <= {length}, got {offset}");
                 Debug.Assert(_writeOffset <= _writeBuffer.Length, $"Expected {nameof(_writeOffset)} to be <= {_writeBuffer.Length}, got {_writeOffset}");
-                if (offset == bytes.Length)
+                if (offset == length)
                 {
                     break;
                 }
@@ -1235,6 +1245,43 @@ namespace System.Net.Http
             return WriteStringAsyncSlow(s, async);
         }
 
+        private Task WriteStringAsync(string s, bool async, Encoding? encoding)
+        {
+            if (encoding is null)
+            {
+                return WriteStringAsync(s, async);
+            }
+
+            // If there's enough space in the buffer to just copy all of the string's bytes, do so.
+            if (encoding.GetMaxByteCount(s.Length) <= _writeBuffer.Length - _writeOffset)
+            {
+                _writeOffset += encoding.GetBytes(s, _writeBuffer.AsSpan(_writeOffset));
+                return Task.CompletedTask;
+            }
+
+            // Otherwise, fall back to doing a normal slow string write
+            return WriteStringWithEncodingAsyncSlow(s, async, encoding);
+        }
+
+        private async Task WriteStringWithEncodingAsyncSlow(string s, bool async, Encoding encoding)
+        {
+            // Avoid calculating the length if the rented array would be small anyway
+            int length = s.Length <= 512
+                ? encoding.GetMaxByteCount(s.Length)
+                : encoding.GetByteCount(s);
+
+            byte[] rentedBuffer = ArrayPool<byte>.Shared.Rent(length);
+            try
+            {
+                int written = encoding.GetBytes(s, rentedBuffer);
+                await WriteBytesSlowAsync(rentedBuffer, written, async).ConfigureAwait(false);
+            }
+            finally
+            {
+                ArrayPool<byte>.Shared.Return(rentedBuffer);
+            }
+        }
+
         private Task WriteAsciiStringAsync(string s, bool async)
         {
             // If there's enough space in the buffer to just copy all of the string's bytes, do so.
index b4b0aa1..df8191b 100644 (file)
@@ -7,6 +7,7 @@ using System.IO;
 using System.Net.Http.Headers;
 using System.Net.Security;
 using System.Runtime.CompilerServices;
+using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -20,19 +21,19 @@ namespace System.Net.Http
         private string? _lastServerHeaderValue;
 
         /// <summary>Uses <see cref="HeaderDescriptor.GetHeaderValue"/>, but first special-cases several known headers for which we can use caching.</summary>
-        public string GetResponseHeaderValueWithCaching(HeaderDescriptor descriptor, ReadOnlySpan<byte> value)
+        public string GetResponseHeaderValueWithCaching(HeaderDescriptor descriptor, ReadOnlySpan<byte> value, Encoding? valueEncoding)
         {
             return
-                ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value) :
-                ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value) :
-                descriptor.GetHeaderValue(value);
+                ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value, valueEncoding) :
+                ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value, valueEncoding) :
+                descriptor.GetHeaderValue(value, valueEncoding);
 
-            static string GetOrAddCachedValue([NotNull] ref string? cache, HeaderDescriptor descriptor, ReadOnlySpan<byte> value)
+            static string GetOrAddCachedValue([NotNull] ref string? cache, HeaderDescriptor descriptor, ReadOnlySpan<byte> value, Encoding? encoding)
             {
                 string? lastValue = cache;
                 if (lastValue is null || !ByteArrayHelpers.EqualsOrdinalAscii(lastValue, value))
                 {
-                    cache = lastValue = descriptor.GetHeaderValue(value);
+                    cache = lastValue = descriptor.GetHeaderValue(value, encoding);
                 }
                 return lastValue;
             }
index 558e6cf..507af88 100644 (file)
@@ -46,6 +46,9 @@ namespace System.Net.Http
         internal TimeSpan _expect100ContinueTimeout = HttpHandlerDefaults.DefaultExpect100ContinueTimeout;
         internal TimeSpan _connectTimeout = HttpHandlerDefaults.DefaultConnectTimeout;
 
+        internal HeaderEncodingSelector<HttpRequestMessage>? _requestHeaderEncodingSelector;
+        internal HeaderEncodingSelector<HttpRequestMessage>? _responseHeaderEncodingSelector;
+
         internal Version _maxHttpVersion;
 
         internal bool _allowUnencryptedHttp2;
@@ -110,6 +113,8 @@ namespace System.Net.Http
                 _useProxy = _useProxy,
                 _allowUnencryptedHttp2 = _allowUnencryptedHttp2,
                 _assumePrenegotiatedHttp3ForTesting = _assumePrenegotiatedHttp3ForTesting,
+                _requestHeaderEncodingSelector = _requestHeaderEncodingSelector,
+                _responseHeaderEncodingSelector = _responseHeaderEncodingSelector,
                 _enableMultipleHttp2Connections = _enableMultipleHttp2Connections,
                 _connectionFactory = _connectionFactory,
                 _plaintextFilter = _plaintextFilter
index de391e9..8eb5a92 100644 (file)
@@ -3,11 +3,12 @@
 
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Net.Connections;
 using System.Net.Security;
 using System.Threading;
 using System.Threading.Tasks;
 using System.Diagnostics.CodeAnalysis;
-using System.Net.Connections;
+using System.Text;
 
 namespace System.Net.Http
 {
@@ -319,6 +320,34 @@ namespace System.Net.Http
         public IDictionary<string, object?> Properties =>
             _settings._properties ?? (_settings._properties = new Dictionary<string, object?>());
 
+        /// <summary>
+        /// Gets or sets a callback that returns the <see cref="Encoding"/> to encode the value for the specified request header name,
+        /// or <see langword="null"/> to use the default behavior.
+        /// </summary>
+        public HeaderEncodingSelector<HttpRequestMessage>? RequestHeaderEncodingSelector
+        {
+            get => _settings._requestHeaderEncodingSelector;
+            set
+            {
+                CheckDisposedOrStarted();
+                _settings._requestHeaderEncodingSelector = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets a callback that returns the <see cref="Encoding"/> to decode the value for the specified response header name,
+        /// or <see langword="null"/> to use the default behavior.
+        /// </summary>
+        public HeaderEncodingSelector<HttpRequestMessage>? ResponseHeaderEncodingSelector
+        {
+            get => _settings._responseHeaderEncodingSelector;
+            set
+            {
+                CheckDisposedOrStarted();
+                _settings._responseHeaderEncodingSelector = value;
+            }
+        }
+
         protected override void Dispose(bool disposing)
         {
             if (disposing && !_disposed)
index afffb68..3ed0f55 100644 (file)
@@ -6,6 +6,7 @@ using System.IO;
 using System.Linq;
 using System.Net.Http.Headers;
 using System.Net.Test.Common;
+using System.Text;
 using System.Threading.Tasks;
 
 using Xunit;
@@ -332,5 +333,132 @@ namespace System.Net.Http.Functional.Tests
                     });
                 });
         }
+
+        private static readonly (string Name, Encoding ValueEncoding, string[] Values)[] s_nonAsciiHeaders = new[]
+        {
+            ("foo",             Encoding.ASCII,     new[] { "bar" }),
+            ("header-0",        Encoding.UTF8,      new[] { "\uD83D\uDE03", "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" }),
+            ("Cache-Control",   Encoding.UTF8,      new[] { "no-cache" }),
+            ("header-1",        Encoding.UTF8,      new[] { "\uD83D\uDE03" }),
+            ("Some-Header1",    Encoding.Latin1,    new[] { "\uD83D\uDE03", "UTF8-best-fit-to-latin1" }),
+            ("Some-Header2",    Encoding.Latin1,    new[] { "\u00FF", "\u00C4nd", "Ascii\u00A9" }),
+            ("Some-Header3",    Encoding.ASCII,     new[] { "\u00FF", "\u00C4nd", "Ascii\u00A9", "Latin1-best-fit-to-ascii" }),
+            ("header-2",        Encoding.UTF8,      new[] { "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" }),
+            ("header-3",        Encoding.UTF8,      new[] { "\uFFFD" }),
+            ("header-4",        Encoding.UTF8,      new[] { "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A", "\uD83D\uDE03" }),
+            ("Cookie",          Encoding.UTF8,      new[] { "Cookies", "\uD83C\uDF6A", "everywhere" }),
+            ("Set-Cookie",      Encoding.UTF8,      new[] { "\uD83C\uDDF8\uD83C\uDDEE" }),
+            ("header-5",        Encoding.UTF8,      new[] { "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A", "foo", "\uD83D\uDE03", "bar" }),
+            ("bar",             Encoding.UTF8,      new[] { "foo" })
+        };
+
+        [Fact]
+        public async Task SendAsync_CustomRequestEncodingSelector_CanSendNonAsciiHeaderValues()
+        {
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    var requestMessage = new HttpRequestMessage(HttpMethod.Get, uri)
+                    {
+                        Version = UseVersion
+                    };
+
+                    foreach ((string name, _, string[] values) in s_nonAsciiHeaders)
+                    {
+                        requestMessage.Headers.Add(name, values);
+                    }
+
+                    List<string> seenHeaderNames = new List<string>();
+
+                    using HttpClientHandler handler = CreateHttpClientHandler();
+                    var underlyingHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+
+                    underlyingHandler.RequestHeaderEncodingSelector = (name, request) =>
+                    {
+                        Assert.NotNull(name);
+                        Assert.Same(request, requestMessage);
+                        seenHeaderNames.Add(name);
+                        return Assert.Single(s_nonAsciiHeaders, h => h.Name.Equals(name, StringComparison.OrdinalIgnoreCase)).ValueEncoding;
+                    };
+
+                    using HttpClient client = CreateHttpClient(handler);
+
+                    await client.SendAsync(requestMessage);
+
+                    foreach ((string name, _, _) in s_nonAsciiHeaders)
+                    {
+                        Assert.Contains(name, seenHeaderNames);
+                    }
+                },
+                async server =>
+                {
+                    HttpRequestData requestData = await server.HandleRequestAsync();
+
+                    Assert.All(requestData.Headers,
+                        h => Assert.False(h.HuffmanEncoded, "Expose raw decoded bytes once HuffmanEncoding is supported"));
+
+                    foreach ((string name, Encoding valueEncoding, string[] values) in s_nonAsciiHeaders)
+                    {
+                        byte[] valueBytes = valueEncoding.GetBytes(string.Join(", ", values));
+                        Assert.Single(requestData.Headers,
+                            h => h.Name.Equals(name, StringComparison.OrdinalIgnoreCase) && h.Raw.AsSpan().IndexOf(valueBytes) != -1);
+                    }
+                });
+        }
+
+        [Fact]
+        public async Task SendAsync_CustomResponseEncodingSelector_CanReceiveNonAsciiHeaderValues()
+        {
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    var requestMessage = new HttpRequestMessage(HttpMethod.Get, uri)
+                    {
+                        Version = UseVersion
+                    };
+
+                    List<string> seenHeaderNames = new List<string>();
+
+                    using HttpClientHandler handler = CreateHttpClientHandler();
+                    var underlyingHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+
+                    underlyingHandler.ResponseHeaderEncodingSelector = (name, request) =>
+                    {
+                        Assert.NotNull(name);
+                        Assert.Same(request, requestMessage);
+                        seenHeaderNames.Add(name);
+
+                        if (s_nonAsciiHeaders.Any(h => h.Name.Equals(name, StringComparison.OrdinalIgnoreCase)))
+                        {
+                            return Assert.Single(s_nonAsciiHeaders, h => h.Name.Equals(name, StringComparison.OrdinalIgnoreCase)).ValueEncoding;
+                        }
+
+                        // Not one of our custom headers
+                        return null;
+                    };
+
+                    using HttpClient client = CreateHttpClient(handler);
+
+                    using HttpResponseMessage response = await client.SendAsync(requestMessage);
+
+                    foreach ((string name, Encoding valueEncoding, string[] values) in s_nonAsciiHeaders)
+                    {
+                        Assert.Contains(name, seenHeaderNames);
+                        IEnumerable<string> receivedValues = Assert.Single(response.Headers, h => h.Key.Equals(name, StringComparison.OrdinalIgnoreCase)).Value;
+                        string value = Assert.Single(receivedValues);
+
+                        string expected = valueEncoding.GetString(valueEncoding.GetBytes(string.Join(", ", values)));
+                        Assert.Equal(expected, value, StringComparer.OrdinalIgnoreCase);
+                    }
+                },
+                async server =>
+                {
+                    List<HttpHeaderData> headerData = s_nonAsciiHeaders
+                        .Select(h => new HttpHeaderData(h.Name, string.Join(", ", h.Values), valueEncoding: h.ValueEncoding))
+                        .ToList();
+
+                    await server.HandleRequestAsync(headers: headerData);
+                });
+        }
     }
 }
index 730cab3..3fd84d6 100644 (file)
@@ -14,25 +14,35 @@ namespace System.Net.Http.Unit.Tests.HPack
 {
     public class HPackRoundtripTests
     {
-
         public static IEnumerable<object[]> TestHeaders()
         {
-            yield return new object[] { new HttpRequestHeaders() { { "header", "value" } } };
-            yield return new object[] { new HttpRequestHeaders() { { "header", new[] { "value1", "value2" } } } };
+            yield return new object[] { new HttpRequestHeaders() { { "header", "value" } }, null };
+            yield return new object[] { new HttpRequestHeaders() { { "header", "value" } }, Encoding.ASCII };
+            yield return new object[] { new HttpRequestHeaders() { { "header", new[] { "value1", "value2" } } }, null };
+            yield return new object[] { new HttpRequestHeaders() { { "header", new[] { "value1", "value2" } } }, Encoding.ASCII };
             yield return new object[] { new HttpRequestHeaders()
             {
                 { "header-0", new[] { "value1", "value2" } },
                 { "header-0", "value3" },
                 { "header-1", "value1" },
                 { "header-2", new[] { "value1", "value2" } },
-            } };
+            }, null };
+            yield return new object[] { new HttpRequestHeaders() { { "header", "foo" } }, Encoding.UTF8 };
+            yield return new object[] { new HttpRequestHeaders() { { "header", "\uD83D\uDE03" } }, Encoding.UTF8 };
+            yield return new object[] { new HttpRequestHeaders()
+            {
+                { "header-0", new[] { "\uD83D\uDE03", "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" } },
+                { "header-1", "\uD83D\uDE03" },
+                { "header-2", "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" },
+                { "header-3", new[] { "\uD83D\uDE03", "\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A" } }
+            }, Encoding.UTF8 };
         }
 
         [Theory, MemberData(nameof(TestHeaders))]
-        public void HPack_HeaderEncodeDecodeRoundtrip_ShouldMatchOriginalInput(HttpHeaders headers)
+        public void HPack_HeaderEncodeDecodeRoundtrip_ShouldMatchOriginalInput(HttpHeaders headers, Encoding? valueEncoding)
         {
-            Memory<byte> encoding = HPackEncode(headers);
-            HttpHeaders decodedHeaders = HPackDecode(encoding);
+            Memory<byte> encoding = HPackEncode(headers, valueEncoding);
+            HttpHeaders decodedHeaders = HPackDecode(encoding, valueEncoding);
 
             // Assert: decoded headers are structurally equal to original headers
             Assert.Equal(headers.Count(), decodedHeaders.Count());
@@ -44,7 +54,7 @@ namespace System.Net.Http.Unit.Tests.HPack
         }
 
         // adapted from Header serialization code in Http2Connection.cs
-        private static Memory<byte> HPackEncode(HttpHeaders headers)
+        private static Memory<byte> HPackEncode(HttpHeaders headers, Encoding? valueEncoding)
         {
             var buffer = new ArrayBuffer(4);
             FillAvailableSpaceWithOnes(buffer);
@@ -101,7 +111,7 @@ namespace System.Net.Http.Unit.Tests.HPack
             void WriteLiteralHeaderValues(ReadOnlySpan<string> values, string separator)
             {
                 int bytesWritten;
-                while (!HPackEncoder.EncodeStringLiterals(values, separator, buffer.AvailableSpan, out bytesWritten))
+                while (!HPackEncoder.EncodeStringLiterals(values, separator, valueEncoding, buffer.AvailableSpan, out bytesWritten))
                 {
                     buffer.EnsureAvailableSpace(buffer.AvailableLength + 1);
                     FillAvailableSpaceWithOnes(buffer);
@@ -113,7 +123,7 @@ namespace System.Net.Http.Unit.Tests.HPack
             void WriteLiteralHeader(string name, ReadOnlySpan<string> values)
             {
                 int bytesWritten;
-                while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, buffer.AvailableSpan, out bytesWritten))
+                while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, valueEncoding, buffer.AvailableSpan, out bytesWritten))
                 {
                     buffer.EnsureAvailableSpace(buffer.AvailableLength + 1);
                     FillAvailableSpaceWithOnes(buffer);
@@ -127,12 +137,12 @@ namespace System.Net.Http.Unit.Tests.HPack
         }
 
         // adapted from header deserialization code in Http2Connection.cs
-        private static HttpHeaders HPackDecode(Memory<byte> memory)
+        private static HttpHeaders HPackDecode(Memory<byte> memory, Encoding? valueEncoding)
         {
             var header = new HttpRequestHeaders();
             var hpackDecoder = new HPackDecoder(maxDynamicTableSize: 0, maxHeadersLength: HttpHandlerDefaults.DefaultMaxResponseHeadersLength * 1024);
 
-            hpackDecoder.Decode(memory.Span, true, new HeaderHandler(header));
+            hpackDecoder.Decode(memory.Span, true, new HeaderHandler(header, valueEncoding));
 
             return header;
         }
@@ -140,9 +150,12 @@ namespace System.Net.Http.Unit.Tests.HPack
         private class HeaderHandler : IHttpHeadersHandler
         {
             HttpRequestHeaders _headers;
-            public HeaderHandler(HttpRequestHeaders headers)
+            Encoding? _valueEncoding;
+
+            public HeaderHandler(HttpRequestHeaders headers, Encoding? valueEncoding)
             {
                 _headers = headers;
+                _valueEncoding = valueEncoding;
             }
 
             public void OnHeader(ReadOnlySpan<byte> name, ReadOnlySpan<byte> value)
@@ -152,7 +165,7 @@ namespace System.Net.Http.Unit.Tests.HPack
                     throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, Encoding.ASCII.GetString(name)));
                 }
 
-                string headerValue = descriptor.GetHeaderValue(value);
+                string headerValue = descriptor.GetHeaderValue(value, _valueEncoding);
 
                 _headers.TryAddWithoutValidation(descriptor, headerValue.Split(',').Select(x => x.Trim()));
             }
diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs
new file mode 100644 (file)
index 0000000..d7f41e0
--- /dev/null
@@ -0,0 +1,37 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Net.Http.Headers;
+using System.Text;
+using Xunit;
+
+namespace System.Net.Http.Tests
+{
+    public class HeaderEncodingTest
+    {
+        [Theory]
+        [InlineData("")]
+        [InlineData("foo")]
+        [InlineData("\uD83D\uDE03")]
+        [InlineData("\0")]
+        [InlineData("\x01")]
+        [InlineData("\xFF")]
+        [InlineData("\uFFFF")]
+        [InlineData("\uFFFD")]
+        [InlineData("\uD83D\uDE48\uD83D\uDE49\uD83D\uDE4A")]
+        public void RoundTripsUtf8(string input)
+        {
+            byte[] encoded = Encoding.UTF8.GetBytes(input);
+
+            Assert.True(HeaderDescriptor.TryGet("custom-header", out HeaderDescriptor descriptor));
+            Assert.Null(descriptor.KnownHeader);
+            string roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8);
+            Assert.Equal(input, roundtrip);
+
+            Assert.True(HeaderDescriptor.TryGet("Cache-Control", out descriptor));
+            Assert.NotNull(descriptor.KnownHeader);
+            roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8);
+            Assert.Equal(input, roundtrip);
+        }
+    }
+}
index a877602..78fada7 100644 (file)
@@ -222,11 +222,11 @@ namespace System.Net.Http.Tests
             {
                 Assert.NotNull(knownHeader);
 
-                string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray());
+                string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray(), valueEncoding: null);
                 Assert.NotNull(v1);
                 Assert.Equal(value, v1, StringComparer.OrdinalIgnoreCase);
 
-                string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray());
+                string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray(), valueEncoding: null);
                 Assert.Same(v1, v2);
             }
         }
@@ -239,8 +239,8 @@ namespace System.Net.Http.Tests
             KnownHeader knownHeader = KnownHeaders.TryGetKnownHeader(name);
             Assert.NotNull(knownHeader);
 
-            string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray());
-            string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray());
+            string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray(), valueEncoding: null);
+            string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray(), valueEncoding: null);
             Assert.Equal(value, v1);
             Assert.Equal(value, v2);
             Assert.NotSame(v1, v2);
index 1a1199a..f5419d0 100644 (file)
     <Compile Include="Headers\GenericHeaderParserTest\TokenListParserTest.cs" />
     <Compile Include="Headers\GenericHeaderParserTest\ViaParserTest.cs" />
     <Compile Include="Headers\GenericHeaderParserTest\WarningParserTest.cs" />
+    <Compile Include="Headers\HeaderEncodingTest.cs" />
     <Compile Include="Headers\HeaderUtilitiesTest.cs" />
     <Compile Include="Headers\HttpContentHeadersTest.cs" />
     <Compile Include="Headers\HttpHeadersTest.cs" />