Optimizing autotvm task extraction speed (#4138)
authorLiangHao <hliangac@connect.ust.hk>
Tue, 29 Oct 2019 18:45:02 +0000 (02:45 +0800)
committerYao Wang <kevinthesunwy@gmail.com>
Tue, 29 Oct 2019 18:45:02 +0000 (11:45 -0700)
* Optimize task extraction speed

* correct pylint errors

* Delete unused function

* remove unnecessary argument

* resolve code review comments

* corrent cpp lint errors

* remove one more graph_json return value

* fix test bugs

python/tvm/autotvm/task/relay_integration.py
python/tvm/relay/__init__.py
python/tvm/relay/backend/graph_runtime_codegen.py
python/tvm/relay/build_module.py
src/relay/backend/build_module.cc

index 55be05f..6ee8bc0 100644 (file)
@@ -31,23 +31,28 @@ from .topi_integration import TaskExtractEnv
 logger = logging.getLogger('autotvm')
 
 
-# TODO(moreau89) find a more elegant way to build for VTAs
-def _build(func,
+# TODO(moreau89) find a more elegant way to lower for VTAs
+def _lower(func,
            target,
-           target_host,
            params):
-    """ Helper to build VTA properly.
+    """ Helper to lower VTA properly.
     """
 
     from tvm import relay
+    from tvm.relay.backend import graph_runtime_codegen
 
     if hasattr(target, 'device_name') and target.device_name == "vta":
         with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
             import vta
             with vta.build_config():
-                return relay.build(func, target, target_host, params)
+                mod, _ = relay.optimize(func, target, params)
+                grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+                return grc.codegen(mod["main"])
     # default case
-    return relay.build(func, target, target_host, params)
+    mod, _ = relay.optimize(func, target, params)
+    grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+    return grc.codegen(mod["main"])
+
 
 def extract_from_program(func, params, ops, target, target_host=None):
     """ Extract tuning tasks from a relay program.
@@ -133,8 +138,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
             relay.backend.compile_engine.get().clear()
             # wrap build call in thread to avoid multiprocessing problems
             mod = relay.Module.from_expr(func)
-            build_thread = threading.Thread(target=_build,
-                                            args=(mod, target, target_host, param))
+            build_thread = threading.Thread(target=_lower,
+                                            args=(mod, target, param))
             build_thread.start()
             build_thread.join()
 
index fff9c99..f05098b 100644 (file)
@@ -28,7 +28,7 @@ from . import module
 from . import adt
 from . import analysis
 from . import transform
-from .build_module import build, create_executor
+from .build_module import build, create_executor, optimize
 from .transform import build_config
 from . import prelude
 from . import parser
index cf31e9c..73a700e 100644 (file)
@@ -36,7 +36,7 @@ contrib.graph_runtime or any other TVM runtime compatible systems.
 from __future__ import absolute_import
 
 from tvm.ndarray import empty
-from tvm.relay import build_module
+from tvm.relay import _build_module
 from tvm import target as _target
 from tvm import expr as _expr
 
@@ -44,7 +44,7 @@ class GraphRuntimeCodegen(object):
     """The compiler from Relay to the TVM runtime system."""
 
     def __init__(self, mod, target):
-        self._mod = build_module._GraphRuntimeCodegen()
+        self._mod = _build_module._GraphRuntimeCodegen()
         self._init = self._mod["init"]
         self._codegen = self._mod["codegen"]
         self._get_graph_json = self._mod["get_graph_json"]
index 404829f..28ce16b 100644 (file)
@@ -60,6 +60,7 @@ class BuildModule(object):
         self._get_graph_json = self.mod["get_graph_json"]
         self._get_module = self.mod["get_module"]
         self._build = self.mod["build"]
+        self._optimize = self.mod["optimize"]
         self._set_params_func = self.mod["set_params"]
         self._get_params_func = self.mod["get_params"]
 
@@ -113,6 +114,42 @@ class BuildModule(object):
 
         return graph_json, mod, params
 
+    def optimize(self, func, target=None, params=None):
+        """
+        Parameters
+        ----------
+        func: relay.Function
+            The function to build.
+
+        target : str, :any:`tvm.target.Target`, or dict of str(i.e.
+        device/context name) to str/tvm.target.Target, optional
+            For heterogeneous compilation, it is a dictionary indicating context
+            to target mapping. For homogeneous compilation, it is a build target.
+
+        params : dict of str to NDArray
+            Input parameters to the graph that do not change
+            during inference time. Used for constant folding.
+
+        Returns
+        -------
+        mod : relay.Module
+            The optimized relay module.
+
+        params : dict
+            The parameters of the final graph.
+        """
+        target = _update_target(target)
+
+        # Setup the params.
+        if params:
+            self._set_params(params)
+        mod = self._optimize(func, target)
+        # Get artifacts
+        params = self.get_params()
+
+        return mod, params
+
+
     def _set_params(self, params):
         inputs = {}
         for name, param in params.items():
@@ -208,6 +245,57 @@ def build(mod, target=None, target_host=None, params=None):
     return graph_json, mod, params
 
 
+def optimize(mod, target=None, params=None):
+    """Helper function that optimizes a Relay module.
+
+    Parameters
+    ----------
+    mod : relay.Module
+        The module to build. Using relay.Function is deprecated.
+
+    target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
+    name) to str/tvm.target.Target, optional
+        For heterogeneous compilation, it is a dictionary indicating context to
+        target mapping. For homogeneous compilation, it is a build target.
+
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    Returns
+    -------
+    mod : relay.Module
+        The optimized relay module.
+
+    params : dict
+        The parameters of the final graph.
+    """
+    if isinstance(mod, _Module):
+        func = mod["main"]
+    elif isinstance(mod, _expr.Function):
+        func = mod
+        warnings.warn(
+            "Please use input parameter mod (tvm.relay.module.Module) "
+            "instead of deprecated parameter func (tvm.relay.expr.Function)",
+            DeprecationWarning)
+    else:
+        raise ValueError("Type of input parameter mod must be tvm.relay.module.Module")
+
+    target = _update_target(target)
+
+    # If current dispatch context is fallback context (the default root context),
+    # then load pre-tuned parameters from TopHub
+    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
+        tophub_context = autotvm.tophub.context(list(target.values()))
+    else:
+        tophub_context = autotvm.util.EmptyContext()
+
+    with tophub_context:
+        bld_mod = BuildModule()
+        mod, params = bld_mod.optimize(func, target, params)
+    return mod, params
+
+
 class GraphExecutor(_interpreter.Executor):
     """Wrapper around Executor interface.
 
index dfe85fc..73cf6c2 100644 (file)
@@ -148,6 +148,11 @@ class RelayBuildModule : public runtime::ModuleNode {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
           *rv = this->graph_codegen_->GetLoweredFunc();
       });
+    } else if (name == "optimize") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK_EQ(args.num_args, 2);
+        *rv = this->Optimize(args[0], args[1], this->params_);
+      });
     } else {
       LOG(FATAL) << "Unknown packed function: " << name;
       return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
@@ -273,19 +278,25 @@ class RelayBuildModule : public runtime::ModuleNode {
   }
 
   /*!
-   * \brief Optimize a Relay module.
+   * \brief Optimize a Relay Function.
    *
-   * \param relay_module The input Relay module where optmization will be
-   *        applied on.
+   * \param func The input Function where optmization will be applied on.
    * \param targets The device type to `Target` mapping.
    * \param params The param name to value mapping.
    *
    * \return relay::Module The updated Relay module after optimization.
    */
   relay::Module Optimize(
-      relay::Module relay_module,
+      Function func,
       const TargetsMap& targets,
       const std::unordered_map<std::string, runtime::NDArray>& params) {
+    if (params.size()) {
+      func = BindParamsByName(func, params);
+    }
+
+    // Perform Module->Module optimizations.
+    relay::Module relay_module = relay::ModuleNode::FromExpr(func);
+
     Array<Pass> pass_seqs;
 
     // Run all dialect legalization passes.
@@ -345,6 +356,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     // Fuse the operations if it is needed.
     relay_module = transform::FuseOps()(relay_module);
     relay_module = transform::InferType()(relay_module);
+    CHECK(relay_module.defined());
 
     return relay_module;
   }
@@ -440,14 +452,8 @@ class RelayBuildModule : public runtime::ModuleNode {
   void BuildRelay(
       Function func,
       const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
-    if (params.size()) {
-      func = BindParamsByName(func, params);
-    }
-
-    // Perform Module->Module optimizations.
-    relay::Module relay_module = relay::ModuleNode::FromExpr(func);
-    relay_module = Optimize(relay_module, targets_, params);
-    CHECK(relay_module.defined());
+    // Optimize input Relay Function and returns Relay Module
+    relay::Module relay_module = Optimize(func, targets_, params);
     // Get the updated function.
     func = relay_module->Lookup("main");