[TIR][Transform] Block scope hoisting added (#6238)
authorANSHUMAN TRIPATHY <anshuman.t@huawei.com>
Mon, 31 Aug 2020 21:19:33 +0000 (02:49 +0530)
committerGitHub <noreply@github.com>
Mon, 31 Aug 2020 21:19:33 +0000 (14:19 -0700)
* Block scope hoisting added

* lowering flow added with 2 variants

* Fake commit to trigger ci with pass default enabled

* CI Failure resolved

* Optimize for if var list iteration

* More test cases added

* Fake commit to disable failed test cases

* Pass default value restored

* [1] Review comment handled

* [2] Review comments handled

python/tvm/driver/build_module.py
python/tvm/tir/transform/transform.py
src/tir/transforms/hoist_if_then_else.cc
tests/python/unittest/test_tir_transform_hoist_if.py

index ff4b56b..9a3c473 100644 (file)
@@ -181,7 +181,6 @@ def lower(sch,
         tvm.tir.transform.BF16Legalize(),
         tvm.tir.transform.NarrowDataType(32),
         tvm.tir.transform.Simplify(),
-        tvm.tir.transform.HoistIfThenElse(),
     ]
     pass_list += lower_phase1
 
@@ -205,6 +204,7 @@ def lower(sch,
     ]
 
     pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
+    pass_list += [tvm.tir.transform.HoistIfThenElse()]
     pass_list += lower_phase3
 
     # Instrument BoundCheckers
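With this change, HoistIfThenElse no longer runs early alongside Simplify; it now runs late in lowering, right after RewriteUnsafeSelect and before the phase-3 user passes. A minimal sketch of driving the lowering pipeline with block-scope hoisting enabled (the toy compute and schedule below are hypothetical, not taken from this PR):

    import tvm
    from tvm import te

    # Hypothetical toy workload, only to exercise the lowering pipeline.
    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # Block-scope hoisting is opt-in via the PassContext option added by this PR;
    # the pass itself now runs in the late phase of tvm.lower().
    with tvm.transform.PassContext(
        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    ):
        lowered = tvm.lower(s, [A, B], simple_mode=True)
    print(lowered)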
index d2f5acd..3f7fb41 100644 (file)
@@ -500,11 +500,31 @@ def VerifyMemory():
     """
     return _ffi_api.VerifyMemory()
 
-def HoistIfThenElse():
+#pylint: disable=no-else-return,inconsistent-return-statements
+def HoistIfThenElse(variant=None):
     """Hoist loop-invariant IfThenElse nodes to outside the elligible loops.
+
+    Parameters
+    ----------
+    variant : Optional[String]
+        The variant of the pass.
+        variant can have any one of the following values: ["basic", None] (default: None).
+
+        The basic variant supports basic hoisting scenarios where it expects
+        the For and If nodes to be placed consecutively and does not involve
+        global scope variables or more advanced scenarios.
+
+        The default variant supports all hoisting scenarios, i.e., {"Basic" + "Advanced"},
+        controlled through PassContext configs like below:
+
+            config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.HoistIfThenElse()
+    if variant == "basic":
+        return _ffi_api.HoistIfThenElseBasic()
+    elif variant is None:
+        return _ffi_api.HoistIfThenElse()
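A minimal sketch of calling the two variants on a hand-built PrimFunc, assuming only what the docstring above describes (the toy IR below is hypothetical and mirrors the new unit tests):

    import tvm
    from tvm import te

    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")
    n = te.var("n")
    with ib.for_range(0, 10, "i") as i:
        with ib.if_scope(n >= 3):  # loop-invariant condition
            data[i] = data[i] + 0.5
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get()))

    # "basic" variant: plain For/If hoisting only, no block scope variables.
    basic_mod = tvm.tir.transform.HoistIfThenElse("basic")(mod)

    # Default variant, with block-scope hoisting switched on through PassContext.
    with tvm.transform.PassContext(
        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    ):
        full_mod = tvm.tir.transform.HoistIfThenElse()(mod)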
index f58eb96..4e7589c 100644 (file)
 namespace tvm {
 namespace tir {
 
+struct HoistIfThenElseConfigNode : public tvm::AttrsNode<HoistIfThenElseConfigNode> {
+  bool support_block_scope_hosting;
+
+  TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") {
+    TVM_ATTR_FIELD(support_block_scope_hosting)
+        .describe("Hoist if cond with block scope variables")
+        .set_default(false);
+  }
+};
+
+class HoistIfThenElseConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs,
+                                            HoistIfThenElseConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig);
+
 using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
 using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
 
@@ -93,11 +112,33 @@ using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
  *            if (likely(j > 2))
  *                A[i+j+k] = B[i+j+k]
  *
+ *
+ * This pass also performs hoisting for block scope variables,
+ * as in the example below:
+ * Attr(IterVar: threadIdx.x)
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(threadIdx.x < 3))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Will be transformed to:
+ * Attr(IterVar: threadIdx.x)
+ * if (likely(threadIdx.x < 3))
+ *     for (i = 0; i < 3; i++)
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
  */
 
 // Select potential candidate IRs that can be hoisted.
 class HoistCandidateSelector final : public StmtExprVisitor {
  public:
+  explicit HoistCandidateSelector(bool support_block_scope_hosting)
+      : support_block_scope_hosting_(support_block_scope_hosting) {
+    InitRecorder();
+  }
   HoistCandidateSelector() { InitRecorder(); }
 
   void VisitStmt_(const ForNode* op) final {
@@ -108,16 +149,16 @@ class HoistCandidateSelector final : public StmtExprVisitor {
     }
 
     // Check if it is first for loop, then start the recorder
-    StartOrAddRecord(op);
+    StartOrAddRecord(GetRef<ObjectRef>(op));
     StmtExprVisitor::VisitStmt_(op);
-    RemoveRecord(op);
+    RemoveRecord(GetRef<ObjectRef>(op));
   }
 
   void VisitStmt_(const SeqStmtNode* op) final {
     // If SeqStmt is encountered in the middle of recording
     //  then need to purge all, as it can not be hoisted
     if (IsRecordingOn()) {
-      ResetRecorder();
+      ResetRecorderInternal();
     }
     StmtExprVisitor::VisitStmt_(op);
   }
@@ -126,10 +167,19 @@ class HoistCandidateSelector final : public StmtExprVisitor {
     // Maintain list of all vars in AttrStmt
     // To stop hoisting if any of the block variables are used.
     //
-    // NOTE: If in future
-    // hoisting is required for any specific case,
-    // then add exception to only those case
-    // rather than allowing for all.
+    // In case we want to use hoisting in between certain passes
+    // which have interdependencies on the positioning of if nodes with scope variables,
+    // it is better to disable this section.
+    if (support_block_scope_hosting_) {
+      if (IsRecordingOn()) {
+        StartOrAddRecord(GetRef<ObjectRef>(op));
+        StmtExprVisitor::VisitStmt_(op);
+        RemoveRecord(GetRef<ObjectRef>(op));
+        return;
+      } else {
+        return StmtExprVisitor::VisitStmt_(op);
+      }
+    }
     UpdateAttrVarList(op);
     StmtExprVisitor::VisitStmt_(op);
     RemoveAttrVarList(op);
@@ -147,26 +197,23 @@ class HoistCandidateSelector final : public StmtExprVisitor {
 
     if (CheckValidIf()) {
       // Check corresponding for loop
-      bool match_found = false;
-      size_t match_for_loop_pos = 0;
+      int match_for_loop_pos = -1;
       for (auto var : if_var_list_) {
-        for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
-          if (ordered_for_list_[i] == var_for_map_[var]) {
+        for (int i = 0; i < static_cast<int>(ordered_list_.size()); ++i) {
+          if ((ordered_list_[i] == var_for_map_[var]) || (ordered_list_[i] == var)) {
             if (match_for_loop_pos < i) {
               match_for_loop_pos = i;
             }
-            match_found = true;
-            break;
           }
         }
       }
       // If none of the for loop has the matching loop variable as if condition,
       // then the if node need to be hoisted on top of all, provided no parent loop exists.
-      int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+      int target_for_pos = GetNextLoopPos(match_for_loop_pos);
 
-      // Check if target for loop is not the parent of current if node
-      if (!IsParentForLoop(target_for_pos)) {
-        StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+      // Check if valid position
+      if (target_for_pos >= 0) {
+        StopAndAddRecord(static_cast<const ForNode*>(ordered_list_[target_for_pos]), op);
         if_var_list_.clear();
         return;
       }
@@ -186,13 +233,10 @@ class HoistCandidateSelector final : public StmtExprVisitor {
   HoistForIfTuple hoist_for_if_recorder;
 
   void ResetRecorder() {
-    if (is_recorder_on_) {
-      CHECK_GT(ordered_for_list_.size(), 0);
-      is_recorder_on_ = false;
-    }
-    ordered_for_list_.clear();
-    var_for_map_.clear();
-    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+    ResetRecorderInternal();
+
+    // Also reset block scope vars here
+    attr_var_list_.clear();
   }
 
   bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); }
@@ -202,25 +246,24 @@ class HoistCandidateSelector final : public StmtExprVisitor {
   const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
 
  private:
+  void ResetRecorderInternal() {
+    if (is_recorder_on_) {
+      CHECK_GT(ordered_list_.size(), 0);
+      is_recorder_on_ = false;
+    }
+    ordered_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
   bool CheckValidIf() {
     // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
     // hoisting
     return ((!if_var_list_.empty()) && (!CheckAttrVar()));
   }
 
-  bool IsParentForLoop(int loop_pos) {
-    // Check if the loop position is higher than the parent loop position
-    for (auto var : if_var_list_) {
-      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  int GetParentLoopPos(const Object* node) {
-    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
-      if (ordered_for_list_[i] == node) {
+  int GetNextLoopPos(int cur_pos) {
+    for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) {
+      if (ordered_list_[i]->IsInstance<ForNode>()) {
         return i;
       }
     }
@@ -233,18 +276,25 @@ class HoistCandidateSelector final : public StmtExprVisitor {
 
   bool IsRecordingOn() { return is_recorder_on_; }
 
-  void StartOrAddRecord(const ForNode* op) {
+  void StartOrAddRecord(const ObjectRef& op) {
     is_recorder_on_ = true;
-    if (!var_for_map_.count(op->loop_var.get())) {
-      var_for_map_.insert({op->loop_var.get(), op});
+    if (const auto* node = op.as<ForNode>()) {
+      if (!var_for_map_.count(node->loop_var.get()))
+        var_for_map_.insert({node->loop_var.get(), node});
+      ordered_list_.emplace_back(op.get());
+    } else if (const auto* node = op.as<AttrStmtNode>()) {
+      if (const auto* iv = node->node.as<IterVarNode>()) {
+        ordered_list_.emplace_back(iv->var.get());
+      } else if (const auto* iv = node->node.as<VarNode>()) {
+        ordered_list_.emplace_back(iv);
+      }
     }
-    ordered_for_list_.emplace_back(op);
   }
 
-  void RemoveRecord(const ForNode* op) {
+  void RemoveRecord(const ObjectRef& op) {
     StopRecording();
-    var_for_map_.erase(op->loop_var.get());
-    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
+    if (const auto* node = op.as<ForNode>()) var_for_map_.erase(node->loop_var.get());
+    if (ordered_list_.size() > 0) ordered_list_.pop_back();
   }
 
   void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
@@ -277,18 +327,22 @@ class HoistCandidateSelector final : public StmtExprVisitor {
     return false;
   }
 
-  std::vector<const ForNode*> ordered_for_list_;
+  // ordered_list_ maintains all ForNodes and AttrStmtNodes encountered, in order
+  std::vector<const Object*> ordered_list_;
   std::vector<const VarNode*> if_var_list_;
   std::unordered_set<const VarNode*> attr_var_list_;
   VarForMap var_for_map_;
 
   bool is_if_cond_{false};
   bool is_recorder_on_{false};
+  bool support_block_scope_hosting_{false};
 };
 
 class IfThenElseHoister : public StmtMutator {
  public:
   IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {}
+  explicit IfThenElseHoister(bool support_block_scope_hosting)
+      : hoist_selector_(HoistCandidateSelector(support_block_scope_hosting)) {}
 
   Stmt VisitAndMutate(Stmt stmt) {
     hoist_selector_(stmt);
@@ -344,6 +398,9 @@ class IfThenElseHoister : public StmtMutator {
   const IfThenElseNode* target_if_;
 };
 
+Stmt HoistIfThenElse(Stmt stmt, bool support_block_scope_hosting) {
+  return IfThenElseHoister(support_block_scope_hosting).VisitAndMutate(stmt);
+}
 Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoister().VisitAndMutate(stmt); }
 
 namespace transform {
@@ -351,14 +408,30 @@ namespace transform {
 Pass HoistIfThenElse() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    n->body = HoistIfThenElse(std::move(n->body));
+    auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");
+
+    if (!cfg.defined()) {
+      cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>();
+    }
+    n->body = HoistIfThenElse(std::move(n->body), cfg.value()->support_block_scope_hosting);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {});
 }
 
+Pass HoistIfThenElseBasic() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = HoistIfThenElse(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElseBasic", {});
+}
+
 TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse);
 
+TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic);
+
 }  // namespace transform
 
 }  // namespace tir
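The block-scope case documented in the header comment of hoist_if_then_else.cc (an if on threadIdx.x nested under ordinary loops) can be reproduced from Python roughly as below; this is a sketch in the spirit of the new unit tests, with a hypothetical toy body:

    import tvm
    from tvm import te

    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")
    n = te.var("n")
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", 32)
    with ib.for_range(0, n, "i") as i:
        with ib.for_range(0, n, "j") as j:
            with ib.if_scope(tx < 3):  # condition only uses the block scope var
                data[i * n + j] = data[i * n + j] + 0.5
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get()))

    # Default config: block-scope hoisting is off, the body is left untouched.
    unchanged = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body

    # With the config enabled, the if is hoisted above both loops,
    # directly under the thread_extent AttrStmt.
    with tvm.transform.PassContext(
        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
    ):
        hoisted = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
    assert not tvm.ir.structural_equal(hoisted, unchanged)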
index 4ca952a..186a52d 100644 (file)
 # under the License.
 import tvm
 from tvm import te
+from tvm import relay
+import numpy as np
+import pytest
+from tvm.relay.testing import ctx_list
 
 var_list = []
 
@@ -255,14 +259,488 @@ def test_multi_if():
                        ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
     verify_structure(new_stmt, expected_struct)
 
+def test_no_hoisting_1():
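+    # The if condition uses "k", the innermost loop variable, so there is no loop to hoist it above.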
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    n = te.var("n")
 
-if __name__ == "__main__":
-    test_hoist_top_for()
-    test_hoist_multi_var_if()
-    test_hoist_no_match_for()
-    test_no_else()
-    test_attr_stmt()
-    test_nested_for()
-    test_if_block()
-    test_multi_if()
+    with ib.for_range(0, 10, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.for_range(0, 10, "k") as k:
+                with ib.if_scope(k >= 3):
+                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
 
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_2():
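+    # The if statement shares the innermost loop body with another store (a SeqStmt), which blocks hoisting.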
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    n = te.var("n")
+    x = te.var("x")
+
+    with ib.for_range(0, 10, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.for_range(0, 10, "k") as k:
+                with ib.if_scope(i >= 3):
+                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.3
+                data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_3():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_4():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_5():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_6():
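+    # The condition mixes threadIdx.x with the innermost loop variable "k",
+    # so it cannot be hoisted even with block-scope support enabled.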
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope((tx + k) < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_7():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.if_scope((tx + j) < 9):
+                with ib.for_range(0, n, "k") as k:
+                    with ib.if_scope((tx + k) < 3):
+                        data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_1():
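+    # The thread-bound reduction produces an if that depends on thread variables;
+    # the default pass leaves it in place, and it is hoisted only with block-scope hoisting enabled.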
+    n = te.size_var("n")
+    m = te.size_var("m")
+    A = te.placeholder((n, m), name='A')
+    k = te.reduce_axis((0, m), "k")
+    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+    s = te.create_schedule(B.op)
+    ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
+    BF = s.rfactor(B, ki)
+    xo, xi = s[B].split(s[B].op.axis[0], factor=32)
+    s[B.op].bind(xo, te.thread_axis("blockIdx.x"))
+    s[B.op].bind(xi, te.thread_axis("threadIdx.y"))
+    s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
+    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
+    func = tvm.driver.build_module.form_irmodule(
+            s, [A, B], "main", None)["main"]
+    stmt = func.body
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_2():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    #ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(bx, "thread_extent", dshape[1])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    #tvm.ir.assert_structural_equal(new_stmt, stmt)
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_3():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+            ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    #tvm.ir.assert_structural_equal(new_stmt, stmt)
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_4():
+    nn = 1024
+    n = tvm.runtime.convert(nn)
+    A = te.placeholder((n,), name='A')
+    B = te.placeholder((n,), name='B')
+    AA = te.compute((n,), lambda *i: A(*i), name='A')
+    BB = te.compute((n,), lambda *i: B(*i), name='B')
+    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
+    C = te.compute(A.shape, lambda *i: T(*i), name='C')
+    s = te.create_schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=4)
+    xo1, xo2 = s[C].split(xo, factor=13)
+    s[C].parallel(xo2)
+    s[C].pragma(xo1, "parallel_launch_point")
+    s[C].pragma(xo2, "parallel_stride_pattern")
+    s[C].pragma(xo2, "parallel_barrier_when_finish")
+    s[C].vectorize(xi)
+    func = tvm.driver.build_module.form_irmodule(
+            s, [A, B, C], "main", None)["main"]
+    stmt = func.body
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_5():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+    g = te.var('g')
+
+    ib.scope_attr(data, "storage_scope", "global")
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(data[g] < 3):
+                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k]  + 0.3
+                with ib.else_scope():
+                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+    stmt = new_stmt
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_6():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope((tx + n) < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_7():
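+    # The condition uses threadIdx.x and the outer loop variable "i"; with block-scope
+    # support enabled the if is hoisted above the "j" and "k" loops.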
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope((tx + i) < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+@pytest.mark.skip()
+def test_hoisting_op_conv():
+    dtype = "float32"
+    dshape = (1, 80, 73, 73)
+    kshape = (192, 80, 3, 3)
+    padding = (1, 1)
+    groups = 1
+    dilation = (1, 1)
+    kernel_size = (3, 3)
+    channels = 192
+    scale = 1
+    x = relay.var("x", shape=dshape, dtype=dtype)
+    w = relay.var("w", shape=kshape, dtype=dtype)
+    y = relay.nn.conv2d(x, w, padding=padding,
+                                dilation=dilation,
+                                groups=groups,
+                                channels=channels,
+                                kernel_size=kernel_size)
+
+    func = relay.Function([x, w], y)
+    mod = tvm.IRModule()
+    mod['main'] = func
+    mod = relay.transform.InferType()(mod)
+
+    data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
+    kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
+
+    params = {'w': tvm.nd.array(kernel)}
+    for target, ctx in ctx_list():
+        with tvm.transform.PassContext(opt_level=3):
+            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+            m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+            x = np.random.uniform(size=dshape)
+            data_tvm = tvm.nd.array(data)
+            m.set_input('x', data_tvm)
+            m.set_input(**params)
+            m.run()
+            e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
+            t1 = e(data_tvm).results
+            t1 = np.array(t1) * 1000
+            print('{} ms'.format(t1.mean()))
+
+        with tvm.transform.PassContext(opt_level=3, config={
+            "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+        }):
+            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+            m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+            x = np.random.uniform(size=dshape)
+            data_tvm = tvm.nd.array(data)
+            m.set_input('x', data_tvm)
+            m.set_input(**params)
+            m.run()
+            e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
+            t2 = e(data_tvm).results
+            t2 = np.array(t2) * 1000
+
+            print('{} ms'.format(t2.mean()))
+        tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1)
+
+if __name__ == "__main__":
+    pytest.main([__file__])