[Ansor][AutoTVM v2.0] Phase 1: Add RPC Runner (#6077)
authorChenfan <chengfan.jcf@alibaba-inc.com>
Mon, 20 Jul 2020 18:13:48 +0000 (02:13 +0800)
committerGitHub <noreply@github.com>
Mon, 20 Jul 2020 18:13:48 +0000 (11:13 -0700)
* Add rpc runner

* Update

* Update

* Add clflush & non-empty ndarray TODO hints

* Update

* UT Update

* Update timeout in UT

python/tvm/auto_scheduler/__init__.py
python/tvm/auto_scheduler/measure.py
python/tvm/auto_scheduler/utils.py
python/tvm/rpc/server.py
src/auto_scheduler/measure.cc
src/auto_scheduler/measure.h
tests/python/unittest/test_auto_scheduler_measure.py

index 90bec86..c3a3712 100644 (file)
@@ -28,7 +28,8 @@ from . import workload_registry
 from .compute_dag import ComputeDAG
 from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
     auto_schedule, EmptyPolicy
-from .measure import MeasureInput, LocalBuilder, LocalRunner
+from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, \
+    LocalRPCMeasureContext
 from .measure_record import RecordToFile, RecordReader, load_best, \
     load_records, save_records
 from .workload_registry import register_workload, make_workload_key
index e99c47e..03ad23e 100644 (file)
@@ -42,10 +42,14 @@ import tvm._ffi
 from tvm.runtime import Object, module, ndarray
 from tvm.driver import build_module
 from tvm.ir import transform
+from tvm.rpc.tracker import Tracker
+from tvm.rpc.server import Server
+from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
 from tvm.contrib import tar, ndk
 
 from . import _ffi_api
-from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \
+    check_remote
 
 # The maximum length of error message
 MAX_ERROR_MSG_LEN = 512
@@ -53,6 +57,7 @@ MAX_ERROR_MSG_LEN = 512
 # We use fork and a global variable to copy arguments between processings.
 # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
 GLOBAL_BUILD_ARGUMENTS = None
+GLOBAL_RUN_ARGUMENTS = None
 
 @tvm._ffi.register_object("auto_scheduler.MeasureCallback")
 class MeasureCallback(Object):
@@ -195,6 +200,8 @@ class LocalBuilder(ProgramBuilder):
 class LocalRunner(ProgramRunner):
     """ LocalRunner that uses local CPU/GPU to measures the time cost of programs.
 
+    TODO(FrozenGene): Add cpu cache flush to this runner.
+
     Parameters
     ----------
     timeout : int = 10
@@ -230,6 +237,124 @@ class LocalRunner(ProgramRunner):
             _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval)
 
 
+@tvm._ffi.register_object("auto_scheduler.RPCRunner")
+class RPCRunner(ProgramRunner):
+    """ RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
+    Or sometime we may need to use RPC even in local running to insulate the thread environment.
+    (e.g. running CUDA programs)
+
+    TODO(FrozenGene): Add cpu cache flush to this runner.
+
+    Parameters
+    ----------
+    key : str
+        The key of the device registered in the RPC tracker.
+    host : str
+        The host address of the RPC Tracker.
+    port : int
+        The port of RPC Tracker.
+    priority : int = 1
+        The priority of this run request, larger is more prior.
+    n_parallel : int = 1
+        The number of tasks run in parallel.
+    timeout : int = 10
+        The timeout limit (in second) for each run.
+        This is used in a wrapper of the multiprocessing.Process.join().
+    number : int = 3
+        The number of times to run the generated code for taking average.
+        We call these runs as one `repeat` of measurement.
+    repeat : int = 1
+        The number of times to repeat the measurement.
+        In total, the generated code will be run (1 + number x repeat) times,
+        where the first "1" is warm up and will be discarded.
+        The returned result contains `repeat` costs,
+        each of which is an average of `number` costs.
+    min_repeat_ms : int = 0
+        The minimum duration of one `repeat` in milliseconds.
+        By default, one `repeat` contains `number` runs. If this parameter is set,
+        the parameters `number` will be dynamically adjusted to meet the
+        minimum duration requirement of one `repeat`.
+        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
+        will be automatically increased.
+    cooldown_interval : float = 0.0
+        The cool down interval between two measurements.
+    """
+
+    def __init__(self, key, host, port,
+                 priority=1, n_parallel=1, timeout=10, number=3, repeat=1,
+                 min_repeat_ms=0, cooldown_interval=0.0):
+        self.__init_handle_by_constructor__(
+            _ffi_api.RPCRunner, key, host, port, priority, n_parallel, timeout,
+            number, repeat, min_repeat_ms, cooldown_interval)
+
+        if check_remote(key, host, port, priority, timeout):
+            print("Get devices for measurement successfully!")
+        else:
+            raise RuntimeError("Cannot get remote devices from the tracker. "
+                               "Please check the status of tracker by "
+                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
+                               "and make sure you have free devices on the queue status.")
+
+
+class LocalRPCMeasureContext:
+    """ A context wrapper for running RPCRunner locally.
+    This will launch a local RPC Tracker and local RPC Server.
+
+    TODO(FrozenGene): Add cpu cache flush to this RPC context.
+
+    Parameters
+    ----------
+    priority : int = 1
+        The priority of this run request, larger is more prior.
+    n_parallel : int = 1
+        The number of tasks run in parallel.
+    timeout : int = 10
+        The timeout limit (in second) for each run.
+        This is used in a wrapper of the multiprocessing.Process.join().
+    number : int = 3
+        The number of times to run the generated code for taking average.
+        We call these runs as one `repeat` of measurement.
+    repeat : int = 1
+        The number of times to repeat the measurement.
+        In total, the generated code will be run (1 + number x repeat) times,
+        where the first "1" is warm up and will be discarded.
+        The returned result contains `repeat` costs,
+        each of which is an average of `number` costs.
+    min_repeat_ms : int = 0
+        The minimum duration of one `repeat` in milliseconds.
+        By default, one `repeat` contains `number` runs. If this parameter is set,
+        the parameters `number` will be dynamically adjusted to meet the
+        minimum duration requirement of one `repeat`.
+        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
+        will be automatically increased.
+    cooldown_interval : float = 0.0
+        The cool down interval between two measurements.
+    """
+
+    def __init__(self, priority=1, n_parallel=1, timeout=10, number=3, repeat=1,
+                 min_repeat_ms=0, cooldown_interval=0.0):
+        ctx = tvm.context("cuda", 0)
+        if ctx.exist:
+            cuda_arch = "sm_" + "".join(ctx.compute_version.split('.'))
+            set_cuda_target_arch(cuda_arch)
+        host = '0.0.0.0'
+        self.tracker = Tracker(host, port=9000, port_end=10000, silent=True)
+        device_key = '$local$device$%d' % self.tracker.port
+        self.server = Server(host, port=self.tracker.port, port_end=10000,
+                             key=device_key, use_popen=True, silent=True,
+                             tracker_addr=(self.tracker.host, self.tracker.port))
+        self.runner = RPCRunner(device_key, host, self.tracker.port, priority,
+                                n_parallel, timeout, number, repeat,
+                                min_repeat_ms, cooldown_interval)
+        # Wait for the processes to start
+        time.sleep(0.5)
+
+    def __del__(self):
+        # Close the tracker and server before exit
+        self.tracker.terminate()
+        self.server.terminate()
+
+
 class MeasureErrorNo(object):
     """ Error type for MeasureResult. """
     NO_ERROR = 0              # No error
@@ -307,7 +432,8 @@ def local_build_worker(index):
                 dirname, "tmp_func." + build_func.output_format)
 
             try:
-                with transform.PassContext():  # todo(lmzheng): port the unroll pass
+                # TODO(merrymercy): Port the unroll pass.
+                with transform.PassContext():
                     func = build_module.build(
                         sch, args, target=task.target, target_host=task.target_host)
                 func.export_library(filename, build_func)
@@ -376,8 +502,10 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo
 
     return results
 
+
 @tvm._ffi.register_func("auto_scheduler.local_runner.run")
-def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval,
+def local_run(inputs, build_results,
+              timeout=10, number=3, repeat=1, min_repeat_ms=0, cooldown_interval=0,
               verbose=1):
     """
     Run function of LocalRunner to test the performance of the input BuildResults.
@@ -388,7 +516,7 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo
         The MeasureInputs to be measured.
     build_results : List[BuildResult]
         The BuildResults to be measured.
-    timeout : int
+    timeout : int = 10
         The timeout limit (in second) for each run.
         This is used in a wrapper of the multiprocessing.Process.join().
     number : int = 3
@@ -426,6 +554,7 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo
         try:
             func = module.load_module(build_res.filename)
             ctx = ndarray.context(str(inp.task.target), 0)
+            # TODO(FrozenGene): Add cpu cache flush to this function.
             time_f = func.time_evaluator(
                 func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms)
         # pylint: disable=broad-except
@@ -436,6 +565,7 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo
 
         if error_no == 0:
             try:
+                # TODO(FrozenGene): Update to ndarray.non-empty.
                 args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
                         build_res.args]
                 ctx.sync()
@@ -478,3 +608,160 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo
         print("")
 
     return measure_results
+
+
+def rpc_run_worker(index):
+    """ Function to be ran in the RPCRunner thread pool.
+
+    Parameters
+    ----------
+    index : int
+        The MeasureInput and BuildResult index to be processed by the current Runner thread.
+
+    Returns
+    -------
+    res : MeasureResult
+        The measure result of this Runner thread.
+    """
+    global GLOBAL_RUN_ARGUMENTS
+    inputs, build_results, key, host, port, priority, timeout, number, \
+        repeat, min_repeat_ms, cooldown_interval, verbose = GLOBAL_RUN_ARGUMENTS
+
+    max_float = 1e10  # We use 1e10 instead of sys.float_info.max for better readability in log
+    inp = inputs[index]
+    build_res = build_results[index]
+
+    if build_res.error_no != MeasureErrorNo.NO_ERROR:
+        return (max_float,), build_res.error_no, build_res.error_msg, build_res.time_cost, \
+            time.time()
+
+    def timed_func():
+        tic = time.time()
+        error_no = 0
+        error_msg = None
+        try:
+            # upload built module
+            remote = request_remote(key, host, port, priority, timeout)
+            remote.upload(build_res.filename)
+            func = remote.load_module(os.path.split(build_res.filename)[1])
+            ctx = remote.context(str(inp.task.target), 0)
+            # TODO(FrozenGene): Add cpu cache flush to this function.
+            time_f = func.time_evaluator(
+                func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms)
+        # pylint: disable=broad-except
+        except Exception:
+            costs = (max_float,)
+            error_no = MeasureErrorNo.COMPILE_DEVICE
+            error_msg = make_error_msg()
+
+        if error_no == 0:
+            try:
+                # TODO(FrozenGene): Update to ndarray.non-empty.
+                args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
+                        build_res.args]
+                ctx.sync()
+
+                costs = time_f(*args).results
+                # clean up remote files
+                remote.remove(build_res.filename)
+                remote.remove(os.path.splitext(build_res.filename)[0] + '.so')
+                remote.remove('')
+            # pylint: disable=broad-except
+            except Exception:
+                costs = (max_float,)
+                error_no = MeasureErrorNo.RUNTIME_DEVICE
+                error_msg = make_error_msg()
+
+        shutil.rmtree(os.path.dirname(build_res.filename))
+        toc = time.time()
+
+        time.sleep(cooldown_interval)
+        if verbose >= 1:
+            if error_no == MeasureErrorNo.NO_ERROR:
+                print("*", end="")
+            else:
+                print("*E", end="")  # Run error
+
+        return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc
+
+    res = call_func_with_timeout(timeout, timed_func)
+
+    if isinstance(res, TimeoutError):
+        if verbose >= 1:
+            print("*T", end="")  # Run timeout
+        res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + \
+            timeout, time.time()
+    return res
+
+
+@tvm._ffi.register_func("auto_scheduler.rpc_runner.run")
+def rpc_runner_run(inputs, build_results, key, host, port,
+                   priority=1, n_parallel=1, timeout=10, number=3, repeat=1, min_repeat_ms=0,
+                   cooldown_interval=0.0, verbose=1):
+    """ Run function of RPCRunner to test the performance of the input BuildResults.
+
+    Parameters
+    ----------
+    inputs : List[MeasureInput]
+        The MeasureInputs to be measured.
+    build_results : List[BuildResult]
+        The BuildResults to be measured.
+    key : str
+        The key of the device registered in the RPC tracker.
+    host : str
+        The host address of the RPC Tracker.
+    port : int
+        The port of RPC Tracker.
+    priority : int = 1
+        The priority of this run request, larger is more prior.
+    n_parallel : int = 1
+        The number of tasks run in parallel.
+    timeout : int = 10
+        The timeout limit (in second) for each run.
+        This is used in a wrapper of the multiprocessing.Process.join().
+    number : int = 3
+        The number of times to run the generated code for taking average.
+        We call these runs as one `repeat` of measurement.
+    repeat : int = 1
+        The number of times to repeat the measurement.
+        In total, the generated code will be run (1 + number x repeat) times,
+        where the first "1" is warm up and will be discarded.
+        The returned result contains `repeat` costs,
+        each of which is an average of `number` costs.
+    min_repeat_ms : int = 0
+        The minimum duration of one `repeat` in milliseconds.
+        By default, one `repeat` contains `number` runs. If this parameter is set,
+        the parameters `number` will be dynamically adjusted to meet the
+        minimum duration requirement of one `repeat`.
+        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
+        will be automatically increased.
+    cooldown_interval : float = 0.0
+        The cool down interval between two measurements.
+    verbose: int = 1
+        Verbosity level. 0 for silent, 1 to output information during program measuring.
+
+    Returns
+    -------
+    res : List[MeasureResult]
+        The measure results of these MeasureInputs.
+    """
+    global GLOBAL_RUN_ARGUMENTS
+    GLOBAL_RUN_ARGUMENTS = (inputs, build_results, key, host, port, priority, timeout, number,
+                            repeat, min_repeat_ms, cooldown_interval, verbose)
+
+    assert len(inputs) == len(build_results), \
+        "Measure input size should be equal to build results"
+    pool = NoDaemonPool(n_parallel)
+    tuple_res = pool.map(rpc_run_worker, range(len(build_results)))
+    pool.terminate()
+    pool.join()
+    del pool
+
+    results = []
+    for res in tuple_res:
+        results.append(MeasureResult(*res))
+
+    if verbose >= 1:
+        print("")
+
+    return results
index f7b1202..f5b53fb 100644 (file)
@@ -22,12 +22,15 @@ import multiprocessing
 import multiprocessing.pool
 import queue
 import signal
+import threading
+import os
 
 try:
     import psutil
 except ImportError:
     raise ImportError("psutil not found, try `pip install psutil` to fix this")
 
+from tvm import rpc
 from tvm.tir import expr
 from tvm.tir.transform import Simplify
 from tvm.ir.transform import Sequential
@@ -193,3 +196,70 @@ def call_func_with_timeout(timeout, func, args=(), kwargs=None):
     del que
 
     return res
+
+
+def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
+    """ Request a remote session.
+
+    Parameters
+    ----------
+    device_key : str
+        The device key of registered device in tracker.
+    host : Optional[str]
+        The host address of rpc tracker.
+        If is none, will use environment variable "TVM_TRACKER_HOST".
+    port : Optional[int]
+        The port of rpc tracker.
+        If is none, will use environment variable "TVM_TRACKER_PORT".
+    priority : int = 1
+        The priority of this request, larger is more prior.
+    timeout : int = 60
+        The timeout of this session in second.
+
+    Returns
+    -------
+    remote : RPCSession
+        The connected remote RPCSession.
+    """
+    # connect to the tracker
+    host = host or os.environ['TVM_TRACKER_HOST']
+    port = port or int(os.environ['TVM_TRACKER_PORT'])
+
+    tracker = rpc.connect_tracker(host, port)
+    remote = tracker.request(device_key, priority=priority,
+                             session_timeout=timeout)
+    return remote
+
+
+def check_remote(device_key, host=None, port=None, priority=100, timeout=10):
+    """
+    Check the availability of a remote device.
+
+    Parameters
+    ----------
+    device_key: str
+        device key of registered device in tracker.
+    host: Optional[str]
+        The host address of rpc tracker.
+        If is none, will use environment variable "TVM_TRACKER_HOST".
+    port: Optional[int]
+        The port address of rpc tracker.
+        If is none, will use environment variable "TVM_TRACKER_PORT".
+    priority: int = 100
+        The priority of this request, larger is more prior.
+    timeout: int = 10
+        The timeout of this check in seconds.
+
+    Returns
+    -------
+    available: bool
+        True if can find available device.
+    """
+
+    def _check():
+        request_remote(device_key, host, port, priority)
+
+    t = threading.Thread(target=_check, )
+    t.start()
+    t.join(timeout)
+    return not t.is_alive()
index 15a3c7d..42bcb00 100644 (file)
@@ -348,7 +348,8 @@ class Server(object):
             cmd = [sys.executable,
                    "-m", "tvm.exec.rpc_server",
                    "--host=%s" % host,
-                   "--port=%s" % port]
+                   "--port=%s" % port,
+                   "--port-end=%s" % port_end]
             if tracker_addr:
                 assert key
                 cmd += ["--tracker=%s:%d" % tracker_addr,
index 3c54552..6198f60 100644 (file)
@@ -41,6 +41,7 @@ TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
 TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
 TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
 TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode);
 
 static const char* ErrorNoToStr[] = {
     "NoError",
@@ -146,6 +147,39 @@ Array<MeasureResult> LocalRunnerNode::Run(const Array<MeasureInput>& inputs,
   throw;
 }
 
+/********** RPCRunner **********/
+RPCRunner::RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel,
+                     int timeout, int number, int repeat, int min_repeat_ms,
+                     double cooldown_interval) {
+  auto node = make_object<RPCRunnerNode>();
+  node->key = key;
+  node->host = host;
+  node->port = port;
+  node->priority = priority;
+  node->timeout = timeout;
+  node->n_parallel = n_parallel;
+  node->number = number;
+  node->repeat = repeat;
+  node->min_repeat_ms = min_repeat_ms;
+  node->cooldown_interval = cooldown_interval;
+  data_ = std::move(node);
+}
+
+Array<MeasureResult> RPCRunnerNode::Run(const Array<MeasureInput>& inputs,
+                                        const Array<BuildResult>& build_results, int verbose) {
+  if (const auto* f = runtime::Registry::Get("auto_scheduler.rpc_runner.run")) {
+    Array<MeasureResult> results =
+        (*f)(inputs, build_results, key, host, port, priority, n_parallel, timeout, number, repeat,
+             min_repeat_ms, cooldown_interval, verbose);
+    return results;
+  } else {
+    LOG(FATAL) << "auto_scheduler.rpc_runner.run is not registered. "
+               << "This is a function registered in Python, "
+               << "make sure the TVM Python runtime has been loaded successfully.";
+  }
+  return Array<MeasureResult>();
+}
+
 /********** ProgramMeasurer **********/
 ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
                                  Optional<Array<MeasureCallback>> callbacks, int verbose,
@@ -327,5 +361,13 @@ TVM_REGISTER_GLOBAL("auto_scheduler.LocalRunner")
       return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval);
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.RPCRunner")
+    .set_body_typed([](const String& key, const String& host, int port, int priority,
+                       int n_parallel, int timeout, int number, int repeat, int min_repeat_ms,
+                       double cooldown_interval) {
+      return RPCRunner(key, host, port, priority, n_parallel, timeout, number, repeat,
+                       min_repeat_ms, cooldown_interval);
+    });
+
 }  // namespace auto_scheduler
 }  // namespace tvm
index 50b6fcf..02d6e87 100644 (file)
@@ -266,6 +266,14 @@ class ProgramRunnerNode : public Object {
  public:
   /*! \brief Timeout of a run. */
   int timeout;
+  /*! \brief The number of times to run the generated code for taking average. */
+  int number;
+  /*! \brief The number of times to repeat the measurement. */
+  int repeat;
+  /*! \brief The minimum duration of one repeat in milliseconds. */
+  int min_repeat_ms;
+  /*! \brief The cool down interval between two measurements. */
+  double cooldown_interval;
 
   /*!
    * \brief Run measurement and return results.
@@ -326,15 +334,6 @@ class LocalBuilder : public ProgramBuilder {
 /*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */
 class LocalRunnerNode : public ProgramRunnerNode {
  public:
-  /*! \brief Number of measure times. */
-  int number;
-  /*! \brief Number of repeat times in each measure. */
-  int repeat;
-  /*! \brief The minimum duration of one repeat in milliseconds. */
-  int min_repeat_ms;
-  /*! \brief The cool down interval between two measurements. */
-  double cooldown_interval;
-
   Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
                            const Array<BuildResult>& build_results, int verbose) final;
 
@@ -353,8 +352,8 @@ class LocalRunner : public ProgramRunner {
    * for more detailed parameter explaination.
    * \param timeout The timeout limit (in second) for each run.
    * This is used in a wrapper of the multiprocessing.Process.join().
-   * \param number Number of measure times.
-   * \param repeat Number of repeat times in each measure.
+   * \param number The number of times to run the generated code for taking average.
+   * \param repeat The number of times to repeat the measurement.
    * \param min_repeat_ms The minimum duration of one repeat in milliseconds.
    * \param cooldown_interval The cool down interval between two measurements.
    */
@@ -364,6 +363,57 @@ class LocalRunner : public ProgramRunner {
 };
 
 /*!
+ * \brief RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
+ * Or sometime we may need to use RPC even in local running to insulate the thread environment.
+ * (e.g. running CUDA programs)
+ */
+class RPCRunnerNode : public ProgramRunnerNode {
+ public:
+  /*! \brief The key of the device registered in the RPC tracker. */
+  String key;
+  /*! \brief The host address of the RPC Tracker. */
+  String host;
+  /*! \brief The port of RPC Tracker. */
+  int port;
+  /*! \brief The priority of this run request, larger is more prior. */
+  int priority;
+  /*! \brief The number of tasks run in parallel. */
+  int n_parallel;
+  /*! \brief The number of times to run the generated code for taking average. */
+
+  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+                           const Array<BuildResult>& build_results, int verbose) final;
+
+  static constexpr const char* _type_key = "auto_scheduler.RPCRunner";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, ProgramRunnerNode);
+};
+
+/*!
+ * \brief Managed reference to RPCRunnerNode.
+ * \sa RPCRunnerNode
+ */
+class RPCRunner : public ProgramRunner {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param key The key of the device registered in the RPC tracker.
+   * \param host The host address of the RPC Tracker.
+   * \param prot The port of RPC Tracker.
+   * \param priority The priority of this run request, larger is more prior.
+   * \param n_parallel The number of tasks run in parallel.
+   * \param timeout Timeout of a run.
+   * \param number The number of times to run the generated code for taking average.
+   * \param repeat The number of times to repeat the measurement.
+   * \param min_repeat_ms The minimum duration of one repeat in milliseconds.
+   * \param cooldown_interval The cool down interval between two measurements.
+   */
+  RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel,
+            int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, ProgramRunner, RPCRunnerNode);
+};
+
+/*!
  * \brief Measurer that measures the time costs of tvm programs
  * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */
 class ProgramMeasurerNode : public Object {
index 1bcd054..d6e6c51 100644 (file)
@@ -25,10 +25,10 @@ from test_auto_scheduler_common import get_tiled_matmul
 
 
 def test_record():
-    dag, s = get_tiled_matmul()
-
     if not tvm.runtime.enabled("llvm"):
         return
+
+    dag, s = get_tiled_matmul()
     target = tvm.target.create("llvm")
     task = auto_scheduler.SearchTask(dag, "test", target)
 
@@ -50,23 +50,43 @@ def test_record():
 
 
 def test_measure_local_builder_runner():
+    if not tvm.runtime.enabled("llvm"):
+        return
+
     dag, s0 = get_tiled_matmul()
+    tgt = tvm.target.create("llvm")
+    task = auto_scheduler.SearchTask(dag, "test", tgt)
 
+    minp = auto_scheduler.MeasureInput(task, s0)
+    local_builder = auto_scheduler.LocalBuilder()
+    local_runner = auto_scheduler.LocalRunner(timeout=60)
+
+    bress = local_builder.build([minp])
+    assert bress[0].error_no == 0
+    mress = local_runner.run([minp], bress)
+    assert mress[0].error_no == 0
+
+
+def test_measure_local_builder_rpc_runner():
     if not tvm.runtime.enabled("llvm"):
         return
+
+    dag, s0 = get_tiled_matmul()
     tgt = tvm.target.create("llvm")
     task = auto_scheduler.SearchTask(dag, "test", tgt)
 
     minp = auto_scheduler.MeasureInput(task, s0)
     local_builder = auto_scheduler.LocalBuilder()
-    local_runner = auto_scheduler.LocalRunner()
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60)
+    rpc_runner = measure_ctx.runner
 
     bress = local_builder.build([minp])
     assert bress[0].error_no == 0
-    mress = local_runner.run([minp], bress)
+    mress = rpc_runner.run([minp], bress)
     assert mress[0].error_no == 0
 
 
 if __name__ == "__main__":
     test_record()
     test_measure_local_builder_runner()
+    test_measure_local_builder_rpc_runner()