From 50f80474e2f0f3aa83c97295b22f23e93101e859 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Tue, 21 Jul 2020 02:13:48 +0800 Subject: [PATCH] [Ansor][AutoTVM v2.0] Phase 1: Add RPC Runner (#6077) * Add rpc runner * Update * Update * Add clflush & non-empty ndarray TODO hints * Update * UT Update * Update timeout in UT --- python/tvm/auto_scheduler/__init__.py | 3 +- python/tvm/auto_scheduler/measure.py | 295 ++++++++++++++++++++- python/tvm/auto_scheduler/utils.py | 70 +++++ python/tvm/rpc/server.py | 3 +- src/auto_scheduler/measure.cc | 42 +++ src/auto_scheduler/measure.h | 72 ++++- .../python/unittest/test_auto_scheduler_measure.py | 28 +- 7 files changed, 492 insertions(+), 21 deletions(-) diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 90bec86..c3a3712 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -28,7 +28,8 @@ from . import workload_registry from .compute_dag import ComputeDAG from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \ auto_schedule, EmptyPolicy -from .measure import MeasureInput, LocalBuilder, LocalRunner +from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, \ + LocalRPCMeasureContext from .measure_record import RecordToFile, RecordReader, load_best, \ load_records, save_records from .workload_registry import register_workload, make_workload_key diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index e99c47e..03ad23e 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -42,10 +42,14 @@ import tvm._ffi from tvm.runtime import Object, module, ndarray from tvm.driver import build_module from tvm.ir import transform +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server +from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from tvm.contrib import tar, ndk from . 
import _ffi_api -from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \ + check_remote # The maximum length of error message MAX_ERROR_MSG_LEN = 512 @@ -53,6 +57,7 @@ MAX_ERROR_MSG_LEN = 512 # We use fork and a global variable to copy arguments between processings. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool GLOBAL_BUILD_ARGUMENTS = None +GLOBAL_RUN_ARGUMENTS = None @tvm._ffi.register_object("auto_scheduler.MeasureCallback") class MeasureCallback(Object): @@ -195,6 +200,8 @@ class LocalBuilder(ProgramBuilder): class LocalRunner(ProgramRunner): """ LocalRunner that uses local CPU/GPU to measures the time cost of programs. + TODO(FrozenGene): Add cpu cache flush to this runner. + Parameters ---------- timeout : int = 10 @@ -230,6 +237,124 @@ class LocalRunner(ProgramRunner): _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) +@tvm._ffi.register_object("auto_scheduler.RPCRunner") +class RPCRunner(ProgramRunner): + """ RPCRunner that uses RPC call to measure the time cost of programs on remote devices. + Or sometimes we may need to use RPC even in local running to insulate the thread environment. + (e.g. running CUDA programs) + + TODO(FrozenGene): Add cpu cache flush to this runner. + + Parameters + ---------- + key : str + The key of the device registered in the RPC tracker. + host : str + The host address of the RPC Tracker. + port : int + The port of RPC Tracker. + priority : int = 1 + The priority of this run request, larger is more prior. + n_parallel : int = 1 + The number of tasks run in parallel. + timeout : int = 10 + The timeout limit (in second) for each run. + This is used in a wrapper of the multiprocessing.Process.join(). + number : int = 3 + The number of times to run the generated code for taking average. + We call these runs as one `repeat` of measurement. 
+ repeat : int = 1 + The number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first "1" is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + min_repeat_ms : int = 0 + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + cooldown_interval : float = 0.0 + The cool down interval between two measurements. + """ + + def __init__(self, key, host, port, + priority=1, n_parallel=1, timeout=10, number=3, repeat=1, + min_repeat_ms=0, cooldown_interval=0.0): + self.__init_handle_by_constructor__( + _ffi_api.RPCRunner, key, host, port, priority, n_parallel, timeout, + number, repeat, min_repeat_ms, cooldown_interval) + + if check_remote(key, host, port, priority, timeout): + print("Get devices for measurement successfully!") + else: + raise RuntimeError("Cannot get remote devices from the tracker. " + "Please check the status of tracker by " + "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " + "and make sure you have free devices on the queue status.") + + +class LocalRPCMeasureContext: + """ A context wrapper for running RPCRunner locally. + This will launch a local RPC Tracker and local RPC Server. + + TODO(FrozenGene): Add cpu cache flush to this RPC context. + + Parameters + ---------- + priority : int = 1 + The priority of this run request, larger is more prior. + n_parallel : int = 1 + The number of tasks run in parallel. + timeout : int = 10 + The timeout limit (in second) for each run. + This is used in a wrapper of the multiprocessing.Process.join(). 
+ number : int = 3 + The number of times to run the generated code for taking average. + We call these runs as one `repeat` of measurement. + repeat : int = 1 + The number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first "1" is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + min_repeat_ms : int = 0 + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + cooldown_interval : float = 0.0 + The cool down interval between two measurements. + """ + + def __init__(self, priority=1, n_parallel=1, timeout=10, number=3, repeat=1, + min_repeat_ms=0, cooldown_interval=0.0): + ctx = tvm.context("cuda", 0) + if ctx.exist: + cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) + set_cuda_target_arch(cuda_arch) + host = '0.0.0.0' + self.tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % self.tracker.port + self.server = Server(host, port=self.tracker.port, port_end=10000, + key=device_key, use_popen=True, silent=True, + tracker_addr=(self.tracker.host, self.tracker.port)) + self.runner = RPCRunner(device_key, host, self.tracker.port, priority, + n_parallel, timeout, number, repeat, + min_repeat_ms, cooldown_interval) + # Wait for the processes to start + time.sleep(0.5) + + def __del__(self): + # Close the tracker and server before exit + self.tracker.terminate() + self.server.terminate() + + class MeasureErrorNo(object): """ Error type for MeasureResult. """ NO_ERROR = 0 # No error @@ -307,7 +432,8 @@ def local_build_worker(index): dirname, "tmp_func." 
+ build_func.output_format) try: - with transform.PassContext(): # todo(lmzheng): port the unroll pass + # TODO(merrymercy): Port the unroll pass. + with transform.PassContext(): func = build_module.build( sch, args, target=task.target, target_host=task.target_host) func.export_library(filename, build_func) @@ -376,8 +502,10 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo return results + @tvm._ffi.register_func("auto_scheduler.local_runner.run") -def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, +def local_run(inputs, build_results, + timeout=10, number=3, repeat=1, min_repeat_ms=0, cooldown_interval=0, verbose=1): """ Run function of LocalRunner to test the performance of the input BuildResults. @@ -388,7 +516,7 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo The MeasureInputs to be measured. build_results : List[BuildResult] The BuildResults to be measured. - timeout : int + timeout : int = 10 The timeout limit (in second) for each run. This is used in a wrapper of the multiprocessing.Process.join(). number : int = 3 @@ -426,6 +554,7 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo try: func = module.load_module(build_res.filename) ctx = ndarray.context(str(inp.task.target), 0) + # TODO(FrozenGene): Add cpu cache flush to this function. time_f = func.time_evaluator( func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) # pylint: disable=broad-except @@ -436,6 +565,7 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo if error_no == 0: try: + # TODO(FrozenGene): Update to ndarray.non-empty. 
args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() @@ -478,3 +608,160 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo print("") return measure_results + + +def rpc_run_worker(index): + """ Function to be run in the RPCRunner thread pool. + + Parameters + ---------- + index : int + The MeasureInput and BuildResult index to be processed by the current Runner thread. + + Returns + ------- + res : MeasureResult + The measure result of this Runner thread. + """ + global GLOBAL_RUN_ARGUMENTS + inputs, build_results, key, host, port, priority, timeout, number, \ + repeat, min_repeat_ms, cooldown_interval, verbose = GLOBAL_RUN_ARGUMENTS + + max_float = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + inp = inputs[index] + build_res = build_results[index] + + if build_res.error_no != MeasureErrorNo.NO_ERROR: + return (max_float,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ + time.time() + + def timed_func(): + tic = time.time() + error_no = 0 + error_msg = None + try: + # upload built module + remote = request_remote(key, host, port, priority, timeout) + remote.upload(build_res.filename) + func = remote.load_module(os.path.split(build_res.filename)[1]) + ctx = remote.context(str(inp.task.target), 0) + # TODO(FrozenGene): Add cpu cache flush to this function. + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + # pylint: disable=broad-except + except Exception: + costs = (max_float,) + error_no = MeasureErrorNo.COMPILE_DEVICE + error_msg = make_error_msg() + + if error_no == 0: + try: + # TODO(FrozenGene): Update to ndarray.non-empty. 
+ args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + build_res.args] + ctx.sync() + + costs = time_f(*args).results + # clean up remote files + remote.remove(build_res.filename) + remote.remove(os.path.splitext(build_res.filename)[0] + '.so') + remote.remove('') + # pylint: disable=broad-except + except Exception: + costs = (max_float,) + error_no = MeasureErrorNo.RUNTIME_DEVICE + error_msg = make_error_msg() + + shutil.rmtree(os.path.dirname(build_res.filename)) + toc = time.time() + + time.sleep(cooldown_interval) + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print("*", end="") + else: + print("*E", end="") # Run error + + return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc + + res = call_func_with_timeout(timeout, timed_func) + + if isinstance(res, TimeoutError): + if verbose >= 1: + print("*T", end="") # Run timeout + res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + \ + timeout, time.time() + return res + + +@tvm._ffi.register_func("auto_scheduler.rpc_runner.run") +def rpc_runner_run(inputs, build_results, key, host, port, + priority=1, n_parallel=1, timeout=10, number=3, repeat=1, min_repeat_ms=0, + cooldown_interval=0.0, verbose=1): + """ Run function of RPCRunner to test the performance of the input BuildResults. + + Parameters + ---------- + inputs : List[MeasureInput] + The MeasureInputs to be measured. + build_results : List[BuildResult] + The BuildResults to be measured. + key : str + The key of the device registered in the RPC tracker. + host : str + The host address of the RPC Tracker. + port : int + The port of RPC Tracker. + priority : int = 1 + The priority of this run request, larger is more prior. + n_parallel : int = 1 + The number of tasks run in parallel. + timeout : int = 10 + The timeout limit (in second) for each run. + This is used in a wrapper of the multiprocessing.Process.join(). 
+ number : int = 3 + The number of times to run the generated code for taking average. + We call these runs as one `repeat` of measurement. + repeat : int = 1 + The number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first "1" is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + min_repeat_ms : int = 0 + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + cooldown_interval : float = 0.0 + The cool down interval between two measurements. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during program measuring. + + Returns + ------- + res : List[MeasureResult] + The measure results of these MeasureInputs. 
+ """ + global GLOBAL_RUN_ARGUMENTS + GLOBAL_RUN_ARGUMENTS = (inputs, build_results, key, host, port, priority, timeout, number, + repeat, min_repeat_ms, cooldown_interval, verbose) + + assert len(inputs) == len(build_results), \ + "Measure input size should be equal to build results" + pool = NoDaemonPool(n_parallel) + tuple_res = pool.map(rpc_run_worker, range(len(build_results))) + pool.terminate() + pool.join() + del pool + + results = [] + for res in tuple_res: + results.append(MeasureResult(*res)) + + if verbose >= 1: + print("") + + return results diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index f7b1202..f5b53fb 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -22,12 +22,15 @@ import multiprocessing import multiprocessing.pool import queue import signal +import threading +import os try: import psutil except ImportError: raise ImportError("psutil not found, try `pip install psutil` to fix this") +from tvm import rpc from tvm.tir import expr from tvm.tir.transform import Simplify from tvm.ir.transform import Sequential @@ -193,3 +196,70 @@ def call_func_with_timeout(timeout, func, args=(), kwargs=None): del que return res + + +def request_remote(device_key, host=None, port=None, priority=1, timeout=60): + """ Request a remote session. + + Parameters + ---------- + device_key : str + The device key of registered device in tracker. + host : Optional[str] + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST". + port : Optional[int] + The port of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT". + priority : int = 1 + The priority of this request, larger is more prior. + timeout : int = 60 + The timeout of this session in second. + + Returns + ------- + remote : RPCSession + The connected remote RPCSession. 
+ """ + # connect to the tracker + host = host or os.environ['TVM_TRACKER_HOST'] + port = port or int(os.environ['TVM_TRACKER_PORT']) + + tracker = rpc.connect_tracker(host, port) + remote = tracker.request(device_key, priority=priority, + session_timeout=timeout) + return remote + + +def check_remote(device_key, host=None, port=None, priority=100, timeout=10): + """ + Check the availability of a remote device. + + Parameters + ---------- + device_key: str + device key of registered device in tracker. + host: Optional[str] + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST". + port: Optional[int] + The port address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT". + priority: int = 100 + The priority of this request, larger is more prior. + timeout: int = 10 + The timeout of this check in seconds. + + Returns + ------- + available: bool + True if can find available device. + """ + + def _check(): + request_remote(device_key, host, port, priority) + + t = threading.Thread(target=_check, ) + t.start() + t.join(timeout) + return not t.is_alive() diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 15a3c7d..42bcb00 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -348,7 +348,8 @@ class Server(object): cmd = [sys.executable, "-m", "tvm.exec.rpc_server", "--host=%s" % host, - "--port=%s" % port] + "--port=%s" % port, + "--port-end=%s" % port_end] if tracker_addr: assert key cmd += ["--tracker=%s:%d" % tracker_addr, diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index 3c54552..6198f60 100644 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -41,6 +41,7 @@ TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode); TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode); TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode); TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); +TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode); static const char* ErrorNoToStr[] 
= { "NoError", @@ -146,6 +147,39 @@ Array LocalRunnerNode::Run(const Array& inputs, throw; } +/********** RPCRunner **********/ +RPCRunner::RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel, + int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + auto node = make_object(); + node->key = key; + node->host = host; + node->port = port; + node->priority = priority; + node->timeout = timeout; + node->n_parallel = n_parallel; + node->number = number; + node->repeat = repeat; + node->min_repeat_ms = min_repeat_ms; + node->cooldown_interval = cooldown_interval; + data_ = std::move(node); +} + +Array RPCRunnerNode::Run(const Array& inputs, + const Array& build_results, int verbose) { + if (const auto* f = runtime::Registry::Get("auto_scheduler.rpc_runner.run")) { + Array results = + (*f)(inputs, build_results, key, host, port, priority, n_parallel, timeout, number, repeat, + min_repeat_ms, cooldown_interval, verbose); + return results; + } else { + LOG(FATAL) << "auto_scheduler.rpc_runner.run is not registered. 
" + << "This is a function registered in Python, " + << "make sure the TVM Python runtime has been loaded successfully."; + } + return Array(); +} + /********** ProgramMeasurer **********/ ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Optional> callbacks, int verbose, @@ -327,5 +361,13 @@ TVM_REGISTER_GLOBAL("auto_scheduler.LocalRunner") return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); }); +TVM_REGISTER_GLOBAL("auto_scheduler.RPCRunner") + .set_body_typed([](const String& key, const String& host, int port, int priority, + int n_parallel, int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + return RPCRunner(key, host, port, priority, n_parallel, timeout, number, repeat, + min_repeat_ms, cooldown_interval); + }); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/measure.h b/src/auto_scheduler/measure.h index 50b6fcf..02d6e87 100644 --- a/src/auto_scheduler/measure.h +++ b/src/auto_scheduler/measure.h @@ -266,6 +266,14 @@ class ProgramRunnerNode : public Object { public: /*! \brief Timeout of a run. */ int timeout; + /*! \brief The number of times to run the generated code for taking average. */ + int number; + /*! \brief The number of times to repeat the measurement. */ + int repeat; + /*! \brief The minimum duration of one repeat in milliseconds. */ + int min_repeat_ms; + /*! \brief The cool down interval between two measurements. */ + double cooldown_interval; /*! * \brief Run measurement and return results. @@ -326,15 +334,6 @@ class LocalBuilder : public ProgramBuilder { /*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ class LocalRunnerNode : public ProgramRunnerNode { public: - /*! \brief Number of measure times. */ - int number; - /*! \brief Number of repeat times in each measure. */ - int repeat; - /*! \brief The minimum duration of one repeat in milliseconds. */ - int min_repeat_ms; - /*! 
\brief The cool down interval between two measurements. */ - double cooldown_interval; - Array Run(const Array& inputs, const Array& build_results, int verbose) final; @@ -353,8 +352,8 @@ class LocalRunner : public ProgramRunner { * for more detailed parameter explaination. * \param timeout The timeout limit (in second) for each run. * This is used in a wrapper of the multiprocessing.Process.join(). - * \param number Number of measure times. - * \param repeat Number of repeat times in each measure. + * \param number The number of times to run the generated code for taking average. + * \param repeat The number of times to repeat the measurement. * \param min_repeat_ms The minimum duration of one repeat in milliseconds. * \param cooldown_interval The cool down interval between two measurements. */ @@ -364,6 +363,57 @@ class LocalRunner : public ProgramRunner { }; /*! + * \brief RPCRunner that uses RPC call to measure the time cost of programs on remote devices. + * Or sometimes we may need to use RPC even in local running to insulate the thread environment. + * (e.g. running CUDA programs) + */ +class RPCRunnerNode : public ProgramRunnerNode { + public: + /*! \brief The key of the device registered in the RPC tracker. */ + String key; + /*! \brief The host address of the RPC Tracker. */ + String host; + /*! \brief The port of RPC Tracker. */ + int port; + /*! \brief The priority of this run request, larger is more prior. */ + int priority; + /*! \brief The number of tasks run in parallel. */ + int n_parallel; + /*! \brief number, repeat, min_repeat_ms and cooldown_interval are inherited from ProgramRunnerNode. */ + + Array Run(const Array& inputs, + const Array& build_results, int verbose) final; + + static constexpr const char* _type_key = "auto_scheduler.RPCRunner"; + TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, ProgramRunnerNode); +}; + +/*! + * \brief Managed reference to RPCRunnerNode. + * \sa RPCRunnerNode + */ +class RPCRunner : public ProgramRunner { + public: + /*! 
+ * \brief The constructor. + * \param key The key of the device registered in the RPC tracker. + * \param host The host address of the RPC Tracker. + * \param port The port of RPC Tracker. + * \param priority The priority of this run request, larger is more prior. + * \param n_parallel The number of tasks run in parallel. + * \param timeout Timeout of a run. + * \param number The number of times to run the generated code for taking average. + * \param repeat The number of times to repeat the measurement. + * \param min_repeat_ms The minimum duration of one repeat in milliseconds. + * \param cooldown_interval The cool down interval between two measurements. + */ + RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel, + int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, ProgramRunner, RPCRunnerNode); +}; + +/*! + * \brief Measurer that measures the time costs of tvm programs + * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */ class ProgramMeasurerNode : public Object { diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 1bcd054..d6e6c51 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -25,10 +25,10 @@ from test_auto_scheduler_common import get_tiled_matmul def test_record(): - dag, s = get_tiled_matmul() - if not tvm.runtime.enabled("llvm"): return + + dag, s = get_tiled_matmul() target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target) @@ -50,23 +50,43 @@ def test_record(): def test_measure_local_builder_runner(): + if not tvm.runtime.enabled("llvm"): + return + dag, s0 = get_tiled_matmul() + tgt = tvm.target.create("llvm") + task = auto_scheduler.SearchTask(dag, "test", tgt) + minp = auto_scheduler.MeasureInput(task, s0) + 
local_builder = auto_scheduler.LocalBuilder() + local_runner = auto_scheduler.LocalRunner(timeout=60) + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + +def test_measure_local_builder_rpc_runner(): if not tvm.runtime.enabled("llvm"): return + + dag, s0 = get_tiled_matmul() tgt = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", tgt) minp = auto_scheduler.MeasureInput(task, s0) local_builder = auto_scheduler.LocalBuilder() - local_runner = auto_scheduler.LocalRunner() + measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60) + rpc_runner = measure_ctx.runner bress = local_builder.build([minp]) assert bress[0].error_no == 0 - mress = local_runner.run([minp], bress) + mress = rpc_runner.run([minp], bress) assert mress[0].error_no == 0 if __name__ == "__main__": test_record() test_measure_local_builder_runner() + test_measure_local_builder_rpc_runner() -- 2.7.4