from tvm.runtime import Object, module, ndarray
from tvm.driver import build_module
from tvm.ir import transform
+from tvm.rpc.tracker import Tracker
+from tvm.rpc.server import Server
+from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
from . import _ffi_api
-from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \
+ check_remote
# The maximum length of error message
MAX_ERROR_MSG_LEN = 512
# We use fork and a global variable to copy arguments between processings.
# This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
GLOBAL_BUILD_ARGUMENTS = None
+GLOBAL_RUN_ARGUMENTS = None
@tvm._ffi.register_object("auto_scheduler.MeasureCallback")
class MeasureCallback(Object):
class LocalRunner(ProgramRunner):
""" LocalRunner that uses local CPU/GPU to measures the time cost of programs.
+ TODO(FrozenGene): Add cpu cache flush to this runner.
+
Parameters
----------
timeout : int = 10
_ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval)
@tvm._ffi.register_object("auto_scheduler.RPCRunner")
class RPCRunner(ProgramRunner):
    """ RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
    Or sometime we may need to use RPC even in local running to insulate the thread environment.
    (e.g. running CUDA programs)

    TODO(FrozenGene): Add cpu cache flush to this runner.

    Parameters
    ----------
    key : str
        The key of the device registered in the RPC tracker.
    host : str
        The host address of the RPC Tracker.
    port : int
        The port of RPC Tracker.
    priority : int = 1
        The priority of this run request, larger is more prior.
    n_parallel : int = 1
        The number of tasks run in parallel.
    timeout : int = 10
        The timeout limit (in second) for each run.
        This is used in a wrapper of the multiprocessing.Process.join().
    number : int = 3
        The number of times to run the generated code for taking average.
        We call these runs as one `repeat` of measurement.
    repeat : int = 1
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first "1" is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms : int = 0
        The minimum duration of one `repeat` in milliseconds.
        If set, `number` is dynamically adjusted so that one `repeat` never
        falls below this duration.
    cooldown_interval : float = 0.0
        The cool down interval between two measurements.
    """

    def __init__(self, key, host, port,
                 priority=1, n_parallel=1, timeout=10, number=3, repeat=1,
                 min_repeat_ms=0, cooldown_interval=0.0):
        self.__init_handle_by_constructor__(
            _ffi_api.RPCRunner, key, host, port, priority, n_parallel, timeout,
            number, repeat, min_repeat_ms, cooldown_interval)

        # Fail fast if the tracker cannot hand out a matching device, so the
        # user learns about a misconfigured tracker before tuning starts.
        if not check_remote(key, host, port, priority, timeout):
            raise RuntimeError("Cannot get remote devices from the tracker. "
                               "Please check the status of tracker by "
                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
                               "and make sure you have free devices on the queue status.")
        print("Get devices for measurement successfully!")
+
+
class LocalRPCMeasureContext:
    """ A context wrapper for running RPCRunner locally.
    This will launch a local RPC Tracker and local RPC Server.

    TODO(FrozenGene): Add cpu cache flush to this RPC context.

    Parameters
    ----------
    priority : int = 1
        The priority of this run request, larger is more prior.
    n_parallel : int = 1
        The number of tasks run in parallel.
    timeout : int = 10
        The timeout limit (in second) for each run.
        This is used in a wrapper of the multiprocessing.Process.join().
    number : int = 3
        The number of times to run the generated code for taking average.
        We call these runs as one `repeat` of measurement.
    repeat : int = 1
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first "1" is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms : int = 0
        The minimum duration of one `repeat` in milliseconds.
        If set, `number` is dynamically adjusted so that one `repeat` never
        falls below this duration.
    cooldown_interval : float = 0.0
        The cool down interval between two measurements.
    """

    def __init__(self, priority=1, n_parallel=1, timeout=10, number=3, repeat=1,
                 min_repeat_ms=0, cooldown_interval=0.0):
        # If a local CUDA device is present, pin code generation to its arch.
        cuda_ctx = tvm.context("cuda", 0)
        if cuda_ctx.exist:
            set_cuda_target_arch("sm_" + cuda_ctx.compute_version.replace('.', ''))
        host = '0.0.0.0'
        # Launch a private tracker/server pair; the key embeds the tracker
        # port so concurrent contexts never collide.
        self.tracker = Tracker(host, port=9000, port_end=10000, silent=True)
        device_key = '$local$device$%d' % self.tracker.port
        self.server = Server(host, port=self.tracker.port, port_end=10000,
                             key=device_key, use_popen=True, silent=True,
                             tracker_addr=(self.tracker.host, self.tracker.port))
        self.runner = RPCRunner(device_key, host, self.tracker.port, priority,
                                n_parallel, timeout, number, repeat,
                                min_repeat_ms, cooldown_interval)
        # Wait for the processes to start
        time.sleep(0.5)

    def __del__(self):
        # Close the tracker and server before exit
        self.tracker.terminate()
        self.server.terminate()
+
+
class MeasureErrorNo(object):
""" Error type for MeasureResult. """
NO_ERROR = 0 # No error
dirname, "tmp_func." + build_func.output_format)
try:
- with transform.PassContext(): # todo(lmzheng): port the unroll pass
+ # TODO(merrymercy): Port the unroll pass.
+ with transform.PassContext():
func = build_module.build(
sch, args, target=task.target, target_host=task.target_host)
func.export_library(filename, build_func)
return results
+
@tvm._ffi.register_func("auto_scheduler.local_runner.run")
-def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval,
+def local_run(inputs, build_results,
+ timeout=10, number=3, repeat=1, min_repeat_ms=0, cooldown_interval=0,
verbose=1):
"""
Run function of LocalRunner to test the performance of the input BuildResults.
The MeasureInputs to be measured.
build_results : List[BuildResult]
The BuildResults to be measured.
- timeout : int
+ timeout : int = 10
The timeout limit (in second) for each run.
This is used in a wrapper of the multiprocessing.Process.join().
number : int = 3
try:
func = module.load_module(build_res.filename)
ctx = ndarray.context(str(inp.task.target), 0)
+ # TODO(FrozenGene): Add cpu cache flush to this function.
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms)
# pylint: disable=broad-except
if error_no == 0:
try:
+ # TODO(FrozenGene): Update to ndarray.non-empty.
args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
build_res.args]
ctx.sync()
print("")
return measure_results
+
+
def rpc_run_worker(index):
    """ Function to be ran in the RPCRunner thread pool.

    Parameters
    ----------
    index : int
        The MeasureInput and BuildResult index to be processed by the current Runner thread.

    Returns
    -------
    res : MeasureResult
        The measure result of this Runner thread.
    """
    global GLOBAL_RUN_ARGUMENTS
    inputs, build_results, key, host, port, priority, timeout, number, \
        repeat, min_repeat_ms, cooldown_interval, verbose = GLOBAL_RUN_ARGUMENTS

    max_float = 1e10  # We use 1e10 instead of sys.float_info.max for better readability in log
    inp = inputs[index]
    build_res = build_results[index]

    # A failed build cannot be measured; propagate its error as the result.
    if build_res.error_no != MeasureErrorNo.NO_ERROR:
        return ((max_float,), build_res.error_no, build_res.error_msg,
                build_res.time_cost, time.time())

    def timed_func():
        t_begin = time.time()
        error_no = 0
        error_msg = None
        try:
            # Upload the built module to the remote device and prepare a timer.
            remote_sess = request_remote(key, host, port, priority, timeout)
            remote_sess.upload(build_res.filename)
            mod_func = remote_sess.load_module(os.path.split(build_res.filename)[1])
            dev_ctx = remote_sess.context(str(inp.task.target), 0)
            # TODO(FrozenGene): Add cpu cache flush to this function.
            time_func = mod_func.time_evaluator(
                mod_func.entry_name, dev_ctx, number=number, repeat=repeat,
                min_repeat_ms=min_repeat_ms)
        # pylint: disable=broad-except
        except Exception:
            costs = (max_float,)
            error_no = MeasureErrorNo.COMPILE_DEVICE
            error_msg = make_error_msg()

        if error_no == 0:
            try:
                # TODO(FrozenGene): Update to ndarray.non-empty.
                arg_bufs = [ndarray.empty(get_const_tuple(x.shape), x.dtype, dev_ctx)
                            for x in build_res.args]
                dev_ctx.sync()

                costs = time_func(*arg_bufs).results
                # Clean up the uploaded artifacts on the remote device.
                remote_sess.remove(build_res.filename)
                remote_sess.remove(os.path.splitext(build_res.filename)[0] + '.so')
                remote_sess.remove('')
            # pylint: disable=broad-except
            except Exception:
                costs = (max_float,)
                error_no = MeasureErrorNo.RUNTIME_DEVICE
                error_msg = make_error_msg()

        # The local build directory is no longer needed after measurement.
        shutil.rmtree(os.path.dirname(build_res.filename))
        t_end = time.time()

        time.sleep(cooldown_interval)
        if verbose >= 1:
            print("*" if error_no == MeasureErrorNo.NO_ERROR else "*E", end="")

        return costs, error_no, error_msg, t_end - t_begin + build_res.time_cost, t_end

    res = call_func_with_timeout(timeout, timed_func)

    if isinstance(res, TimeoutError):
        if verbose >= 1:
            print("*T", end="")  # Run timeout
        res = ((max_float,), MeasureErrorNo.RUN_TIMEOUT, None,
               build_res.time_cost + timeout, time.time())
    return res
+
+
@tvm._ffi.register_func("auto_scheduler.rpc_runner.run")
def rpc_runner_run(inputs, build_results, key, host, port,
                   priority=1, n_parallel=1, timeout=10, number=3, repeat=1, min_repeat_ms=0,
                   cooldown_interval=0.0, verbose=1):
    """ Run function of RPCRunner to test the performance of the input BuildResults.

    Parameters
    ----------
    inputs : List[MeasureInput]
        The MeasureInputs to be measured.
    build_results : List[BuildResult]
        The BuildResults to be measured.
    key : str
        The key of the device registered in the RPC tracker.
    host : str
        The host address of the RPC Tracker.
    port : int
        The port of RPC Tracker.
    priority : int = 1
        The priority of this run request, larger is more prior.
    n_parallel : int = 1
        The number of tasks run in parallel.
    timeout : int = 10
        The timeout limit (in second) for each run.
        This is used in a wrapper of the multiprocessing.Process.join().
    number : int = 3
        The number of times to run the generated code for taking average.
        We call these runs as one `repeat` of measurement.
    repeat : int = 1
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first "1" is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms : int = 0
        The minimum duration of one `repeat` in milliseconds.
        If set, `number` is dynamically adjusted so that one `repeat` never
        falls below this duration.
    cooldown_interval : float = 0.0
        The cool down interval between two measurements.
    verbose: int = 1
        Verbosity level. 0 for silent, 1 to output information during program measuring.

    Returns
    -------
    res : List[MeasureResult]
        The measure results of these MeasureInputs.
    """
    # Arguments are handed to forked workers through a global to avoid
    # serializing TVM IR across process boundaries.
    global GLOBAL_RUN_ARGUMENTS
    GLOBAL_RUN_ARGUMENTS = (inputs, build_results, key, host, port, priority, timeout, number,
                            repeat, min_repeat_ms, cooldown_interval, verbose)

    assert len(inputs) == len(build_results), \
        "Measure input size should be equal to build results"
    worker_pool = NoDaemonPool(n_parallel)
    raw_results = worker_pool.map(rpc_run_worker, range(len(build_results)))
    worker_pool.terminate()
    worker_pool.join()
    del worker_pool

    results = [MeasureResult(*item) for item in raw_results]

    if verbose >= 1:
        print("")

    return results
TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode);
static const char* ErrorNoToStr[] = {
"NoError",
throw;
}
+/********** RPCRunner **********/
+RPCRunner::RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel,
+ int timeout, int number, int repeat, int min_repeat_ms,
+ double cooldown_interval) {
+ auto node = make_object<RPCRunnerNode>();
+ node->key = key;
+ node->host = host;
+ node->port = port;
+ node->priority = priority;
+ node->timeout = timeout;
+ node->n_parallel = n_parallel;
+ node->number = number;
+ node->repeat = repeat;
+ node->min_repeat_ms = min_repeat_ms;
+ node->cooldown_interval = cooldown_interval;
+ data_ = std::move(node);
+}
+
+Array<MeasureResult> RPCRunnerNode::Run(const Array<MeasureInput>& inputs,
+ const Array<BuildResult>& build_results, int verbose) {
+ if (const auto* f = runtime::Registry::Get("auto_scheduler.rpc_runner.run")) {
+ Array<MeasureResult> results =
+ (*f)(inputs, build_results, key, host, port, priority, n_parallel, timeout, number, repeat,
+ min_repeat_ms, cooldown_interval, verbose);
+ return results;
+ } else {
+ LOG(FATAL) << "auto_scheduler.rpc_runner.run is not registered. "
+ << "This is a function registered in Python, "
+ << "make sure the TVM Python runtime has been loaded successfully.";
+ }
+ return Array<MeasureResult>();
+}
+
/********** ProgramMeasurer **********/
ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
Optional<Array<MeasureCallback>> callbacks, int verbose,
return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval);
});
+TVM_REGISTER_GLOBAL("auto_scheduler.RPCRunner")
+ .set_body_typed([](const String& key, const String& host, int port, int priority,
+ int n_parallel, int timeout, int number, int repeat, int min_repeat_ms,
+ double cooldown_interval) {
+ return RPCRunner(key, host, port, priority, n_parallel, timeout, number, repeat,
+ min_repeat_ms, cooldown_interval);
+ });
+
} // namespace auto_scheduler
} // namespace tvm
public:
/*! \brief Timeout of a run. */
int timeout;
+ /*! \brief The number of times to run the generated code for taking average. */
+ int number;
+ /*! \brief The number of times to repeat the measurement. */
+ int repeat;
+ /*! \brief The minimum duration of one repeat in milliseconds. */
+ int min_repeat_ms;
+ /*! \brief The cool down interval between two measurements. */
+ double cooldown_interval;
/*!
* \brief Run measurement and return results.
/*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */
class LocalRunnerNode : public ProgramRunnerNode {
public:
- /*! \brief Number of measure times. */
- int number;
- /*! \brief Number of repeat times in each measure. */
- int repeat;
- /*! \brief The minimum duration of one repeat in milliseconds. */
- int min_repeat_ms;
- /*! \brief The cool down interval between two measurements. */
- double cooldown_interval;
-
Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
const Array<BuildResult>& build_results, int verbose) final;
* for more detailed parameter explaination.
* \param timeout The timeout limit (in second) for each run.
* This is used in a wrapper of the multiprocessing.Process.join().
- * \param number Number of measure times.
- * \param repeat Number of repeat times in each measure.
+ * \param number The number of times to run the generated code for taking average.
+ * \param repeat The number of times to repeat the measurement.
* \param min_repeat_ms The minimum duration of one repeat in milliseconds.
* \param cooldown_interval The cool down interval between two measurements.
*/
};
/*!
+ * \brief RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
+ * Or sometime we may need to use RPC even in local running to insulate the thread environment.
+ * (e.g. running CUDA programs)
+ */
+class RPCRunnerNode : public ProgramRunnerNode {
+ public:
+ /*! \brief The key of the device registered in the RPC tracker. */
+ String key;
+ /*! \brief The host address of the RPC Tracker. */
+ String host;
+ /*! \brief The port of RPC Tracker. */
+ int port;
+ /*! \brief The priority of this run request, larger is more prior. */
+ int priority;
+ /*! \brief The number of tasks run in parallel. */
+ int n_parallel;
+ /*! \brief The number of times to run the generated code for taking average. */
+
+ Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+ const Array<BuildResult>& build_results, int verbose) final;
+
+ static constexpr const char* _type_key = "auto_scheduler.RPCRunner";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, ProgramRunnerNode);
+};
+
+/*!
+ * \brief Managed reference to RPCRunnerNode.
+ * \sa RPCRunnerNode
+ */
+class RPCRunner : public ProgramRunner {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param key The key of the device registered in the RPC tracker.
+ * \param host The host address of the RPC Tracker.
+ * \param prot The port of RPC Tracker.
+ * \param priority The priority of this run request, larger is more prior.
+ * \param n_parallel The number of tasks run in parallel.
+ * \param timeout Timeout of a run.
+ * \param number The number of times to run the generated code for taking average.
+ * \param repeat The number of times to repeat the measurement.
+ * \param min_repeat_ms The minimum duration of one repeat in milliseconds.
+ * \param cooldown_interval The cool down interval between two measurements.
+ */
+ RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel,
+ int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, ProgramRunner, RPCRunnerNode);
+};
+
+/*!
* \brief Measurer that measures the time costs of tvm programs
* This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */
class ProgramMeasurerNode : public Object {