namespace tvm {
namespace tir {
+// Pass-config attrs node for the "tir.HoistIfThenElse" pass. It holds the single
+// option controlling whether if-conditions that reference block-scope (AttrStmt /
+// thread) variables may also be hoisted.
+// NOTE(review): "hosting" is presumably a typo for "hoisting", but the field name
+// is the user-facing config key (see tests below), so renaming would break users.
+struct HoistIfThenElseConfigNode : public tvm::AttrsNode<HoistIfThenElseConfigNode> {
+ // When true, conditions using block-scope variables become hoist candidates.
+ bool support_block_scope_hosting;
+
+ TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") {
+ TVM_ATTR_FIELD(support_block_scope_hosting)
+ .describe("Hoist if cond with block scope variables")
+ .set_default(false);
+ }
+};
+
+// Managed, non-nullable reference wrapper around HoistIfThenElseConfigNode,
+// used as the value type of PassContext::GetConfig below.
+class HoistIfThenElseConfig : public Attrs {
+ public:
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs,
+ HoistIfThenElseConfigNode);
+};
+
+// Register the attrs node type and expose it under the "tir.HoistIfThenElse"
+// PassContext config key.
+TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig);
+
using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
* if (likely(j > 2))
* A[i+j+k] = B[i+j+k]
*
+ *
+ * This pass also performs hoisting for conditions that use block-scope variables.
+ * As below:
+ * Attr(IterVar: threadIdx.x)
+ * for (i = 0; i < 3; i++)
+ * for (j = 0; j < 4; j++)
+ * for (k = 0; k < 5; k++)
+ * if (likely(threadIdx.x < 3))
+ * A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Will be transformed to as below:
+ * Attr(IterVar: threadIdx.x)
+ * if (likely(threadIdx.x < 3))
+ * for (i = 0; i < 3; i++)
+ * for (j = 0; j < 4; j++)
+ * for (k = 0; k < 5; k++)
+ * A[3*i+2j+k] = B[7*i+3j+k]
+ *
*/
// Select potential candidate IRs that can be hoisted.
class HoistCandidateSelector final : public StmtExprVisitor {
public:
+ explicit HoistCandidateSelector(bool support_block_scope_hosting)
+ : support_block_scope_hosting_(support_block_scope_hosting) {
+ InitRecorder();
+ }
HoistCandidateSelector() { InitRecorder(); }
void VisitStmt_(const ForNode* op) final {
}
// Check if it is first for loop, then start the recorder
- StartOrAddRecord(op);
+ StartOrAddRecord(GetRef<ObjectRef>(op));
StmtExprVisitor::VisitStmt_(op);
- RemoveRecord(op);
+ RemoveRecord(GetRef<ObjectRef>(op));
}
void VisitStmt_(const SeqStmtNode* op) final {
// If SeqStmt is encountered in the middle of recording
// then need to purge all, as it can not be hoisted
if (IsRecordingOn()) {
- ResetRecorder();
+ ResetRecorderInternal();
}
StmtExprVisitor::VisitStmt_(op);
}
// Maintain list of all vars in AttrStmt
// To stop hoisting if any of the block variables are used.
//
- // NOTE: If in future
- // hoisting is required for any specific case,
- // then add exception to only those case
- // rather than allowing for all.
+ // In case we want to use hoisting in between certain passes
+ // which depend on the positioning of if nodes relative to scope variables,
+ // it is better to disable this section.
+ if (support_block_scope_hosting_) {
+ if (IsRecordingOn()) {
+ StartOrAddRecord(GetRef<ObjectRef>(op));
+ StmtExprVisitor::VisitStmt_(op);
+ RemoveRecord(GetRef<ObjectRef>(op));
+ return;
+ } else {
+ return StmtExprVisitor::VisitStmt_(op);
+ }
+ }
UpdateAttrVarList(op);
StmtExprVisitor::VisitStmt_(op);
RemoveAttrVarList(op);
if (CheckValidIf()) {
// Check corresponding for loop
- bool match_found = false;
- size_t match_for_loop_pos = 0;
+ int match_for_loop_pos = -1;
for (auto var : if_var_list_) {
- for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
- if (ordered_for_list_[i] == var_for_map_[var]) {
+ for (int i = 0; i < static_cast<int>(ordered_list_.size()); ++i) {
+ if ((ordered_list_[i] == var_for_map_[var]) || (ordered_list_[i] == var)) {
if (match_for_loop_pos < i) {
match_for_loop_pos = i;
}
- match_found = true;
- break;
}
}
}
// If none of the for loop has the matching loop variable as if condition,
// then the if node need to be hoisted on top of all, provided no parent loop exists.
- int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+ int target_for_pos = GetNextLoopPos(match_for_loop_pos);
- // Check if target for loop is not the parent of current if node
- if (!IsParentForLoop(target_for_pos)) {
- StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+ // Check if valid position
+ if (target_for_pos >= 0) {
+ StopAndAddRecord(static_cast<const ForNode*>(ordered_list_[target_for_pos]), op);
if_var_list_.clear();
return;
}
HoistForIfTuple hoist_for_if_recorder;
void ResetRecorder() {
- if (is_recorder_on_) {
- CHECK_GT(ordered_for_list_.size(), 0);
- is_recorder_on_ = false;
- }
- ordered_for_list_.clear();
- var_for_map_.clear();
- hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+ ResetRecorderInternal();
+
+ // Reset Block scope vars also here
+ attr_var_list_.clear();
}
bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); }
const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
private:
+ void ResetRecorderInternal() {
+ if (is_recorder_on_) {
+ CHECK_GT(ordered_list_.size(), 0);
+ is_recorder_on_ = false;
+ }
+ ordered_list_.clear();
+ var_for_map_.clear();
+ hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+ }
bool CheckValidIf() {
// If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
// hoisting
return ((!if_var_list_.empty()) && (!CheckAttrVar()));
}
- bool IsParentForLoop(int loop_pos) {
- // Check if the loop position is higher than the parent loop position
- for (auto var : if_var_list_) {
- if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
- return true;
- }
- }
- return false;
- }
-
- int GetParentLoopPos(const Object* node) {
- for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
- if (ordered_for_list_[i] == node) {
+ int GetNextLoopPos(int cur_pos) {
+ for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) {
+ if (ordered_list_[i]->IsInstance<ForNode>()) {
return i;
}
}
bool IsRecordingOn() { return is_recorder_on_; }
- void StartOrAddRecord(const ForNode* op) {
+ void StartOrAddRecord(const ObjectRef& op) {
is_recorder_on_ = true;
- if (!var_for_map_.count(op->loop_var.get())) {
- var_for_map_.insert({op->loop_var.get(), op});
+ if (const auto* node = op.as<ForNode>()) {
+ if (!var_for_map_.count(node->loop_var.get()))
+ var_for_map_.insert({node->loop_var.get(), node});
+ ordered_list_.emplace_back(op.get());
+ } else if (const auto* node = op.as<AttrStmtNode>()) {
+ if (const auto* iv = node->node.as<IterVarNode>()) {
+ ordered_list_.emplace_back(iv->var.get());
+ } else if (const auto* iv = node->node.as<VarNode>()) {
+ ordered_list_.emplace_back(iv);
+ }
}
- ordered_for_list_.emplace_back(op);
}
- void RemoveRecord(const ForNode* op) {
+ void RemoveRecord(const ObjectRef& op) {
StopRecording();
- var_for_map_.erase(op->loop_var.get());
- if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
+ if (const auto* node = op.as<ForNode>()) var_for_map_.erase(node->loop_var.get());
+ if (ordered_list_.size() > 0) ordered_list_.pop_back();
}
void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
return false;
}
- std::vector<const ForNode*> ordered_for_list_;
+ // Ordered List maintains all ForNodes & AttrStmtNodes encountered in sequence
+ std::vector<const Object*> ordered_list_;
std::vector<const VarNode*> if_var_list_;
std::unordered_set<const VarNode*> attr_var_list_;
VarForMap var_for_map_;
bool is_if_cond_{false};
bool is_recorder_on_{false};
+ bool support_block_scope_hosting_{false};
};
class IfThenElseHoister : public StmtMutator {
public:
IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {}
+ explicit IfThenElseHoister(bool support_block_scope_hosting)
+ : hoist_selector_(HoistCandidateSelector(support_block_scope_hosting)) {}
Stmt VisitAndMutate(Stmt stmt) {
hoist_selector_(stmt);
const IfThenElseNode* target_if_;
};
+// Hoist eligible if-then-else nodes in `stmt` above their enclosing loops.
+// When `support_block_scope_hosting` is true, conditions that reference
+// block-scope (thread/attr) variables are also treated as hoist candidates.
+Stmt HoistIfThenElse(Stmt stmt, bool support_block_scope_hosting) {
+ return IfThenElseHoister(support_block_scope_hosting).VisitAndMutate(stmt);
+}
Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoister().VisitAndMutate(stmt); }
namespace transform {
+// Pass entry point: reads the "tir.HoistIfThenElse" option from the PassContext
+// (falling back to defaults when unset) and rewrites the PrimFunc body in place.
Pass HoistIfThenElse() {
 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
 auto* n = f.CopyOnWrite();
- n->body = HoistIfThenElse(std::move(n->body));
+ auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");
+
+ // GetConfig returns an undefined Optional when the user did not set the key.
+ if (!cfg.defined()) {
+ cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>();
+ }
+ n->body = HoistIfThenElse(std::move(n->body), cfg.value()->support_block_scope_hosting);
 return f;
 };
 return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {});
}
+// Config-free variant of the pass: always uses the single-argument
+// HoistIfThenElse overload, i.e. never hoists block-scope conditions.
+Pass HoistIfThenElseBasic() {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = HoistIfThenElse(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElseBasic", {});
+}
+
TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse);
+// Expose the config-free variant to the FFI as well.
+TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic);
+
} // namespace transform
} // namespace tir
# under the License.
import tvm
from tvm import te
+from tvm import relay
+import numpy as np
+import pytest
+from tvm.relay.testing import ctx_list
var_list = []
('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
+def test_no_hoisting_1():
+ # The if-condition uses the innermost loop var `k`, so the if cannot move
+ # above any loop; the pass must leave the stmt unchanged with the
+ # block-scope option both off and on.
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
-if __name__ == "__main__":
- test_hoist_top_for()
- test_hoist_multi_var_if()
- test_hoist_no_match_for()
- test_no_else()
- test_attr_stmt()
- test_nested_for()
- test_if_block()
- test_multi_if()
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(k >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_2():
+ # The if shares the innermost loop body with a second statement (a SeqStmt),
+ # which blocks hoisting; stmt must be unchanged in both configurations.
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
+ x = te.var("x")
+
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(i >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.3
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_3():
+ # thread_extent attrs for tx/bx are re-declared directly around the if inside
+ # the innermost loop; the if must stay put in both configurations.
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_4():
+ # The tx thread_extent attr sits immediately inside the innermost loop,
+ # directly wrapping the if; no hoisting in either configuration.
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_5():
+ # Inner thread_extent attrs are re-declared at two different loop levels
+ # (bx at j, tx at k), pinning the if in place in both configurations.
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_6():
+ # The condition mixes a block-scope var with the innermost loop var
+ # ((tx + k) < 3), so it can never be hoisted, even with the option on.
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + k) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_7():
+ # Two nested ifs, each tied to the loop var of its enclosing loop
+ # ((tx + j) and (tx + k)); neither can be hoisted in either configuration.
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.if_scope((tx + j) < 9):
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + k) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_1():
+ # An rfactor reduction bound to thread axes: unchanged by default, but
+ # transformed (structural inequality) once the block-scope option is on.
+ n = te.size_var("n")
+ m = te.size_var("m")
+ A = te.placeholder((n, m), name='A')
+ k = te.reduce_axis((0, m), "k")
+ B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+ s = te.create_schedule(B.op)
+ ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
+ BF = s.rfactor(B, ki)
+ xo, xi = s[B].split(s[B].op.axis[0], factor=32)
+ s[B.op].bind(xo, te.thread_axis("blockIdx.x"))
+ s[B.op].bind(xi, te.thread_axis("threadIdx.y"))
+ s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
+ s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
+ func = tvm.driver.build_module.form_irmodule(
+ s, [A, B], "main", None)["main"]
+ stmt = func.body
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_2():
+ # The if-condition uses only tx (declared at the outermost scope); with the
+ # block-scope option on, the pass changes the stmt (structural inequality).
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ #ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ #tvm.ir.assert_structural_equal(new_stmt, stmt)
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_3():
+ # Inner tx/bx attrs sit at the j level, above the innermost loop; the if on
+ # tx is hoistable only when the block-scope option is enabled.
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ #tvm.ir.assert_structural_equal(new_stmt, stmt)
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_4():
+ # A parallel/vectorized CPU schedule with pragma attrs: unchanged by default,
+ # but changed (structural inequality) when the block-scope option is on.
+ nn = 1024
+ n = tvm.runtime.convert(nn)
+ A = te.placeholder((n,), name='A')
+ B = te.placeholder((n,), name='B')
+ AA = te.compute((n,), lambda *i: A(*i), name='A')
+ BB = te.compute((n,), lambda *i: B(*i), name='B')
+ T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
+ C = te.compute(A.shape, lambda *i: T(*i), name='C')
+ s = te.create_schedule(C.op)
+ xo, xi = s[C].split(C.op.axis[0], factor=4)
+ xo1, xo2 = s[C].split(xo, factor=13)
+ s[C].parallel(xo2)
+ s[C].pragma(xo1, "parallel_launch_point")
+ s[C].pragma(xo2, "parallel_stride_pattern")
+ s[C].pragma(xo2, "parallel_barrier_when_finish")
+ s[C].vectorize(xi)
+ func = tvm.driver.build_module.form_irmodule(
+ s, [A, B, C], "main", None)["main"]
+ stmt = func.body
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_5():
+ # The condition data[g] < 3 is loop-invariant: the default pass hoists it
+ # (first run changes the stmt); rerunning on the hoisted result with the
+ # block-scope option on is a fixed point (second run leaves it unchanged).
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+ g = te.var('g')
+
+ ib.scope_attr(data, "storage_scope", "global")
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope(data[g] < 3):
+ data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 0.3
+ with ib.else_scope():
+ data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+ stmt = new_stmt
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_6():
+ # The condition (tx + n) < 3 uses tx plus a free var (not a loop var), so it
+ # is hoistable only with the block-scope option on (structural inequality).
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + n) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_7():
+ # The condition (tx + i) < 3 depends on the outermost loop var i only, so
+ # with the block-scope option on it can rise above the j/k loops
+ # (structural inequality); default config leaves the stmt unchanged.
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + i) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+# End-to-end timing comparison of a relay conv2d built with and without the
+# block-scope hoisting option; skipped by default (needs built targets/devices
+# from ctx_list and runs 300-iteration time_evaluator loops).
+@pytest.mark.skip()
+def test_hoisting_op_conv():
+ dtype = "float32"
+ dshape = (1, 80, 73, 73)
+ kshape = (192, 80, 3, 3)
+ padding=(1, 1)
+ groups=1
+ dilation=(1, 1)
+ kernel_size=(3, 3)
+ channels=192
+ scale=1
+ x = relay.var("x", shape=dshape, dtype=dtype)
+ w = relay.var("w", shape=kshape, dtype=dtype)
+ y = relay.nn.conv2d(x, w, padding=padding,
+ dilation=dilation,
+ groups=groups,
+ channels=channels,
+ kernel_size=kernel_size)
+
+ func = relay.Function([x, w], y)
+ mod = tvm.IRModule()
+ mod['main'] = func
+ mod = relay.transform.InferType()(mod)
+
+ data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
+ kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
+
+ params = {'w': tvm.nd.array(kernel)}
+ for target, ctx in ctx_list():
+ with tvm.transform.PassContext(opt_level=3):
+ graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+ m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+ x = np.random.uniform(size=dshape)
+ data_tvm = tvm.nd.array(data)
+ m.set_input('x', data_tvm)
+ m.set_input(**params)
+ m.run()
+ e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
+ t1 = e(data_tvm).results
+ t1 = np.array(t1) * 1000
+ print('{} ms'.format(t1.mean()))
+
+ with tvm.transform.PassContext(opt_level=3, config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+ m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+ x = np.random.uniform(size=dshape)
+ data_tvm = tvm.nd.array(data)
+ m.set_input('x', data_tvm)
+ m.set_input(**params)
+ m.run()
+ e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
+ t2 = e(data_tvm).results
+ t2 = np.array(t2) * 1000
+
+ print('{} ms'.format(t2.mean()))
+ tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1)
+
+# Delegate test discovery to pytest instead of calling each test manually.
+if __name__ == "__main__":
+ pytest.main([__file__])