Dedup BindParamByName function in VM compiler (#4793)
authormasahi <masahi129@gmail.com>
Thu, 30 Jan 2020 19:09:48 +0000 (04:09 +0900)
committerGitHub <noreply@github.com>
Thu, 30 Jan 2020 19:09:48 +0000 (04:09 +0900)
src/relay/backend/build_module.cc
src/relay/backend/utils.h
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/compiler.h

index 035ab1b..ff64d4a 100644 (file)
@@ -42,43 +42,6 @@ using TargetsMap = Map<tvm::Integer, tvm::Target>;
 using namespace tvm::relay::transform;
 
 /*!
- * \brief Bind params to function by using name
- * \param func Relay function
- * \param params params dict
- * \return relay::Function
- */
-relay::Function BindParamsByName(relay::Function func,
-                                 const std::unordered_map<std::string, runtime::NDArray>& params) {
-  std::unordered_map<std::string, relay::Var> name_dict;
-  std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
-  for (auto arg : func->params) {
-    const auto& name = arg->name_hint();
-    if (name_dict.count(name)) {
-      repeat_var.insert(arg);
-    } else {
-      name_dict[name] = arg;
-    }
-  }
-
-  std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
-  for (auto& kv : params) {
-    if (name_dict.count(kv.first) == 0) {
-      continue;
-    }
-    auto arg = name_dict.at(kv.first);
-    if (repeat_var.count(arg)) {
-      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
-    }
-    bind_dict[arg] = ConstantNode::make(kv.second);
-  }
-  Expr bound_expr = relay::Bind(func, bind_dict);
-  Function ret = Downcast<Function>(bound_expr);
-  CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
-                       << "\n";
-  return ret;
-}
-
-/*!
  * \brief Output of building module
  *
  */
@@ -527,7 +490,7 @@ TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
   for (const auto& kv : params) {
     params_[kv.first] = kv.second->data;
   }
-  *rv = BindParamsByName(args[0], params_);
+  *rv = relay::backend::BindParamsByName(args[0], params_);
 });
 
 }  // namespace backend
index 24e338e..cccd4ba 100644 (file)
@@ -27,6 +27,7 @@
 #include <dmlc/json.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
+#include <tvm/relay/transform.h>
 #include <tvm/driver/driver_api.h>
 #include <tvm/target/codegen.h>
 #include <tvm/tir/ir_pass.h>
@@ -34,6 +35,8 @@
 
 #include <typeinfo>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 
 namespace tvm {
 namespace relay {
@@ -81,6 +84,44 @@ inline std::string DType2String(const tvm::DataType dtype) {
   return os.str();
 }
 
+/*!
+ * \brief Bind params to function by using name
+ * \param func Relay function
+ * \param params params dict
+ * \return relay::Function
+ */
+inline relay::Function
+BindParamsByName(relay::Function func,
+                 const std::unordered_map<std::string, runtime::NDArray>& params) {
+  std::unordered_map<std::string, relay::Var> name_dict;
+  std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
+  for (auto arg : func->params) {
+    const auto& name = arg->name_hint();
+    if (name_dict.count(name)) {
+      repeat_var.insert(arg);
+    } else {
+      name_dict[name] = arg;
+    }
+  }
+
+  std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
+  for (auto& kv : params) {
+    if (name_dict.count(kv.first) == 0) {
+      continue;
+    }
+    auto arg = name_dict.at(kv.first);
+    if (repeat_var.count(arg)) {
+      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
+    }
+    bind_dict[arg] = ConstantNode::make(kv.second);
+  }
+  Expr bound_expr = relay::Bind(func, bind_dict);
+  Function ret = Downcast<Function>(bound_expr);
+  CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
+                       << "\n";
+  return ret;
+}
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm
index c899644..cc5d6bc 100644 (file)
@@ -37,9 +37,8 @@
 #include <memory>
 #include <string>
 #include <tuple>
-#include <unordered_map>
-#include <unordered_set>
 #include <vector>
+#include "../utils.h"
 #include "../../backend/compile_engine.h"
 #include "../../pass/pass_util.h"
 #include "../../op/op_common.h"
@@ -783,38 +782,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
   params_[name] = data_in;
 }
 
-relay::Function VMCompiler::BindParamsByName(
-    relay::Function func,
-    const std::unordered_map<std::string, runtime::NDArray>& params) {
-  std::unordered_map<std::string, relay::Var> name_dict;
-  std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
-  for (auto arg : func->params) {
-    const auto &name = arg->name_hint();
-    if (name_dict.count(name)) {
-      repeat_var.insert(arg);
-    } else {
-      name_dict[name] = arg;
-    }
-  }
-  std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
-  for (auto &kv : params) {
-    if (name_dict.count(kv.first) == 0) {
-      continue;
-    }
-    auto arg = name_dict.at(kv.first);
-    if (repeat_var.count(arg)) {
-      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
-    }
-    bind_dict[arg] = ConstantNode::make(kv.second);
-  }
-  Expr bound_expr = relay::Bind(func, bind_dict);
-  Function ret = Downcast<Function>(bound_expr);
-  CHECK(ret.defined())
-      << "The returning type is expected to be a Relay Function."
-      << "\n";
-  return ret;
-}
-
 void VMCompiler::Lower(IRModule mod,
                        const TargetsMap& targets,
                        const tvm::Target& target_host) {
@@ -824,7 +791,7 @@ void VMCompiler::Lower(IRModule mod,
     BaseFunc base_func = mod->Lookup("main");
     CHECK(base_func->IsInstance<FunctionNode>())
         << "VM compiler expects to compile relay::Function";
-    auto f = BindParamsByName(Downcast<Function>(base_func), params_);
+    auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
     auto gvar = mod->GetGlobalVar("main");
     mod->Add(gvar, f);
   }
index 602e6cc..19e1ee8 100644 (file)
@@ -115,16 +115,6 @@ class VMCompiler : public runtime::ModuleNode {
   void Codegen();
 
  protected:
-  /*!
-   * \brief Bind params to function by using name
-   * \param func Relay function
-   * \param params params dict
-   * \return relay::Function
-   */
-  relay::Function BindParamsByName(
-      relay::Function func,
-      const std::unordered_map<std::string, runtime::NDArray>& params);
-
   IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets);
 
   void PopulateGlobalMap();