[Relay][Pass] Avoid FoldConstant folding some ops (#4245)
authorWuwei Lin <vincentl13x@gmail.com>
Fri, 1 Nov 2019 17:36:36 +0000 (13:36 -0400)
committerZhi <5145158+zhiics@users.noreply.github.com>
Fri, 1 Nov 2019 17:36:36 +0000 (10:36 -0700)
* [Relay][Pass] Avoid FoldConstant folding some ops

* rename

src/relay/pass/fold_constant.cc
tests/python/relay/test_pass_fold_constant.py

index 6848876..5825c1e 100644 (file)
@@ -102,6 +102,9 @@ class ConstantFolder : public ExprMutator {
 
   Expr VisitExpr_(const CallNode* call) final {
     static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+
+    std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
+
     auto origin_args = call->args;
     Expr res = ExprMutator::VisitExpr_(call);
     call = res.as<CallNode>();
@@ -111,6 +114,9 @@ class ConstantFolder : public ExprMutator {
     if (call->args.size() == 0) return res;
     const OpNode* op = call->op.as<OpNode>();
     if (op == nullptr) return res;
+    if (skip_list.count(op->name)) {
+        return res;
+    }
     // skip stateful ops.
     if (op_stateful.get(GetRef<Op>(op), false)) return res;
     // Try to evaluate shape_of op
index 97b20c6..4752597 100644 (file)
@@ -146,9 +146,25 @@ def test_fold_shape_of():
         assert relay.analysis.graph_equal(zz, zexpected)
 
 
+def test_fold_full():
+    c_shape = (8, 9, 10)
+    def before():
+        dtype = 'float32'
+        return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)
+
+    def expected():
+        # expect no changes
+        return before()
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    assert relay.analysis.graph_equal(zz, zexpected)
+
+
 if __name__ == "__main__":
     test_fold_const()
     test_fold_let()
     test_fold_tuple()
     test_fold_concat()
     test_fold_shape_of()
+    test_fold_full()