[BYOC][ACL] Enable remote device via environment variables (#6279)
authorlhutton1 <35535092+lhutton1@users.noreply.github.com>
Tue, 25 Aug 2020 16:04:20 +0000 (17:04 +0100)
committerGitHub <noreply@github.com>
Tue, 25 Aug 2020 16:04:20 +0000 (09:04 -0700)
* [BYOC][ACL] Enable remote device via environment variables

Improves the ACL remote testing infrastructure by allowing a remote device to be specified via environment variables. This means external scripts can be used to enable the runtime tests. By default an RPC server will not be used and the runtime tests will be skipped.

Change-Id: I8fc0b88106683ac6f1cbff44c8954726325cda21

* Use json file as configuration for tests

Change-Id: Iadce931d91056ed3a2d57a49f14af1ce771ae14b

* Do not load the test config during class creation

Change-Id: If718b5d163e399711111830f878db325db9c5f84

* Add check for existence of file

Change-Id: I2568bca7f4c3ad22ee8f9d065a9486ee3114f35c

docs/deploy/arm_compute_lib.rst
tests/python/contrib/test_arm_compute_lib/infrastructure.py
tests/python/contrib/test_arm_compute_lib/test_config.json [new file with mode: 0644]
tests/python/contrib/test_arm_compute_lib/test_conv2d.py
tests/python/contrib/test_arm_compute_lib/test_dense.py
tests/python/contrib/test_arm_compute_lib/test_network.py
tests/python/contrib/test_arm_compute_lib/test_pooling.py
tests/python/contrib/test_arm_compute_lib/test_reshape.py
tests/python/contrib/test_arm_compute_lib/test_runtime.py

index c0b1a7e..26b42ae 100644 (file)
@@ -162,7 +162,28 @@ More examples
 The example above only shows a basic example of how ACL can be used for offloading a single
 Maxpool2D. If you would like to see more examples for each implemented operator and for
 networks refer to the tests: `tests/python/contrib/test_arm_compute_lib`. Here you can modify
-`infrastructure.py` to use the remote device you have setup.
+`test_config.json` to configure how a remote device is created in `infrastructure.py` and,
+as a result, how runtime tests will be run.
+
+An example configuration for `test_config.json`:
+
+* connection_type - The type of RPC connection. Options: local, tracker, remote.
+* host - The host device to connect to.
+* port - The port to use when connecting.
+* target - The target to use for compilation.
+* device_key - The device key when connecting via a tracker.
+* cross_compile - Path to cross compiler when connecting from a non-arm platform e.g. aarch64-linux-gnu-g++.
+
+.. code:: json
+
+    {
+      "connection_type": "local",
+      "host": "localhost",
+      "port": 9090,
+      "target": "llvm -mtriple=aarch64-linux-gnu -mattr=+neon",
+      "device_key": "",
+      "cross_compile": ""
+    }
 
 
 Operator support
index 5ed2763..4e930e2 100644 (file)
@@ -16,6 +16,8 @@
 # under the License.
 from itertools import zip_longest, combinations
 import json
+import os
+import warnings
 
 import numpy as np
 
@@ -25,15 +27,52 @@ from tvm import rpc
 from tvm.contrib import graph_runtime
 from tvm.relay.op.contrib import arm_compute_lib
 from tvm.contrib import util
+from tvm.autotvm.measure import request_remote
 
 
 class Device:
-    """Adjust the following settings to connect to and use a remote device for tests."""
-    use_remote = False
+    """
+    Configuration for Arm Compute Library tests.
+
+    Check tests/python/contrib/arm_compute_lib/ for the presence of an test_config.json file.
+    This file can be used to override the default configuration here which will attempt to run the Arm
+    Compute Library runtime tests locally if the runtime is available. Changing the configuration
+    will allow these runtime tests to be offloaded to a remote Arm device via a tracker for example.
+
+    Notes
+    -----
+        The test configuration will be loaded once when the the class is created. If the configuration
+        changes between tests, any changes will not be picked up.
+
+    Parameters
+    ----------
+    device : RPCSession
+        Allows tests to connect to and use remote device.
+
+    Attributes
+    ----------
+    connection_type : str
+        Details the type of RPC connection to use. Options:
+        local - Use the local device,
+        tracker - Connect to a tracker to request a remote device,
+        remote - Connect to a remote device directly.
+    host : str
+        Specify IP address or hostname of remote target.
+    port : int
+        Specify port number of remote target.
+    target : str
+        The compilation target.
+    device_key : str
+        The device key of the remote target. Use when connecting to a remote device via a tracker.
+    cross_compile : str
+        Specify path to cross compiler to use when connecting a remote device from a non-arm platform.
+    """
+    connection_type = "local"
+    host = "localhost"
+    port = 9090
     target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
-    # Enable cross compilation when connecting a remote device from a non-arm platform.
-    cross_compile = None
-    # cross_compile = "aarch64-linux-gnu-g++"
+    device_key = ""
+    cross_compile = ""
 
     def __init__(self):
         """Keep remote device for lifetime of object."""
@@ -42,30 +81,42 @@ class Device:
     @classmethod
     def _get_remote(cls):
         """Get a remote (or local) device to use for testing."""
-        if cls.use_remote:
-            # Here you may adjust settings to run the ACL unit tests via a remote
-            # device using the RPC mechanism. Use this in the case you want to compile
-            # an ACL module on a different machine to what you run the module on i.e.
-            # x86 -> AArch64.
-            #
-            # Use the following to connect directly to a remote device:
-            # device = rpc.connect(
-            #     hostname="0.0.0.0",
-            #     port=9090)
-            #
-            # Or connect via a tracker:
-            # device = tvm.autotvm.measure.request_remote(
-            #     host="0.0.0.0",
-            #     port=9090,
-            #     device_key="device_key",
-            #     timeout=1000)
-            #
-            # return device
-            raise NotImplementedError(
-                "Please adjust these settings to connect to your remote device.")
-        else:
+        if cls.connection_type == "tracker":
+            device = request_remote(cls.device_key,
+                                    cls.host,
+                                    cls.port,
+                                    timeout=1000)
+        elif cls.connection_type == "remote":
+            device = rpc.connect(cls.host, cls.port)
+        elif cls.connection_type == "local":
             device = rpc.LocalSession()
-            return device
+        else:
+            raise ValueError("connection_type in test_config.json should be one of: "
+                             "local, tracker, remote.")
+
+        return device
+
+    @classmethod
+    def load(cls, file_name):
+        """Load test config
+
+        Load the test configuration by looking for file_name relative
+        to the test_arm_compute_lib directory.
+        """
+        location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
+        config_file = os.path.join(location, file_name)
+        if not os.path.exists(config_file):
+            warnings.warn("Config file doesn't exist, resuming Arm Compute Library tests with default config.")
+            return
+        with open(config_file, mode="r") as config:
+            test_config = json.load(config)
+
+        cls.connection_type = test_config["connection_type"]
+        cls.host = test_config["host"]
+        cls.port = test_config["port"]
+        cls.target = test_config["target"]
+        cls.device_key = test_config.get("device_key") or ""
+        cls.cross_compile = test_config.get("cross_compile") or ""
 
 
 def get_cpu_op_count(mod):
@@ -94,7 +145,8 @@ def skip_runtime_test():
         return True
 
     # Remote device is in use or ACL runtime not present
-    if not Device.use_remote and not arm_compute_lib.is_arm_compute_runtime_enabled():
+    # Note: Ensure that the device config has been loaded before this check
+    if not Device.connection_type != "local" and not arm_compute_lib.is_arm_compute_runtime_enabled():
         print("Skip because runtime isn't present or a remote device isn't being used.")
         return True
 
diff --git a/tests/python/contrib/test_arm_compute_lib/test_config.json b/tests/python/contrib/test_arm_compute_lib/test_config.json
new file mode 100644 (file)
index 0000000..c8168ae
--- /dev/null
@@ -0,0 +1,8 @@
+{
+  "connection_type": "local",
+  "host": "localhost",
+  "port": 9090,
+  "target": "llvm -mtriple=aarch64-linux-gnu -mattr=+neon",
+  "device_key": "",
+  "cross_compile": ""
+}
index a89f04d..555cbe1 100644 (file)
@@ -235,6 +235,8 @@ def _get_expected_codegen(shape, kernel_h, kernel_w, padding, strides,
 
 
 def test_conv2d():
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
@@ -325,6 +327,8 @@ def test_codegen_conv2d():
 
 
 def test_qnn_conv2d():
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
index 2208026..5482075 100644 (file)
@@ -176,6 +176,8 @@ def _get_expected_codegen(shape, weight_shape, units, dtype,
 
 
 def test_dense():
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
@@ -231,6 +233,8 @@ def test_codegen_dense():
 
 
 def test_qnn_dense():
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
index ceef179..18cac33 100644 (file)
@@ -79,6 +79,8 @@ def _get_keras_model(keras_model, inputs_dict):
 
 
 def test_vgg16():
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
@@ -98,6 +100,8 @@ def test_vgg16():
 
 
 def test_mobilenet():
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
@@ -117,6 +121,8 @@ def test_mobilenet():
 
 
 def test_quantized_mobilenet():
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
index c9ae1d9..32176af 100644 (file)
@@ -68,6 +68,8 @@ def _get_expected_codegen(shape, dtype, typef, sizes, strides,
 
 
 def test_pooling():
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
index 8ab9437..38694e8 100644 (file)
@@ -57,6 +57,8 @@ def _get_expected_codegen(input_shape, output_shape, dtype):
 
 
 def test_reshape():
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
index 2bb17ad..1ce2909 100644 (file)
@@ -31,6 +31,8 @@ def test_multiple_ops():
     The ACL runtime will expect these ops as 2 separate functions for
     the time being.
     """
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
@@ -61,6 +63,8 @@ def test_heterogeneous():
     Test to check if offloading only supported operators works,
     while leaving unsupported operators computed via tvm.
     """
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return
 
@@ -92,6 +96,8 @@ def test_multiple_runs():
     """
     Test that multiple runs of an operator work.
     """
+    Device.load("test_config.json")
+
     if skip_runtime_test():
         return