[TIR][Transform] HoistIfThenElse added (#6066)
authorANSHUMAN TRIPATHY <anshuman.t@huawei.com>
Mon, 3 Aug 2020 20:50:11 +0000 (02:20 +0530)
committerGitHub <noreply@github.com>
Mon, 3 Aug 2020 20:50:11 +0000 (13:50 -0700)
* [TIR][Transform] HoistIfThenElse added

* lint error resolved

* Pass position changed

* pylint error resolved

* CI issues resolved

* Frontend tflite test case failure resolved

* [1] Review comment handled

* [2] Review comment handled

* [3] Review comment handled

* Lint error resolved

include/tvm/tir/transform.h
python/tvm/driver/build_module.py
python/tvm/tir/transform/transform.py
src/tir/transforms/hoist_if_then_else.cc [new file with mode: 0644]
tests/python/unittest/test_te_build_lower.py
tests/python/unittest/test_tir_transform_hoist_if.py [new file with mode: 0644]

index 5e04838..f31e515 100644 (file)
@@ -338,6 +338,14 @@ TVM_DLL Pass BF16Legalize();
  */
 TVM_DLL Pass PointerValueTypeRewrite();
 
+/*!
+ * \brief Hoist loop-invariant IfThenElse nodes to
+ * outside the elligible loops.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass HoistIfThenElse();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
index b107000..663a17a 100644 (file)
@@ -179,6 +179,7 @@ def lower(sch,
         tvm.tir.transform.BF16Legalize(),
         tvm.tir.transform.NarrowDataType(32),
         tvm.tir.transform.Simplify(),
+        tvm.tir.transform.HoistIfThenElse(),
     ]
     pass_list += lower_phase1
 
index 86e7a33..d2f5acd 100644 (file)
@@ -499,3 +499,12 @@ def VerifyMemory():
         The result pass
     """
     return _ffi_api.VerifyMemory()
+
+def HoistIfThenElse():
+    """Hoist loop-invariant IfThenElse nodes to outside the elligible loops.
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.HoistIfThenElse()
diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc
new file mode 100644 (file)
index 0000000..f58eb96
--- /dev/null
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // If already recording complete,
+    // then stop tracing
+    if (RecordingComplete()) {
+      return;
+    }
+
+    // Check if it is first for loop, then start the recorder
+    StartOrAddRecord(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveRecord(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (!IsRecordingOn()) {
+      StmtExprVisitor::VisitStmt_(op);
+      return;
+    }
+
+    is_if_cond_ = true;
+    StmtExprVisitor::VisitExpr(op->condition);
+    is_if_cond_ = false;
+
+    if (CheckValidIf()) {
+      // Check corresponding for loop
+      bool match_found = false;
+      size_t match_for_loop_pos = 0;
+      for (auto var : if_var_list_) {
+        for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+          if (ordered_for_list_[i] == var_for_map_[var]) {
+            if (match_for_loop_pos < i) {
+              match_for_loop_pos = i;
+            }
+            match_found = true;
+            break;
+          }
+        }
+      }
+      // If none of the for loop has the matching loop variable as if condition,
+      // then the if node need to be hoisted on top of all, provided no parent loop exists.
+      int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+      // Check if target for loop is not the parent of current if node
+      if (!IsParentForLoop(target_for_pos)) {
+        StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        if_var_list_.clear();
+        return;
+      }
+    }
+
+    if_var_list_.clear();
+    StmtExprVisitor::VisitStmt_(op);
+    StopRecording();
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond_) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on_) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on_ = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    return ((!if_var_list_.empty()) && (!CheckAttrVar()));
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() { is_recorder_on_ = false; }
+
+  bool IsRecordingOn() { return is_recorder_on_; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    is_recorder_on_ = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
+  }
+
+  void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
+    hoist_for_if_recorder = std::make_tuple(true, for_node, if_node);
+    StopRecording();
+  }
+
+  void UpdateAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.insert(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.insert(iv);
+    }
+  }
+
+  void RemoveAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.erase(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.erase(iv);
+    }
+  }
+
+  bool CheckAttrVar() {
+    for (auto var : if_var_list_) {
+      if (attr_var_list_.count(var)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  std::vector<const ForNode*> ordered_for_list_;
+  std::vector<const VarNode*> if_var_list_;
+  std::unordered_set<const VarNode*> attr_var_list_;
+  VarForMap var_for_map_;
+
+  bool is_if_cond_{false};
+  bool is_recorder_on_{false};
+};
+
+class IfThenElseHoister : public StmtMutator {
+ public:
+  IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {}
+
+  Stmt VisitAndMutate(Stmt stmt) {
+    hoist_selector_(stmt);
+    Stmt stmt_copy = std::move(stmt);
+
+    while (hoist_selector_.RecordingComplete()) {
+      target_for_ = hoist_selector_.GetTargetForNode();
+      target_if_ = hoist_selector_.GetTargetIfNode();
+
+      stmt_copy = operator()(stmt_copy);
+
+      hoist_selector_.ResetRecorder();
+      hoist_selector_(stmt_copy);
+    }
+
+    // Support SSA Form
+    stmt_copy = ConvertSSA(stmt_copy);
+    return stmt_copy;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    if ((!is_updating_) && (target_for_ == op)) {
+      is_updating_ = true;
+      is_then_case_ = true;
+      Stmt then_case = StmtMutator::VisitStmt_(op);
+      is_then_case_ = false;
+      Stmt else_case = Stmt();
+      if (target_if_->else_case.defined()) {
+        else_case = StmtMutator::VisitStmt_(op);
+      }
+      is_updating_ = false;
+      return IfThenElse(target_if_->condition, then_case, else_case);
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt VisitStmt_(const IfThenElseNode* op) final {
+    if (is_updating_ && (target_if_ == op)) {
+      if (is_then_case_) {
+        return StmtMutator::VisitStmt(op->then_case);
+      } else if (op->else_case.defined()) {
+        return StmtMutator::VisitStmt(op->else_case);
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+ private:
+  bool is_updating_{false};
+  bool is_then_case_{false};
+  HoistCandidateSelector hoist_selector_;
+  const ForNode* target_for_;
+  const IfThenElseNode* target_if_;
+};
+
+Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoister().VisitAndMutate(stmt); }
+
+namespace transform {
+
+Pass HoistIfThenElse() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = HoistIfThenElse(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
index b1d7546..1fc2fcd 100644 (file)
@@ -49,7 +49,7 @@ def test_split_uneven_unique_likely():
     sch = te.create_schedule(c.op)
     xo, xi = sch[c].split(x, 5)
     stmt = tvm.lower(sch, [a, b, c])["main"].body
-    assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse)
+    assert isinstance(stmt.body.body, tvm.tir.stmt.IfThenElse)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py
new file mode 100644 (file)
index 0000000..4ca952a
--- /dev/null
@@ -0,0 +1,268 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+
+var_list = []
+
+def verify_structure(stmt, expected_struct):
+    node_dict = {}
+    struct = {}
+    def _extract_vars(op):
+        global var_list
+        if isinstance(op, tvm.tir.Var):
+            var_list.append(op.name)
+
+    def _visit(op):
+        key = op
+        if isinstance(op, tvm.tir.IfThenElse):
+            global var_list
+            tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars)
+            val = [(op.then_case, op.else_case), ("tir.IfThenElse", tuple(var_list))]
+            var_list.clear()
+        elif isinstance(op, tvm.tir.For):
+            val = [(op.body,), ("tir.For", op.loop_var.name)]
+        elif isinstance(op, tvm.tir.AttrStmt):
+            val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))]
+        else:
+            return
+        node_dict[key] = val
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
+    for key, val in node_dict.items():
+        struct[val[1]] = tuple(node_dict[child][1] if child in node_dict
+                               else None for child in val[0])
+
+    assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \
+                                      % (expected_struct, struct)
+    var_list.clear()
+
+def test_hoist_top_for():
+    ib = tvm.tir.ir_builder.create()
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+    data = ib.pointer("float32", name="data")
+
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(ib.likely(i < 2)):
+                    ib.emit(tvm.tir.Evaluate(m))
+                with ib.else_scope():
+                    ib.emit(tvm.tir.Evaluate(n))
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
+                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), ('tir.For', 'j')),
+                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
+    verify_structure(new_stmt, expected_struct)
+
+def test_hoist_multi_var_if():
+    ib = tvm.tir.ir_builder.create()
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+    data = ib.pointer("float32", name="data")
+
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(ib.likely(i + j < 2)):
+                    ib.emit(tvm.tir.Evaluate(m))
+                with ib.else_scope():
+                    ib.emit(tvm.tir.Evaluate(n))
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    expected_struct = {('tir.For', 'k'): (None,),
+                       ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')),
+                       ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),),
+                       ('tir.For', 'i'): (('tir.For', 'j'),)}
+    verify_structure(new_stmt, expected_struct)
+
+def test_hoist_no_match_for():
+    ib = tvm.tir.ir_builder.create()
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+    data = ib.pointer("float32", name="data")
+
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            data[i * 3 + j] = data[i * 3 + j] + 0.5
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(ib.likely(i < 2)):
+                    ib.emit(tvm.tir.Evaluate(m))
+                with ib.else_scope():
+                    ib.emit(tvm.tir.Evaluate(n))
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    expected_struct = {('tir.For', 'k'): (None,),
+                       ('tir.IfThenElse', ('i', )): (('tir.For', 'k'), ('tir.For', 'k')),
+                       ('tir.For', 'j'): (None,),
+                       ('tir.For', 'i'): (('tir.For', 'j'),)}
+    verify_structure(new_stmt, expected_struct)
+
+def test_no_else():
+    ib = tvm.tir.ir_builder.create()
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(ib.likely(i < 2)):
+                    ib.emit(tvm.tir.Evaluate(m))
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
+                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
+                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
+    verify_structure(new_stmt, expected_struct)
+
+def test_attr_stmt():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.5
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.0
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')),
+                       ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), ('tir.For', 'i'): (('tir.For', 'j'),),
+                       ('tir.AttrStmt', 'thread_extent', 64): (('tir.For', 'i'),),
+                       ('tir.AttrStmt', 'thread_extent', 32): (('tir.AttrStmt', 'thread_extent', 64),)}
+    verify_structure(new_stmt, expected_struct)
+
+def test_nested_for():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+
+
+    with ib.for_range(0, 5, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.if_scope(i >= 3):
+                data[i * 3 + j] = data[i * 3 + j] + 0.5
+                with ib.for_range(0, 15, "k") as k:
+                    with ib.for_range(0, 20, "l") as l:
+                        with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
+                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
+                        with ib.else_scope():
+                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    expected_struct = {('tir.For', 'l'): (None,), ('tir.For', 'k'): (('tir.For', 'l'),),
+                       ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')),
+                       ('tir.For', 'j'): (None,),
+                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
+                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
+    verify_structure(new_stmt, expected_struct)
+
+def test_if_block():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    n = te.var("n")
+
+
+    with ib.for_range(0, 5, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.if_scope(i >= 3):
+                data[i * 3 + j] = data[i * 3 + j] + 0.5
+                with ib.for_range(0, 15, "k") as k:
+                    with ib.for_range(0, 20, "l") as l:
+                        with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
+                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
+                        with ib.else_scope():
+                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5
+                        with ib.if_scope(j <5):
+                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1
+
+
+    with ib.for_range(0, 5, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+                with ib.for_range(0, 15, "k") as k:
+                    with ib.if_scope(n >= 3):
+                        data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.IfThenElse', ('j',)): (None, None),
+                       ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'j'),),
+                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),),
+                       ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)}
+    verify_structure(new_stmt, expected_struct)
+
+
+def test_multi_if():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+
+    with ib.for_range(0, 10, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.for_range(0, 10, "k") as k:
+                with ib.if_scope(i >= 3):
+                    with ib.if_scope(j >= 3):
+                        data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    expected_struct = {('tir.For', 'k'): (None,),
+                       ('tir.IfThenElse', ('j',)): (('tir.For', 'k'), None),
+                       ('tir.For', 'j'): (('tir.IfThenElse', ('j',)),),
+                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
+                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
+    verify_structure(new_stmt, expected_struct)
+
+
+if __name__ == "__main__":
+    test_hoist_top_for()
+    test_hoist_multi_var_if()
+    test_hoist_no_match_for()
+    test_no_else()
+    test_attr_stmt()
+    test_nested_for()
+    test_if_block()
+    test_multi_if()
+