Expose relay BindParamsByName to Python (#4751)
authormasahi <masahi129@gmail.com>
Mon, 20 Jan 2020 22:32:22 +0000 (07:32 +0900)
committerGitHub <noreply@github.com>
Mon, 20 Jan 2020 22:32:22 +0000 (07:32 +0900)
* expose BindParamByName to python

* fixed alpha equal test

python/tvm/relay/build_module.py
src/relay/backend/build_module.cc
tests/python/relay/test_pass_fold_constant.py

index 28ce16b..d848d90 100644 (file)
@@ -51,6 +51,15 @@ def _update_target(target):
     return tgts
 
 
+def _convert_param_map(params):
+    inputs = {}
+    for name, param in params.items():
+        if isinstance(param, np.ndarray):
+            param = _nd.array(param)
+        inputs[name] = _expr.const(param)
+    return inputs
+
+
 class BuildModule(object):
     """Build a Relay function to run on TVM graph runtime. This class is used
     to expose the `RelayBuildModule` APIs implemented in C++.
@@ -151,12 +160,7 @@ class BuildModule(object):
 
 
     def _set_params(self, params):
-        inputs = {}
-        for name, param in params.items():
-            if isinstance(param, np.ndarray):
-                param = _nd.array(param)
-            inputs[name] = _expr.const(param)
-        self._set_params_func(inputs)
+        self._set_params_func(_convert_param_map(params))
 
     def get_json(self):
         """Return the json file of the built program."""
@@ -296,6 +300,29 @@ def optimize(mod, target=None, params=None):
     return mod, params
 
 
+def bind_params_by_name(func, params):
+    """Bind params to function by name.
+    This could be useful when assembling custom Relay optimization
+    passes that involve constant folding.
+
+    Parameters
+    ----------
+    func : relay.Function
+        The function to bind parameters to.
+
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    Returns
+    -------
+    func : relay.Function
+        The function with parameters bound
+    """
+    inputs = _convert_param_map(params)
+    return _build_module.BindParamsByName(func, inputs)
+
+
 class GraphExecutor(_interpreter.Executor):
     """Wrapper around Executor interface.
 
index 480fd9e..7b9f1e2 100644 (file)
@@ -42,6 +42,43 @@ using TargetsMap = Map<tvm::Integer, tvm::Target>;
 using namespace tvm::relay::transform;
 
 /*!
+ * \brief Bind params to function by using name
+ * \param func Relay function
+ * \param params params dict
+ * \return relay::Function
+ */
+relay::Function BindParamsByName(relay::Function func,
+                                 const std::unordered_map<std::string, runtime::NDArray>& params) {
+  std::unordered_map<std::string, relay::Var> name_dict;
+  std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
+  for (auto arg : func->params) {
+    const auto& name = arg->name_hint();
+    if (name_dict.count(name)) {
+      repeat_var.insert(arg);
+    } else {
+      name_dict[name] = arg;
+    }
+  }
+
+  std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
+  for (auto& kv : params) {
+    if (name_dict.count(kv.first) == 0) {
+      continue;
+    }
+    auto arg = name_dict.at(kv.first);
+    if (repeat_var.count(arg)) {
+      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
+    }
+    bind_dict[arg] = ConstantNode::make(kv.second);
+  }
+  Expr bound_expr = relay::Bind(func, bind_dict);
+  Function ret = Downcast<Function>(bound_expr);
+  CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
+                       << "\n";
+  return ret;
+}
+
+/*!
  * \brief Output of building module
  *
  */
@@ -249,45 +286,6 @@ class RelayBuildModule : public runtime::ModuleNode {
 
  protected:
   /*!
-   * \brief Bind params to function by using name
-   * \param func Relay function
-   * \param params params dict
-   * \return relay::Function
-   */
-  relay::Function BindParamsByName(
-      relay::Function func,
-      const std::unordered_map<std::string, runtime::NDArray>& params) {
-    std::unordered_map<std::string, relay::Var> name_dict;
-    std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
-    for (auto arg : func->params) {
-      const auto &name = arg->name_hint();
-      if (name_dict.count(name)) {
-        repeat_var.insert(arg);
-      } else {
-        name_dict[name] = arg;
-      }
-    }
-
-    std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
-    for (auto &kv : params) {
-      if (name_dict.count(kv.first) == 0) {
-        continue;
-      }
-      auto arg = name_dict.at(kv.first);
-      if (repeat_var.count(arg)) {
-        LOG(FATAL) << "Multiple args in the function have name " << kv.first;
-      }
-      bind_dict[arg] = ConstantNode::make(kv.second);
-    }
-    Expr bound_expr = relay::Bind(func, bind_dict);
-    Function ret = Downcast<Function>(bound_expr);
-    CHECK(ret.defined())
-        << "The returning type is expected to be a Relay Function."
-        << "\n";
-    return ret;
-  }
-
-  /*!
    * \brief Optimize a Relay Function.
    *
    * \param func The input Function where optmization will be applied on.
@@ -522,6 +520,16 @@ TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
   *rv = RelayBuildCreate();
 });
 
+TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  Map<std::string, Constant> params = args[1];
+  std::unordered_map<std::string, runtime::NDArray> params_;
+  for (const auto& kv : params) {
+    params_[kv.first] = kv.second->data;
+  }
+  *rv = BindParamsByName(args[0], params_);
+});
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm
index 4752597..ca901b1 100644 (file)
@@ -18,6 +18,8 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.testing import run_infer_type, create_workload
 
 
 def run_opt_pass(expr, opt_pass):
@@ -161,6 +163,47 @@ def test_fold_full():
     assert relay.analysis.graph_equal(zz, zexpected)
 
 
+def test_fold_batch_norm():
+    def expected():
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.const(np.zeros((16, 3, 3, 3)))
+        bias = relay.const(np.zeros((16, 1, 1)))
+        conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
+                               channels=16, padding=(1, 1))
+        add = relay.add(conv, bias)
+        return relay.Function(relay.analysis.free_vars(add), add)
+
+    remove_bn_pass = transform.Sequential([
+        relay.transform.InferType(),
+        relay.transform.SimplifyInference(),
+        relay.transform.FoldConstant(),
+        relay.transform.FoldScaleAxis(),
+    ])
+
+    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+    weight = relay.var("weight")
+    bn_gamma = relay.var("bn_gamma")
+    bn_beta = relay.var("bn_beta")
+    bn_mmean = relay.var("bn_mean")
+    bn_mvar = relay.var("bn_var")
+
+    conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
+                           channels=16, padding=(1, 1))
+    bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta,
+                                    bn_mmean, bn_mvar)
+    def initializer(_, param):
+        param = np.zeros(param.shape)
+
+    mod, params = create_workload(bn_output[0], initializer)
+    mod["main"] = bind_params_by_name(mod["main"], params)
+
+    with relay.build_config(opt_level=3):
+        mod = remove_bn_pass(mod)
+
+    expect = run_infer_type(expected())
+    assert relay.analysis.graph_equal(mod["main"], expect)
+
+
 if __name__ == "__main__":
     test_fold_const()
     test_fold_let()
@@ -168,3 +211,4 @@ if __name__ == "__main__":
     test_fold_concat()
     test_fold_shape_of()
     test_fold_full()
+    test_fold_batch_norm()