Fix a few Socket.SendFile issues (#42535)
authorStephen Toub <stoub@microsoft.com>
Wed, 23 Sep 2020 13:22:58 +0000 (09:22 -0400)
committerGitHub <noreply@github.com>
Wed, 23 Sep 2020 13:22:58 +0000 (09:22 -0400)
* Fix a few Socket.SendFile issues

- The string argument in the single-argument overload should be nullable.
- All overloads on Windows should allow a null file path, but they've been throwing an exception
- On Linux, data was silently truncated when sending a file larger than int.MaxValue with BeginSendFile.

* Address PR feedback

src/libraries/Common/src/Interop/Windows/WinSock/Interop.TransmitFile.cs
src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs

index d097205..d6e2907 100644 (file)
@@ -14,7 +14,7 @@ internal static partial class Interop
         [DllImport(Interop.Libraries.Mswsock, SetLastError = true)]
         internal static extern unsafe bool TransmitFile(
             SafeHandle socket,
-            SafeHandle? fileHandle,
+            IntPtr fileHandle,
             int numberOfBytesToWrite,
             int numberOfBytesPerSend,
             NativeOverlapped* overlapped,
index d8ccc33..4025a30 100644 (file)
@@ -383,7 +383,7 @@ namespace System.Net.Sockets
         public int Send(System.ReadOnlySpan<byte> buffer, System.Net.Sockets.SocketFlags socketFlags) { throw null; }
         public int Send(System.ReadOnlySpan<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, out System.Net.Sockets.SocketError errorCode) { throw null; }
         public bool SendAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
-        public void SendFile(string fileName) { }
+        public void SendFile(string? fileName) { }
         public void SendFile(string? fileName, byte[]? preBuffer, byte[]? postBuffer, System.Net.Sockets.TransmitFileOptions flags) { }
         public bool SendPacketsAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
         public int SendTo(byte[] buffer, int offset, int size, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; }
index ce976d0..82465b6 100644 (file)
@@ -1299,7 +1299,7 @@ namespace System.Net.Sockets
             return bytesTransferred;
         }
 
-        public void SendFile(string fileName)
+        public void SendFile(string? fileName)
         {
             SendFile(fileName, null, null, TransmitFileOptions.UseDefaultWorkerThread);
         }
index fca424b..9bbe0f4 100644 (file)
@@ -1809,9 +1809,9 @@ namespace System.Net.Sockets
         }
 
         public static SocketError SendFileAsync(SafeSocketHandle handle, FileStream fileStream, Action<long, SocketError> callback) =>
-            SendFileAsync(handle, fileStream, 0, (int)fileStream.Length, callback);
+            SendFileAsync(handle, fileStream, 0, fileStream.Length, callback);
 
-        private static SocketError SendFileAsync(SafeSocketHandle handle, FileStream fileStream, long offset, int count, Action<long, SocketError> callback)
+        private static SocketError SendFileAsync(SafeSocketHandle handle, FileStream fileStream, long offset, long count, Action<long, SocketError> callback)
         {
             long bytesSent;
             SocketError socketError = handle.AsyncContext.SendFileAsync(fileStream.SafeFileHandle, offset, count, out bytesSent, callback);
@@ -1849,7 +1849,7 @@ namespace System.Net.Sockets
 
                             var tcs = new TaskCompletionSource<SocketError>();
                             error = SendFileAsync(socket.InternalSafeHandle, fs, e.OffsetLong,
-                                e.Count > 0 ? e.Count : checked((int)(fs.Length - e.OffsetLong)),
+                                e.Count > 0 ? e.Count : fs.Length - e.OffsetLong,
                                 (transferred, se) =>
                                 {
                                     bytesTransferred += transferred;
index 1d79475..1a23f93 100644 (file)
@@ -1137,10 +1137,27 @@ namespace System.Net.Sockets
                 transmitFileBuffers.TailLength = postBuffer.Length;
             }
 
-            bool success = Interop.Mswsock.TransmitFile(socket, fileHandle, 0, 0, overlapped,
-                needTransmitFileBuffers ? &transmitFileBuffers : null, flags);
+            bool releaseRef = false;
+            IntPtr fileHandlePtr = IntPtr.Zero;
+            try
+            {
+                if (fileHandle != null)
+                {
+                    fileHandle.DangerousAddRef(ref releaseRef);
+                    fileHandlePtr = fileHandle.DangerousGetHandle();
+                }
 
-            return success;
+                return Interop.Mswsock.TransmitFile(
+                    socket, fileHandlePtr, 0, 0, overlapped,
+                    needTransmitFileBuffers ? &transmitFileBuffers : null, flags);
+            }
+            finally
+            {
+                if (releaseRef)
+                {
+                    fileHandle!.DangerousRelease();
+                }
+            }
         }
 
         public static unsafe SocketError SendFileAsync(SafeSocketHandle handle, FileStream? fileStream, byte[]? preBuffer, byte[]? postBuffer, TransmitFileOptions flags, TransmitFileAsyncResult asyncResult)
index e17435e..3a8db0f 100644 (file)
@@ -104,6 +104,108 @@ namespace System.Net.Sockets.Tests
             }
         }
 
+        [Theory]
+        [InlineData(false, false, false)]
+        [InlineData(false, false, true)]
+        [InlineData(false, true, false)]
+        [InlineData(false, true, true)]
+        [InlineData(true, false, false)]
+        [InlineData(true, false, true)]
+        [InlineData(true, true, false)]
+        [InlineData(true, true, true)]
+        public async Task SendFile_NoFile_Succeeds(bool useAsync, bool usePreBuffer, bool usePostBuffer)
+        {
+            using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+            using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+            listener.BindToAnonymousPort(IPAddress.Loopback);
+            listener.Listen(1);
+
+            client.Connect(listener.LocalEndPoint);
+            using Socket server = listener.Accept();
+
+            if (useAsync)
+            {
+                await Task.Factory.FromAsync<string>(server.BeginSendFile, server.EndSendFile, null, null);
+            }
+            else
+            {
+                server.SendFile(null);
+            }
+            Assert.Equal(0, client.Available);
+
+            byte[] preBuffer = usePreBuffer ? new byte[1] : null;
+            byte[] postBuffer = usePostBuffer ? new byte[1] : null;
+            int bytesExpected = (usePreBuffer ? 1 : 0) + (usePostBuffer ? 1 : 0);
+
+            if (useAsync)
+            {
+                await Task.Factory.FromAsync((c, s) => server.BeginSendFile(null, preBuffer, postBuffer, TransmitFileOptions.UseDefaultWorkerThread, c, s), server.EndSendFile, null);
+            }
+            else
+            {
+                server.SendFile(null, preBuffer, postBuffer, TransmitFileOptions.UseDefaultWorkerThread);
+            }
+
+            byte[] receiveBuffer = new byte[1];
+            for (int i = 0; i < bytesExpected; i++)
+            {
+                Assert.Equal(1, client.Receive(receiveBuffer));
+            }
+
+            Assert.Equal(0, client.Available);
+        }
+
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/42534", TestPlatforms.Windows)]
+        [OuterLoop("Creates and sends a file several gigabytes long")]
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
+        public async Task SendFile_GreaterThan2GBFile_SendsAllBytes(bool useAsync)
+        {
+            const long FileLength = 100L + int.MaxValue;
+
+            string tmpFile = GetTestFilePath();
+            using (FileStream fs = File.Create(tmpFile))
+            {
+                fs.SetLength(FileLength);
+            }
+
+            using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+            using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+            listener.BindToAnonymousPort(IPAddress.Loopback);
+            listener.Listen(1);
+
+            client.Connect(listener.LocalEndPoint);
+            using Socket server = listener.Accept();
+
+            await new Task[]
+            {
+                Task.Run(async () =>
+                {
+                    if (useAsync)
+                    {
+                        await Task.Factory.FromAsync(server.BeginSendFile, server.EndSendFile, tmpFile, null);
+                    }
+                    else
+                    {
+                        server.SendFile(tmpFile);
+                    }
+                }),
+                Task.Run(() =>
+                {
+                    byte[] buffer = new byte[100_000];
+                    long count = 0;
+                    while (count < FileLength)
+                    {
+                        int received = client.Receive(buffer);
+                        Assert.NotEqual(0, received);
+                        count += received;
+                    }
+                    Assert.Equal(0, client.Available);
+                })
+            }.WhenAllOrAnyFailed();
+        }
+
         [OuterLoop]
         [Theory]
         [MemberData(nameof(SendFileSync_MemberData))]