Fix UriParser static table thread safety (dotnet/corefx#34411)
author Stephen Toub <stoub@microsoft.com>
Tue, 8 Jan 2019 14:09:24 +0000 (09:09 -0500)
committer GitHub <noreply@github.com>
Tue, 8 Jan 2019 14:09:24 +0000 (09:09 -0500)
There are two static Dictionaries in UriParser that are read without holding a lock and modified while holding a lock; while it is safe for any number of readers to be concurrently accessing a dictionary, it is not safe for any positive number of readers to be doing so while there's at least one writer.  Based on how this code is written, my assumption is that at some point it was implemented using Hashtable, which does make such a situation safe.  As such, I'm fixing it by just switching back to using Hashtable; in the process, I'm also removing the static cctor and fixing a few readonly statics to instead be const.

I considered a few alternate implementations, including using an immutable dictionary, but this approach is simple and fast enough, and doesn't incur some of the costs other schemes would.

Commit migrated from https://github.com/dotnet/corefx/commit/47b47e0b2247856f1cf630535b7d821ea2ad60eb

src/libraries/System.Private.Uri/src/System/UriSyntax.cs

index 07dbd0a..b1f8dc1 100644 (file)
@@ -5,7 +5,7 @@
 // This file utilizes partial class feature and contains
 // only internal implementation of UriParser type
 
-using System.Collections.Generic;
+using System.Collections;
 using System.Diagnostics;
 
 namespace System
@@ -58,8 +58,46 @@ namespace System
     //
     public abstract partial class UriParser
     {
-        private static readonly Dictionary<string, UriParser> s_table;
-        private static Dictionary<string, UriParser> s_tempTable;
+        // These are always available without paying hashtable lookup cost
+        // Note: see UpdateStaticSyntaxReference()
+        internal static readonly UriParser HttpUri = new BuiltInUriParser("http", 80, HttpSyntaxFlags);
+        internal static readonly UriParser HttpsUri = new BuiltInUriParser("https", 443, HttpUri._flags);
+        internal static readonly UriParser WsUri = new BuiltInUriParser("ws", 80, HttpSyntaxFlags);
+        internal static readonly UriParser WssUri = new BuiltInUriParser("wss", 443, HttpSyntaxFlags);
+        internal static readonly UriParser FtpUri = new BuiltInUriParser("ftp", 21, FtpSyntaxFlags);
+        internal static readonly UriParser FileUri = new BuiltInUriParser("file", NoDefaultPort, FileSyntaxFlags);
+        internal static readonly UriParser UnixFileUri = new BuiltInUriParser("file", NoDefaultPort, UnixFileSyntaxFlags);
+        internal static readonly UriParser GopherUri = new BuiltInUriParser("gopher", 70, GopherSyntaxFlags);
+        internal static readonly UriParser NntpUri = new BuiltInUriParser("nntp", 119, NntpSyntaxFlags);
+        internal static readonly UriParser NewsUri = new BuiltInUriParser("news", NoDefaultPort, NewsSyntaxFlags);
+        internal static readonly UriParser MailToUri = new BuiltInUriParser("mailto", 25, MailtoSyntaxFlags);
+        internal static readonly UriParser UuidUri = new BuiltInUriParser("uuid", NoDefaultPort, NewsUri._flags);
+        internal static readonly UriParser TelnetUri = new BuiltInUriParser("telnet", 23, TelnetSyntaxFlags);
+        internal static readonly UriParser LdapUri = new BuiltInUriParser("ldap", 389, LdapSyntaxFlags);
+        internal static readonly UriParser NetTcpUri = new BuiltInUriParser("net.tcp", 808, NetTcpSyntaxFlags);
+        internal static readonly UriParser NetPipeUri = new BuiltInUriParser("net.pipe", NoDefaultPort, NetPipeSyntaxFlags);
+        internal static readonly UriParser VsMacrosUri = new BuiltInUriParser("vsmacros", NoDefaultPort, VsmacrosSyntaxFlags);
+
+        private static readonly Hashtable s_table = new Hashtable(16) // Hashtable used instead of Dictionary<> for lock-free reads
+        {
+            { HttpUri.SchemeName, HttpUri }, // HTTP
+            { HttpsUri.SchemeName, HttpsUri }, // HTTPS cloned from HTTP
+            { WsUri.SchemeName, WsUri }, // WebSockets
+            { WssUri.SchemeName, WssUri }, // Secure WebSockets
+            { FtpUri.SchemeName, FtpUri }, //FTP
+            { FileUri.SchemeName, FileUri }, //FILE
+            { GopherUri.SchemeName, GopherUri }, //GOPHER
+            { NntpUri.SchemeName, NntpUri }, //NNTP
+            { NewsUri.SchemeName, NewsUri }, //NEWS
+            { MailToUri.SchemeName, MailToUri }, //MAILTO
+            { UuidUri.SchemeName, UuidUri }, //UUID cloned from NEWS
+            { TelnetUri.SchemeName, TelnetUri }, //TELNET
+            { LdapUri.SchemeName, LdapUri }, //LDAP
+            { NetTcpUri.SchemeName, NetTcpUri },
+            { NetPipeUri.SchemeName, NetPipeUri },
+            { VsMacrosUri.SchemeName, VsMacrosUri }, //VSMACROS
+        };
+        private static Hashtable s_tempTable = new Hashtable(c_InitialTableSize); // Hashtable used instead of Dictionary<> for lock-free reads
 
         private UriSyntaxFlags _flags;
 
@@ -79,27 +117,6 @@ namespace System
         internal const int NoDefaultPort = -1;
         private const int c_InitialTableSize = 25;
 
-        // These are always available without paying hashtable lookup cost
-        // Note: see UpdateStaticSyntaxReference()
-        internal static UriParser HttpUri;
-        internal static UriParser HttpsUri;
-        internal static UriParser WsUri;
-        internal static UriParser WssUri;
-        internal static UriParser FtpUri;
-        internal static UriParser FileUri;
-        internal static UriParser UnixFileUri;
-        internal static UriParser GopherUri;
-        internal static UriParser NntpUri;
-        internal static UriParser NewsUri;
-        internal static UriParser MailToUri;
-        internal static UriParser UuidUri;
-        internal static UriParser TelnetUri;
-        internal static UriParser LdapUri;
-        internal static UriParser NetTcpUri;
-        internal static UriParser NetPipeUri;
-
-        internal static UriParser VsMacrosUri;
-
         internal static bool DontEnableStrictRFC3986ReservedCharacterSets
         {
             // In .NET Framework this would test against an AppContextSwitch. Since this is a potentially
@@ -120,65 +137,6 @@ namespace System
             }
         }
 
-        static UriParser()
-        {
-            s_table = new Dictionary<string, UriParser>(c_InitialTableSize);
-            s_tempTable = new Dictionary<string, UriParser>(c_InitialTableSize);
-
-            //Now we will call for the instance constructors that will interrupt this static one.
-
-            // Below we simulate calls into FetchSyntax() but avoid using lock() and other things redundant for a .cctor
-
-            HttpUri = new BuiltInUriParser("http", 80, HttpSyntaxFlags);
-            s_table[HttpUri.SchemeName] = HttpUri;                   //HTTP
-
-            HttpsUri = new BuiltInUriParser("https", 443, HttpUri._flags);
-            s_table[HttpsUri.SchemeName] = HttpsUri;                  //HTTPS cloned from HTTP
-
-            WsUri = new BuiltInUriParser("ws", 80, HttpSyntaxFlags);
-            s_table[WsUri.SchemeName] = WsUri;                   // WebSockets
-
-            WssUri = new BuiltInUriParser("wss", 443, HttpSyntaxFlags);
-            s_table[WssUri.SchemeName] = WssUri;                  // Secure WebSockets
-
-            FtpUri = new BuiltInUriParser("ftp", 21, FtpSyntaxFlags);
-            s_table[FtpUri.SchemeName] = FtpUri;                    //FTP
-
-            FileUri = new BuiltInUriParser("file", NoDefaultPort, s_fileSyntaxFlags);
-            UnixFileUri = new BuiltInUriParser("file", NoDefaultPort, s_unixFileSyntaxFlags);
-            s_table[FileUri.SchemeName] = FileUri;                   //FILE
-
-            GopherUri = new BuiltInUriParser("gopher", 70, GopherSyntaxFlags);
-            s_table[GopherUri.SchemeName] = GopherUri;                 //GOPHER
-
-            NntpUri = new BuiltInUriParser("nntp", 119, NntpSyntaxFlags);
-            s_table[NntpUri.SchemeName] = NntpUri;                   //NNTP
-
-            NewsUri = new BuiltInUriParser("news", NoDefaultPort, NewsSyntaxFlags);
-            s_table[NewsUri.SchemeName] = NewsUri;                   //NEWS
-
-            MailToUri = new BuiltInUriParser("mailto", 25, MailtoSyntaxFlags);
-            s_table[MailToUri.SchemeName] = MailToUri;                 //MAILTO
-
-            UuidUri = new BuiltInUriParser("uuid", NoDefaultPort, NewsUri._flags);
-            s_table[UuidUri.SchemeName] = UuidUri;                   //UUID cloned from NEWS
-
-            TelnetUri = new BuiltInUriParser("telnet", 23, TelnetSyntaxFlags);
-            s_table[TelnetUri.SchemeName] = TelnetUri;                 //TELNET
-
-            LdapUri = new BuiltInUriParser("ldap", 389, LdapSyntaxFlags);
-            s_table[LdapUri.SchemeName] = LdapUri;                   //LDAP
-
-            NetTcpUri = new BuiltInUriParser("net.tcp", 808, NetTcpSyntaxFlags);
-            s_table[NetTcpUri.SchemeName] = NetTcpUri;
-
-            NetPipeUri = new BuiltInUriParser("net.pipe", NoDefaultPort, NetPipeSyntaxFlags);
-            s_table[NetPipeUri.SchemeName] = NetPipeUri;
-
-            VsMacrosUri = new BuiltInUriParser("vsmacros", NoDefaultPort, VsmacrosSyntaxFlags);
-            s_table[VsMacrosUri.SchemeName] = VsMacrosUri;               //VSMACROS
-        }
-
         private class BuiltInUriParser : UriParser
         {
             //
@@ -259,12 +217,11 @@ namespace System
             lock (s_table)
             {
                 syntax._flags &= ~UriSyntaxFlags.V1_UnknownUri;
-                UriParser oldSyntax = null;
-                s_table.TryGetValue(lwrCaseSchemeName, out oldSyntax);
+                UriParser oldSyntax = (UriParser)s_table[lwrCaseSchemeName];
                 if (oldSyntax != null)
                     throw new InvalidOperationException(SR.Format(SR.net_uri_AlreadyRegistered, oldSyntax.SchemeName));
-                
-                s_tempTable.TryGetValue(syntax.SchemeName, out oldSyntax);
+
+                oldSyntax = (UriParser)s_tempTable[syntax.SchemeName];
                 if (oldSyntax != null)
                 {
                     // optimization on schemeName, will try to keep the first reference
@@ -286,13 +243,12 @@ namespace System
         internal static UriParser FindOrFetchAsUnknownV1Syntax(string lwrCaseScheme)
         {
             // check may be other thread just added one
-            UriParser syntax = null;
-            s_table.TryGetValue(lwrCaseScheme, out syntax);
+            UriParser syntax = (UriParser)s_table[lwrCaseScheme];
             if (syntax != null)
             {
                 return syntax;
             }
-            s_tempTable.TryGetValue(lwrCaseScheme, out syntax);
+            syntax = (UriParser)s_tempTable[lwrCaseScheme];
             if (syntax != null)
             {
                 return syntax;
@@ -301,7 +257,7 @@ namespace System
             {
                 if (s_tempTable.Count >= c_MaxCapacity)
                 {
-                    s_tempTable = new Dictionary<string, UriParser>(c_InitialTableSize);
+                    s_tempTable = new Hashtable(c_InitialTableSize);
                 }
                 syntax = new BuiltInUriParser(lwrCaseScheme, NoDefaultPort, UnknownV1SyntaxFlags);
                 s_tempTable[lwrCaseScheme] = syntax;
@@ -309,16 +265,8 @@ namespace System
             }
         }
 
-        internal static UriParser GetSyntax(string lwrCaseScheme)
-        {
-            UriParser ret = null;
-            s_table.TryGetValue(lwrCaseScheme, out ret);
-            if (ret == null)
-            {
-                s_tempTable.TryGetValue(lwrCaseScheme, out ret);
-            }
-            return ret;
-        }
+        internal static UriParser GetSyntax(string lwrCaseScheme) =>
+            (UriParser)(s_table[lwrCaseScheme] ?? s_tempTable[lwrCaseScheme]);
 
         //
         // Builtin and User Simple syntaxes do not need custom validation/parsing (i.e. virtual method calls),
@@ -468,7 +416,7 @@ namespace System
                                         UriSyntaxFlags.AllowIdn |
                                         UriSyntaxFlags.AllowIriParsing;
 
-        private static readonly UriSyntaxFlags s_fileSyntaxFlags =
+        private const UriSyntaxFlags FileSyntaxFlags =
                                         UriSyntaxFlags.MustHaveAuthority |
                                         //
                                         UriSyntaxFlags.AllowEmptyHost |
@@ -491,8 +439,8 @@ namespace System
                                         UriSyntaxFlags.AllowIdn |
                                         UriSyntaxFlags.AllowIriParsing;
 
-        private static readonly UriSyntaxFlags s_unixFileSyntaxFlags =
-                                        s_fileSyntaxFlags & ~UriSyntaxFlags.ConvertPathSlashes;
+        private const UriSyntaxFlags UnixFileSyntaxFlags =
+                                        FileSyntaxFlags & ~UriSyntaxFlags.ConvertPathSlashes;
 
         private const UriSyntaxFlags VsmacrosSyntaxFlags =
                                         UriSyntaxFlags.MustHaveAuthority |