refactor build module to take IRModule (#4988)
authorZhi <5145158+zhiics@users.noreply.github.com>
Thu, 5 Mar 2020 15:37:26 +0000 (07:37 -0800)
committerGitHub <noreply@github.com>
Thu, 5 Mar 2020 15:37:26 +0000 (07:37 -0800)
include/tvm/relay/transform.h
python/tvm/relay/build_module.py
src/relay/backend/build_module.cc
src/relay/backend/vm/compiler.cc
tests/cpp/relay_build_module_test.cc

index 2837c1f..0a2c77a 100644 (file)
@@ -332,6 +332,16 @@ TVM_DLL Pass PartitionGraph();
  */
 TVM_DLL Pass Inline();
 
+/*!
+ * \brief Remove the unused functions in the Relay IRModule.
+ *
+ * \param entry_functions The entry functions used to search the functions that
+ *        are being used.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
+
 }  // namespace transform
 
 /*!
index 22e0b91..e894933 100644 (file)
@@ -62,7 +62,7 @@ def _convert_param_map(params):
 
 
 class BuildModule(object):
-    """Build a Relay function to run on TVM graph runtime. This class is used
+    """Build an IR module to run on TVM graph runtime. This class is used
     to expose the `RelayBuildModule` APIs implemented in C++.
     """
     def __init__(self):
@@ -74,12 +74,12 @@ class BuildModule(object):
         self._set_params_func = self.mod["set_params"]
         self._get_params_func = self.mod["get_params"]
 
-    def build(self, func, target=None, target_host=None, params=None):
+    def build(self, mod, target=None, target_host=None, params=None):
         """
         Parameters
         ----------
-        func: relay.Function
-            The function to build.
+        mod : :py:class:`~tvm.IRModule`
+            The IRModule to build.
 
         target : str, :any:`tvm.target.Target`, or dict of str(i.e.
         device/context name) to str/tvm.target.Target, optional
@@ -115,8 +115,8 @@ class BuildModule(object):
         # Setup the params.
         if params:
             self._set_params(params)
-        # Build the function
-        self._build(func, target, target_host)
+        # Build the IR module
+        self._build(mod, target, target_host)
         # Get artifacts
         graph_json = self.get_json()
         mod = self.get_module()
@@ -124,12 +124,12 @@ class BuildModule(object):
 
         return graph_json, mod, params
 
-    def optimize(self, func, target=None, params=None):
+    def optimize(self, mod, target=None, params=None):
         """
         Parameters
         ----------
-        func: relay.Function
-            The function to build.
+        mod : :py:class:`~tvm.IRModule`
+            The IR module to build.
 
         target : str, :any:`tvm.target.Target`, or dict of str(i.e.
         device/context name) to str/tvm.target.Target, optional
@@ -142,7 +142,7 @@ class BuildModule(object):
 
         Returns
         -------
-        mod : tvm.IRModule
+        mod : :py:class:`~tvm.IRModule`
             The optimized relay module.
 
         params : dict
@@ -153,7 +153,7 @@ class BuildModule(object):
         # Setup the params.
         if params:
             self._set_params(params)
-        mod = self._optimize(func, target)
+        mod = self._optimize(mod, target)
         # Get artifacts
         params = self.get_params()
 
@@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None):
 
     Parameters
     ----------
-    mod : tvm.IRModule
-        The module to build. Using relay.Function is deprecated.
+    mod : :py:class:`~tvm.IRModule`
+        The IR module to build. Using relay.Function is deprecated.
 
     target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
     name) to str/tvm.target.Target, optional
@@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None):
     params : dict
         The parameters of the final graph.
     """
-    if isinstance(mod, IRModule):
-        func = mod["main"]
-    elif isinstance(mod, _expr.Function):
-        func = mod
+    if not isinstance(mod, (IRModule, _expr.Function)):
+        raise ValueError("Type of input parameter mod must be tvm.IRModule")
+
+    if isinstance(mod, _expr.Function):
+        mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
-            "instead of deprecated parameter func (tvm.relay.expr.Function)",
+            "instead of deprecated parameter mod (tvm.relay.expr.Function)",
             DeprecationWarning)
-    else:
-        raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
     target = _update_target(target)
 
@@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None):
 
     with tophub_context:
         bld_mod = BuildModule()
-        graph_json, mod, params = bld_mod.build(func, target, target_host, params)
+        graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
     return graph_json, mod, params
 
 
@@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None):
 
     Parameters
     ----------
-    mod : tvm.IRModule
+    mod : :py:class:`~tvm.IRModule`
         The module to build. Using relay.Function is deprecated.
 
     target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
@@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None):
 
     Returns
     -------
-    mod : tvm.IRModule
+    mod : :py:class:`~tvm.IRModule`
         The optimized relay module.
 
     params : dict
         The parameters of the final graph.
     """
-    if isinstance(mod, IRModule):
-        func = mod["main"]
-    elif isinstance(mod, _expr.Function):
-        func = mod
+    if not isinstance(mod, (IRModule, _expr.Function)):
+        raise ValueError("Type of input parameter mod must be tvm.IRModule")
+
+    if isinstance(mod, _expr.Function):
+        mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
             "instead of deprecated parameter func (tvm.relay.expr.Function)",
             DeprecationWarning)
-    else:
-        raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
     target = _update_target(target)
 
@@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None):
 
     with tophub_context:
         bld_mod = BuildModule()
-        mod, params = bld_mod.optimize(func, target, params)
+        mod, params = bld_mod.optimize(mod, target, params)
     return mod, params
 
 
index 0c0a8b8..61ec281 100644 (file)
@@ -233,42 +233,46 @@ class RelayBuildModule : public runtime::ModuleNode {
   }
 
   /*!
-   * \brief Build relay function for graph runtime
+   * \brief Build relay IRModule for graph runtime
    *
-   * \param func Relay Function
+   * \param mod Relay IRModule
    * \param target Target device
    * \param target_host Host target device
    */
-  void Build(Function func,
+  void Build(IRModule mod,
              const TargetsMap& targets,
              const tvm::Target& target_host) {
     targets_ = targets;
     target_host_ = target_host;
-    BuildRelay(func, params_);
+    BuildRelay(mod, params_);
   }
 
  protected:
   /*!
-   * \brief Optimize a Relay Function.
+   * \brief Optimize a Relay IRModule.
    *
-   * \param func The input Function where optmization will be applied on.
+   * \param relay_module The input IRModule where optmization will be applied on.
    * \param targets The device type to `Target` mapping.
    * \param params The param name to value mapping.
    *
-   * \return relay::Module The updated Relay module after optimization.
+   * \return relay::IRModule The updated Relay IR module after optimization.
    */
   IRModule Optimize(
-      Function func,
+      IRModule relay_module,
       const TargetsMap& targets,
       const std::unordered_map<std::string, runtime::NDArray>& params) {
     if (params.size()) {
-      func = BindParamsByName(func, params);
+      CHECK(relay_module->ContainGlobalVar("main"))
+        << "Missing the main entry function";
+      GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
+      Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
+      auto new_main = BindParamsByName(main_func, params);
+      relay_module->Update(main_glb_var, new_main);
     }
 
-    // Perform Module->Module optimizations.
-    IRModule relay_module = IRModule::FromExpr(func);
-
     Array<Pass> pass_seqs;
+    Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
+    pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
 
     // Run all dialect legalization passes.
     pass_seqs.push_back(relay::qnn::transform::Legalize());
@@ -418,18 +422,18 @@ class RelayBuildModule : public runtime::ModuleNode {
   }
 
   /*!
-   * \brief Compile a Relay function to runtime module.
+   * \brief Compile a Relay IR module to runtime module.
    *
-   * \param func The Relay function.
+   * \param relay_module The Relay IR module.
    * \param params The parameters.
    */
   void BuildRelay(
-      Function func,
+      IRModule relay_module,
       const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
-    // Optimize input Relay Function and returns Relay Module
-    IRModule relay_module = Optimize(func, targets_, params);
+    // Relay IRModule -> IRModule optimizations.
+    relay_module = Optimize(relay_module, targets_, params);
     // Get the updated function.
-    func = Downcast<Function>(relay_module->Lookup("main"));
+    auto func = Downcast<Function>(relay_module->Lookup("main"));
 
     // Generate code for the updated function.
     graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
index 73a6450..2129b64 100644 (file)
@@ -51,7 +51,6 @@ namespace transform {
 
 Pass LambdaLift();
 Pass InlinePrimitives();
-Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
 
 Pass ManifestAlloc(Target target_host) {
   auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
index b9a8f8f..a94dce6 100644 (file)
@@ -29,6 +29,7 @@
 #include <topi/broadcast.h>
 #include <topi/generic/injective.h>
 #include <tvm/runtime/packed_func.h>
+#include <tvm/ir/module.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/registry.h>
 
@@ -115,7 +116,8 @@ TEST(Relay, BuildModule) {
   Map<tvm::Integer, tvm::Target> targets;
   Target llvm_tgt = Target::Create("llvm");
   targets.Set(0, llvm_tgt);
-  build_f(func, targets, llvm_tgt);
+  auto relay_mod = tvm::IRModule::FromExpr(func);
+  build_f(relay_mod, targets, llvm_tgt);
   std::string json = json_f();
   tvm::runtime::Module mod = mod_f();
   // run