From 672203f241a33edf3c1e765821cb99d397594b5d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Fri, 2 Aug 2019 10:35:27 -0700 Subject: [PATCH] [Relay] [Error] Fix error in partial evaluator (#3693) * fix * lint --- src/relay/pass/partial_eval.cc | 39 +++++++++++++++++++--------- tests/python/relay/test_pass_partial_eval.py | 11 +++++++- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index ca55e8c..869c056 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -131,7 +131,7 @@ Expr PostProcess(const Expr&); /*! \brief The base container type of Relay values. */ class StaticNode : public RelayNode { public: - static constexpr const char* _type_key = "relay.Value"; + static constexpr const char* _type_key = "relay.Static"; TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); }; @@ -161,6 +161,7 @@ struct PStaticNode : Node { PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic), created_time(time()) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } + static constexpr const char* _type_key = "relay.PStatic"; TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); }; @@ -169,6 +170,7 @@ RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef); struct STupleNode : StaticNode { std::vector fields; explicit STupleNode(const std::vector& fields) : fields(fields) { } + static constexpr const char* _type_key = "relay.STuple"; TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); }; @@ -181,7 +183,8 @@ Static MkSTuple(const std::vector& fields) { struct STensorNode : StaticNode { runtime::NDArray data; explicit STensorNode(const NDArray& data) : data(data) { } - TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); + static constexpr const char* _type_key = "relay.STensor"; + TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode); }; RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value); @@ -195,6 +198,7 @@ struct SConstructorNode : StaticNode { std::vector fields; SConstructorNode(const Constructor& constructor, const std::vector& fields) : constructor(constructor), fields(fields) { } + static constexpr const char* _type_key = "relay.SConstructor"; TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode); }; @@ -205,6 +209,7 @@ Static MkSConstructor(const Constructor& constructor, const std::vector } struct SRefNode : StaticNode { + static constexpr const char* _type_key = "relay.SRef"; // we will use the address as the guid for hashing TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode); }; @@ -223,6 +228,7 @@ using Func = std::function&, struct SFuncNode : StaticNode { Func func; explicit SFuncNode(const Func& func) : func(func) { } + static constexpr const char* _type_key = "relay.SFunc"; TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode); }; @@ -711,8 +717,14 @@ class PartialEvaluator : public ExprFunctor return VisitFunc(GetRef(op), ll); } + struct ReflectError : dmlc::Error { + ReflectError() : dmlc::Error("static value not found") { } + }; + Expr Reflect(const PStatic& st) { - if (const STensorNode* op = st->pstatic.as()) { + if (!st->pstatic.defined()) { + throw ReflectError(); + } else if (const STensorNode* op = st->pstatic.as()) { return ConstantNode::make(op->data); } else if (const STupleNode* op = st->pstatic.as()) { tvm::Array fields; @@ -721,7 +733,7 @@ class PartialEvaluator : public ExprFunctor } return TupleNode::make(fields); } else { - LOG(FATAL) << "Unknown case"; + LOG(FATAL) << "Unknown case: " << st->dynamic; throw; } } @@ -767,19 +779,22 @@ class PartialEvaluator : public ExprFunctor for (const PStatic& ps : pv) { ns_args.push_back(ps->dynamic); } - PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args))); + auto ns = [&]() { + return NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args))); + }; if (StatefulOp(expr)) { - return ns; + return ns(); } - tvm::Array args; - for (const PStatic& ps : pv) { - if (ps->pstatic.defined()) { + try { + tvm::Array args; + for (const PStatic& ps : pv) { args.push_back(Reflect(ps)); - } else { - return ns; } + return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll); + } + catch (const ReflectError&) { + return ns(); } - return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll); }; } diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index 7eb025a..b4dfa4e 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -18,7 +18,7 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.analysis import alpha_equal +from tvm.relay.analysis import alpha_equal, assert_alpha_equal from tvm.relay.prelude import Prelude from tvm.relay import op, create_executor, transform from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate @@ -306,6 +306,14 @@ def test_double(): assert alpha_equal(res.body, make_nat_expr(p, 6)) +def test_concat(): + t = relay.TensorType([10], "float32") + x = Var("x", t) + y = Var("x", t) + orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0))) + assert_alpha_equal(orig, dcpe(orig)) + + if __name__ == '__main__': test_ref() test_tuple() @@ -323,3 +331,4 @@ if __name__ == '__main__': test_nat_id() test_global_match_nat_id() test_match_nat_id() + test_concat() -- 2.7.4