Simplify catch-rethrow logic in NetworkStream (#44246)
authorAnton Firszov <Anton.Firszov@microsoft.com>
Fri, 6 Nov 2020 13:52:49 +0000 (14:52 +0100)
committerGitHub <noreply@github.com>
Fri, 6 Nov 2020 13:52:49 +0000 (14:52 +0100)
A follow-up on #40772 (comment), simplifies and harmonizes the way we wrap exceptions into IOException. Having one catch block working with System.Exception seems to be enough here, no need for specific handling of SocketException.

src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs

index 3f13b71..32a02d9 100644 (file)
@@ -50,15 +50,15 @@ namespace System.Net.Sockets
                 // allowing non-blocking sockets could result in non-deterministic failures from those
                 // operations. A developer that requires using NetworkStream with a non-blocking socket can
                 // temporarily flip Socket.Blocking as a workaround.
-                throw GetCustomException(SR.net_sockets_blocking);
+                throw new IOException(SR.net_sockets_blocking);
             }
             if (!socket.Connected)
             {
-                throw GetCustomException(SR.net_notconnected);
+                throw new IOException(SR.net_notconnected);
             }
             if (socket.SocketType != SocketType.Stream)
             {
-                throw GetCustomException(SR.net_notstream);
+                throw new IOException(SR.net_notstream);
             }
 
             _streamSocket = socket;
@@ -227,13 +227,9 @@ namespace System.Net.Sockets
             {
                 return _streamSocket.Receive(buffer, offset, count, 0);
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception);
+                throw WrapException(SR.net_io_readfailure, exception);
             }
         }
 
@@ -250,23 +246,14 @@ namespace System.Net.Sockets
             ThrowIfDisposed();
             if (!CanRead) throw new InvalidOperationException(SR.net_writeonlystream);
 
-            int bytesRead;
-            SocketError errorCode;
             try
             {
-                bytesRead = _streamSocket.Receive(buffer, SocketFlags.None, out errorCode);
+                return _streamSocket.Receive(buffer, SocketFlags.None);
             }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception);
-            }
-
-            if (errorCode != SocketError.Success)
-            {
-                var socketException = new SocketException((int)errorCode);
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException);
+                throw WrapException(SR.net_io_readfailure, exception);
             }
-            return bytesRead;
         }
 
         public override unsafe int ReadByte()
@@ -306,13 +293,9 @@ namespace System.Net.Sockets
                 // after ALL the requested number of bytes was transferred.
                 _streamSocket.Send(buffer, offset, count, SocketFlags.None);
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception);
+                throw WrapException(SR.net_io_writefailure, exception);
             }
         }
 
@@ -330,20 +313,13 @@ namespace System.Net.Sockets
             ThrowIfDisposed();
             if (!CanWrite) throw new InvalidOperationException(SR.net_readonlystream);
 
-            SocketError errorCode;
             try
             {
-                _streamSocket.Send(buffer, SocketFlags.None, out errorCode);
+                _streamSocket.Send(buffer, SocketFlags.None);
             }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception);
-            }
-
-            if (errorCode != SocketError.Success)
-            {
-                var socketException = new SocketException((int)errorCode);
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException);
+                throw WrapException(SR.net_io_writefailure, exception);
             }
         }
 
@@ -424,13 +400,9 @@ namespace System.Net.Sockets
                         callback,
                         state);
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception);
+                throw WrapException(SR.net_io_readfailure, exception);
             }
         }
 
@@ -456,13 +428,9 @@ namespace System.Net.Sockets
             {
                 return _streamSocket.EndReceive(asyncResult);
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception);
+                throw WrapException(SR.net_io_readfailure, exception);
             }
         }
 
@@ -500,13 +468,9 @@ namespace System.Net.Sockets
                         callback,
                         state);
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception);
+                throw WrapException(SR.net_io_writefailure, exception);
             }
         }
 
@@ -528,13 +492,9 @@ namespace System.Net.Sockets
             {
                 _streamSocket.EndSend(asyncResult);
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception);
+                throw WrapException(SR.net_io_writefailure, exception);
             }
         }
 
@@ -570,13 +530,9 @@ namespace System.Net.Sockets
                     fromNetworkStream: true,
                     cancellationToken).AsTask();
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception);
+                throw WrapException(SR.net_io_readfailure, exception);
             }
         }
 
@@ -597,13 +553,9 @@ namespace System.Net.Sockets
                     fromNetworkStream: true,
                     cancellationToken: cancellationToken);
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception);
+                throw WrapException(SR.net_io_readfailure, exception);
             }
         }
 
@@ -638,13 +590,9 @@ namespace System.Net.Sockets
                     SocketFlags.None,
                     cancellationToken).AsTask();
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception);
+                throw WrapException(SR.net_io_writefailure, exception);
             }
         }
 
@@ -664,13 +612,9 @@ namespace System.Net.Sockets
                     SocketFlags.None,
                     cancellationToken);
             }
-            catch (SocketException socketException)
-            {
-                throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException);
-            }
             catch (Exception exception) when (!(exception is OutOfMemoryException))
             {
-                throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception);
+                throw WrapException(SR.net_io_writefailure, exception);
             }
         }
 
@@ -728,14 +672,9 @@ namespace System.Net.Sockets
             void ThrowObjectDisposedException() => throw new ObjectDisposedException(GetType().FullName);
         }
 
-        private static IOException GetExceptionFromSocketException(string message, SocketException innerException)
-        {
-            return new IOException(message, innerException);
-        }
-
-        private static IOException GetCustomException(string message, Exception? innerException = null)
+        private static IOException WrapException(string resourceFormatString, Exception innerException)
         {
-            return new IOException(message, innerException);
+            return new IOException(SR.Format(resourceFormatString, innerException.Message), innerException);
         }
     }
 }
index 4f03d4a..05f7f9b 100644 (file)
@@ -321,21 +321,28 @@ namespace System.Net.Sockets.Tests
                 Task<Socket> acceptTask = listener.AcceptAsync();
                 await Task.WhenAll(acceptTask, client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listener.LocalEndPoint).Port)));
                 using Socket serverSocket = await acceptTask;
+
                 using NetworkStream server = derivedNetworkStream ? (NetworkStream)new DerivedNetworkStream(serverSocket) : new NetworkStream(serverSocket);
-                
+
                 serverSocket.Dispose();
 
-                Assert.Throws<IOException>(() => server.Read(new byte[1], 0, 1));
-                Assert.Throws<IOException>(() => server.Write(new byte[1], 0, 1));
+                ExpectIOException(() => server.Read(new byte[1], 0, 1));
+                ExpectIOException(() => server.Write(new byte[1], 0, 1));
 
-                Assert.Throws<IOException>(() => server.Read((Span<byte>)new byte[1]));
-                Assert.Throws<IOException>(() => server.Write((ReadOnlySpan<byte>)new byte[1]));
+                ExpectIOException(() => server.Read((Span<byte>)new byte[1]));
+                ExpectIOException(() => server.Write((ReadOnlySpan<byte>)new byte[1]));
 
-                Assert.Throws<IOException>(() => server.BeginRead(new byte[1], 0, 1, null, null));
-                Assert.Throws<IOException>(() => server.BeginWrite(new byte[1], 0, 1, null, null));
+                ExpectIOException(() => server.BeginRead(new byte[1], 0, 1, null, null));
+                ExpectIOException(() => server.BeginWrite(new byte[1], 0, 1, null, null));
 
-                Assert.Throws<IOException>(() => { server.ReadAsync(new byte[1], 0, 1); });
-                Assert.Throws<IOException>(() => { server.WriteAsync(new byte[1], 0, 1); });
+                ExpectIOException(() => { _ = server.ReadAsync(new byte[1], 0, 1); });
+                ExpectIOException(() => { _ = server.WriteAsync(new byte[1], 0, 1); });
+            }
+
+            static void ExpectIOException(Action action)
+            {
+                IOException ex = Assert.Throws<IOException>(action);
+                Assert.IsType<ObjectDisposedException>(ex.InnerException);
             }
         }