From f3d84ca003943b09a8b80091e4b61221be339b22 Mon Sep 17 00:00:00 2001 From: Anton Firszov Date: Wed, 4 Nov 2020 13:45:51 +0100 Subject: [PATCH] Fix: NetworkStream throwing inconsistent exceptions (#40772) Fix a bug: Span overloads of NetworkStream throwing ObjectDisposedException instead of NetworkException, when not using derived NetworkStream. --- .../src/System/Net/Sockets/NetworkStream.cs | 23 +++++++++++++++-- .../tests/FunctionalTests/NetworkStreamTest.cs | 30 ++++++++++++---------- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs index 0cff6c8..3f13b71 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs @@ -250,7 +250,17 @@ namespace System.Net.Sockets ThrowIfDisposed(); if (!CanRead) throw new InvalidOperationException(SR.net_writeonlystream); - int bytesRead = _streamSocket.Receive(buffer, SocketFlags.None, out SocketError errorCode); + int bytesRead; + SocketError errorCode; + try + { + bytesRead = _streamSocket.Receive(buffer, SocketFlags.None, out errorCode); + } + catch (Exception exception) when (!(exception is OutOfMemoryException)) + { + throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception); + } + if (errorCode != SocketError.Success) { var socketException = new SocketException((int)errorCode); @@ -320,7 +330,16 @@ namespace System.Net.Sockets ThrowIfDisposed(); if (!CanWrite) throw new InvalidOperationException(SR.net_readonlystream); - _streamSocket.Send(buffer, SocketFlags.None, out SocketError errorCode); + SocketError errorCode; + try + { + _streamSocket.Send(buffer, SocketFlags.None, out errorCode); + } + catch (Exception exception) when (!(exception is OutOfMemoryException)) + { + throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception); + } + if (errorCode != SocketError.Success) { var socketException = new SocketException((int)errorCode); diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs index b812da8..4f03d4a 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs @@ -307,8 +307,10 @@ namespace System.Net.Sockets.Tests }); } - [Fact] - public async Task DisposeSocketDirectly_ReadWriteThrowNetworkException() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task DisposeSocketDirectly_ReadWriteThrowNetworkException(bool derivedNetworkStream) { using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) @@ -318,20 +320,22 @@ namespace System.Net.Sockets.Tests Task acceptTask = listener.AcceptAsync(); await Task.WhenAll(acceptTask, client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listener.LocalEndPoint).Port))); - using (Socket serverSocket = await acceptTask) - using (DerivedNetworkStream server = new DerivedNetworkStream(serverSocket)) - { - serverSocket.Dispose(); + using Socket serverSocket = await acceptTask; + using NetworkStream server = derivedNetworkStream ? (NetworkStream)new DerivedNetworkStream(serverSocket) : new NetworkStream(serverSocket); + + serverSocket.Dispose(); - Assert.Throws(() => server.Read(new byte[1], 0, 1)); - Assert.Throws(() => server.Write(new byte[1], 0, 1)); + Assert.Throws(() => server.Read(new byte[1], 0, 1)); + Assert.Throws(() => server.Write(new byte[1], 0, 1)); - Assert.Throws(() => server.BeginRead(new byte[1], 0, 1, null, null)); - Assert.Throws(() => server.BeginWrite(new byte[1], 0, 1, null, null)); + Assert.Throws(() => server.Read((Span)new byte[1])); + Assert.Throws(() => server.Write((ReadOnlySpan)new byte[1])); - Assert.Throws(() => { server.ReadAsync(new byte[1], 0, 1); }); - Assert.Throws(() => { server.WriteAsync(new byte[1], 0, 1); }); - } + Assert.Throws(() => server.BeginRead(new byte[1], 0, 1, null, null)); + Assert.Throws(() => server.BeginWrite(new byte[1], 0, 1, null, null)); + + Assert.Throws(() => { server.ReadAsync(new byte[1], 0, 1); }); + Assert.Throws(() => { server.WriteAsync(new byte[1], 0, 1); }); } } -- 2.7.4