[Fix] Add ConstantNode to IsAtomic (#5457)
authorZhi <5145158+zhiics@users.noreply.github.com>
Thu, 30 Apr 2020 17:00:27 +0000 (10:00 -0700)
committerGitHub <noreply@github.com>
Thu, 30 Apr 2020 17:00:27 +0000 (10:00 -0700)
* add constantnode to atomic

* Add ToANormalForm to FoldConstant

src/relay/transforms/fold_constant.cc
tests/python/relay/test_pass_fold_constant.py

index a52f420..fab184c 100644 (file)
@@ -203,6 +203,7 @@ class ConstantFolder : public ExprMutator {
   // Constant evaluate a expression.
   Expr ConstEvaluate(Expr expr) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0),
+                                           transform::ToANormalForm(),
                                            transform::InferType()};
     Function func;
     if (expr.as<FunctionNode>()) {
index b212b26..a981667 100644 (file)
@@ -32,6 +32,25 @@ def run_opt_pass(expr, opt_pass):
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
+def test_concatenate_const():
+    def before():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        concat = relay.op.concatenate([const, const], axis=0)
+        func = relay.Function([], concat)
+        return func
+
+    def expected():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        func = relay.Function([], const)
+        return func
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(zz, zexpected)
+
+
 def test_fold_const():
     c_data = np.array([1, 2, 3]).astype("float32")
     t = relay.TensorType([1, 2, 3], "float32")