From 4c0a53dc5bef49797ccbb5f05b0c61ca832d7c85 Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Sun, 19 Apr 2020 19:57:25 -0700
Subject: [PATCH] [TIR][REFACTOR] RewriteForTensorCore -> te/schedule (#5379)

* [TIR][REFACTOR] RewriteForTensorCore -> te/schedule

RewriteForTensorCore depends on the schedule information, which makes it
differ from a typical pass (which should get all the information from the
input TIR). As a result, we refactor it as a SchedulePostProc step for now.
We should revisit it later as we introduce more support for tensor core
patterns in the TIR.

* Fix VTA to fit the new IR Pattern
---
 include/tvm/te/schedule_pass.h                     | 49 ++++++++++++++--------
 include/tvm/tir/ir_pass.h                          | 13 ------
 python/tvm/driver/build_module.py                  | 45 ++++++++++++--------
 python/tvm/te/hybrid/__init__.py                   |  6 ++-
 .../schedule_postproc_rewrite_for_tensor_core.cc}  | 39 ++++++++++-------
 src/tir/pass/ffi_api.cc                            |  8 ----
 vta/python/vta/ir_pass.py                          | 20 ++++-----
 7 files changed, 98 insertions(+), 82 deletions(-)
 rename src/{tir/pass/tensor_core.cc => te/schedule/schedule_postproc_rewrite_for_tensor_core.cc} (97%)

diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h
index e64ea21..c06ab50 100644
--- a/include/tvm/te/schedule_pass.h
+++ b/include/tvm/te/schedule_pass.h
@@ -35,6 +35,23 @@
 namespace tvm {
 namespace te {
 /*!
+ * \brief To automatically inline the element-wise operations.
+ *
+ * \param sch The schedule to be inlined.
+ */
+void AutoInlineElemWise(Schedule sch);
+
+/*!
+ * \brief To automatically inline operations with injective writes
+ * (i.e. writes without reduction or sequential loops). Note
+ * that in this case, guarantees about contiguity, transpose, stride,
+ * alignment and memory footprint in general do not hold.
+ *
+ * \param sch The schedule to be inlined.
+ */
+TVM_DLL void AutoInlineInjective(Schedule sch);
+
+/*!
  * \brief Infer the bound of all iteration variables relates to the schedule.
  *
  * \param sch The root schedule to infer all the bounds.
@@ -55,6 +72,21 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
  */
 Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);

+
+/*!
+ * \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
+ *
+ * \param stmt The stmt to be transformed.
+ * \param schedule The original schedule.
+ * \param extern_buffer Map specifies external
+ * buffer assignment of input and outputs.
+ * \return Transformed stmt.
+ */
+Stmt SchedulePostProcRewriteForTensorCore(
+    Stmt stmt,
+    Schedule schedule,
+    Map<Tensor, Buffer> extern_buffer);
+
 /*!
  * \brief Postprocessing the Stmt generated by ScheduleOps to create
  * a PrimFunc that can then be used for further TIR optimizations.
@@ -75,23 +107,6 @@
 PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
                                     Stmt body,
                                     Optional<Map<Tensor, Buffer>> bindings);
-/*!
- * \brief To automatically inline the element-wise operations.
- *
- * \param sch The schedule to be inlined.
- */
-void AutoInlineElemWise(Schedule sch);
-
-/*!
- * \brief To automatically inline operations with injective writes
- * (i.e. writes without reduction or sequential loops). Note
- * that in this case, guarantees about contiguity, transpose, stride,
- * alignemnt and memory footprint in general do not hold.
- *
- * \param sch The schedule to be inlined.
- */
-TVM_DLL void AutoInlineInjective(Schedule sch);
-
 } // namespace te
 } // namespace tvm
 #endif // TVM_TE_SCHEDULE_PASS_H_
diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h
index e6e2de6..592a79f 100644
--- a/include/tvm/tir/ir_pass.h
+++ b/include/tvm/tir/ir_pass.h
@@ -165,19 +165,6 @@ Stmt Inline(Stmt stmt, PrimExpr body);
 /*!
- * \brief Try to modify the AST to support TensorCore
- *
- * \param stmt The stmt to be trasnformed.
- * \param schedule The original schedule.
- * \param extern_buffer Map specifies external
- * buffer assignment of input and outputs.
- * \return Transformed stmt.
- */
-Stmt RewriteForTensorCore(Stmt stmt,
-    te::Schedule schedule,
-    Map<te::Tensor, Buffer> extern_buffer);
-
-/*!
  * \brief Verify if there is any argument bound to compact buffer.
  *
  * \param stmt The stmt to be verified.
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index eea3727..5c92965 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -84,23 +84,43 @@ def get_binds(args, compact=False, binds=None):
     return binds, arg_list


-def form_body(sch):
+def form_irmodule(sch, args, name, binds):
     """According to the given schedule, form a function.

     Parameters
     ----------
     sch : tvm.te.schedule.Schedule
-        The given scheduler to form the raw body
+        The given scheduler to form the raw body
+
+    args : list of Buffer or Tensor or Var
+        The argument lists to the function.
+
+    name : str
+        The name of result function.
+
+    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
+        The binds information

     Returns
     -------
     The body formed according to the given schedule
     """
     # normalize schedule first
+    cfg = BuildConfig.current()
     sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
-    return stmt
+
+    compact = ir_pass.VerifyCompactBuffer(stmt)
+    binds, arg_list = get_binds(args, compact, binds)
+
+    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
+    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
+
+    func = func.with_attr("global_symbol", name)
+    if cfg.restricted_func:
+        func = func.with_attr("tir.noalias", True)
+    return tvm.IRModule({name: func})


 def _wrap_as_prim_func_pass(flist, name):
@@ -166,24 +186,13 @@ def lower(sch,
     # Phase 0
     if isinstance(sch, schedule.Schedule):
-        stmt = form_body(sch)
-
-        for f in lower_phase0:
-            stmt = f(stmt)
-
-        compact = ir_pass.VerifyCompactBuffer(stmt)
-        binds, arg_list = get_binds(args, compact, binds)
-        stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
-
-        # Start the new style pass manager.
-        func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
-        func = func.with_attr("global_symbol", name)
-        if cfg.restricted_func:
-            func = func.with_attr("tir.noalias", True)
-        mod = tvm.IRModule({name: func})
+        mod = form_irmodule(sch, args, name, binds)
+    else:
+        mod = sch

     # Phase 1
     pass_list = [
+        _wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
         tvm.tir.transform.InjectPrefetch(),
         tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
         tvm.tir.transform.NarrowDataType(32),
diff --git a/python/tvm/te/hybrid/__init__.py b/python/tvm/te/hybrid/__init__.py
index 31acaeb..42bcc86 100644
--- a/python/tvm/te/hybrid/__init__.py
+++ b/python/tvm/te/hybrid/__init__.py
@@ -30,7 +30,7 @@ HalideIR.
 # 2. Support multi-level HalideIR
 import inspect
 import tvm._ffi
-from tvm.driver.build_module import form_body
+import tvm.te.schedule
 from tvm._ffi.base import decorate
 from .module import HybridModule
@@ -87,8 +87,10 @@ def build(sch, inputs, outputs, name="hybrid_func"):
     The built results is wrapped in a HybridModule. The usage of HybridModule is
     roughly the same as normal TVM-built modules.
     """
+    sch = sch.normalize()
+    bounds = tvm.te.schedule.InferBound(sch)
+    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)

-    stmt = form_body(sch)
     src = _Dump(stmt, inputs, outputs, name)

     return HybridModule(src, name)
diff --git a/src/tir/pass/tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
similarity index 97%
rename from src/tir/pass/tensor_core.cc
rename to src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
index dc2df98..5623559 100644
--- a/src/tir/pass/tensor_core.cc
+++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
@@ -18,9 +18,12 @@
  */

 /*!
- * \file tensor_core.cc
+ * \file schedule_postproc_rewrite_for_tensor_core.cc
+ *
+ * \brief Rewrite the Stmt generated by ScheduleOps
+ * to accommodate tensorcore.
  */
-// IR Passes for TensorCore CodeGen
+#include
 #include
 #include
 #include
@@ -32,12 +35,11 @@
 #include
 #include
 #include
-#include "ir_util.h"
 #include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"

 namespace tvm {
-namespace tir {
+namespace te {

 using namespace te;
 using runtime::StorageRank;
@@ -86,10 +88,10 @@ class MMAMatcher: public StmtVisitor {
   }

   void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::pragma_tensor_core) {
+    if (op->attr_key == tir::attr::pragma_tensor_core) {
       tensor_core_on_ = true;
       StmtVisitor::VisitStmt_(op);
-    } else if (op->attr_key == attr::realize_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope) {
       storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       this->VisitStmt(op->body);
     } else {
@@ -414,7 +416,7 @@ class BufferAnalyser : public StmtExprVisitor {
   }

   void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::thread_extent) {
+    if (op->attr_key == tir::attr::thread_extent) {
       if (const IntImmNode* value = op->value.as<IntImmNode>()) {
         thread_extent_.insert(
             std::make_pair(
@@ -422,10 +424,10 @@
                 value->value));
       }
       StmtExprVisitor::VisitStmt_(op);
-    } else if (op->attr_key == attr::realize_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope) {
       storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::buffer_dim_align) {
+    } else if (op->attr_key == tir::attr::buffer_dim_align) {
       te::Tensor tensor = Downcast<Tensor>(op->node);
       const CallNode* tuple = op->value.as<CallNode>();
       CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
@@ -850,7 +852,7 @@ class TensorCoreIRMutator : public StmtExprMutator {

   Stmt VisitStmt_(const AttrStmtNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    if (op->attr_key == attr::realize_scope) {
+    if (op->attr_key == tir::attr::realize_scope) {
       auto node = op->node.as<te::OperationNode>();
       if (node != nullptr) {
         if (!frag_reg_.count(node->name)) {
@@ -1186,9 +1188,10 @@ class TensorCoreIRMutator : public StmtExprMutator {
   int warp_threads_y_{-1};
 };

-Stmt RewriteForTensorCore(Stmt stmt,
-    Schedule schedule,
-    Map<Tensor, Buffer> extern_buffer) {
+Stmt SchedulePostProcRewriteForTensorCore(
+    Stmt stmt,
+    Schedule schedule,
+    Map<Tensor, Buffer> extern_buffer) {
   // Check if current lower target is CUDA
   auto target = tvm::Target::Current(true);
  if (target.defined() && target->target_name != "cuda") {
@@ -1223,5 +1226,13 @@ Stmt RewriteForTensorCore(Stmt stmt,
   return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt));
 }

-} // namespace tir
+TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore")
+.set_body_typed([](Stmt stmt,
+                   Schedule schedule,
+                   Map<Tensor, Buffer> extern_buffer) {
+  return SchedulePostProcRewriteForTensorCore(
+      stmt, schedule, extern_buffer);
+});
+
+} // namespace te
 } // namespace tvm
diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc
index 4d7ed5d..60b5bd9 100644
--- a/src/tir/pass/ffi_api.cc
+++ b/src/tir/pass/ffi_api.cc
@@ -75,14 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
   }
 });

-TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
-.set_body_typed
-  ([](const Stmt& stmt,
-      const te::Schedule& schedule,
-      const Map<te::Tensor, Buffer>& extern_buffer) {
-    return RewriteForTensorCore(stmt, schedule, extern_buffer);
-  });
-
 TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py
index 8a7798a..c2684d1 100644
--- a/vta/python/vta/ir_pass.py
+++ b/vta/python/vta/ir_pass.py
@@ -638,7 +638,7 @@ def inject_conv2d_transpose_skip(stmt_in):
     selects = []

     def _find_basics(op):
-        if isinstance(op, tvm.tir.Call):
+        if isinstance(op, tvm.tir.BufferLoad):
            calls.append(op)
        elif isinstance(op, tvm.tir.Select):
            selects.append(op)
@@ -664,18 +664,18 @@ def inject_conv2d_transpose_skip(stmt_in):
                body = op.body.body
                while isinstance(body, tvm.tir.IfThenElse):
                    body = body.then_case
-               args = body.args
-               res_tensor = body.func.output(0)
+               args = body.indices
+               res_buffer = body.buffer
                tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
                inner = tvm.tir.AttrStmt(
-                   [dout, res_tensor], 'buffer_bind_scope',
+                   [dout, res_buffer], 'buffer_bind_scope',
                    tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
                return inner
            else:
                conv_call, data_call, kernel_call = calls[-3:]
-               pad_data_tensor = data_call.func.output(0)
-               kernel_tensor = kernel_call.func.output(0)
-               res_tensor = conv_call.func.output(0)
+               pad_data_tensor = data_call.buffer
+               kernel_tensor = kernel_call.buffer
+               res_tensor = conv_call.buffer

                if selects:
                    condition = selects[0].condition
@@ -696,19 +696,19 @@ def inject_conv2d_transpose_skip(stmt_in):
                                            0, 0, 0))
                inner = irb.get()

-               args = conv_call.args
+               args = conv_call.indices
                tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
                inner = tvm.tir.AttrStmt(
                    [dout, res_tensor], 'buffer_bind_scope',
                    tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-               args = kernel_call.args
+               args = kernel_call.indices
                tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
                inner = tvm.tir.AttrStmt(
                    [dwgt, kernel_tensor], 'buffer_bind_scope',
                    tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-               args = data_call.args
+               args = data_call.indices
                tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_IN)
                inner = tvm.tir.AttrStmt(
--
2.7.4
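
The following usage sketch is not part of the patch. It mirrors the new lowering flow that form_irmodule in python/tvm/driver/build_module.py follows after this refactor: the TensorCore rewrite now runs as a te schedule post-processing step instead of the removed ir_pass.RewriteForTensorCore. The placeholder compute (n, A, B) and the symbol name "main" are illustrative assumptions; the schedule.*, get_binds, and ir_pass.VerifyCompactBuffer calls are the ones the patch itself uses.

# Minimal sketch of the refactored lowering flow (assumes a TVM build that
# contains this patch; the tiny elementwise compute below is a placeholder).
import tvm
from tvm import te
from tvm.te import schedule
from tvm.tir import ir_pass
from tvm.driver.build_module import get_binds

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
sch = te.create_schedule(B.op)

# Same steps form_irmodule performs: normalize, infer bounds, emit the Stmt.
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)

# Bind buffers, then run the TensorCore rewrite as a schedule post-processing
# step (it returns the Stmt unchanged when no TensorCore pattern is matched).
compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds([A, B], compact, None)
stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)

# Wrap the result into a PrimFunc / IRModule for the TIR pass pipeline.
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
func = func.with_attr("global_symbol", "main")
mod = tvm.IRModule({"main": func})
print(mod)

On a schedule with no tensor_core pragma (or a non-CUDA target) the rewrite step simply passes the Stmt through, so the sketch also runs on CPU-only builds.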