Add adb port forward and usbmux support to dsrouter. (#2366)
authorJohan Lorensson <lateralusx.github@gmail.com>
Wed, 23 Jun 2021 18:22:32 +0000 (20:22 +0200)
committerGitHub <noreply@github.com>
Wed, 23 Jun 2021 18:22:32 +0000 (11:22 -0700)
src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterFactory.cs
src/Microsoft.Diagnostics.NETCore.Client/DiagnosticsServerRouter/DiagnosticsServerRouterRunner.cs
src/Tools/dotnet-dsrouter/ADBTcpRouterFactory.cs [new file with mode: 0644]
src/Tools/dotnet-dsrouter/DiagnosticsServerRouterCommands.cs
src/Tools/dotnet-dsrouter/Program.cs
src/Tools/dotnet-dsrouter/USBMuxTcpClientRouterFactory.cs [new file with mode: 0644]

index 9f36ba052058c2f06a879f62b088eec9ab8a0607..c92f607b157f8972cd8adc3c8e3b329bdc67dd75 100644 (file)
@@ -150,7 +150,7 @@ namespace Microsoft.Diagnostics.NETCore.Client
     /// </summary>
     internal class TcpServerRouterFactory : IIpcServerTransportCallbackInternal
     {
-        readonly ILogger _logger;
+        protected readonly ILogger _logger;
 
         string _tcpServerAddress;
 
@@ -177,6 +177,13 @@ namespace Microsoft.Diagnostics.NETCore.Client
             get { return _tcpServerAddress; }
         }
 
+        public delegate TcpServerRouterFactory CreateInstanceDelegate(string tcpServer, int runtimeTimeoutMs, ILogger logger);
+
+        public static TcpServerRouterFactory CreateDefaultInstance(string tcpServer, int runtimeTimeoutMs, ILogger logger)
+        {
+            return new TcpServerRouterFactory(tcpServer, runtimeTimeoutMs, logger);
+        }
+
         public TcpServerRouterFactory(string tcpServer, int runtimeTimeoutMs, ILogger logger)
         {
             _logger = logger;
@@ -192,12 +199,12 @@ namespace Microsoft.Diagnostics.NETCore.Client
             _tcpServer.TransportCallback = this;
         }
 
-        public void Start()
+        public virtual void Start()
         {
             _tcpServer.Start();
         }
 
-        public async Task Stop()
+        public virtual async Task Stop()
         {
             await _tcpServer.DisposeAsync().ConfigureAwait(false);
         }
@@ -281,15 +288,22 @@ namespace Microsoft.Diagnostics.NETCore.Client
     /// </summary>
     internal class TcpClientRouterFactory
     {
-        readonly ILogger _logger;
+        protected readonly ILogger _logger;
 
-        readonly string _tcpClientAddress;
+        protected readonly string _tcpClientAddress;
 
-        bool _auto_shutdown;
+        protected bool _auto_shutdown;
+
+        protected int TcpClientTimeoutMs { get; set; } = Timeout.Infinite;
 
-        int TcpClientTimeoutMs { get; set; } = Timeout.Infinite;
+        protected int TcpClientRetryTimeoutMs { get; set; } = 500;
 
-        int TcpClientRetryTimeoutMs { get; set; } = 500;
+        public delegate TcpClientRouterFactory CreateInstanceDelegate(string tcpClient, int runtimeTimeoutMs, ILogger logger);
+
+        public static TcpClientRouterFactory CreateDefaultInstance(string tcpClient, int runtimeTimeoutMs, ILogger logger)
+        {
+            return new TcpClientRouterFactory(tcpClient, runtimeTimeoutMs, logger);
+        }
 
         public string TcpClientAddress {
             get { return _tcpClientAddress; }
@@ -304,7 +318,7 @@ namespace Microsoft.Diagnostics.NETCore.Client
                 TcpClientTimeoutMs = runtimeTimeoutMs;
         }
 
-        public async Task<Stream> ConnectTcpStreamAsync(CancellationToken token)
+        public virtual async Task<Stream> ConnectTcpStreamAsync(CancellationToken token)
         {
             Stream tcpClientStream = null;
 
@@ -368,6 +382,14 @@ namespace Microsoft.Diagnostics.NETCore.Client
             return tcpClientStream;
         }
 
+        public virtual void Start()
+        {
+        }
+
+        public virtual void Stop()
+        {
+        }
+
         async Task ConnectAsyncInternal(Socket clientSocket, EndPoint remoteEP, CancellationToken token)
         {
             using (token.Register(() => clientSocket.Close(0)))
@@ -597,10 +619,10 @@ namespace Microsoft.Diagnostics.NETCore.Client
         TcpServerRouterFactory _tcpServerRouterFactory;
         IpcServerRouterFactory _ipcServerRouterFactory;
 
-        public IpcServerTcpServerRouterFactory(string ipcServer, string tcpServer, int runtimeTimeoutMs, ILogger logger)
+        public IpcServerTcpServerRouterFactory(string ipcServer, string tcpServer, int runtimeTimeoutMs, TcpServerRouterFactory.CreateInstanceDelegate factory, ILogger logger)
         {
             _logger = logger;
-            _tcpServerRouterFactory = new TcpServerRouterFactory(tcpServer, runtimeTimeoutMs, logger);
+            _tcpServerRouterFactory = factory(tcpServer, runtimeTimeoutMs, logger);
             _ipcServerRouterFactory = new IpcServerRouterFactory(ipcServer, logger);
         }
 
@@ -638,6 +660,7 @@ namespace Microsoft.Diagnostics.NETCore.Client
 
         public override Task Stop()
         {
+            _logger?.LogInformation($"Stopping IPC server ({_ipcServerRouterFactory.IpcServerPath}) <--> TCP server ({_tcpServerRouterFactory.TcpServerAddress}) router.");
             _ipcServerRouterFactory.Stop();
             return _tcpServerRouterFactory.Stop();
         }
@@ -786,11 +809,11 @@ namespace Microsoft.Diagnostics.NETCore.Client
         IpcServerRouterFactory _ipcServerRouterFactory;
         TcpClientRouterFactory _tcpClientRouterFactory;
 
-        public IpcServerTcpClientRouterFactory(string ipcServer, string tcpClient, int runtimeTimeoutMs, ILogger logger)
+        public IpcServerTcpClientRouterFactory(string ipcServer, string tcpClient, int runtimeTimeoutMs, TcpClientRouterFactory.CreateInstanceDelegate factory, ILogger logger)
         {
             _logger = logger;
             _ipcServerRouterFactory = new IpcServerRouterFactory(ipcServer, logger);
-            _tcpClientRouterFactory = new TcpClientRouterFactory(tcpClient, runtimeTimeoutMs, logger);
+            _tcpClientRouterFactory = factory(tcpClient, runtimeTimeoutMs, logger);
         }
 
         public override string IpcAddress
@@ -820,11 +843,14 @@ namespace Microsoft.Diagnostics.NETCore.Client
         public override void Start()
         {
             _ipcServerRouterFactory.Start();
+            _tcpClientRouterFactory.Start();
             _logger?.LogInformation($"Starting IPC server ({_ipcServerRouterFactory.IpcServerPath}) <--> TCP client ({_tcpClientRouterFactory.TcpClientAddress}) router.");
         }
 
         public override Task Stop()
         {
+            _logger?.LogInformation($"Stopping IPC server ({_ipcServerRouterFactory.IpcServerPath}) <--> TCP client ({_tcpClientRouterFactory.TcpClientAddress}) router.");
+            _tcpClientRouterFactory.Stop();
             _ipcServerRouterFactory.Stop();
             return Task.CompletedTask;
         }
@@ -905,11 +931,11 @@ namespace Microsoft.Diagnostics.NETCore.Client
         IpcClientRouterFactory _ipcClientRouterFactory;
         TcpServerRouterFactory _tcpServerRouterFactory;
 
-        public IpcClientTcpServerRouterFactory(string ipcClient, string tcpServer, int runtimeTimeoutMs, ILogger logger)
+        public IpcClientTcpServerRouterFactory(string ipcClient, string tcpServer, int runtimeTimeoutMs, TcpServerRouterFactory.CreateInstanceDelegate factory, ILogger logger)
         {
             _logger = logger;
             _ipcClientRouterFactory = new IpcClientRouterFactory(ipcClient, logger);
-            _tcpServerRouterFactory = new TcpServerRouterFactory(tcpServer, runtimeTimeoutMs, logger);
+            _tcpServerRouterFactory = factory(tcpServer, runtimeTimeoutMs, logger);
         }
 
         public override string IpcAddress
@@ -944,6 +970,7 @@ namespace Microsoft.Diagnostics.NETCore.Client
 
         public override Task Stop()
         {
+            _logger?.LogInformation($"Stopping IPC client ({_ipcClientRouterFactory.IpcClientPath}) <--> TCP server ({_tcpServerRouterFactory.TcpServerAddress}) router.");
             return _tcpServerRouterFactory.Stop();
         }
 
index c4fdd81b2c90d8f86753320e5d420383736f10a2..91d0874756c8cc4e4809af23d422cb2f5d81cffb 100644 (file)
@@ -22,19 +22,19 @@ namespace Microsoft.Diagnostics.NETCore.Client
             void OnRouterStopped();
         }
 
-        public static async Task<int> runIpcClientTcpServerRouter(CancellationToken token, string ipcClient, string tcpServer, int runtimeTimeoutMs, ILogger logger, Callbacks callbacks)
+        public static async Task<int> runIpcClientTcpServerRouter(CancellationToken token, string ipcClient, string tcpServer, int runtimeTimeoutMs, TcpServerRouterFactory.CreateInstanceDelegate tcpServerRouterFactory, ILogger logger, Callbacks callbacks)
         {
-            return await runRouter(token, new IpcClientTcpServerRouterFactory(ipcClient, tcpServer, runtimeTimeoutMs, logger), callbacks).ConfigureAwait(false);
+            return await runRouter(token, new IpcClientTcpServerRouterFactory(ipcClient, tcpServer, runtimeTimeoutMs, tcpServerRouterFactory, logger), callbacks).ConfigureAwait(false);
         }
 
-        public static async Task<int> runIpcServerTcpServerRouter(CancellationToken token, string ipcServer, string tcpServer, int runtimeTimeoutMs, ILogger logger, Callbacks callbacks)
+        public static async Task<int> runIpcServerTcpServerRouter(CancellationToken token, string ipcServer, string tcpServer, int runtimeTimeoutMs, TcpServerRouterFactory.CreateInstanceDelegate tcpServerRouterFactory, ILogger logger, Callbacks callbacks)
         {
-            return await runRouter(token, new IpcServerTcpServerRouterFactory(ipcServer, tcpServer, runtimeTimeoutMs, logger), callbacks).ConfigureAwait(false);
+            return await runRouter(token, new IpcServerTcpServerRouterFactory(ipcServer, tcpServer, runtimeTimeoutMs, tcpServerRouterFactory, logger), callbacks).ConfigureAwait(false);
         }
 
-        public static async Task<int> runIpcServerTcpClientRouter(CancellationToken token, string ipcServer, string tcpClient, int runtimeTimeoutMs, ILogger logger, Callbacks callbacks)
+        public static async Task<int> runIpcServerTcpClientRouter(CancellationToken token, string ipcServer, string tcpClient, int runtimeTimeoutMs, TcpClientRouterFactory.CreateInstanceDelegate tcpClientRouterFactory, ILogger logger, Callbacks callbacks)
         {
-            return await runRouter(token, new IpcServerTcpClientRouterFactory(ipcServer, tcpClient, runtimeTimeoutMs, logger), callbacks).ConfigureAwait(false);
+            return await runRouter(token, new IpcServerTcpClientRouterFactory(ipcServer, tcpClient, runtimeTimeoutMs, tcpClientRouterFactory, logger), callbacks).ConfigureAwait(false);
         }
 
         public static bool isLoopbackOnly(string address)
diff --git a/src/Tools/dotnet-dsrouter/ADBTcpRouterFactory.cs b/src/Tools/dotnet-dsrouter/ADBTcpRouterFactory.cs
new file mode 100644 (file)
index 0000000..f741593
--- /dev/null
@@ -0,0 +1,173 @@
+using System;
+using System.Diagnostics;
+using System.IO;
+using System.Threading.Tasks;
+using Microsoft.Diagnostics.NETCore.Client;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
+{
+    internal class ADBCommandExec
+    {
+        public static bool AdbAddPortForward(int port, ILogger logger)
+        {
+            bool ownsPortForward = false;
+            if (!RunAdbCommandInternal($"forward --list", $"tcp:{port}", 0, logger))
+            {
+                ownsPortForward = RunAdbCommandInternal($"forward tcp:{port} tcp:{port}", "", 0, logger);
+                if (!ownsPortForward)
+                    logger?.LogError($"Failed setting up port forward for tcp:{port}.");
+            }
+            return ownsPortForward;
+        }
+
+        public static bool AdbAddPortReverse(int port, ILogger logger)
+        {
+            bool ownsPortForward = false;
+            if (!RunAdbCommandInternal($"reverse --list", $"tcp:{port}", 0, logger))
+            {
+                ownsPortForward = RunAdbCommandInternal($"reverse tcp:{port} tcp:{port}", "", 0, logger);
+                if (!ownsPortForward)
+                    logger?.LogError($"Failed setting up port forward for tcp:{port}.");
+            }
+            return ownsPortForward;
+        }
+
+        public static void AdbRemovePortForward(int port, bool ownsPortForward, ILogger logger)
+        {
+            if (ownsPortForward)
+            {
+                if (!RunAdbCommandInternal($"forward --remove tcp:{port}", "", 0, logger))
+                    logger?.LogError($"Failed removing port forward for tcp:{port}.");
+            }
+        }
+
+        public static void AdbRemovePortReverse(int port, bool ownsPortForward, ILogger logger)
+        {
+            if (ownsPortForward)
+            {
+                if (!RunAdbCommandInternal($"reverse --remove tcp:{port}", "", 0, logger))
+                    logger?.LogError($"Failed removing port forward for tcp:{port}.");
+            }
+        }
+
+        public static bool RunAdbCommandInternal(string command, string expectedOutput, int expectedExitCode, ILogger logger)
+        {
+            var sdkRoot = Environment.GetEnvironmentVariable("ANDROID_SDK_ROOT");
+            var adbTool = "adb";
+
+            if (!string.IsNullOrEmpty(sdkRoot))
+                adbTool = sdkRoot + Path.DirectorySeparatorChar + "platform-tools" + Path.DirectorySeparatorChar + adbTool;
+
+            logger?.LogDebug($"Executing {adbTool} {command}.");
+
+            var process = new Process();
+            process.StartInfo.FileName = adbTool;
+            process.StartInfo.Arguments = command;
+
+            process.StartInfo.UseShellExecute = false;
+            process.StartInfo.RedirectStandardOutput = true;
+            process.StartInfo.RedirectStandardError = true;
+            process.StartInfo.RedirectStandardInput = false;
+
+            bool processStartedResult = false;
+            bool expectedOutputResult = true;
+            bool expectedExitCodeResult = true;
+
+            try
+            {
+                processStartedResult = process.Start();
+            }
+            catch (Exception)
+            {
+            }
+
+            if (processStartedResult)
+            {
+                var stdout = process.StandardOutput.ReadToEnd();
+                var stderr = process.StandardError.ReadToEnd();
+
+                if (!string.IsNullOrEmpty(expectedOutput))
+                    expectedOutputResult = !string.IsNullOrEmpty(stdout) ? stdout.Contains(expectedOutput) : false;
+
+                if (!string.IsNullOrEmpty(stdout))
+                    logger.LogTrace($"stdout: {stdout}");
+
+                if (!string.IsNullOrEmpty(stderr))
+                    logger.LogError($"stderr: {stderr}");
+            }
+
+            if (processStartedResult)
+            {
+                process.WaitForExit();
+                expectedExitCodeResult = (expectedExitCode != -1) ? (process.ExitCode == expectedExitCode) : true;
+            }
+
+            return processStartedResult && expectedOutputResult && expectedExitCodeResult;
+        }
+    }
+
+    internal class ADBTcpServerRouterFactory : TcpServerRouterFactory
+    {
+        readonly int _port;
+        bool _ownsPortReverse;
+
+        public static TcpServerRouterFactory CreateADBInstance(string tcpServer, int runtimeTimeoutMs, ILogger logger)
+        {
+            return new ADBTcpServerRouterFactory(tcpServer, runtimeTimeoutMs, logger);
+        }
+
+        public ADBTcpServerRouterFactory(string tcpServer, int runtimeTimeoutMs, ILogger logger)
+            : base(tcpServer, runtimeTimeoutMs, logger)
+        {
+            _port = new IpcTcpSocketEndPoint(tcpServer).EndPoint.Port;
+        }
+
+        public override void Start()
+        {
+            // Enable port reverse.
+            _ownsPortReverse = ADBCommandExec.AdbAddPortReverse(_port, _logger);
+
+            base.Start();
+        }
+
+        public override async Task Stop()
+        {
+            await base.Stop().ConfigureAwait(false);
+
+            // Disable port reverse.
+            ADBCommandExec.AdbRemovePortReverse(_port, _ownsPortReverse, _logger);
+            _ownsPortReverse = false;
+        }
+    }
+
+    internal class ADBTcpClientRouterFactory : TcpClientRouterFactory
+    {
+        readonly int _port;
+        bool _ownsPortForward;
+
+        public static TcpClientRouterFactory CreateADBInstance(string tcpClient, int runtimeTimeoutMs, ILogger logger)
+        {
+            return new ADBTcpClientRouterFactory(tcpClient, runtimeTimeoutMs, logger);
+        }
+
+        public ADBTcpClientRouterFactory(string tcpClient, int runtimeTimeoutMs, ILogger logger)
+            : base(tcpClient, runtimeTimeoutMs, logger)
+        {
+            _port = new IpcTcpSocketEndPoint(tcpClient).EndPoint.Port;
+        }
+
+        public override void Start()
+        {
+            // Enable port forwarding.
+            _ownsPortForward = ADBCommandExec.AdbAddPortForward(_port, _logger);
+        }
+
+        public override void Stop()
+        {
+            // Disable port forwarding.
+            ADBCommandExec.AdbRemovePortForward(_port, _ownsPortForward, _logger);
+            _ownsPortForward = false;
+        }
+    }
+}
index c51639f3d85cc2d3077e833ea22e1443fafdcdd4..f79b5cbcc75890ea5f90ce3d3cf629f71707a213 100644 (file)
@@ -48,7 +48,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
         {
         }
 
-        public async Task<int> RunIpcClientTcpServerRouter(CancellationToken token, string ipcClient, string tcpServer, int runtimeTimeout, string verbose)
+        public async Task<int> RunIpcClientTcpServerRouter(CancellationToken token, string ipcClient, string tcpServer, int runtimeTimeout, string verbose, string forwardPort)
         {
             checkLoopbackOnly(tcpServer);
 
@@ -69,7 +69,22 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
             Launcher.Verbose = logLevel != LogLevel.Information;
             Launcher.CommandToken = token;
 
-            var routerTask = DiagnosticsServerRouterRunner.runIpcClientTcpServerRouter(linkedCancelToken.Token, ipcClient, tcpServer, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, factory.CreateLogger("dotnet-dsrounter"), Launcher);
+            var logger = factory.CreateLogger("dotnet-dsrouter");
+
+            TcpServerRouterFactory.CreateInstanceDelegate tcpServerRouterFactory = TcpServerRouterFactory.CreateDefaultInstance;
+            if (!string.IsNullOrEmpty(forwardPort))
+            {
+                if (string.Compare(forwardPort, "android", StringComparison.OrdinalIgnoreCase) == 0)
+                {
+                    tcpServerRouterFactory = ADBTcpServerRouterFactory.CreateADBInstance;
+                }
+                else
+                {
+                    logger.LogError($"Unknown port forwarding argument, {forwardPort}. Only Android port fowarding is supported for TcpServer mode. Ignoring --forward-port argument.");
+                }
+            }
+
+            var routerTask = DiagnosticsServerRouterRunner.runIpcClientTcpServerRouter(linkedCancelToken.Token, ipcClient, tcpServer, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, tcpServerRouterFactory, logger, Launcher);
 
             while (!linkedCancelToken.IsCancellationRequested)
             {
@@ -91,7 +106,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
             return routerTask.Result;
         }
 
-        public async Task<int> RunIpcServerTcpServerRouter(CancellationToken token, string ipcServer, string tcpServer, int runtimeTimeout, string verbose)
+        public async Task<int> RunIpcServerTcpServerRouter(CancellationToken token, string ipcServer, string tcpServer, int runtimeTimeout, string verbose, string forwardPort)
         {
             checkLoopbackOnly(tcpServer);
 
@@ -112,7 +127,22 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
             Launcher.Verbose = logLevel != LogLevel.Information;
             Launcher.CommandToken = token;
 
-            var routerTask = DiagnosticsServerRouterRunner.runIpcServerTcpServerRouter(linkedCancelToken.Token, ipcServer, tcpServer, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, factory.CreateLogger("dotnet-dsrounter"), Launcher);
+            var logger = factory.CreateLogger("dotnet-dsrouter");
+
+            TcpServerRouterFactory.CreateInstanceDelegate tcpServerRouterFactory = TcpServerRouterFactory.CreateDefaultInstance;
+            if (!string.IsNullOrEmpty(forwardPort))
+            {
+                if (string.Compare(forwardPort, "android", StringComparison.OrdinalIgnoreCase) == 0)
+                {
+                    tcpServerRouterFactory = ADBTcpServerRouterFactory.CreateADBInstance;
+                }
+                else
+                {
+                    logger.LogError($"Unknown port forwarding argument, {forwardPort}. Only Android port fowarding is supported for TcpServer mode. Ignoring --forward-port argument.");
+                }
+            }
+
+            var routerTask = DiagnosticsServerRouterRunner.runIpcServerTcpServerRouter(linkedCancelToken.Token, ipcServer, tcpServer, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, tcpServerRouterFactory, logger, Launcher);
 
             while (!linkedCancelToken.IsCancellationRequested)
             {
@@ -134,7 +164,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
             return routerTask.Result;
         }
 
-        public async Task<int> RunIpcServerTcpClientRouter(CancellationToken token, string ipcServer, string tcpClient, int runtimeTimeout, string verbose)
+        public async Task<int> RunIpcServerTcpClientRouter(CancellationToken token, string ipcServer, string tcpClient, int runtimeTimeout, string verbose, string forwardPort)
         {
             using CancellationTokenSource cancelRouterTask = new CancellationTokenSource();
             using CancellationTokenSource linkedCancelToken = CancellationTokenSource.CreateLinkedTokenSource(token, cancelRouterTask.Token);
@@ -153,7 +183,26 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
             Launcher.Verbose = logLevel != LogLevel.Information;
             Launcher.CommandToken = token;
 
-            var routerTask = DiagnosticsServerRouterRunner.runIpcServerTcpClientRouter(linkedCancelToken.Token, ipcServer, tcpClient, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, factory.CreateLogger("dotnet-dsrounter"), Launcher);
+            var logger = factory.CreateLogger("dotnet-dsrouter");
+
+            TcpClientRouterFactory.CreateInstanceDelegate tcpClientRouterFactory = TcpClientRouterFactory.CreateDefaultInstance;
+            if (!string.IsNullOrEmpty(forwardPort))
+            {
+                if (string.Compare(forwardPort, "android", StringComparison.OrdinalIgnoreCase) == 0)
+                {
+                    tcpClientRouterFactory = ADBTcpClientRouterFactory.CreateADBInstance;
+                }
+                else if (string.Compare(forwardPort, "ios", StringComparison.OrdinalIgnoreCase) == 0)
+                {
+                    tcpClientRouterFactory = USBMuxTcpClientRouterFactory.CreateUSBMuxInstance;
+                }
+                else
+                {
+                    logger.LogError($"Unknown port forwarding argument, {forwardPort}. Ignoring --forward-port argument.");
+                }
+            }
+
+            var routerTask = DiagnosticsServerRouterRunner.runIpcServerTcpClientRouter(linkedCancelToken.Token, ipcServer, tcpClient, runtimeTimeout == Timeout.Infinite ? runtimeTimeout : runtimeTimeout * 1000, tcpClientRouterFactory, logger, Launcher);
 
             while (!linkedCancelToken.IsCancellationRequested)
             {
index a56da4fbb8da4ad735b42eaf3c2d475cbeb3da3b..52ab8165a061a4565429172211fc8e9413ad5416 100644 (file)
@@ -17,9 +17,9 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
 {
     internal class Program
     {
-        delegate Task<int> DiagnosticsServerIpcClientTcpServerRouterDelegate(CancellationToken ct, string ipcClient, string tcpServer, int runtimeTimeoutS, string verbose);
-        delegate Task<int> DiagnosticsServerIpcServerTcpServerRouterDelegate(CancellationToken ct, string ipcServer, string tcpServer, int runtimeTimeoutS, string verbose);
-        delegate Task<int> DiagnosticsServerIpcServerTcpClientRouterDelegate(CancellationToken ct, string ipcServer, string tcpClient, int runtimeTimeoutS, string verbose);
+        delegate Task<int> DiagnosticsServerIpcClientTcpServerRouterDelegate(CancellationToken ct, string ipcClient, string tcpServer, int runtimeTimeoutS, string verbose, string forwardPort);
+        delegate Task<int> DiagnosticsServerIpcServerTcpServerRouterDelegate(CancellationToken ct, string ipcServer, string tcpServer, int runtimeTimeoutS, string verbose, string forwardPort);
+        delegate Task<int> DiagnosticsServerIpcServerTcpClientRouterDelegate(CancellationToken ct, string ipcServer, string tcpClient, int runtimeTimeoutS, string verbose, string forwardPort);
 
         private static Command IpcClientTcpServerRouterCommand() =>
             new Command(
@@ -31,7 +31,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
                 // Handler
                 HandlerDescriptor.FromDelegate((DiagnosticsServerIpcClientTcpServerRouterDelegate)new DiagnosticsServerRouterCommands().RunIpcClientTcpServerRouter).GetCommandHandler(),
                 // Options
-                IpcClientAddressOption(), TcpServerAddressOption(), RuntimeTimeoutOption(), VerboseOption()
+                IpcClientAddressOption(), TcpServerAddressOption(), RuntimeTimeoutOption(), VerboseOption(), ForwardPortOption()
             };
 
         private static Command IpcServerTcpServerRouterCommand() =>
@@ -44,7 +44,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
                 // Handler
                 HandlerDescriptor.FromDelegate((DiagnosticsServerIpcServerTcpServerRouterDelegate)new DiagnosticsServerRouterCommands().RunIpcServerTcpServerRouter).GetCommandHandler(),
                 // Options
-                IpcServerAddressOption(), TcpServerAddressOption(), RuntimeTimeoutOption(), VerboseOption()
+                IpcServerAddressOption(), TcpServerAddressOption(), RuntimeTimeoutOption(), VerboseOption(), ForwardPortOption()
             };
 
         private static Command IpcServerTcpClientRouterCommand() =>
@@ -57,7 +57,7 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
                 // Handler
                 HandlerDescriptor.FromDelegate((DiagnosticsServerIpcServerTcpClientRouterDelegate)new DiagnosticsServerRouterCommands().RunIpcServerTcpClientRouter).GetCommandHandler(),
                 // Options
-                IpcServerAddressOption(), TcpClientAddressOption(), RuntimeTimeoutOption(), VerboseOption()
+                IpcServerAddressOption(), TcpClientAddressOption(), RuntimeTimeoutOption(), VerboseOption(), ForwardPortOption()
             };
 
         private static Option IpcClientAddressOption() =>
@@ -118,6 +118,14 @@ namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
                 Argument = new Argument<string>(name: "verbose", getDefaultValue: () => "")
             };
 
+        private static Option ForwardPortOption() =>
+            new Option(
+                aliases: new[] { "--forward-port", "-fp" },
+                description: "Enable port forwarding, values Android|iOS for TcpClient and only Android for TcpServer. Make sure to set ANDROID_SDK_ROOT before using this option on Android.")
+            {
+                Argument = new Argument<string>(name: "forwardPort", getDefaultValue: () => "")
+            };
+
         private static int Main(string[] args)
         {
             StringBuilder message = new StringBuilder();
diff --git a/src/Tools/dotnet-dsrouter/USBMuxTcpClientRouterFactory.cs b/src/Tools/dotnet-dsrouter/USBMuxTcpClientRouterFactory.cs
new file mode 100644 (file)
index 0000000..29ca466
--- /dev/null
@@ -0,0 +1,492 @@
+using System;
+using System.IO;
+using System.Net;
+using System.Runtime.InteropServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Diagnostics.NETCore.Client;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.Diagnostics.Tools.DiagnosticsServerRouter
+{
+    internal class USBMuxInterop
+    {
+        public const string CoreFoundationLibrary = "/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation";
+        public const string MobileDeviceLibrary = "/System/Library/PrivateFrameworks/MobileDevice.framework/MobileDevice";
+        public const string LibC = "libc";
+
+        public const int EINTR = 4;
+
+        public enum AMDeviceNotificationMessage : uint
+        {
+            None = 0,
+            Connected = 1,
+            Disconnected = 2,
+            Unsubscribed = 3
+        }
+
+        public struct AMDeviceNotificationCallbackInfo
+        {
+            public AMDeviceNotificationCallbackInfo(IntPtr device, AMDeviceNotificationMessage message)
+            {
+                this.am_device = device;
+                this.message = message;
+            }
+
+            public IntPtr am_device;
+            public AMDeviceNotificationMessage message;
+        }
+
+        public delegate void DeviceNotificationDelegate(ref AMDeviceNotificationCallbackInfo info);
+
+#region MobileDeviceLibrary
+        [DllImport(MobileDeviceLibrary)]
+        public static extern uint AMDeviceNotificationSubscribe(DeviceNotificationDelegate callback, uint unused0, uint unused1, uint unused2, out IntPtr context);
+
+        [DllImport(MobileDeviceLibrary)]
+        public static extern uint AMDeviceNotificationUnsubscribe(IntPtr context);
+
+        [DllImport(MobileDeviceLibrary)]
+        public static extern uint AMDeviceConnect(IntPtr device);
+
+        [DllImport(MobileDeviceLibrary)]
+        public static extern uint AMDeviceDisconnect(IntPtr device);
+
+        [DllImport(MobileDeviceLibrary)]
+        public static extern uint AMDeviceGetConnectionID(IntPtr device);
+
+        [DllImport(MobileDeviceLibrary)]
+        public static extern int AMDeviceGetInterfaceType(IntPtr device);
+
+        [DllImport(MobileDeviceLibrary)]
+        public static extern uint USBMuxConnectByPort(uint connection, ushort port, out int socketHandle);
+#endregion
+#region CoreFoundationLibrary
+        [DllImport(CoreFoundationLibrary)]
+        public static extern void CFRunLoopRun();
+
+        [DllImport(CoreFoundationLibrary)]
+        public static extern void CFRunLoopStop(IntPtr runLoop);
+
+        [DllImport(CoreFoundationLibrary)]
+        public static extern IntPtr CFRunLoopGetCurrent();
+#endregion
+#region LibC
+        [DllImport(LibC, SetLastError = true)]
+        public static extern unsafe int send(int handle, byte* buffer, IntPtr length, int flags);
+
+        [DllImport(LibC, SetLastError = true)]
+        public static extern unsafe int recv(int handle, byte* buffer, IntPtr length, int flags);
+
+        [DllImport(LibC, SetLastError = true)]
+        public static extern int close(int handle);
+#endregion
+    }
+
+    internal class USBMuxStream : Stream
+    {
+        int _handle = -1;
+
+        public USBMuxStream(int handle)
+        {
+            _handle = handle;
+        }
+
+        public bool IsOpen => _handle != -1;
+
+        public override bool CanRead => true;
+
+        public override bool CanSeek => false;
+
+        public override bool CanWrite => true;
+
+        public override long Length => throw new NotImplementedException();
+
+        public override long Position {
+            get => throw new NotImplementedException();
+            set => throw new NotImplementedException();
+        }
+
+        public override void Flush()
+        {
+        }
+
+        public override Task FlushAsync(CancellationToken cancellationToken)
+        {
+            return Task.CompletedTask;
+        }
+
+        public override int Read(byte[] buffer, int offset, int count)
+        {
+            bool continueRead = true;
+            int bytesToRead = count;
+            int totalBytesRead = 0;
+            int currentBytesRead = 0;
+
+            while (continueRead && bytesToRead - totalBytesRead > 0)
+            {
+                if (!IsOpen)
+                    throw new EndOfStreamException();
+
+                unsafe
+                {
+                    fixed (byte* fixedBuffer = buffer)
+                    {
+                        currentBytesRead = USBMuxInterop.recv(_handle, fixedBuffer + totalBytesRead, new IntPtr(bytesToRead - totalBytesRead), 0);
+                    }
+                }
+
+                if (currentBytesRead == -1 && Marshal.GetLastWin32Error() == USBMuxInterop.EINTR)
+                    continue;
+
+                continueRead = currentBytesRead > 0;
+                if (!continueRead)
+                    break;
+
+                totalBytesRead += currentBytesRead;
+            }
+
+            return totalBytesRead;
+        }
+
+        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            return Task.Run(() =>
+            {
+                int result = 0;
+                using (cancellationToken.Register(() => Close()))
+                {
+                    try
+                    {
+                        result = Read(buffer, offset, count);
+                    }
+                    catch (Exception)
+                    {
+                        cancellationToken.ThrowIfCancellationRequested();
+                        result = 0;
+                    }
+                }
+                return result;
+            });
+        }
+
+        public override long Seek(long offset, SeekOrigin origin)
+        {
+            throw new NotImplementedException();
+        }
+
+        public override void SetLength(long value)
+        {
+            throw new NotImplementedException();
+        }
+
+        public override void Write(byte[] buffer, int offset, int count)
+        {
+            bool continueWrite = true;
+            int bytesToWrite = count;
+            int currentBytesWritten = 0;
+            int totalBytesWritten = 0;
+
+            while (continueWrite && bytesToWrite - totalBytesWritten > 0)
+            {
+                if (!IsOpen)
+                    throw new EndOfStreamException();
+
+                unsafe
+                {
+                    fixed (byte* fixedBuffer = buffer)
+                    {
+                        currentBytesWritten = USBMuxInterop.send(_handle, fixedBuffer + totalBytesWritten, new IntPtr(bytesToWrite - totalBytesWritten), 0);
+                    }
+                }
+
+                if (currentBytesWritten == -1 && Marshal.GetLastWin32Error() == USBMuxInterop.EINTR)
+                    continue;
+
+                continueWrite = currentBytesWritten != -1;
+
+                if (!continueWrite)
+                    break;
+
+                totalBytesWritten += currentBytesWritten;
+            }
+        }
+
+        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            return Task.Run(() =>
+            {
+                using (cancellationToken.Register(() => Close()))
+                {
+                    Write(buffer, offset, count);
+                }
+            });
+        }
+
+        public override void Close()
+        {
+            if (IsOpen)
+            {
+                USBMuxInterop.close(_handle);
+                _handle = -1;
+            }
+        }
+
+        protected override void Dispose(bool disposing)
+        {
+            Close();
+            base.Dispose(disposing);
+        }
+    }
+
+    internal class USBMuxTcpClientRouterFactory : TcpClientRouterFactory
+    {
+        readonly int _port;
+
+        IntPtr _device = IntPtr.Zero;
+        uint _deviceConnectionID = 0;
+        IntPtr _loopingThread = IntPtr.Zero;
+
+        public static TcpClientRouterFactory CreateUSBMuxInstance(string tcpClient, int runtimeTimeoutMs, ILogger logger)
+        {
+            return new USBMuxTcpClientRouterFactory(tcpClient, runtimeTimeoutMs, logger);
+        }
+
+        public USBMuxTcpClientRouterFactory(string tcpClient, int runtimeTimeoutMs, ILogger logger)
+            : base(tcpClient, runtimeTimeoutMs, logger)
+        {
+            _port = new IpcTcpSocketEndPoint(tcpClient).EndPoint.Port;
+        }
+
+        public override async Task<Stream> ConnectTcpStreamAsync(CancellationToken token)
+        {
+            bool retry = false;
+            int handle = -1;
+            ushort networkPort = (ushort)IPAddress.HostToNetworkOrder(unchecked((short)_port));
+
+            _logger?.LogDebug($"Connecting new tcp endpoint over usbmux \"{_tcpClientAddress}\".");
+
+            using var connectTimeoutTokenSource = new CancellationTokenSource();
+            using var connectTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, connectTimeoutTokenSource.Token);
+
+            connectTimeoutTokenSource.CancelAfter(TcpClientTimeoutMs);
+
+            do
+            {
+                try
+                {
+                    handle = ConnectTcpClientOverUSBMux();
+                    retry = false;
+                }
+                catch (Exception)
+                {
+                    if (connectTimeoutTokenSource.IsCancellationRequested)
+                    {
+                        _logger?.LogDebug("No USB stream connected, timing out.");
+
+                        if (_auto_shutdown)
+                            throw new RuntimeTimeoutException(TcpClientTimeoutMs);
+
+                        throw new TimeoutException();
+                    }
+
+                    // If we are not doing auto shutdown when runtime is unavailable, fail right away, this will
+                    // break any accepted IPC connections, making sure client is notified and could reconnect.
+                    // If we do have auto shutdown enabled, retry until succeed or time out.
+                    if (!_auto_shutdown)
+                    {
+                        _logger?.LogTrace($"Failed connecting {_port} over usbmux.");
+                        throw;
+                    }
+
+                    _logger?.LogTrace($"Failed connecting {_port} over usbmux, wait {TcpClientRetryTimeoutMs} ms before retrying.");
+
+                    // If we get an error (without hitting timeout above), most likely due to unavailable device/listener.
+                    // Delay execution to prevent to rapid retry attempts.
+                    await Task.Delay(TcpClientRetryTimeoutMs, token).ConfigureAwait(false);
+
+                    retry = true;
+                }
+            }
+            while (retry);
+
+            return new USBMuxStream(handle);
+        }
+
+        public override void Start()
+        {
+            // Start device subscription thread.
+            StartNotificationSubscribeThread();
+        }
+
+        public override void Stop()
+        {
+            // Stop device subscription thread.
+            StopNotificationSubscribeThread();
+        }
+
+        int ConnectTcpClientOverUSBMux()
+        {
+            uint result = 0;
+            int handle = -1;
+            ushort networkPort = (ushort)IPAddress.HostToNetworkOrder(unchecked((short)_port));
+
+            lock (this)
+            {
+                if (_deviceConnectionID == 0)
+                    throw new Exception($"Failed to connect device over USB, no device currently connected.");
+
+                result = USBMuxInterop.USBMuxConnectByPort(_deviceConnectionID, networkPort, out handle);
+            }
+
+            if (result != 0)
+                throw new Exception($"Failed to connect device over USB using connection {_deviceConnectionID} and port {_port}.");
+
+            return handle;
+        }
+
+        bool ConnectDevice(IntPtr newDevice)
+        {
+            if (_device != IntPtr.Zero)
+                return false;
+
+            _device = newDevice;
+            if (USBMuxInterop.AMDeviceConnect(_device) == 0)
+            {
+                _deviceConnectionID = USBMuxInterop.AMDeviceGetConnectionID(_device);
+                _logger?.LogInformation($"Successfully connected new device, id={_deviceConnectionID}.");
+                return true;
+            }
+            else
+            {
+                _logger?.LogError($"Failed connecting new device.");
+                return false;
+            }
+        }
+
+        bool DisconnectDevice()
+        {
+            if (_device != IntPtr.Zero)
+            {
+                if (_deviceConnectionID != 0)
+                {
+                    USBMuxInterop.AMDeviceDisconnect(_device);
+                    _logger?.LogInformation($"Successfully disconnected device, id={_deviceConnectionID}.");
+                    _deviceConnectionID = 0;
+                }
+
+                _device = IntPtr.Zero;
+            }
+
+            return true;
+        }
+
+        void AMDeviceNotificationCallback(ref USBMuxInterop.AMDeviceNotificationCallbackInfo info)
+        {
+            _logger?.LogTrace($"AMDeviceNotificationInternal callback, device={info.am_device}, action={info.message}");
+
+            try
+            {
+                lock (this)
+                {
+                    int interfaceType = USBMuxInterop.AMDeviceGetInterfaceType(info.am_device);
+                    switch (info.message)
+                    {
+                        case USBMuxInterop.AMDeviceNotificationMessage.Connected:
+                            if (interfaceType == 1 && _device == IntPtr.Zero)
+                            {
+                                ConnectDevice(info.am_device);
+                            }
+                            else if (interfaceType == 1 && _device != IntPtr.Zero)
+                            {
+                                _logger?.LogInformation($"Discovered new device, but one is already connected, ignoring new device.");
+                            }
+                            else if (interfaceType == 0)
+                            {
+                                _logger?.LogInformation($"Discovered new device not connected over USB, ignoring new device.");
+                            }
+                            break;
+                        case USBMuxInterop.AMDeviceNotificationMessage.Disconnected:
+                        case USBMuxInterop.AMDeviceNotificationMessage.Unsubscribed:
+                            if (_device == info.am_device)
+                            {
+                                DisconnectDevice();
+                            }
+                            break;
+                    }
+                }
+            }
+            catch (Exception ex)
+            {
+                _logger?.LogError($"Failed AMDeviceNotificationCallback: {ex.Message}. Failed handling device={info.am_device} using action={info.message}");
+            }
+        }
+
+        void AMDeviceNotificationSubscribeLoop()
+        {
+            IntPtr context = IntPtr.Zero;
+
+            try
+            {
+                lock (this)
+                {
+                    if (_loopingThread != IntPtr.Zero)
+                    {
+                        _logger?.LogError($"AMDeviceNotificationSubscribeLoop already running.");
+                        throw new Exception("AMDeviceNotificationSubscribeLoop already running.");
+                    }
+
+                    _loopingThread = USBMuxInterop.CFRunLoopGetCurrent();
+                }
+
+                _logger?.LogTrace($"Calling AMDeviceNotificationSubscribe.");
+
+                if (USBMuxInterop.AMDeviceNotificationSubscribe(AMDeviceNotificationCallback, 0, 0, 0, out context) != 0)
+                {
+                    _logger?.LogError($"Failed AMDeviceNotificationSubscribe call.");
+                    throw new Exception("Failed AMDeviceNotificationSubscribe call.");
+                }
+
+                _logger?.LogTrace($"Start dispatching notifications.");
+                USBMuxInterop.CFRunLoopRun();
+                _logger?.LogTrace($"Stop dispatching notifications.");
+            }
+            catch (Exception ex)
+            {
+                _logger?.LogError($"Failed running subscribe loop: {ex.Message}. Disabling detection of devices connected over USB.");
+            }
+            finally
+            {
+                lock (this)
+                {
+                    if (_loopingThread != IntPtr.Zero)
+                    {
+                        _loopingThread = IntPtr.Zero;
+                    }
+
+                    DisconnectDevice();
+                }
+
+                if (context != IntPtr.Zero)
+                {
+                    _logger?.LogTrace($"Calling AMDeviceNotificationUnsubscribe.");
+                    USBMuxInterop.AMDeviceNotificationUnsubscribe(context);
+                }
+            }
+        }
+
+        void StartNotificationSubscribeThread()
+        {
+            new Thread(new ThreadStart(() => AMDeviceNotificationSubscribeLoop())).Start();
+        }
+
+        void StopNotificationSubscribeThread()
+        {
+            lock (this)
+            {
+                if (_loopingThread != IntPtr.Zero)
+                    USBMuxInterop.CFRunLoopStop(_loopingThread);
+            }
+        }
+    }
+}