Bugfix StmtMutator IfThenElse (#4609)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 2 Jan 2020 17:08:54 +0000 (09:08 -0800)
committerGitHub <noreply@github.com>
Thu, 2 Jan 2020 17:08:54 +0000 (09:08 -0800)
src/pass/ir_functor.cc
tests/cpp/ir_functor_test.cc

index 079da75..fae4c03 100644 (file)
@@ -417,7 +417,7 @@ Stmt StmtMutator::VisitStmt_(const IfThenElse* op) {
     auto n = CopyOnWrite(op);
     n->condition = std::move(condition);
     n->then_case = std::move(then_case);
-    n->else_case = std::move(then_case);
+    n->else_case = std::move(else_case);
     return Stmt(n);
   }
 }
index 5f08601..ea854f5 100644 (file)
@@ -146,16 +146,22 @@ TEST(IRF, StmtMutator) {
       return ExprMutator::VisitExpr(expr);
     }
   };
-  auto fmaketest = [&]() {
+  auto fmakealloc = [&]() {
     auto z = x + 1;
     Stmt body = Evaluate::make(z);
     Var buffer("b", DataType::Handle());
     return Allocate::make(buffer, DataType::Float(32), {1, z}, const_true(), body);
   };
 
+  auto fmakeif = [&]() {
+    auto z = x + 1;
+    Stmt body = Evaluate::make(z);
+    return IfThenElse::make(x < 0, Evaluate::make(0), body);
+  };
+
   MyVisitor v;
   {
-    auto body = fmaketest();
+    auto body = fmakealloc();
     Stmt body2 = Evaluate::make(1);
     Stmt bref = body.as<Allocate>()->body;
     auto* extentptr = body.as<Allocate>()->extents.get();
@@ -172,7 +178,7 @@ TEST(IRF, StmtMutator) {
     CHECK(bref.as<Evaluate>()->value.as<Add>());
   }
   {
-    Array<Stmt> arr{fmaketest()};
+    Array<Stmt> arr{fmakealloc()};
     // mutate array get reference by another one, triiger copy.
     Array<Stmt> arr2 = arr;
     auto* arrptr = arr.get();
@@ -186,6 +192,16 @@ TEST(IRF, StmtMutator) {
     CHECK(arr2.get() == arr.get());
   }
   {
+    Array<Stmt> arr{fmakeif()};
+    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
+    CHECK(arr[0].as<IfThenElse>()->else_case.as<Evaluate>()->value.same_as(x));
+    // mutate but no content change.
+    auto arr2 = arr;
+    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
+    CHECK(arr2.get() == arr.get());
+  }
+
+  {
     auto body = Evaluate::make(Call::make(DataType::Int(32), "xyz", {x + 1}, Call::Extern));
     auto res = v(std::move(body));
     CHECK(res.as<Evaluate>()->value.as<Call>()->args[0].same_as(x));