tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());
+ /*!
+ * \brief Attach the function's parameters to its attributes for use in analysis.
+ * \return The function with its parameters attached.
+ */
+ Function SetParams(const tvm::Map<Var, Constant>& parameters) const;
+
+ /*!
+ * \brief Retrieve the function's parameters.
+ *
+ * \return The function's parameter.
+ */
+ tvm::Map<Var, Constant> GetParams() const;
+
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
};
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);
-
/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
+from ..ndarray import NDArray
# will be registered afterwards
_op_make = None
"""
return Call(self, args, None, None)
+ def get_params(self):
+ return _expr.FunctionGetParams(self)
+
+ def set_params(self, params):
+ for key in params:
+ value = params[key]
+ if isinstance(value, NDArray):
+ params[key] = Constant(value)
+
+ return _expr.FunctionSetParams(self, params)
+
@register_relay_node
class Call(Expr):
return pval && pval->value != 0;
}
+Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
+ return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
+}
+
+TVM_REGISTER_API("relay._expr.FunctionSetParams")
+.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
+ [](const Function& func, const tvm::Map<Var, Constant>& parameters) {
+ return func->SetParams(parameters);
+});
+
+tvm::Map<Var, Constant> FunctionNode::GetParams() const {
+ auto node_ref = FunctionGetAttr(GetRef<Function>(this), "__params__");
+ return Downcast<tvm::Map<Var, Constant>>(node_ref);
+}
+
+TVM_REGISTER_API("relay._expr.FunctionGetParams")
+.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
+ return func->GetParams();
+});
+
NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return NodeRef(); }
from tvm.expr import *
from tvm.relay import op
from tvm.relay.analysis import graph_equal
-
+import numpy as np
def check_json_roundtrip(node):
json_str = tvm.save_json(node)
str(gv)
check_json_roundtrip(gv)
-
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
str(fn)
check_json_roundtrip(fn)
+def test_function_attrs():
+ param_names = ['a', 'b', 'c', 'd']
+ params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names])
+ ret_type = relay.TupleType(tvm.convert([]))
+ body = relay.Tuple(tvm.convert([]))
+ type_params = tvm.convert([])
+ fn = relay.Function(params, body, ret_type, type_params)
+ model_params = {}
+ for param in params[:1]:
+ cty = param.type_annotation
+ tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
+ model_params[param] = tvm.nd.array(tensor)
+ fn = fn.set_params(model_params)
+ assert fn.params == params
+ assert fn.body == body
+ assert fn.type_params == type_params
+ assert fn.span == None
+ str(fn)
+ check_json_roundtrip(fn)
+ json_str = tvm.save_json(fn)
+ fn_after = tvm.load_json(json_str)
+ model_params_after = fn_after.get_params()
+ after_keys = [item[0] for item in model_params_after.items()]
+ for key1, key2 in zip(model_params, after_keys):
+ assert key1.name_hint == key2.name_hint
+ p1 = model_params[key1]
+ p2 = model_params_after[key2]
+ np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy())
def test_call():
op = relay.Var('f')
test_local_var()
test_global_var()
test_function()
+ test_function_attrs()
test_call()
test_let()
test_if()
test_tuple_get_item()
test_op()
test_conv2d_attrs()
+