ensure we never send EndStream after RST_STREAM
authorGeoff Kizer <geoffrek>
Sat, 27 Jul 2019 22:37:42 +0000 (15:37 -0700)
committerGeoff Kizer <geoffrek>
Sat, 27 Jul 2019 22:45:17 +0000 (15:45 -0700)
Commit migrated from https://github.com/dotnet/corefx/commit/9389088720e38b9d4f1d126d9daaa3461ad158e8

src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs

index a03e11d..4ed6157 100644 (file)
@@ -178,26 +178,38 @@ namespace System.Net.Http
                         Debug.Assert(_requestCompletionState == StreamCompletionState.InProgress, $"Request already completed with state={_requestCompletionState}");
 
                         _requestCompletionState = StreamCompletionState.Failed;
-                        CheckForCompletion();
+
+                        // Cancel above should ensure that the response is either Completed or Failed now.
+                        Debug.Assert(_responseCompletionState != StreamCompletionState.InProgress);
+
+                        Reset();
                     }
 
                     throw;
                 }
 
-                // Before we send the EndStream, mark the request as completed.
-                // This avoids races where the server may close the connection after finishing all requests
-                // and we get an OnReset callback before the request is marked as completed.
                 lock (SyncObject)
                 {
                     Debug.Assert(_requestCompletionState == StreamCompletionState.InProgress, $"Request already completed with state={_requestCompletionState}");
 
                     _requestCompletionState = StreamCompletionState.Completed;
-                    CheckForCompletion();
-                }
+                    if (_responseCompletionState == StreamCompletionState.Failed)
+                    {
+                        // Note, we can reach this point if the response stream failed but cancellation didn't propagate before we finished.
+                        Reset();
+                    }
+                    else
+                    {
+                        // Send EndStream asynchronously and without cancellation.
+                        // If this fails, it means that the connection is aborting and we will be reset.
+                        _connection.LogExceptions(_connection.SendEndStreamAsync(_streamId));
 
-                // Send EndStream asynchronously and without cancellation.
-                // If this fails, it means that the connection is aborting and we will be reset.
-                _connection.LogExceptions(_connection.SendEndStreamAsync(_streamId));
+                        if (_responseCompletionState == StreamCompletionState.Completed)
+                        {
+                            Complete();
+                        }
+                    }
+                }
             }
 
             // Delay sending request body if we sent Expect: 100-continue.
@@ -232,28 +244,34 @@ namespace System.Net.Http
                 return sendRequestContent;
             }
 
-            private void CheckForCompletion()
+            private void Reset()
             {
                 Debug.Assert(Monitor.IsEntered(SyncObject));
-                Debug.Assert(_requestCompletionState != StreamCompletionState.InProgress || _responseCompletionState != StreamCompletionState.InProgress,
-                    $"CheckForCompletion called but neither request nor response is completed");
-
-                if (_requestCompletionState == StreamCompletionState.InProgress || _responseCompletionState == StreamCompletionState.InProgress)
+                Debug.Assert(_requestCompletionState != StreamCompletionState.InProgress);
+                Debug.Assert(_responseCompletionState != StreamCompletionState.InProgress);
+                Debug.Assert(_requestCompletionState == StreamCompletionState.Failed || _responseCompletionState == StreamCompletionState.Failed,
+                    "Reset called but neither request nor response is failed");
+
+                if (NetEventSource.IsEnabled) Trace($"Stream reset. This is a test. Request={_requestCompletionState}, Response={_responseCompletionState}.");
+                // Don't send a RST_STREAM if we've already received one from the server.
+                if (_resetException == null)
                 {
-                    return;
+                    _connection.LogExceptions(_connection.SendRstStreamAsync(_streamId, Http2ProtocolErrorCode.Cancel));
                 }
 
-                if (NetEventSource.IsEnabled) Trace($"Stream complete. Request={_requestCompletionState}, Response={_responseCompletionState}.");
+                Complete();
+            }
 
-                if (_resetException == null &&
-                    (_requestCompletionState == StreamCompletionState.Failed || _responseCompletionState == StreamCompletionState.Failed))
-                {
-                    IgnoreExceptions(_connection.SendRstStreamAsync(_streamId, Http2ProtocolErrorCode.Cancel));
-                }
+            private void Complete()
+            {
+                Debug.Assert(Monitor.IsEntered(SyncObject));
+                Debug.Assert(_requestCompletionState != StreamCompletionState.InProgress);
+                Debug.Assert(_responseCompletionState != StreamCompletionState.InProgress);
+
+                if (NetEventSource.IsEnabled) Trace($"Stream complete. Request={_requestCompletionState}, Response={_responseCompletionState}.");
 
                 _connection.RemoveStream(this);
 
-                // Do cleanup.
                 _streamWindow.Dispose();
                 _requestBodyCancellationSource?.Dispose();
             }
@@ -276,7 +294,10 @@ namespace System.Net.Http
                     if (_responseCompletionState == StreamCompletionState.InProgress)
                     {
                         _responseCompletionState = StreamCompletionState.Failed;
-                        CheckForCompletion();
+                        if (_requestCompletionState != StreamCompletionState.InProgress)
+                        {
+                            Reset();
+                        }
                     }
 
                     // Discard any remaining buffered response data
@@ -512,7 +533,14 @@ namespace System.Net.Http
                         Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}");
 
                         _responseCompletionState = StreamCompletionState.Completed;
-                        CheckForCompletion();
+                        if (_requestCompletionState == StreamCompletionState.Completed)
+                        {
+                            Complete();
+                        }
+
+                        // We should never reach here with the request failed. It's only set to Failed in SendRequestBodyAsync after we've called Cancel,
+                        // which will set the _responseCompletionState to Failed, meaning we'll never get here.
+                        Debug.Assert(_requestCompletionState != StreamCompletionState.Failed);
                     }
 
                     signalWaiter = _hasWaiter;
@@ -558,7 +586,14 @@ namespace System.Net.Http
                         Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}");
 
                         _responseCompletionState = StreamCompletionState.Completed;
-                        CheckForCompletion();
+                        if (_requestCompletionState == StreamCompletionState.Completed)
+                        {
+                            Complete();
+                        }
+
+                        // We should never reach here with the request failed. It's only set to Failed in SendRequestBodyAsync after we've called Cancel,
+                        // which will set the _responseCompletionState to Failed, meaning we'll never get here.
+                        Debug.Assert(_requestCompletionState != StreamCompletionState.Failed);
                     }
 
                     signalWaiter = _hasWaiter;