From 9cc7874166fe696abfae028d5f56db3ff75ee8ef Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 26 Oct 2019 17:05:03 -0700 Subject: [PATCH] [Relay][Params] Add APIs for storing and retrieving parameters from individual functions. (#4194) * Add support for attaching params * Fix types * Fix test --- include/tvm/relay/expr.h | 14 +++++++++++++- python/tvm/relay/expr.py | 12 ++++++++++++ src/relay/ir/expr.cc | 20 ++++++++++++++++++++ tests/python/relay/test_ir_nodes.py | 33 +++++++++++++++++++++++++++++++-- 4 files changed, 76 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 6df4273..ff075e3 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -274,6 +274,19 @@ class FunctionNode : public ExprNode { tvm::Array ty_params, tvm::Attrs attrs = Attrs()); + /*! + * \brief Attach the function's parameters to its attributes for use in analysis. + * \return The function with its parameters attached. + */ + Function SetParams(const tvm::Map& parameters) const; + + /*! + * \brief Retrieve the function's parameters. + * + * \return The function's parameter. + */ + tvm::Map GetParams() const; + static constexpr const char* _type_key = "relay.Function"; TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); }; @@ -284,7 +297,6 @@ RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key); TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data); - /*! * \brief Call corresponds to operator invocation. * Corresponds to the operator in computational graph terminology. diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 88779df..8d59e99 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -27,6 +27,7 @@ from . import ty as _ty from .._ffi import base as _base from .. import nd as _nd from .. import convert +from ..ndarray import NDArray # will be registered afterwards _op_make = None @@ -305,6 +306,17 @@ class Function(Expr): """ return Call(self, args, None, None) + def get_params(self): + return _expr.FunctionGetParams(self) + + def set_params(self, params): + for key in params: + value = params[key] + if isinstance(value, NDArray): + params[key] = Constant(value) + + return _expr.FunctionSetParams(self, params) + @register_relay_node class Call(Expr): diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 35e4f2b..c36b4c8 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -159,6 +159,26 @@ bool FunctionNode::IsPrimitive() const { return pval && pval->value != 0; } +Function FunctionNode::SetParams(const tvm::Map& parameters) const { + return FunctionSetAttr(GetRef(this), "__params__", parameters); +} + +TVM_REGISTER_API("relay._expr.FunctionSetParams") +.set_body_typed&)>( + [](const Function& func, const tvm::Map& parameters) { + return func->SetParams(parameters); +}); + +tvm::Map FunctionNode::GetParams() const { + auto node_ref = FunctionGetAttr(GetRef(this), "__params__"); + return Downcast>(node_ref); +} + +TVM_REGISTER_API("relay._expr.FunctionGetParams") +.set_body_typed(const Function&)>([](const Function& func) { + return func->GetParams(); +}); + NodeRef FunctionGetAttr(const Function& func, const std::string& key) { if (!func->attrs.defined()) { return NodeRef(); } diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index b42a1e6..dec840a 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -20,7 +20,7 @@ from tvm import relay from tvm.expr import * from tvm.relay import op from tvm.relay.analysis import graph_equal - +import numpy as np def check_json_roundtrip(node): json_str = tvm.save_json(node) @@ -160,7 +160,6 @@ def test_global_var(): str(gv) check_json_roundtrip(gv) - def test_function(): param_names = ['a', 'b', 'c', 'd'] params = tvm.convert([relay.Var(n) for n in param_names]) @@ -175,6 +174,34 @@ def test_function(): str(fn) check_json_roundtrip(fn) +def test_function_attrs(): + param_names = ['a', 'b', 'c', 'd'] + params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names]) + ret_type = relay.TupleType(tvm.convert([])) + body = relay.Tuple(tvm.convert([])) + type_params = tvm.convert([]) + fn = relay.Function(params, body, ret_type, type_params) + model_params = {} + for param in params[:1]: + cty = param.type_annotation + tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype) + model_params[param] = tvm.nd.array(tensor) + fn = fn.set_params(model_params) + assert fn.params == params + assert fn.body == body + assert fn.type_params == type_params + assert fn.span == None + str(fn) + check_json_roundtrip(fn) + json_str = tvm.save_json(fn) + fn_after = tvm.load_json(json_str) + model_params_after = fn_after.get_params() + after_keys = [item[0] for item in model_params_after.items()] + for key1, key2 in zip(model_params, after_keys): + assert key1.name_hint == key2.name_hint + p1 = model_params[key1] + p2 = model_params_after[key2] + np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy()) def test_call(): op = relay.Var('f') @@ -257,9 +284,11 @@ if __name__ == "__main__": test_local_var() test_global_var() test_function() + test_function_attrs() test_call() test_let() test_if() test_tuple_get_item() test_op() test_conv2d_attrs() + -- 2.7.4