From: Wuwei Lin Date: Fri, 1 Nov 2019 17:36:36 +0000 (-0400) Subject: [Relay][Pass] Avoid FoldConstant folding some ops (#4245) X-Git-Tag: upstream/0.7.0~1702 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=aa49e851e7fed1e9111ffbd7d890188a3a173178;p=platform%2Fupstream%2Ftvm.git [Relay][Pass] Avoid FoldConstant folding some ops (#4245) * [Relay][Pass] Avoid FoldConstant folding some ops * rename --- diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 6848876..5825c1e 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -102,6 +102,9 @@ class ConstantFolder : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { static auto op_stateful = Op::GetAttr("TOpIsStateful"); + + std::unordered_set skip_list{"zeros_like", "ones_like", "full_like", "full"}; + auto origin_args = call->args; Expr res = ExprMutator::VisitExpr_(call); call = res.as(); @@ -111,6 +114,9 @@ class ConstantFolder : public ExprMutator { if (call->args.size() == 0) return res; const OpNode* op = call->op.as(); if (op == nullptr) return res; + if (skip_list.count(op->name)) { + return res; + } // skip stateful ops. if (op_stateful.get(GetRef(op), false)) return res; // Try to evaluate shape_of op diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 97b20c6..4752597 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -146,9 +146,25 @@ def test_fold_shape_of(): assert relay.analysis.graph_equal(zz, zexpected) +def test_fold_full(): + c_shape = (8, 9, 10) + def before(): + dtype = 'float32' + return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype) + + def expected(): + # expect no changes + return before() + + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert relay.analysis.graph_equal(zz, zexpected) + + if __name__ == "__main__": test_fold_const() test_fold_let() test_fold_tuple() test_fold_concat() test_fold_shape_of() + test_fold_full()