Replace thread aborts in SoundPlayer (dotnet/corefx#34815)
authorStephen Toub <stoub@microsoft.com>
Fri, 25 Jan 2019 00:21:39 +0000 (19:21 -0500)
committerGitHub <noreply@github.com>
Fri, 25 Jan 2019 00:21:39 +0000 (19:21 -0500)
This only addresses the issue of SoundPlayer using Thread.Abort, replacing it with cancellation.

Commit migrated from https://github.com/dotnet/corefx/commit/2100e4f54637c124bdf880c7c036b4c79053e72f

src/libraries/System.Windows.Extensions/src/System.Windows.Extensions.csproj
src/libraries/System.Windows.Extensions/src/System/Media/SoundPlayer.cs
src/libraries/System.Windows.Extensions/tests/System/Media/SoundPlayerTests.cs

index 2f93a70f8d994b5b042dd128d21b532337fa11c5..0b0b9cf875ace85a86d37e06e46060bbee85b042 100644 (file)
@@ -90,6 +90,7 @@
     <Reference Include="System.Runtime" />
     <Reference Include="System.Security.Cryptography.X509Certificates" />
     <Reference Include="System.Text.RegularExpressions" />
+    <Reference Include="System.Threading.Tasks" />
     <Reference Include="System.Threading.Thread" />
     <Reference Include="System.Threading" />
   </ItemGroup>
index 97a9ef47f006d5c86a0a07a0aa6ddcfead875f63..d00bb2e5102bfc8cc26a8701f5b171288c420bcc 100644 (file)
@@ -9,6 +9,7 @@ using System.Net;
 using System.Runtime.InteropServices;
 using System.Runtime.Serialization;
 using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.Media
 {
@@ -24,10 +25,11 @@ namespace System.Media
         // used to lock all synchronous calls to the SoundPlayer object
         private readonly ManualResetEvent _semaphore = new ManualResetEvent(true);
 
-        // the worker copyThread
-        // we start the worker copyThread ONLY from entry points in the SoundPlayer API
+        // the worker copyTask
+        // we start the worker copyTask ONLY from entry points in the SoundPlayer API
         // we also set the tread to null only from the entry points in the SoundPlayer API
-        private Thread _copyThread = null;
+        private Task _copyTask = null;
+        private CancellationTokenSource _copyTaskCancellation = null;
 
         // local buffer information
         private int _currentPos = 0;
@@ -146,7 +148,7 @@ namespace System.Media
             }
 
             // if we are actively loading, keep it running
-            if (_copyThread != null && _copyThread.ThreadState == ThreadState.Running)
+            if (_copyTask != null && !_copyTask.IsCompleted)
             {
                 return;
             }
@@ -174,7 +176,7 @@ namespace System.Media
             IsLoadCompleted = false;
             _lastLoadException = null;
             _doesLoadAppearSynchronous = false;
-            _copyThread = null;
+            _copyTask = null;
             _semaphore.Set();
         }
 
@@ -230,6 +232,12 @@ namespace System.Media
             }
         }
 
+        private void CancelLoad()
+        {
+            _copyTaskCancellation?.Cancel();
+            _copyTaskCancellation = null;
+        }
+
         private void LoadSync()
         {
             Debug.Assert((_uri == null || !_uri.IsFile), "we only load streams");
@@ -237,7 +245,7 @@ namespace System.Media
             // first make sure that any possible download ended
             if (!_semaphore.WaitOne(LoadTimeout, false))
             {
-                _copyThread?.Abort();
+                CancelLoad();
                 CleanupStreamData();
                 throw new TimeoutException(SR.SoundAPILoadTimedOut);
             }
@@ -275,7 +283,7 @@ namespace System.Media
 
                 if (!_semaphore.WaitOne(LoadTimeout, false))
                 {
-                    _copyThread?.Abort();
+                    CancelLoad();
                     CleanupStreamData();
                     throw new TimeoutException(SR.SoundAPILoadTimedOut);
                 }
@@ -289,7 +297,7 @@ namespace System.Media
             }
 
             // we don't need the worker copyThread anymore
-            _copyThread = null;
+            _copyTask = null;
         }
 
         private void LoadStream(bool loadSync)
@@ -308,8 +316,9 @@ namespace System.Media
                 // lock any synchronous calls on the Sound object
                 _semaphore.Reset();
                 // start loading
-                _copyThread = new Thread(new ThreadStart(WorkerThread));
-                _copyThread.Start();
+                var cts = new CancellationTokenSource();
+                _copyTaskCancellation = cts;
+                _copyTask = CopyStreamAsync(cts.Token);
             }
         }
 
@@ -359,9 +368,9 @@ namespace System.Media
         private void SetupSoundLocation(string soundLocation)
         {
             // if we are loading a file, stop it right now
-            if (_copyThread != null)
+            if (_copyTask != null)
             {
-                _copyThread.Abort();
+                CancelLoad();
                 CleanupStreamData();
             }
 
@@ -390,9 +399,9 @@ namespace System.Media
 
         private void SetupStream(Stream stream)
         {
-            if (_copyThread != null)
+            if (_copyTask != null)
             {
-                _copyThread.Abort();
+                CancelLoad();
                 CleanupStreamData();
             }
 
@@ -463,7 +472,7 @@ namespace System.Media
             ((EventHandler)Events[s_eventStreamChanged])?.Invoke(this, e);
         }
 
-        private void WorkerThread()
+        private async Task CopyStreamAsync(CancellationToken cancellationToken)
         {
             try
             {
@@ -471,15 +480,16 @@ namespace System.Media
                 if (_uri != null && !_uri.IsFile && _stream == null)
                 {
                     WebRequest webRequest = WebRequest.Create(_uri);
-
-                    WebResponse webResponse = webRequest.GetResponse();
-
-                    _stream = webResponse.GetResponseStream();
+                    using (cancellationToken.Register(r => ((WebRequest)r).Abort(), webRequest))
+                    {
+                        WebResponse webResponse = await webRequest.GetResponseAsync().ConfigureAwait(false);
+                        _stream = webResponse.GetResponseStream();
+                    }
                 }
 
                 _streamData = new byte[BlockSize];
 
-                int readBytes = _stream.Read(_streamData, _currentPos, BlockSize);
+                int readBytes = await _stream.ReadAsync(_streamData, _currentPos, BlockSize, cancellationToken).ConfigureAwait(false);
                 int totalBytes = readBytes;
 
                 while (readBytes > 0)
@@ -491,7 +501,7 @@ namespace System.Media
                         Array.Copy(_streamData, newData, _streamData.Length);
                         _streamData = newData;
                     }
-                    readBytes = _stream.Read(_streamData, _currentPos, BlockSize);
+                    readBytes = await _stream.ReadAsync(_streamData, _currentPos, BlockSize, cancellationToken).ConfigureAwait(false);
                     totalBytes += readBytes;
                 }
 
@@ -502,15 +512,17 @@ namespace System.Media
                 _lastLoadException = exception;
             }
 
+            IsLoadCompleted = true;
+            _semaphore.Set();
+
             if (!_doesLoadAppearSynchronous)
             {
                 // Post notification back to the UI thread.
-                _asyncOperation.PostOperationCompleted(
-                    _loadAsyncOperationCompleted,
-                    new AsyncCompletedEventArgs(_lastLoadException, false, null));
+                AsyncCompletedEventArgs ea = _lastLoadException is OperationCanceledException ?
+                    new AsyncCompletedEventArgs(null, cancelled: true, null) :
+                    new AsyncCompletedEventArgs(_lastLoadException, cancelled: false, null);
+                _asyncOperation.PostOperationCompleted(_loadAsyncOperationCompleted, ea);
             }
-            IsLoadCompleted = true;
-            _semaphore.Set();
         }
 
         private unsafe void ValidateSoundFile(string fileName)
index 0efb92aab99cc0a933a935881702b010ada2bd52..12e8c2cc3173aef007bda9fa07a66bb315550148 100644 (file)
@@ -3,8 +3,13 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Collections.Generic;
+using System.ComponentModel;
 using System.IO;
-using System.Runtime.Serialization.Formatters.Binary;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
 using Xunit;
 
 namespace System.Media.Test
@@ -95,6 +100,49 @@ namespace System.Media.Test
             soundPlayer.Play();
         }
 
+        [Theory]
+        [MemberData(nameof(Play_String_TestData))]
+        [OuterLoop]
+        public async Task LoadAsync_SourceLocationFromNetwork_Success(string sourceLocation)
+        {
+            var player = new SoundPlayer();
+
+            using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            {
+                listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+                listener.Listen(1);
+                var ep = (IPEndPoint)listener.LocalEndPoint;
+
+                Task serverTask = Task.Run(async () =>
+                {
+                    using (Socket server = await listener.AcceptAsync())
+                    using (var serverStream = new NetworkStream(server))
+                    using (var reader = new StreamReader(new NetworkStream(server)))
+                    using (FileStream sourceStream = File.OpenRead(sourceLocation.Replace("file://", "")))
+                    {
+                        string line;
+                        while (!string.IsNullOrEmpty(line = await reader.ReadLineAsync()));
+                        byte[] header = Encoding.UTF8.GetBytes($"HTTP/1.1 200 OK\r\nContent-Length: {sourceStream.Length}\r\n\r\n");
+                        serverStream.Write(header, 0, header.Length);
+                        await sourceStream.CopyToAsync(serverStream);
+                        server.Shutdown(SocketShutdown.Both);
+                    }
+                });
+
+                var tcs = new TaskCompletionSource<AsyncCompletedEventArgs>();
+                player.LoadCompleted += (s, e) => tcs.TrySetResult(e);
+                player.SoundLocation = $"http://{ep.Address}:{ep.Port}";
+                player.LoadAsync();
+                AsyncCompletedEventArgs ea = await tcs.Task;
+                Assert.Null(ea.Error);
+                Assert.False(ea.Cancelled);
+
+                await serverTask;
+            }
+
+            player.Play();
+        }
+
         [Theory]
         [MemberData(nameof(Play_String_TestData))]
         [OuterLoop]
@@ -394,5 +442,85 @@ namespace System.Media.Test
                 Assert.False(calledHandler);
             }
         }
+
+        [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, "netfx aborts a worker thread and never signals operation completion")]
+        [Theory]
+        [InlineData(0)]
+        [InlineData(1)]
+        [InlineData(2)]
+        public async Task LoadAsync_CancelDuringLoad_CompletesAsCanceled(int cancellationCause)
+        {
+            var tcs = new TaskCompletionSource<AsyncCompletedEventArgs>();
+            var player = new SoundPlayer();
+            player.LoadCompleted += (s, e) => tcs.SetResult(e);
+            player.Stream = new ReadAsyncBlocksUntilCanceledStream();
+            player.LoadAsync();
+
+            Assert.False(tcs.Task.IsCompleted);
+
+            switch (cancellationCause)
+            {
+                case 0:
+                    player.Stream = new MemoryStream();
+                    break;
+
+                case 1:
+                    player.LoadTimeout = 1;
+                    Assert.Throws<TimeoutException>(() => player.Load());
+                    break;
+
+                case 2:
+                    player.SoundLocation = "DoesntExistButThatDoesntMatter";
+                    break;
+            }
+
+            AsyncCompletedEventArgs ea = await tcs.Task;
+            Assert.Null(ea.Error);
+            Assert.True(ea.Cancelled);
+            Assert.Null(ea.UserState);
+        }
+
+        [Theory]
+        [MemberData(nameof(Play_String_TestData))]
+        [OuterLoop]
+        public async Task CancelDuringLoad_ThenPlay_Success(string sourceLocation)
+        {
+            using (FileStream stream = File.OpenRead(sourceLocation.Replace("file://", "")))
+            {
+                var tcs = new TaskCompletionSource<bool>();
+                AsyncCompletedEventHandler handler = (s, e) => tcs.SetResult(true);
+
+                var player = new SoundPlayer();
+                player.LoadCompleted += handler;
+                player.Stream = new ReadAsyncBlocksUntilCanceledStream();
+                player.LoadAsync();
+
+                player.Stream = stream;
+                await tcs.Task;
+                player.LoadCompleted -= handler;
+
+                player.Play();
+            }
+        }
+
+        private sealed class ReadAsyncBlocksUntilCanceledStream : Stream
+        {
+            public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+            {
+                await Task.Delay(-1, cancellationToken);
+                return 0;
+            }
+
+            public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+            public override void Flush() { }
+            public override bool CanRead => true;
+            public override bool CanSeek => false;
+            public override bool CanWrite => false;
+            public override long Length => throw new NotSupportedException();
+            public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
+            public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
+            public override void SetLength(long value) => throw new NotSupportedException();
+            public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+        }
     }
 }