From 296bda64b6bc156cf44f1984ceaed4baaf4e135d Mon Sep 17 00:00:00 2001 From: Johan Lorensson Date: Wed, 23 Jun 2021 20:22:32 +0200 Subject: [PATCH] Add adb port forward and usbmux support to dsrouter. (#2366) --- .../DiagnosticsServerRouterFactory.cs | 57 +- .../DiagnosticsServerRouterRunner.cs | 12 +- .../dotnet-dsrouter/ADBTcpRouterFactory.cs | 173 ++++++ .../DiagnosticsServerRouterCommands.cs | 61 ++- src/Tools/dotnet-dsrouter/Program.cs | 20 +- .../USBMuxTcpClientRouterFactory.cs | 492 ++++++++++++++++++ 6 files changed, 782 insertions(+), 33 deletions(-) create mode 100644 src/Tools/dotnet-dsrouter/ADBTcpRouterFactory.cs create mode 100644 src/Tools/dotnet-dsrouter/USBMuxTcpClientRouterFactory.cs diff --git a/src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterFactory.cs b/src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterFactory.cs index 9f36ba052..c92f607b1 100644 --- a/src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterFactory.cs +++ b/src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterFactory.cs @@ -150,7 +150,7 @@ namespace Microsoft.Diagnostics.NETCore.Client /// internal class TcpServerRouterFactory : IIpcServerTransportCallbackInternal { - readonly ILogger _logger; + protected readonly ILogger _logger; string _tcpServerAddress; @@ -177,6 +177,13 @@ namespace Microsoft.Diagnostics.NETCore.Client get { return _tcpServerAddress; } } + public delegate TcpServerRouterFactory CreateInstanceDelegate(string tcpServer, int runtimeTimeoutMs, ILogger logger); + + public static TcpServerRouterFactory CreateDefaultInstance(string tcpServer, int runtimeTimeoutMs, ILogger logger) + { + return new TcpServerRouterFactory(tcpServer, runtimeTimeoutMs, logger); + } + public TcpServerRouterFactory(string tcpServer, int runtimeTimeoutMs, ILogger logger) { _logger = logger; @@ -192,12 +199,12 @@ namespace Microsoft.Diagnostics.NETCore.Client _tcpServer.TransportCallback = this; } - public void Start() + public virtual void Start() { _tcpServer.Start(); } - public async Task Stop() + public virtual async Task Stop() { await _tcpServer.DisposeAsync().ConfigureAwait(false); } @@ -281,15 +288,22 @@ namespace Microsoft.Diagnostics.NETCore.Client /// internal class TcpClientRouterFactory { - readonly ILogger _logger; + protected readonly ILogger _logger; - readonly string _tcpClientAddress; + protected readonly string _tcpClientAddress; - bool _auto_shutdown; + protected bool _auto_shutdown; + + protected int TcpClientTimeoutMs { get; set; } = Timeout.Infinite; - int TcpClientTimeoutMs { get; set; } = Timeout.Infinite; + protected int TcpClientRetryTimeoutMs { get; set; } = 500; - int TcpClientRetryTimeoutMs { get; set; } = 500; + public delegate TcpClientRouterFactory CreateInstanceDelegate(string tcpClient, int runtimeTimeoutMs, ILogger logger); + + public static TcpClientRouterFactory CreateDefaultInstance(string tcpClient, int runtimeTimeoutMs, ILogger logger) + { + return new TcpClientRouterFactory(tcpClient, runtimeTimeoutMs, logger); + } public string TcpClientAddress { get { return _tcpClientAddress; } @@ -304,7 +318,7 @@ namespace Microsoft.Diagnostics.NETCore.Client TcpClientTimeoutMs = runtimeTimeoutMs; } - public async Task ConnectTcpStreamAsync(CancellationToken token) + public virtual async Task ConnectTcpStreamAsync(CancellationToken token) { Stream tcpClientStream = null; @@ -368,6 +382,14 @@ namespace Microsoft.Diagnostics.NETCore.Client return tcpClientStream; } + public virtual void Start() + { + } + + public virtual void Stop() + { + } + async Task ConnectAsyncInternal(Socket clientSocket, EndPoint remoteEP, CancellationToken token) { using (token.Register(() => clientSocket.Close(0))) @@ -597,10 +619,10 @@ namespace Microsoft.Diagnostics.NETCore.Client TcpServerRouterFactory _tcpServerRouterFactory; IpcServerRouterFactory _ipcServerRouterFactory; - public IpcServerTcpServerRouterFactory(string ipcServer, string tcpServer, int runtimeTimeoutMs, ILogger logger) + public IpcServerTcpServerRouterFactory(string ipcServer, string tcpServer, int runtimeTimeoutMs, TcpServerRouterFactory.CreateInstanceDelegate factory, ILogger logger) { _logger = logger; - _tcpServerRouterFactory = new TcpServerRouterFactory(tcpServer, runtimeTimeoutMs, logger); + _tcpServerRouterFactory = factory(tcpServer, runtimeTimeoutMs, logger); _ipcServerRouterFactory = new IpcServerRouterFactory(ipcServer, logger); } @@ -638,6 +660,7 @@ namespace Microsoft.Diagnostics.NETCore.Client public override Task Stop() { + _logger?.LogInformation($"Stopping IPC server ({_ipcServerRouterFactory.IpcServerPath}) <--> TCP server ({_tcpServerRouterFactory.TcpServerAddress}) router."); _ipcServerRouterFactory.Stop(); return _tcpServerRouterFactory.Stop(); } @@ -786,11 +809,11 @@ namespace Microsoft.Diagnostics.NETCore.Client IpcServerRouterFactory _ipcServerRouterFactory; TcpClientRouterFactory _tcpClientRouterFactory; - public IpcServerTcpClientRouterFactory(string ipcServer, string tcpClient, int runtimeTimeoutMs, ILogger logger) + public IpcServerTcpClientRouterFactory(string ipcServer, string tcpClient, int runtimeTimeoutMs, TcpClientRouterFactory.CreateInstanceDelegate factory, ILogger logger) { _logger = logger; _ipcServerRouterFactory = new IpcServerRouterFactory(ipcServer, logger); - _tcpClientRouterFactory = new TcpClientRouterFactory(tcpClient, runtimeTimeoutMs, logger); + _tcpClientRouterFactory = factory(tcpClient, runtimeTimeoutMs, logger); } public override string IpcAddress @@ -820,11 +843,14 @@ namespace Microsoft.Diagnostics.NETCore.Client public override void Start() { _ipcServerRouterFactory.Start(); + _tcpClientRouterFactory.Start(); _logger?.LogInformation($"Starting IPC server ({_ipcServerRouterFactory.IpcServerPath}) <--> TCP client ({_tcpClientRouterFactory.TcpClientAddress}) router."); } public override Task Stop() { + _logger?.LogInformation($"Stopping IPC server ({_ipcServerRouterFactory.IpcServerPath}) <--> TCP client ({_tcpClientRouterFactory.TcpClientAddress}) router."); + _tcpClientRouterFactory.Stop(); _ipcServerRouterFactory.Stop(); return Task.CompletedTask; } @@ -905,11 +931,11 @@ namespace Microsoft.Diagnostics.NETCore.Client IpcClientRouterFactory _ipcClientRouterFactory; TcpServerRouterFactory _tcpServerRouterFactory; - public IpcClientTcpServerRouterFactory(string ipcClient, string tcpServer, int runtimeTimeoutMs, ILogger logger) + public IpcClientTcpServerRouterFactory(string ipcClient, string tcpServer, int runtimeTimeoutMs, TcpServerRouterFactory.CreateInstanceDelegate factory, ILogger logger) { _logger = logger; _ipcClientRouterFactory = new IpcClientRouterFactory(ipcClient, logger); - _tcpServerRouterFactory = new TcpServerRouterFactory(tcpServer, runtimeTimeoutMs, logger); + _tcpServerRouterFactory = factory(tcpServer, runtimeTimeoutMs, logger); } public override string IpcAddress @@ -944,6 +970,7 @@ namespace Microsoft.Diagnostics.NETCore.Client public override Task Stop() { + _logger?.LogInformation($"Stopping IPC client ({_ipcClientRouterFactory.IpcClientPath}) <--> TCP server ({_tcpServerRouterFactory.TcpServerAddress}) router."); return _tcpServerRouterFactory.Stop(); } diff --git a/src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterRunner.cs b/src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterRunner.cs index c4fdd81b2..91d087475 100644 --- a/src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterRunner.cs +++ b/src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterRunner.cs @@ -22,19 +22,19 @@ namespace Microsoft.Diagnostics.NETCore.Client void OnRouterStopped(); } - public static async Task runIpcClientTcpServerRouter(CancellationToken token, string ipcClient, string tcpServer, int runtimeTimeoutMs, ILogger logger, Callbacks callbacks) + public static async Task runIpcClientTcpServerRouter(CancellationToken token, string ipcClient, string tcpServer, int runtimeTimeoutMs, TcpServerRouterFactory.CreateInstanceDelegate tcpServerRouterFactory, ILogger logger, Callbacks callbacks) { - return await runRouter(token, new IpcClientTcpServerRouterFactory(ipcClient, tcpServer, runtimeTimeoutMs, logger), callbacks).ConfigureAwait(false); + return await runRouter(token, new IpcClientTcpServerRouterFactory(ipcClient, tcpServer, runtimeTimeoutMs, tcpServerRouterFactory, logger), callbacks).ConfigureAwait(false); } - public static async Task runIpcServerTcpServerRouter(CancellationToken token, string ipcServer, string tcpServer, int runtimeTimeoutMs, ILogger logger, Callbacks callbacks) + public static async Task runIpcServerTcpServerRouter(CancellationToken token, string ipcServer, string tcpServer, int runtimeTimeoutMs, TcpServerRouterFactory.CreateInstanceDelegate tcpServerRouterFactory, ILogger logger, Callbacks callbacks) { - return await runRouter(token, new IpcServerTcpServerRouterFactory(ipcServer, tcpServer, runtimeTimeoutMs, logger), callbacks).ConfigureAwait(false); + return await runRouter(token, new IpcServerTcpServerRouterFactory(ipcServer, tcpServer, runtimeTimeoutMs, tcpServerRouterFactory, logger), callbacks).ConfigureAwait(false); } - public static async Task runIpcServerTcpClientRouter(CancellationToken token, string ipcServer, string tcpClient, int runtimeTimeoutMs, ILogger logger, Callbacks callbacks) + public static async Task runIpcServerTcpClientRouter(CancellationToken token, string ipcServer, string tcpClient, int runtimeTimeoutMs, TcpClientRouterFactory.CreateInstanceDelegate tcpClientRouterFactory, ILogger logger, Callbacks callbacks) { - return await runRouter(token, new IpcServerTcpClientRouterFactory(ipcServer, tcpClient, runtimeTimeoutMs, logger), callbacks).ConfigureAwait(false); + return await runRouter(token, new IpcServerTcpClientRouterFactory(ipcServer, tcpClient, runtimeTimeoutMs, tcpClientRouterFactory, logger), callbacks).ConfigureAwait(false); } public static bool isLoopbackOnly(string address) diff --git a/src/Tools/dotnet-dsrouter/ADBTcpRouterFactory.cs b/src/Tools/dotnet-dsrouter/ADBTcpRouterFactory.cs new file mode 100644 index 000000000..f74159309 --- /dev/null +++ b/src/Tools/dotnet-dsrouter/ADBTcpRouterFactory.cs @@ -0,0 +1,173 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Threading.Tasks; +using Microsoft.Diagnostics.NETCore.Client; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter +{ + internal class ADBCommandExec + { + public static bool AdbAddPortForward(int port, ILogger logger) + { + bool ownsPortForward = false; + if (!RunAdbCommandInternal($"forward --list", $"tcp:{port}", 0, logger)) + { + ownsPortForward = RunAdbCommandInternal($"forward tcp:{port} tcp:{port}", "", 0, logger); + if (!ownsPortForward) + logger?.LogError($"Failed setting up port forward for tcp:{port}."); + } + return ownsPortForward; + } + + public static bool AdbAddPortReverse(int port, ILogger logger) + { + bool ownsPortForward = false; + if (!RunAdbCommandInternal($"reverse --list", $"tcp:{port}", 0, logger)) + { + ownsPortForward = RunAdbCommandInternal($"reverse tcp:{port} tcp:{port}", "", 0, logger); + if (!ownsPortForward) + logger?.LogError($"Failed setting up port forward for tcp:{port}."); + } + return ownsPortForward; + } + + public static void AdbRemovePortForward(int port, bool ownsPortForward, ILogger logger) + { + if (ownsPortForward) + { + if (!RunAdbCommandInternal($"forward --remove tcp:{port}", "", 0, logger)) + logger?.LogError($"Failed removing port forward for tcp:{port}."); + } + } + + public static void AdbRemovePortReverse(int port, bool ownsPortForward, ILogger logger) + { + if (ownsPortForward) + { + if (!RunAdbCommandInternal($"reverse --remove tcp:{port}", "", 0, logger)) + logger?.LogError($"Failed removing port forward for tcp:{port}."); + } + } + + public static bool RunAdbCommandInternal(string command, string expectedOutput, int expectedExitCode, ILogger logger) + { + var sdkRoot = Environment.GetEnvironmentVariable("ANDROID_SDK_ROOT"); + var adbTool = "adb"; + + if (!string.IsNullOrEmpty(sdkRoot)) + adbTool = sdkRoot + Path.DirectorySeparatorChar + "platform-tools" + Path.DirectorySeparatorChar + adbTool; + + logger?.LogDebug($"Executing {adbTool} {command}."); + + var process = new Process(); + process.StartInfo.FileName = adbTool; + process.StartInfo.Arguments = command; + + process.StartInfo.UseShellExecute = false; + process.StartInfo.RedirectStandardOutput = true; + process.StartInfo.RedirectStandardError = true; + process.StartInfo.RedirectStandardInput = false; + + bool processStartedResult = false; + bool expectedOutputResult = true; + bool expectedExitCodeResult = true; + + try + { + processStartedResult = process.Start(); + } + catch (Exception) + { + } + + if (processStartedResult) + { + var stdout = process.StandardOutput.ReadToEnd(); + var stderr = process.StandardError.ReadToEnd(); + + if (!string.IsNullOrEmpty(expectedOutput)) + expectedOutputResult = !string.IsNullOrEmpty(stdout) ? stdout.Contains(expectedOutput) : false; + + if (!string.IsNullOrEmpty(stdout)) + logger.LogTrace($"stdout: {stdout}"); + + if (!string.IsNullOrEmpty(stderr)) + logger.LogError($"stderr: {stderr}"); + } + + if (processStartedResult) + { + process.WaitForExit(); + expectedExitCodeResult = (expectedExitCode != -1) ? (process.ExitCode == expectedExitCode) : true; + } + + return processStartedResult && expectedOutputResult && expectedExitCodeResult; + } + } + + internal class ADBTcpServerRouterFactory : TcpServerRouterFactory + { + readonly int _port; + bool _ownsPortReverse; + + public static TcpServerRouterFactory CreateADBInstance(string tcpServer, int runtimeTimeoutMs, ILogger logger) + { + return new ADBTcpServerRouterFactory(tcpServer, runtimeTimeoutMs, logger); + } + + public ADBTcpServerRouterFactory(string tcpServer, int runtimeTimeoutMs, ILogger logger) + : base(tcpServer, runtimeTimeoutMs, logger) + { + _port = new IpcTcpSocketEndPoint(tcpServer).EndPoint.Port; + } + + public override void Start() + { + // Enable port reverse. + _ownsPortReverse = ADBCommandExec.AdbAddPortReverse(_port, _logger); + + base.Start(); + } + + public override async Task Stop() + { + await base.Stop().ConfigureAwait(false); + + // Disable port reverse. + ADBCommandExec.AdbRemovePortReverse(_port, _ownsPortReverse, _logger); + _ownsPortReverse = false; + } + } + + internal class ADBTcpClientRouterFactory : TcpClientRouterFactory + { + readonly int _port; + bool _ownsPortForward; + + public static TcpClientRouterFactory CreateADBInstance(string tcpClient, int runtimeTimeoutMs, ILogger logger) + { + return new ADBTcpClientRouterFactory(tcpClient, runtimeTimeoutMs, logger); + } + + public ADBTcpClientRouterFactory(string tcpClient, int runtimeTimeoutMs, ILogger logger) + : base(tcpClient, runtimeTimeoutMs, logger) + { + _port = new IpcTcpSocketEndPoint(tcpClient).EndPoint.Port; + } + + public override void Start() + { + // Enable port forwarding. + _ownsPortForward = ADBCommandExec.AdbAddPortForward(_port, _logger); + } + + public override void Stop() + { + // Disable port forwarding. + ADBCommandExec.AdbRemovePortForward(_port, _ownsPortForward, _logger); + _ownsPortForward = false; + } + } +} diff --git a/src/Tools/dotnet-dsrouter/DiagnosticsServerRouterCommands.cs b/src/Tools/dotnet-dsrouter/DiagnosticsServerRouterCommands.cs index c51639f3d..f79b5cbcc 100644 --- a/src/Tools/dotnet-dsrouter/DiagnosticsServerRouterCommands.cs +++ b/src/Tools/dotnet-dsrouter/DiagnosticsServerRouterCommands.cs @@ -48,7 +48,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter { } - public async Task RunIpcClientTcpServerRouter(CancellationToken token, string ipcClient, string tcpServer, int runtimeTimeout, string verbose) + public async Task RunIpcClientTcpServerRouter(CancellationToken token, string ipcClient, string tcpServer, int runtimeTimeout, string verbose, string forwardPort) { checkLoopbackOnly(tcpServer); @@ -69,7 +69,22 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter Launcher.Verbose = logLevel != LogLevel.Information; Launcher.CommandToken = token; - var routerTask = DiagnosticsServerRouterRunner.runIpcClientTcpServerRouter(linkedCancelToken.Token, ipcClient, tcpServer, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, factory.CreateLogger("dotnet-dsrounter"), Launcher); + var logger = factory.CreateLogger("dotnet-dsrouter"); + + TcpServerRouterFactory.CreateInstanceDelegate tcpServerRouterFactory = TcpServerRouterFactory.CreateDefaultInstance; + if (!string.IsNullOrEmpty(forwardPort)) + { + if (string.Compare(forwardPort, "android", StringComparison.OrdinalIgnoreCase) == 0) + { + tcpServerRouterFactory = ADBTcpServerRouterFactory.CreateADBInstance; + } + else + { + logger.LogError($"Unknown port forwarding argument, {forwardPort}. Only Android port fowarding is supported for TcpServer mode. Ignoring --forward-port argument."); + } + } + + var routerTask = DiagnosticsServerRouterRunner.runIpcClientTcpServerRouter(linkedCancelToken.Token, ipcClient, tcpServer, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, tcpServerRouterFactory, logger, Launcher); while (!linkedCancelToken.IsCancellationRequested) { @@ -91,7 +106,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter return routerTask.Result; } - public async Task RunIpcServerTcpServerRouter(CancellationToken token, string ipcServer, string tcpServer, int runtimeTimeout, string verbose) + public async Task RunIpcServerTcpServerRouter(CancellationToken token, string ipcServer, string tcpServer, int runtimeTimeout, string verbose, string forwardPort) { checkLoopbackOnly(tcpServer); @@ -112,7 +127,22 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter Launcher.Verbose = logLevel != LogLevel.Information; Launcher.CommandToken = token; - var routerTask = DiagnosticsServerRouterRunner.runIpcServerTcpServerRouter(linkedCancelToken.Token, ipcServer, tcpServer, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, factory.CreateLogger("dotnet-dsrounter"), Launcher); + var logger = factory.CreateLogger("dotnet-dsrouter"); + + TcpServerRouterFactory.CreateInstanceDelegate tcpServerRouterFactory = TcpServerRouterFactory.CreateDefaultInstance; + if (!string.IsNullOrEmpty(forwardPort)) + { + if (string.Compare(forwardPort, "android", StringComparison.OrdinalIgnoreCase) == 0) + { + tcpServerRouterFactory = ADBTcpServerRouterFactory.CreateADBInstance; + } + else + { + logger.LogError($"Unknown port forwarding argument, {forwardPort}. Only Android port fowarding is supported for TcpServer mode. Ignoring --forward-port argument."); + } + } + + var routerTask = DiagnosticsServerRouterRunner.runIpcServerTcpServerRouter(linkedCancelToken.Token, ipcServer, tcpServer, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, tcpServerRouterFactory, logger, Launcher); while (!linkedCancelToken.IsCancellationRequested) { @@ -134,7 +164,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter return routerTask.Result; } - public async Task RunIpcServerTcpClientRouter(CancellationToken token, string ipcServer, string tcpClient, int runtimeTimeout, string verbose) + public async Task RunIpcServerTcpClientRouter(CancellationToken token, string ipcServer, string tcpClient, int runtimeTimeout, string verbose, string forwardPort) { using CancellationTokenSource cancelRouterTask = new CancellationTokenSource(); using CancellationTokenSource linkedCancelToken = CancellationTokenSource.CreateLinkedTokenSource(token, cancelRouterTask.Token); @@ -153,7 +183,26 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter Launcher.Verbose = logLevel != LogLevel.Information; Launcher.CommandToken = token; - var routerTask = DiagnosticsServerRouterRunner.runIpcServerTcpClientRouter(linkedCancelToken.Token, ipcServer, tcpClient, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, factory.CreateLogger("dotnet-dsrounter"), Launcher); + var logger = factory.CreateLogger("dotnet-dsrouter"); + + TcpClientRouterFactory.CreateInstanceDelegate tcpClientRouterFactory = TcpClientRouterFactory.CreateDefaultInstance; + if (!string.IsNullOrEmpty(forwardPort)) + { + if (string.Compare(forwardPort, "android", StringComparison.OrdinalIgnoreCase) == 0) + { + tcpClientRouterFactory = ADBTcpClientRouterFactory.CreateADBInstance; + } + else if (string.Compare(forwardPort, "ios", StringComparison.OrdinalIgnoreCase) == 0) + { + tcpClientRouterFactory = USBMuxTcpClientRouterFactory.CreateUSBMuxInstance; + } + else + { + logger.LogError($"Unknown port forwarding argument, {forwardPort}. Ignoring --forward-port argument."); + } + } + + var routerTask = DiagnosticsServerRouterRunner.runIpcServerTcpClientRouter(linkedCancelToken.Token, ipcServer, tcpClient, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, tcpClientRouterFactory, logger, Launcher); while (!linkedCancelToken.IsCancellationRequested) { diff --git a/src/Tools/dotnet-dsrouter/Program.cs b/src/Tools/dotnet-dsrouter/Program.cs index a56da4fbb..52ab8165a 100644 --- a/src/Tools/dotnet-dsrouter/Program.cs +++ b/src/Tools/dotnet-dsrouter/Program.cs @@ -17,9 +17,9 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter { internal class Program { - delegate Task DiagnosticsServerIpcClientTcpServerRouterDelegate(CancellationToken ct, string ipcClient, string tcpServer, int runtimeTimeoutS, string verbose); - delegate Task DiagnosticsServerIpcServerTcpServerRouterDelegate(CancellationToken ct, string ipcServer, string tcpServer, int runtimeTimeoutS, string verbose); - delegate Task DiagnosticsServerIpcServerTcpClientRouterDelegate(CancellationToken ct, string ipcServer, string tcpClient, int runtimeTimeoutS, string verbose); + delegate Task DiagnosticsServerIpcClientTcpServerRouterDelegate(CancellationToken ct, string ipcClient, string tcpServer, int runtimeTimeoutS, string verbose, string forwardPort); + delegate Task DiagnosticsServerIpcServerTcpServerRouterDelegate(CancellationToken ct, string ipcServer, string tcpServer, int runtimeTimeoutS, string verbose, string forwardPort); + delegate Task DiagnosticsServerIpcServerTcpClientRouterDelegate(CancellationToken ct, string ipcServer, string tcpClient, int runtimeTimeoutS, string verbose, string forwardPort); private static Command IpcClientTcpServerRouterCommand() => new Command( @@ -31,7 +31,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter // Handler HandlerDescriptor.FromDelegate((DiagnosticsServerIpcClientTcpServerRouterDelegate)new DiagnosticsServerRouterCommands().RunIpcClientTcpServerRouter).GetCommandHandler(), // Options - IpcClientAddressOption(), TcpServerAddressOption(), RuntimeTimeoutOption(), VerboseOption() + IpcClientAddressOption(), TcpServerAddressOption(), RuntimeTimeoutOption(), VerboseOption(), ForwardPortOption() }; private static Command IpcServerTcpServerRouterCommand() => @@ -44,7 +44,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter // Handler HandlerDescriptor.FromDelegate((DiagnosticsServerIpcServerTcpServerRouterDelegate)new DiagnosticsServerRouterCommands().RunIpcServerTcpServerRouter).GetCommandHandler(), // Options - IpcServerAddressOption(), TcpServerAddressOption(), RuntimeTimeoutOption(), VerboseOption() + IpcServerAddressOption(), TcpServerAddressOption(), RuntimeTimeoutOption(), VerboseOption(), ForwardPortOption() }; private static Command IpcServerTcpClientRouterCommand() => @@ -57,7 +57,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter // Handler HandlerDescriptor.FromDelegate((DiagnosticsServerIpcServerTcpClientRouterDelegate)new DiagnosticsServerRouterCommands().RunIpcServerTcpClientRouter).GetCommandHandler(), // Options - IpcServerAddressOption(), TcpClientAddressOption(), RuntimeTimeoutOption(), VerboseOption() + IpcServerAddressOption(), TcpClientAddressOption(), RuntimeTimeoutOption(), VerboseOption(), ForwardPortOption() }; private static Option IpcClientAddressOption() => @@ -118,6 +118,14 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter Argument = new Argument(name: "verbose", getDefaultValue: () => "") }; + private static Option ForwardPortOption() => + new Option( + aliases: new[] { "--forward-port", "-fp" }, + description: "Enable port forwarding, values Android|iOS for TcpClient and only Android for TcpServer. Make sure to set ANDROID_SDK_ROOT before using this option on Android.") + { + Argument = new Argument(name: "forwardPort", getDefaultValue: () => "") + }; + private static int Main(string[] args) { StringBuilder message = new StringBuilder(); diff --git a/src/Tools/dotnet-dsrouter/USBMuxTcpClientRouterFactory.cs b/src/Tools/dotnet-dsrouter/USBMuxTcpClientRouterFactory.cs new file mode 100644 index 000000000..29ca466c5 --- /dev/null +++ b/src/Tools/dotnet-dsrouter/USBMuxTcpClientRouterFactory.cs @@ -0,0 +1,492 @@ +using System; +using System.IO; +using System.Net; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Diagnostics.NETCore.Client; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter +{ + internal class USBMuxInterop + { + public const string CoreFoundationLibrary = "/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation"; + public const string MobileDeviceLibrary = "/System/Library/PrivateFrameworks/MobileDevice.framework/MobileDevice"; + public const string LibC = "libc"; + + public const int EINTR = 4; + + public enum AMDeviceNotificationMessage : uint + { + None = 0, + Connected = 1, + Disconnected = 2, + Unsubscribed = 3 + } + + public struct AMDeviceNotificationCallbackInfo + { + public AMDeviceNotificationCallbackInfo(IntPtr device, AMDeviceNotificationMessage message) + { + this.am_device = device; + this.message = message; + } + + public IntPtr am_device; + public AMDeviceNotificationMessage message; + } + + public delegate void DeviceNotificationDelegate(ref AMDeviceNotificationCallbackInfo info); + +#region MobileDeviceLibrary + [DllImport(MobileDeviceLibrary)] + public static extern uint AMDeviceNotificationSubscribe(DeviceNotificationDelegate callback, uint unused0, uint unused1, uint unused2, out IntPtr context); + + [DllImport(MobileDeviceLibrary)] + public static extern uint AMDeviceNotificationUnsubscribe(IntPtr context); + + [DllImport(MobileDeviceLibrary)] + public static extern uint AMDeviceConnect(IntPtr device); + + [DllImport(MobileDeviceLibrary)] + public static extern uint AMDeviceDisconnect(IntPtr device); + + [DllImport(MobileDeviceLibrary)] + public static extern uint AMDeviceGetConnectionID(IntPtr device); + + [DllImport(MobileDeviceLibrary)] + public static extern int AMDeviceGetInterfaceType(IntPtr device); + + [DllImport(MobileDeviceLibrary)] + public static extern uint USBMuxConnectByPort(uint connection, ushort port, out int socketHandle); +#endregion +#region CoreFoundationLibrary + [DllImport(CoreFoundationLibrary)] + public static extern void CFRunLoopRun(); + + [DllImport(CoreFoundationLibrary)] + public static extern void CFRunLoopStop(IntPtr runLoop); + + [DllImport(CoreFoundationLibrary)] + public static extern IntPtr CFRunLoopGetCurrent(); +#endregion +#region LibC + [DllImport(LibC, SetLastError = true)] + public static extern unsafe int send(int handle, byte* buffer, IntPtr length, int flags); + + [DllImport(LibC, SetLastError = true)] + public static extern unsafe int recv(int handle, byte* buffer, IntPtr length, int flags); + + [DllImport(LibC, SetLastError = true)] + public static extern int close(int handle); +#endregion + } + + internal class USBMuxStream : Stream + { + int _handle = -1; + + public USBMuxStream(int handle) + { + _handle = handle; + } + + public bool IsOpen => _handle != -1; + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length => throw new NotImplementedException(); + + public override long Position { + get => throw new NotImplementedException(); + set => throw new NotImplementedException(); + } + + public override void Flush() + { + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public override int Read(byte[] buffer, int offset, int count) + { + bool continueRead = true; + int bytesToRead = count; + int totalBytesRead = 0; + int currentBytesRead = 0; + + while (continueRead && bytesToRead - totalBytesRead > 0) + { + if (!IsOpen) + throw new EndOfStreamException(); + + unsafe + { + fixed (byte* fixedBuffer = buffer) + { + currentBytesRead = USBMuxInterop.recv(_handle, fixedBuffer + totalBytesRead, new IntPtr(bytesToRead - totalBytesRead), 0); + } + } + + if (currentBytesRead == -1 && Marshal.GetLastWin32Error() == USBMuxInterop.EINTR) + continue; + + continueRead = currentBytesRead > 0; + if (!continueRead) + break; + + totalBytesRead += currentBytesRead; + } + + return totalBytesRead; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return Task.Run(() => + { + int result = 0; + using (cancellationToken.Register(() => Close())) + { + try + { + result = Read(buffer, offset, count); + } + catch (Exception) + { + cancellationToken.ThrowIfCancellationRequested(); + result = 0; + } + } + return result; + }); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + bool continueWrite = true; + int bytesToWrite = count; + int currentBytesWritten = 0; + int totalBytesWritten = 0; + + while (continueWrite && bytesToWrite - totalBytesWritten > 0) + { + if (!IsOpen) + throw new EndOfStreamException(); + + unsafe + { + fixed (byte* fixedBuffer = buffer) + { + currentBytesWritten = USBMuxInterop.send(_handle, fixedBuffer + totalBytesWritten, new IntPtr(bytesToWrite - totalBytesWritten), 0); + } + } + + if (currentBytesWritten == -1 && Marshal.GetLastWin32Error() == USBMuxInterop.EINTR) + continue; + + continueWrite = currentBytesWritten != -1; + + if (!continueWrite) + break; + + totalBytesWritten += currentBytesWritten; + } + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return Task.Run(() => + { + using (cancellationToken.Register(() => Close())) + { + Write(buffer, offset, count); + } + }); + } + + public override void Close() + { + if (IsOpen) + { + USBMuxInterop.close(_handle); + _handle = -1; + } + } + + protected override void Dispose(bool disposing) + { + Close(); + base.Dispose(disposing); + } + } + + internal class USBMuxTcpClientRouterFactory : TcpClientRouterFactory + { + readonly int _port; + + IntPtr _device = IntPtr.Zero; + uint _deviceConnectionID = 0; + IntPtr _loopingThread = IntPtr.Zero; + + public static TcpClientRouterFactory CreateUSBMuxInstance(string tcpClient, int runtimeTimeoutMs, ILogger logger) + { + return new USBMuxTcpClientRouterFactory(tcpClient, runtimeTimeoutMs, logger); + } + + public USBMuxTcpClientRouterFactory(string tcpClient, int runtimeTimeoutMs, ILogger logger) + : base(tcpClient, runtimeTimeoutMs, logger) + { + _port = new IpcTcpSocketEndPoint(tcpClient).EndPoint.Port; + } + + public override async Task ConnectTcpStreamAsync(CancellationToken token) + { + bool retry = false; + int handle = -1; + ushort networkPort = (ushort)IPAddress.HostToNetworkOrder(unchecked((short)_port)); + + _logger?.LogDebug($"Connecting new tcp endpoint over usbmux \"{_tcpClientAddress}\"."); + + using var connectTimeoutTokenSource = new CancellationTokenSource(); + using var connectTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, connectTimeoutTokenSource.Token); + + connectTimeoutTokenSource.CancelAfter(TcpClientTimeoutMs); + + do + { + try + { + handle = ConnectTcpClientOverUSBMux(); + retry = false; + } + catch (Exception) + { + if (connectTimeoutTokenSource.IsCancellationRequested) + { + _logger?.LogDebug("No USB stream connected, timing out."); + + if (_auto_shutdown) + throw new RuntimeTimeoutException(TcpClientTimeoutMs); + + throw new TimeoutException(); + } + + // If we are not doing auto shutdown when runtime is unavailable, fail right away, this will + // break any accepted IPC connections, making sure client is notified and could reconnect. + // If we do have auto shutdown enabled, retry until succeed or time out. + if (!_auto_shutdown) + { + _logger?.LogTrace($"Failed connecting {_port} over usbmux."); + throw; + } + + _logger?.LogTrace($"Failed connecting {_port} over usbmux, wait {TcpClientRetryTimeoutMs} ms before retrying."); + + // If we get an error (without hitting timeout above), most likely due to unavailable device/listener. + // Delay execution to prevent to rapid retry attempts. + await Task.Delay(TcpClientRetryTimeoutMs, token).ConfigureAwait(false); + + retry = true; + } + } + while (retry); + + return new USBMuxStream(handle); + } + + public override void Start() + { + // Start device subscription thread. + StartNotificationSubscribeThread(); + } + + public override void Stop() + { + // Stop device subscription thread. + StopNotificationSubscribeThread(); + } + + int ConnectTcpClientOverUSBMux() + { + uint result = 0; + int handle = -1; + ushort networkPort = (ushort)IPAddress.HostToNetworkOrder(unchecked((short)_port)); + + lock (this) + { + if (_deviceConnectionID == 0) + throw new Exception($"Failed to connect device over USB, no device currently connected."); + + result = USBMuxInterop.USBMuxConnectByPort(_deviceConnectionID, networkPort, out handle); + } + + if (result != 0) + throw new Exception($"Failed to connect device over USB using connection {_deviceConnectionID} and port {_port}."); + + return handle; + } + + bool ConnectDevice(IntPtr newDevice) + { + if (_device != IntPtr.Zero) + return false; + + _device = newDevice; + if (USBMuxInterop.AMDeviceConnect(_device) == 0) + { + _deviceConnectionID = USBMuxInterop.AMDeviceGetConnectionID(_device); + _logger?.LogInformation($"Successfully connected new device, id={_deviceConnectionID}."); + return true; + } + else + { + _logger?.LogError($"Failed connecting new device."); + return false; + } + } + + bool DisconnectDevice() + { + if (_device != IntPtr.Zero) + { + if (_deviceConnectionID != 0) + { + USBMuxInterop.AMDeviceDisconnect(_device); + _logger?.LogInformation($"Successfully disconnected device, id={_deviceConnectionID}."); + _deviceConnectionID = 0; + } + + _device = IntPtr.Zero; + } + + return true; + } + + void AMDeviceNotificationCallback(ref USBMuxInterop.AMDeviceNotificationCallbackInfo info) + { + _logger?.LogTrace($"AMDeviceNotificationInternal callback, device={info.am_device}, action={info.message}"); + + try + { + lock (this) + { + int interfaceType = USBMuxInterop.AMDeviceGetInterfaceType(info.am_device); + switch (info.message) + { + case USBMuxInterop.AMDeviceNotificationMessage.Connected: + if (interfaceType == 1 && _device == IntPtr.Zero) + { + ConnectDevice(info.am_device); + } + else if (interfaceType == 1 && _device != IntPtr.Zero) + { + _logger?.LogInformation($"Discovered new device, but one is already connected, ignoring new device."); + } + else if (interfaceType == 0) + { + _logger?.LogInformation($"Discovered new device not connected over USB, ignoring new device."); + } + break; + case USBMuxInterop.AMDeviceNotificationMessage.Disconnected: + case USBMuxInterop.AMDeviceNotificationMessage.Unsubscribed: + if (_device == info.am_device) + { + DisconnectDevice(); + } + break; + } + } + } + catch (Exception ex) + { + _logger?.LogError($"Failed AMDeviceNotificationCallback: {ex.Message}. Failed handling device={info.am_device} using action={info.message}"); + } + } + + void AMDeviceNotificationSubscribeLoop() + { + IntPtr context = IntPtr.Zero; + + try + { + lock (this) + { + if (_loopingThread != IntPtr.Zero) + { + _logger?.LogError($"AMDeviceNotificationSubscribeLoop already running."); + throw new Exception("AMDeviceNotificationSubscribeLoop already running."); + } + + _loopingThread = USBMuxInterop.CFRunLoopGetCurrent(); + } + + _logger?.LogTrace($"Calling AMDeviceNotificationSubscribe."); + + if (USBMuxInterop.AMDeviceNotificationSubscribe(AMDeviceNotificationCallback, 0, 0, 0, out context) != 0) + { + _logger?.LogError($"Failed AMDeviceNotificationSubscribe call."); + throw new Exception("Failed AMDeviceNotificationSubscribe call."); + } + + _logger?.LogTrace($"Start dispatching notifications."); + USBMuxInterop.CFRunLoopRun(); + _logger?.LogTrace($"Stop dispatching notifications."); + } + catch (Exception ex) + { + _logger?.LogError($"Failed running subscribe loop: {ex.Message}. Disabling detection of devices connected over USB."); + } + finally + { + lock (this) + { + if (_loopingThread != IntPtr.Zero) + { + _loopingThread = IntPtr.Zero; + } + + DisconnectDevice(); + } + + if (context != IntPtr.Zero) + { + _logger?.LogTrace($"Calling AMDeviceNotificationUnsubscribe."); + USBMuxInterop.AMDeviceNotificationUnsubscribe(context); + } + } + } + + void StartNotificationSubscribeThread() + { + new Thread(new ThreadStart(() => AMDeviceNotificationSubscribeLoop())).Start(); + } + + void StopNotificationSubscribeThread() + { + lock (this) + { + if (_loopingThread != IntPtr.Zero) + USBMuxInterop.CFRunLoopStop(_loopingThread); + } + } + } +} -- 2.34.1