Make Uri Thread-Safe (#33042)
authorMiha Zupan <mihazupan.zupan1@gmail.com>
Sat, 9 May 2020 19:26:16 +0000 (21:26 +0200)
committerGitHub <noreply@github.com>
Sat, 9 May 2020 19:26:16 +0000 (15:26 -0400)
* Do not lock on string in Uri

* Make Uri thread-safe without locks

*A lock may still be used for custom parsers

* Update comments

* Use existing style for enum values

* Add a comment about locking for custom parsers

* Add a comment about DebugSetLeftCtor usage

* Add trailing dot in exception message

* Fix typos

* Fix typo

Co-authored-by: Stephen Toub <stoub@microsoft.com>
Co-authored-by: Stephen Toub <stoub@microsoft.com>
src/libraries/System.Private.Uri/src/Resources/Strings.resx
src/libraries/System.Private.Uri/src/System/Uri.cs
src/libraries/System.Private.Uri/src/System/UriExt.cs
src/libraries/System.Private.Uri/src/System/UriScheme.cs
src/libraries/System.Private.Uri/src/System/UriSyntax.cs
src/libraries/System.Private.Uri/tests/FunctionalTests/UriParserTest.cs
src/libraries/System.Private.Uri/tests/FunctionalTests/UriTests.cs

index 655f2820692e2cf121d5bd38dbd7b8b6da143814..16e3b576da63aba40f1597eabcaabd14a0c06b51 100644 (file)
@@ -1,4 +1,5 @@
-<root>
+<?xml version="1.0" encoding="utf-8"?>
+<root>
   <!-- 
     Microsoft ResX Schema 
     
   <data name="Arg_KeyNotFoundWithKey" xml:space="preserve">
     <value>The given key '{0}' was not present in the dictionary.</value>
   </data>
+  <data name="net_uri_InitializeCalledAlreadyOrTooLate" xml:space="preserve">
+    <value>UriParser's base InitializeAndValidate may only be called once on a single Uri instance and only from an override of InitializeAndValidate.</value>
+  </data>
 </root>
\ No newline at end of file
index 30d54570b89f32a8a7fe7e16c206df07db870685..0755daca458aaa4b31ba67536d2a3ee23529f8f5 100644 (file)
@@ -2,12 +2,14 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using Internal.Runtime.CompilerServices;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Globalization;
 using System.Runtime.InteropServices;
 using System.Runtime.Serialization;
 using System.Text;
+using System.Threading;
 
 namespace System
 {
@@ -42,10 +44,13 @@ namespace System
         // untouched user string if string has unicode with iri on or unicode/idn host with idn on
         private string _originalUnicodeString = null!; // initialized in ctor via helper
 
-        private UriParser _syntax = null!;   // Initialized in ctor via helper. This is a whole Uri syntax, not only the scheme name
+        internal UriParser _syntax = null!;   // Initialized in ctor via helper. This is a whole Uri syntax, not only the scheme name
+
+        internal Flags _flags;
+        private UriInfo _info = null!;
 
         [Flags]
-        private enum Flags : ulong
+        internal enum Flags : ulong
         {
             Zero = 0x00000000,
 
@@ -110,10 +115,29 @@ namespace System
             FragmentIriCanonical = 0x40000000000,
             IriCanonical = 0x78000000000,
             UnixPath = 0x100000000000,
+
+            /// <summary>
+            /// Used to ensure that InitializeAndValidate is only called once per Uri instance and only from an override of InitializeAndValidate
+            /// </summary>
+            CustomParser_ParseMinimalAlreadyCalled = 0x4000000000000000,
+
+            /// <summary>
+            /// Used for asserting that certain methods are only called from the constructor to validate thread-safety assumptions
+            /// </summary>
+            Debug_LeftConstructor = 0x8000000000000000
+        }
+
+        [Conditional("DEBUG")]
+        private void DebugSetLeftCtor()
+        {
+            _flags |= Flags.Debug_LeftConstructor;
         }
 
-        private Flags _flags;
-        private UriInfo _info = null!; // initialized during ctor via helper
+        [Conditional("DEBUG")]
+        internal void DebugAssertInCtor()
+        {
+            Debug.Assert((_flags & Flags.Debug_LeftConstructor) == 0);
+        }
 
         private class UriInfo
         {
@@ -122,8 +146,19 @@ namespace System
             public string? String;
             public Offset Offset;
             public string? DnsSafeHost;    // stores dns safe host when idn is on and we have unicode or idn host
-            public MoreInfo? MoreInfo;     // Multi-threading: This field must be always accessed through a _local_
-                                           // stack copy of _info.
+
+            private MoreInfo? _moreInfo;
+            public MoreInfo MoreInfo
+            {
+                get
+                {
+                    if (_moreInfo is null)
+                    {
+                        Interlocked.CompareExchange(ref _moreInfo, new MoreInfo(), null);
+                    }
+                    return _moreInfo;
+                }
+            }
         };
 
         [StructLayout(LayoutKind.Sequential, Pack = 1)]
@@ -148,6 +183,27 @@ namespace System
             public string? RemoteUrl;
         };
 
+        private void InterlockedSetFlags(Flags flags)
+        {
+            Debug.Assert(_syntax != null);
+
+            if (_syntax.IsSimple)
+            {
+                // For built-in (simple) parsers, it is safe to do an Interlocked update here
+                Debug.Assert(sizeof(Flags) == sizeof(ulong));
+                Interlocked.Or(ref Unsafe.As<Flags, ulong>(ref _flags), (ulong)flags);
+            }
+            else
+            {
+                // Custom parsers still use a lock in CreateHostString and perform non-atomic flags updates
+                // We have to take the lock to ensure flags access synchronization if CreateHostString and ParseRemaining are called concurrently
+                lock (_info)
+                {
+                    _flags |= flags;
+                }
+            }
+        }
+
         private bool IsImplicitFile
         {
             get { return (_flags & Flags.ImplicitFile) != 0; }
@@ -208,11 +264,6 @@ namespace System
                 return (_flags & Flags.UserDrivenParsing) != 0;
             }
         }
-        private void SetUserDrivenParsing()
-        {
-            // we use = here to clear all parsing flags for a uri that we think is invalid.
-            _flags = Flags.UserDrivenParsing | (_flags & Flags.UserEscaped);
-        }
 
         private int SecuredPathIndex
         {
@@ -252,10 +303,11 @@ namespace System
         private UriInfo EnsureUriInfo()
         {
             Flags cF = _flags;
-            if ((_flags & Flags.MinimalUriInfoSet) == 0)
+            if ((cF & Flags.MinimalUriInfoSet) == 0)
             {
                 CreateUriInfo(cF);
             }
+            Debug.Assert(_info != null && (_flags & Flags.MinimalUriInfoSet) != 0);
             return _info;
         }
 
@@ -298,6 +350,7 @@ namespace System
                 throw new ArgumentNullException(nameof(uriString));
 
             CreateThis(uriString, false, UriKind.Absolute);
+            DebugSetLeftCtor();
         }
 
         //
@@ -312,6 +365,7 @@ namespace System
                 throw new ArgumentNullException(nameof(uriString));
 
             CreateThis(uriString, dontEscape, UriKind.Absolute);
+            DebugSetLeftCtor();
         }
 
         //
@@ -323,13 +377,14 @@ namespace System
         [Obsolete("The constructor has been deprecated. Please new Uri(Uri, string). The dontEscape parameter is deprecated and is always false. https://go.microsoft.com/fwlink/?linkid=14202")]
         public Uri(Uri baseUri, string? relativeUri, bool dontEscape)
         {
-            if (baseUri == null)
+            if (baseUri is null)
                 throw new ArgumentNullException(nameof(baseUri));
 
             if (!baseUri.IsAbsoluteUri)
                 throw new ArgumentOutOfRangeException(nameof(baseUri));
 
             CreateUri(baseUri, relativeUri, dontEscape);
+            DebugSetLeftCtor();
         }
 
         //
@@ -337,10 +392,11 @@ namespace System
         //
         public Uri(string uriString, UriKind uriKind)
         {
-            if ((object)uriString == null)
+            if (uriString is null)
                 throw new ArgumentNullException(nameof(uriString));
 
             CreateThis(uriString, false, uriKind);
+            DebugSetLeftCtor();
         }
 
         //
@@ -352,13 +408,14 @@ namespace System
         //
         public Uri(Uri baseUri, string? relativeUri)
         {
-            if ((object)baseUri == null)
+            if (baseUri is null)
                 throw new ArgumentNullException(nameof(baseUri));
 
             if (!baseUri.IsAbsoluteUri)
                 throw new ArgumentOutOfRangeException(nameof(baseUri));
 
             CreateUri(baseUri, relativeUri, false);
+            DebugSetLeftCtor();
         }
 
         //
@@ -373,6 +430,7 @@ namespace System
             if (uriString!.Length != 0)
             {
                 CreateThis(uriString, false, UriKind.Absolute);
+                DebugSetLeftCtor();
                 return;
             }
 
@@ -381,6 +439,7 @@ namespace System
                 throw new ArgumentNullException(nameof(uriString));
 
             CreateThis(uriString, false, UriKind.Relative);
+            DebugSetLeftCtor();
         }
 
         //
@@ -410,6 +469,8 @@ namespace System
 
         private void CreateUri(Uri baseUri, string? relativeUri, bool dontEscape)
         {
+            DebugAssertInCtor();
+
             // Parse relativeUri and populate Uri internal data.
             CreateThis(relativeUri, dontEscape, UriKind.RelativeOrAbsolute);
 
@@ -452,7 +513,7 @@ namespace System
         //
         public Uri(Uri baseUri, Uri relativeUri)
         {
-            if ((object)baseUri == null)
+            if (baseUri is null)
                 throw new ArgumentNullException(nameof(baseUri));
 
             if (!baseUri.IsAbsoluteUri)
@@ -477,6 +538,7 @@ namespace System
                     if ((object)resolvedRelativeUri != (object)this)
                         CreateThisFromUri(resolvedRelativeUri);
 
+                    DebugSetLeftCtor();
                     return;
                 }
             }
@@ -492,6 +554,7 @@ namespace System
             _info = null!;
             _syntax = null!;
             CreateThis(newUriString, dontEscape, UriKind.Absolute);
+            DebugSetLeftCtor();
         }
 
         //
@@ -623,18 +686,10 @@ namespace System
         {
             get
             {
-                UriInfo info = EnsureUriInfo();
-                if ((object?)info.MoreInfo == null)
-                {
-                    info.MoreInfo = new MoreInfo();
-                }
-                string? result = info.MoreInfo.Path;
-                if ((object?)result == null)
-                {
-                    result = GetParts(UriComponents.Path | UriComponents.KeepDelimiter, UriFormat.UriEscaped);
-                    info.MoreInfo.Path = result;
-                }
-                return result;
+                Debug.Assert(IsAbsoluteUri);
+
+                MoreInfo info = EnsureUriInfo().MoreInfo;
+                return info.Path ??= GetParts(UriComponents.Path | UriComponents.KeepDelimiter, UriFormat.UriEscaped);
             }
         }
 
@@ -647,18 +702,8 @@ namespace System
                     throw new InvalidOperationException(SR.net_uri_NotAbsolute);
                 }
 
-                UriInfo info = EnsureUriInfo();
-                if ((object?)info.MoreInfo == null)
-                {
-                    info.MoreInfo = new MoreInfo();
-                }
-                string? result = info.MoreInfo.AbsoluteUri;
-                if ((object?)result == null)
-                {
-                    result = GetParts(UriComponents.AbsoluteUri, UriFormat.UriEscaped);
-                    info.MoreInfo.AbsoluteUri = result;
-                }
-                return result;
+                MoreInfo info = EnsureUriInfo().MoreInfo;
+                return info.AbsoluteUri ??= GetParts(UriComponents.AbsoluteUri, UriFormat.UriEscaped);
             }
         }
 
@@ -1024,18 +1069,8 @@ namespace System
                     throw new InvalidOperationException(SR.net_uri_NotAbsolute);
                 }
 
-                UriInfo info = EnsureUriInfo();
-                if ((object?)info.MoreInfo == null)
-                {
-                    info.MoreInfo = new MoreInfo();
-                }
-                string? result = info.MoreInfo.Query;
-                if ((object?)result == null)
-                {
-                    result = GetParts(UriComponents.Query | UriComponents.KeepDelimiter, UriFormat.UriEscaped);
-                    info.MoreInfo.Query = result;
-                }
-                return result;
+                MoreInfo info = EnsureUriInfo().MoreInfo;
+                return info.Query ??= GetParts(UriComponents.Query | UriComponents.KeepDelimiter, UriFormat.UriEscaped);
             }
         }
 
@@ -1050,18 +1085,8 @@ namespace System
                     throw new InvalidOperationException(SR.net_uri_NotAbsolute);
                 }
 
-                UriInfo info = EnsureUriInfo();
-                if ((object?)info.MoreInfo == null)
-                {
-                    info.MoreInfo = new MoreInfo();
-                }
-                string? result = info.MoreInfo.Fragment;
-                if ((object?)result == null)
-                {
-                    result = GetParts(UriComponents.Fragment | UriComponents.KeepDelimiter, UriFormat.UriEscaped);
-                    info.MoreInfo.Fragment = result;
-                }
-                return result;
+                MoreInfo info = EnsureUriInfo().MoreInfo;
+                return info.Fragment ??= GetParts(UriComponents.Fragment | UriComponents.KeepDelimiter, UriFormat.UriEscaped);
             }
         }
 
@@ -1489,8 +1514,8 @@ namespace System
             }
             else
             {
-                MoreInfo moreInfo = EnsureUriInfo().MoreInfo ??= new MoreInfo();
-                string remoteUrl = moreInfo.RemoteUrl ??= GetParts(UriComponents.HttpRequestUrl, UriFormat.SafeUnescaped);
+                MoreInfo info = EnsureUriInfo().MoreInfo;
+                string remoteUrl = info.RemoteUrl ??= GetParts(UriComponents.HttpRequestUrl, UriFormat.SafeUnescaped);
 
                 if (IsUncOrDosPath)
                 {
@@ -1725,31 +1750,11 @@ namespace System
             // We want to cache RemoteUrl to improve perf for Uri as a key.
             // We should consider reducing the overall working set by not caching some other properties mentioned in MoreInfo
 
-            UriInfo selfInfo = _info;
-            UriInfo otherInfo = obj._info;
-            if ((object?)selfInfo.MoreInfo == null)
-            {
-                selfInfo.MoreInfo = new MoreInfo();
-            }
-            if ((object?)otherInfo.MoreInfo == null)
-            {
-                otherInfo.MoreInfo = new MoreInfo();
-            }
+            MoreInfo selfInfo = _info.MoreInfo;
+            MoreInfo otherInfo = obj._info.MoreInfo;
 
-            // NB: To avoid a race condition when creating MoreInfo field
-            // "selfInfo" and "otherInfo" shall remain as local copies.
-            string? selfUrl = selfInfo.MoreInfo.RemoteUrl;
-            if ((object?)selfUrl == null)
-            {
-                selfUrl = GetParts(UriComponents.HttpRequestUrl, UriFormat.SafeUnescaped);
-                selfInfo.MoreInfo.RemoteUrl = selfUrl;
-            }
-            string? otherUrl = otherInfo.MoreInfo.RemoteUrl;
-            if ((object?)otherUrl == null)
-            {
-                otherUrl = obj.GetParts(UriComponents.HttpRequestUrl, UriFormat.SafeUnescaped);
-                otherInfo.MoreInfo.RemoteUrl = otherUrl;
-            }
+            string selfUrl = selfInfo.RemoteUrl ??= GetParts(UriComponents.HttpRequestUrl, UriFormat.SafeUnescaped);
+            string otherUrl = otherInfo.RemoteUrl ??= obj.GetParts(UriComponents.HttpRequestUrl, UriFormat.SafeUnescaped);
 
             if (!IsUncOrDosPath)
             {
@@ -1783,8 +1788,8 @@ namespace System
             // Get Unescaped form as most safe for the comparison
             // Fragment AND UserInfo are ignored
             //
-            return (string.Compare(selfInfo.MoreInfo.RemoteUrl,
-                                   otherInfo.MoreInfo.RemoteUrl,
+            return (string.Compare(selfUrl,
+                                   otherUrl,
                                    IsUncOrDosPath ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal) == 0);
         }
 
@@ -1857,6 +1862,8 @@ namespace System
         //
         private static unsafe ParsingError ParseScheme(string uriString, ref Flags flags, ref UriParser? syntax)
         {
+            Debug.Assert((flags & Flags.Debug_LeftConstructor) == 0);
+
             int length = uriString.Length;
             if (length == 0)
                 return ParsingError.EmptyUriString;
@@ -1885,6 +1892,10 @@ namespace System
         //
         internal UriFormatException? ParseMinimal()
         {
+            Debug.Assert(_syntax != null && !_syntax.IsSimple);
+            Debug.Assert((_flags & Flags.CustomParser_ParseMinimalAlreadyCalled) != 0);
+            DebugAssertInCtor();
+
             ParsingError result = PrivateParseMinimal();
             if (result == ParsingError.None)
                 return null;
@@ -1908,6 +1919,9 @@ namespace System
         //
         private unsafe ParsingError PrivateParseMinimal()
         {
+            Debug.Assert(_syntax != null);
+            DebugAssertInCtor();
+
             int idx = (int)(_flags & Flags.IndexMask);
             int length = _string.Length;
             string? newHost = null;      // stores newly parsed host when original strings are being switched
@@ -2324,13 +2338,21 @@ namespace System
 
         Done:
             cF |= Flags.MinimalUriInfoSet;
-            lock (_string)
+
+            Debug.Assert(sizeof(Flags) == sizeof(ulong));
+
+            Interlocked.CompareExchange(ref _info, info, null!);
+
+            Flags current = _flags;
+            while ((current & Flags.MinimalUriInfoSet) == 0)
             {
-                if ((_flags & Flags.MinimalUriInfoSet) == 0)
+                Flags newValue = (current & ~Flags.IndexMask) | cF;
+                ulong oldValue = Interlocked.CompareExchange(ref Unsafe.As<Flags, ulong>(ref _flags), (ulong)newValue, (ulong)current);
+                if (oldValue == (ulong)current)
                 {
-                    _info = info;
-                    _flags = (_flags & ~Flags.IndexMask) | cF;
+                    return;
                 }
+                current = (Flags)oldValue;
             }
         }
 
@@ -2426,10 +2448,7 @@ namespace System
             }
 
             _info.Host = host;
-            lock (_info)
-            {
-                _flags |= flags;
-            }
+            InterlockedSetFlags(flags);
         }
 
         private static string CreateHostStringHelper(string str, int idx, int end, ref Flags flags, ref string? scopeId)
@@ -2585,6 +2604,8 @@ namespace System
 
         private string GetEscapedParts(UriComponents uriParts)
         {
+            Debug.Assert(_info != null && (_flags & Flags.MinimalUriInfoSet) != 0);
+
             // Which Uri parts are not escaped canonically ?
             // Notice that public UriPart and private Flags must be in Sync so below code can work
             //
@@ -2622,6 +2643,8 @@ namespace System
 
         private string GetUnescapedParts(UriComponents uriParts, UriFormat formatAs)
         {
+            Debug.Assert(_info != null && (_flags & Flags.MinimalUriInfoSet) != 0);
+
             // Which Uri parts are not escaped canonically ?
             // Notice that public UriComponents and private Uri.Flags must me in Sync so below code can work
             //
@@ -3558,14 +3581,10 @@ namespace System
                 }
             }
             _info.Offset.End = (ushort)idx;
-        Done:
 
-            cF |= Flags.AllUriInfoSet;
-            lock (_info)
-            {
-                _flags |= cF;
-            }
-            _flags |= Flags.RestUnicodeNormalized;
+        Done:
+            cF |= Flags.AllUriInfoSet | Flags.RestUnicodeNormalized;
+            InterlockedSetFlags(cF);
         }
 
         //
@@ -3577,6 +3596,8 @@ namespace System
         private static unsafe int ParseSchemeCheckImplicitFile(char* uriString, int length,
             ref ParsingError err, ref Flags flags, ref UriParser? syntax)
         {
+            Debug.Assert((flags & Flags.Debug_LeftConstructor) == 0);
+
             int idx = 0;
 
             //skip whitespace
@@ -3958,6 +3979,8 @@ namespace System
         private unsafe int CheckAuthorityHelper(char* pString, int idx, int length,
             ref ParsingError err, ref Flags flags, UriParser syntax, ref string? newHost)
         {
+            Debug.Assert((_flags & Flags.Debug_LeftConstructor) == 0 || (!_syntax.IsSimple && Monitor.IsEntered(_info)));
+
             int end = length;
             char ch;
             int startInput = idx;
@@ -3969,6 +3992,8 @@ namespace System
             bool hostNotUnicodeNormalized = hasUnicode && ((flags & Flags.HostUnicodeNormalized) == 0);
             UriSyntaxFlags syntaxFlags = syntax.Flags;
 
+            Debug.Assert((_flags & Flags.HasUserInfo) == 0 && (_flags & Flags.HostTypeMask) == 0);
+
             //Special case is an empty authority
             if (idx == length || ((ch = pString[idx]) == '/' || (ch == '\\' && StaticIsFile(syntax)) || ch == '#' || ch == '?'))
             {
index 67e7eab4007b2f425b9a74f3c409755b9ee6b87e..5f40d02e0569f1bd2179df26ba7985bb652e3b22 100644 (file)
@@ -16,6 +16,8 @@ namespace System
         //
         private void CreateThis(string? uri, bool dontEscape, UriKind uriKind)
         {
+            DebugAssertInCtor();
+
             // if (!Enum.IsDefined(typeof(UriKind), uriKind)) -- We currently believe that Enum.IsDefined() is too slow
             // to be used here.
             if ((int)uriKind < (int)UriKind.RelativeOrAbsolute || (int)uriKind > (int)UriKind.Relative)
@@ -23,21 +25,22 @@ namespace System
                 throw new ArgumentException(SR.Format(SR.net_uri_InvalidUriKind, uriKind));
             }
 
-            _string = uri == null ? string.Empty : uri;
+            _string = uri ?? string.Empty;
 
             if (dontEscape)
                 _flags |= Flags.UserEscaped;
 
             ParsingError err = ParseScheme(_string, ref _flags, ref _syntax!);
-            UriFormatException? e;
 
-            InitializeUri(err, uriKind, out e);
+            InitializeUri(err, uriKind, out UriFormatException? e);
             if (e != null)
                 throw e;
         }
 
         private void InitializeUri(ParsingError err, UriKind uriKind, out UriFormatException? e)
         {
+            DebugAssertInCtor();
+
             if (err == ParsingError.None)
             {
                 if (IsImplicitFile)
@@ -157,7 +160,8 @@ namespace System
                         if (err != ParsingError.None || InFact(Flags.ErrorOrParsingRecursion))
                         {
                             // User parser took over on an invalid Uri
-                            SetUserDrivenParsing();
+                            // we use = here to clear all parsing flags for a uri that we think is invalid.
+                            _flags = Flags.UserDrivenParsing | (_flags & Flags.UserEscaped);
                         }
                         else if (uriKind == UriKind.Relative)
                         {
@@ -263,20 +267,20 @@ namespace System
         //
         public static bool TryCreate(string? uriString, UriKind uriKind, [NotNullWhen(true)] out Uri? result)
         {
-            if ((object?)uriString == null)
+            if (uriString is null)
             {
                 result = null;
                 return false;
             }
             UriFormatException? e = null;
             result = CreateHelper(uriString, false, uriKind, ref e);
-            return (object?)e == null && result != null;
+            result?.DebugSetLeftCtor();
+            return e is null && result != null;
         }
 
         public static bool TryCreate(Uri? baseUri, string? relativeUri, [NotNullWhen(true)] out Uri? result)
         {
-            Uri? relativeLink;
-            if (TryCreate(relativeUri, UriKind.RelativeOrAbsolute, out relativeLink))
+            if (TryCreate(relativeUri, UriKind.RelativeOrAbsolute, out Uri? relativeLink))
             {
                 if (!relativeLink.IsAbsoluteUri)
                     return TryCreate(baseUri, relativeLink, out result);
@@ -292,7 +296,7 @@ namespace System
         {
             result = null;
 
-            if ((object?)baseUri == null || (object?)relativeUri == null)
+            if (baseUri is null || relativeUri is null)
                 return false;
 
             if (baseUri.IsNotAbsoluteUri)
@@ -306,6 +310,7 @@ namespace System
             {
                 dontEscape = relativeUri.UserEscaped;
                 result = ResolveHelper(baseUri, relativeUri, ref newUriString, ref dontEscape, out e);
+                Debug.Assert(e is null || result is null);
             }
             else
             {
@@ -316,10 +321,11 @@ namespace System
             if (e != null)
                 return false;
 
-            if ((object?)result == null)
+            if (result is null)
                 result = CreateHelper(newUriString!, dontEscape, UriKind.Absolute, ref e);
 
-            return (object?)e == null && result != null && result.IsAbsoluteUri;
+            result?.DebugSetLeftCtor();
+            return e is null && result != null && result.IsAbsoluteUri;
         }
 
         public string GetComponents(UriComponents components, UriFormat format)
@@ -590,6 +596,13 @@ namespace System
             _flags = flags;
             _syntax = uriParser!;
             _string = uri;
+
+            if (uriParser is null)
+            {
+                // Relative Uris are fully initialized after the call to this constructor
+                // Absolute Uris will be initialized with a call to InitializeUri on the newly created instance
+                DebugSetLeftCtor();
+            }
         }
 
         //
@@ -622,6 +635,7 @@ namespace System
             }
 
             // Cannot be relative Uri if came here
+            Debug.Assert(syntax != null);
             Uri result = new Uri(flags, syntax, uriString);
 
             // Validate instance using ether built in or a user Parser
@@ -630,7 +644,10 @@ namespace System
                 result.InitializeUri(err, uriKind, out e);
 
                 if (e == null)
+                {
+                    result.DebugSetLeftCtor();
                     return result;
+                }
 
                 return null;
             }
@@ -893,6 +910,8 @@ namespace System
         //
         private void CreateThisFromUri(Uri otherUri)
         {
+            DebugAssertInCtor();
+
             // Clone the other URI but develop own UriInfo member
             _info = null!;
 
index 36c759bb2a6af6126f763dcfc60df6eb8ac49f78..0675b41204a87bf0e91594bbe09d915e5c65d8d6 100644 (file)
@@ -2,6 +2,10 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Diagnostics;
+using System.Threading;
+using Internal.Runtime.CompilerServices;
+
 namespace System
 {
     //
@@ -61,6 +65,25 @@ namespace System
         //
         protected virtual void InitializeAndValidate(Uri uri, out UriFormatException? parsingError)
         {
+            if (uri._syntax is null)
+            {
+                throw new InvalidOperationException(SR.net_uri_NotAbsolute);
+            }
+
+            if (!ReferenceEquals(uri._syntax, this))
+            {
+                throw new InvalidOperationException(SR.Format(SR.net_uri_UserDrivenParsing, uri._syntax.GetType()));
+            }
+
+            Debug.Assert(sizeof(Uri.Flags) == sizeof(ulong));
+
+            // If ParseMinimal is called multiple times this Uri instance may be corrupted, throw an exception instead
+            ulong previous = Interlocked.Or(ref Unsafe.As<Uri.Flags, ulong>(ref uri._flags), (ulong)Uri.Flags.CustomParser_ParseMinimalAlreadyCalled);
+            if (((Uri.Flags)previous & Uri.Flags.CustomParser_ParseMinimalAlreadyCalled) != 0)
+            {
+                throw new InvalidOperationException(SR.net_uri_InitializeCalledAlreadyOrTooLate);
+            }
+
             parsingError = uri.ParseMinimal();
         }
 
index 071efb60c0ff0ef4bf0e4f87704ee638a40e1dfb..48afbcb53fd0a032a4f522b9e96804c5d3c055ae 100644 (file)
@@ -6,6 +6,9 @@
 // only internal implementation of UriParser type
 
 using System.Collections;
+using System.Diagnostics;
+using System.Threading;
+using Internal.Runtime.CompilerServices;
 
 namespace System
 {
@@ -266,7 +269,12 @@ namespace System
 
         internal void InternalValidate(Uri thisUri, out UriFormatException? parsingError)
         {
+            thisUri.DebugAssertInCtor();
             InitializeAndValidate(thisUri, out parsingError);
+
+            // InitializeAndValidate should not be called outside of the constructor
+            Debug.Assert(sizeof(Uri.Flags) == sizeof(ulong));
+            Interlocked.Or(ref Unsafe.As<Uri.Flags, ulong>(ref thisUri._flags), (ulong)Uri.Flags.CustomParser_ParseMinimalAlreadyCalled);
         }
 
         internal string? InternalResolve(Uri thisBaseUri, Uri uriLink, out UriFormatException? parsingError)
index 0989ad99cf1b654af372d848e0e7354347c8fdb1..b063572a75696acff4dd94378593edf45b35ead1 100644 (file)
@@ -35,7 +35,14 @@ namespace System.PrivateUri.Tests
     {
         public TestUriParser() : base() { }
         public new string GetComponents(Uri uri, UriComponents components, UriFormat format) => base.GetComponents(uri, components, format);
-        public new void InitializeAndValidate(Uri uri, out UriFormatException parsingError) => base.InitializeAndValidate(uri, out parsingError);
+        protected override void InitializeAndValidate(Uri uri, out UriFormatException parsingError)
+        {
+            parsingError = null;
+            for (int i = 0; i < BaseInitializeAndValidateCallCount; i++)
+            {
+                base.InitializeAndValidate(uri, out parsingError);
+            }
+        }
         public new bool IsBaseOf(Uri baseUri, Uri relativeUri) => base.IsBaseOf(baseUri, relativeUri);
         public new bool IsWellFormedOriginalString(Uri uri) => base.IsWellFormedOriginalString(uri);
         public new string Resolve(Uri baseUri, Uri relativeUri, out UriFormatException parsingError) => base.Resolve(baseUri, relativeUri, out parsingError);
@@ -55,6 +62,11 @@ namespace System.PrivateUri.Tests
         public string SchemeName { get; private set; }
         public int DefaultPort { get; private set; }
 
+        public int BaseInitializeAndValidateCallCount = 1;
+        public void DangerousExposed_InitializeAndValidate(Uri uri, out UriFormatException parsingError)
+        {
+            InitializeAndValidate(uri, out parsingError);
+        }
     }
     #endregion Test class
 
@@ -186,19 +198,40 @@ namespace System.PrivateUri.Tests
         }
 
         [Fact]
-        public static void InitializeAndValidate()
+        public static void InitializeAndValidate_ThrowsOnUriOfDifferentScheme()
         {
-            Uri http = new Uri(FullHttpUri);
+            Uri uri = new Uri(FullHttpUri);
             TestUriParser parser = new TestUriParser();
-            parser.InitializeAndValidate(http, out UriFormatException error);
-            Assert.NotNull(error);
+            Assert.Throws<InvalidOperationException>(() => parser.DangerousExposed_InitializeAndValidate(uri, out _));
         }
 
         [Fact]
-        public static void InitializeAndValidate_Null()
+        public static void InitializeAndValidate_ThrowsOnRelativeUri()
         {
+            Uri uri = new Uri("foo", UriKind.Relative);
             TestUriParser parser = new TestUriParser();
-            Assert.Throws<NullReferenceException>(() => parser.InitializeAndValidate(null, out _));
+            Assert.Throws<InvalidOperationException>(() => parser.DangerousExposed_InitializeAndValidate(uri, out _));
+        }
+
+        [Fact]
+        public static void InitializeAndValidate_ThrowsIfCalledOutsideOfConstructorOrMultipleTimes()
+        {
+            TestUriParser parser = new TestUriParser();
+            UriParser.Register(parser, "test-scheme", 12345);
+
+            // Does not throw if called once from the constructor
+            parser.BaseInitializeAndValidateCallCount = 1;
+            Uri uri = new Uri("test-scheme://foo.bar");
+
+            // Throws if called multiple times
+            parser.BaseInitializeAndValidateCallCount = 2;
+            Assert.Throws<InvalidOperationException>(() => new Uri("test-scheme://foo.bar"));
+
+            // Throws if called after the constructor
+            parser.BaseInitializeAndValidateCallCount = 0;
+            uri = new Uri("test-scheme://foo.bar");
+            parser.BaseInitializeAndValidateCallCount = 1;
+            Assert.Throws<InvalidOperationException>(() => parser.DangerousExposed_InitializeAndValidate(uri, out _));
         }
 
         [Fact]
index 9fd7dda2fbef66caac66f6be1daef5fd38894499..6317c87b2942b1bcc0dbbeefa429c4e1f53d3ce1 100644 (file)
@@ -3,6 +3,8 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
 using Xunit;
 
 namespace System.PrivateUri.Tests
@@ -707,5 +709,34 @@ namespace System.PrivateUri.Tests
             Assert.Equal("http://www.contoso.com/", u.AbsoluteUri);
             Assert.Equal(80, u.Port);
         }
+
+        [Fact]
+        public static void Uri_DoesNotLockOnString()
+        {
+            // Don't intern the string we lock on
+            string uriString = "*http://www.contoso.com".Substring(1);
+
+            bool timedOut = false;
+
+            var enteredLockMre = new ManualResetEvent(false);
+            var finishedParsingMre = new ManualResetEvent(false);
+
+            Task.Factory.StartNew(() =>
+            {
+                lock (uriString)
+                {
+                    enteredLockMre.Set();
+                    timedOut = !finishedParsingMre.WaitOne(TimeSpan.FromSeconds(10));
+                }
+            }, TaskCreationOptions.LongRunning);
+
+            enteredLockMre.WaitOne();
+            int port = new Uri(uriString).Port;
+            finishedParsingMre.Set();
+            Assert.Equal(80, port);
+
+            Assert.True(Monitor.TryEnter(uriString, TimeSpan.FromSeconds(10)));
+            Assert.False(timedOut);
+        }
     }
 }