[Relay][Params] Add APIs for storing and retrieving parameters from individual functi...
authorJared Roesch <roeschinc@gmail.com>
Sun, 27 Oct 2019 00:05:03 +0000 (17:05 -0700)
committerGitHub <noreply@github.com>
Sun, 27 Oct 2019 00:05:03 +0000 (17:05 -0700)
* Add support for attaching params

* Fix types

* Fix test

include/tvm/relay/expr.h
python/tvm/relay/expr.py
src/relay/ir/expr.cc
tests/python/relay/test_ir_nodes.py

index 6df4273..ff075e3 100644 (file)
@@ -274,6 +274,19 @@ class FunctionNode : public ExprNode {
                                tvm::Array<TypeVar> ty_params,
                                tvm::Attrs attrs = Attrs());
 
+  /*!
+   * \brief Attach the function's parameters to its attributes for use in analysis.
+   * \return The function with its parameters attached.
+   */
+  Function SetParams(const tvm::Map<Var, Constant>& parameters) const;
+
+  /*!
+   * \brief Retrieve the function's parameters.
+   *
+   * \return The function's parameter.
+   */
+  tvm::Map<Var, Constant> GetParams() const;
+
   static constexpr const char* _type_key = "relay.Function";
   TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
 };
@@ -284,7 +297,6 @@ RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
 TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
 TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);
 
-
 /*!
  * \brief Call corresponds to operator invocation.
  *  Corresponds to the operator in computational graph terminology.
index 88779df..8d59e99 100644 (file)
@@ -27,6 +27,7 @@ from . import ty as _ty
 from .._ffi import base as _base
 from .. import nd as _nd
 from .. import convert
+from ..ndarray import NDArray
 
 # will be registered afterwards
 _op_make = None
@@ -305,6 +306,17 @@ class Function(Expr):
         """
         return Call(self, args, None, None)
 
+    def get_params(self):
+        return _expr.FunctionGetParams(self)
+
+    def set_params(self, params):
+        for key in params:
+            value = params[key]
+            if isinstance(value, NDArray):
+                params[key] = Constant(value)
+
+        return _expr.FunctionSetParams(self, params)
+
 
 @register_relay_node
 class Call(Expr):
index 35e4f2b..c36b4c8 100644 (file)
@@ -159,6 +159,26 @@ bool FunctionNode::IsPrimitive() const {
   return pval && pval->value != 0;
 }
 
+Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
+  return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
+}
+
+TVM_REGISTER_API("relay._expr.FunctionSetParams")
+.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
+  [](const Function& func, const tvm::Map<Var, Constant>& parameters) {
+    return func->SetParams(parameters);
+});
+
+tvm::Map<Var, Constant> FunctionNode::GetParams() const {
+  auto node_ref = FunctionGetAttr(GetRef<Function>(this), "__params__");
+  return Downcast<tvm::Map<Var, Constant>>(node_ref);
+}
+
+TVM_REGISTER_API("relay._expr.FunctionGetParams")
+.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
+  return func->GetParams();
+});
+
 NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
   if (!func->attrs.defined()) { return NodeRef(); }
 
index b42a1e6..dec840a 100644 (file)
@@ -20,7 +20,7 @@ from tvm import relay
 from tvm.expr import *
 from tvm.relay import op
 from tvm.relay.analysis import graph_equal
-
+import numpy as np
 
 def check_json_roundtrip(node):
     json_str = tvm.save_json(node)
@@ -160,7 +160,6 @@ def test_global_var():
     str(gv)
     check_json_roundtrip(gv)
 
-
 def test_function():
     param_names = ['a', 'b', 'c', 'd']
     params = tvm.convert([relay.Var(n) for n in param_names])
@@ -175,6 +174,34 @@ def test_function():
     str(fn)
     check_json_roundtrip(fn)
 
+def test_function_attrs():
+    param_names = ['a', 'b', 'c', 'd']
+    params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names])
+    ret_type = relay.TupleType(tvm.convert([]))
+    body = relay.Tuple(tvm.convert([]))
+    type_params = tvm.convert([])
+    fn = relay.Function(params, body, ret_type, type_params)
+    model_params = {}
+    for param in params[:1]:
+        cty = param.type_annotation
+        tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
+        model_params[param] = tvm.nd.array(tensor)
+    fn = fn.set_params(model_params)
+    assert fn.params == params
+    assert fn.body == body
+    assert fn.type_params == type_params
+    assert fn.span == None
+    str(fn)
+    check_json_roundtrip(fn)
+    json_str = tvm.save_json(fn)
+    fn_after = tvm.load_json(json_str)
+    model_params_after = fn_after.get_params()
+    after_keys = [item[0] for item in model_params_after.items()]
+    for key1, key2 in zip(model_params, after_keys):
+        assert key1.name_hint == key2.name_hint
+        p1 = model_params[key1]
+        p2 = model_params_after[key2]
+        np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy())
 
 def test_call():
     op = relay.Var('f')
@@ -257,9 +284,11 @@ if __name__ == "__main__":
     test_local_var()
     test_global_var()
     test_function()
+    test_function_attrs()
     test_call()
     test_let()
     test_if()
     test_tuple_get_item()
     test_op()
     test_conv2d_attrs()
+