TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
-
/*!
 * \brief Find undefined vars in the statement.
* \param stmt The function to be checked.
bool VerifyCompactBuffer(Stmt stmt);
/*!
- * \brief Remove No Op from the Stmt.
- * \param stmt The stmt to be trasnformed
- * \return Transformed stmt.
- */
-Stmt RemoveNoOp(Stmt stmt);
-
-/*!
- * \brief unroll the constant loop marked by unroll.
- * This pass also automatically attach pragma unroll tag to loops which meets the standard.
- *
- * \param stmt The statment to be unrolled.
- * \param auto_max_step The maximum step before stop attach automatic unroll
- * \param auto_max_depth The maximum depth before stop attach automatic unroll
- * \param auto_max_extent The maximum extent of the loop we can unroll,
- * this is an legacy option that do not take the loop total steps into account.
- * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
- * \return Transformed stmt.
- */
-Stmt UnrollLoop(Stmt stmt,
- int auto_max_step,
- int auto_max_depth,
- int auto_max_extent,
- bool explicit_unroll);
-
-/*!
- * \brief vectorize the constant loops
- * \param stmt The statement to be vectorized.
- * \return Transformed stmt.
- */
-Stmt VectorizeLoop(Stmt stmt);
-
-/*!
- * \brief convert vectorized loops into serialized loops
- * \param stmt The statement to skip vectorization on.
- * \return Transformed stmt.
- */
-Stmt SkipVectorize(Stmt stmt);
-
-/*!
-* \brief instruments bound checkers.
-* \param stmt The statement to be instrumented.
-* \return Instrumented stmt.
-*/
-Stmt InstrumentBoundCheckers(Stmt stmt);
-
-/*!
- * \brief Inject virtual thread loops into stmt.
- * \param stmt The statement to be transformed.
- * \return Transformed stmt.
- */
-Stmt InjectVirtualThread(Stmt stmt);
-
-/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
Stmt InjectPrefetch(Stmt stmt);
/*!
- * \brief Inject double buffer into stmt.
- * \param stmt The statement to be transformed.
- * \param split_loop Loop splitting factor.
- * \return Transformed stmt.
- */
-Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
-
-/*!
- * \brief Inject copy intrinsics with optional pad.
- *
- * \param stmt The statement to be transformed.
- * \param pragma_key The pragma key for hint of copy.
- * \param fintrin The function with signature
- *
- * Stmt fintrin(Buffer src,
- * Buffer dst,
- * Array<Expr> pad_before,
- * Array<Expr> pad_after,
- * Expr pad_value)
- * \return Transformed stmt.
- */
-Stmt InjectCopyIntrin(Stmt stmt,
- const std::string& pragma_key,
- const runtime::PackedFunc& fintrin);
-
-/*!
- * \brief Rewrite storage allocation pattern.
- * Moves the allocation to outer most possible scope.
- * Trying to share space between allocations to make
- * a static allocation plan when possible.
- *
- * \param stmt The stmt to be transformed
- * \return Transformed stmt.
- */
-Stmt StorageRewrite(Stmt stmt);
-
-/*!
- * \brief partition loops in the stmt
- * \param stmt The stmt to do loop partition
- * \param split_const_loop flag to enable partition for const loop
- * \return Transformed stmt.
- */
-Stmt LoopPartition(Stmt stmt, bool split_const_loop);
-
-/*!
- * \brief Detect and insert sync points to co-processor.
- *
- * \param stmt The stmt to be transformed
- * \return Transformed stmt.
- */
-Stmt CoProcSync(Stmt stmt);
-
-/*!
- * \brief Lift common attrs with attr_key to outer scope.
- *
- * \param stmt The stmt to be transformed
- * \param attr_key The attribute key to be checked.
- * \return Transformed stmt.
- */
-Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
-
-/*!
- * \brief Detect and rewrite unsafe select that contains memory access.
- * \param stmt The statement to be rewritten.
- * \return Transformed stmt.
- */
-Stmt RewriteUnsafeSelect(Stmt stmt);
-
-/*!
- * \brief Lower attached storage access information.
- * Do this pass after all storage access analysis finish.
- *
- * \param stmt The stmt to be transformed
- * \return Transformed stmt.
- */
-Stmt LowerStorageAccessInfo(Stmt stmt);
-
-/*!
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
*
Stmt HoistIfThenElse(Stmt stmt);
/*!
- * \brief Narrow down PrimExpr datatype in stmt to target_bits.
- * \note Run this pass after StorageFlatten.
- * \param stmt The stmt to do datatype rewrite
- * \param target_bits the bit of target datatype
- * \return Transformed stmt.
- */
-Stmt NarrowDataType(Stmt stmt, int target_bits);
-
-/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
const tvm::Array<runtime::String>& required);
/*!
+ * \brief Inject copy intrinsics with optional pad.
+ *
+ * \param pragma_key The pragma key for hint of copy.
+ * \param fintrin The function with signature
+ *
+ * Stmt fintrin(Buffer src,
+ * Buffer dst,
+ * Array<Expr> pad_before,
+ * Array<Expr> pad_after,
+ * Expr pad_value)
+ * \return The pass.
+ */
+TVM_DLL Pass InjectCopyIntrin(std::string pragma_key,
+ runtime::PackedFunc fintrin);
+
+/*!
+ * \brief Detect and insert sync points to co-processor.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CoProcSync();
+
+/*!
+ * \brief Lift common attrs with attr_key to outer scope.
+ *
+ * \param attr_key The attribute key to be checked.
+ * \return The pass.
+ */
+TVM_DLL Pass LiftAttrScope(std::string attr_key);
+
+/*!
+ * \brief partition loops in the stmt.
+ *
+ * \param split_const_loop flag to enable partition for const loop
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LoopPartition(bool split_const_loop);
+
+/*!
+ * \brief Lower vectorization loops.
+ *
+ * \param enable_vectorize Whether vectorization is enabled.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
+
+/*!
+ * \brief Inject virtual thread loops.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InjectVirtualThread();
+
+/*!
+ * \brief Inject double buffer statements.
+ *
+ * \param split_loop_factor Loop splitting factor.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectDoubleBuffer(int split_loop_factor);
+
+/*!
+ * \brief Rewrite storage allocation pattern.
+ * Moves the allocation to outer most possible scope.
+ * Trying to share space between allocations to make
+ * a static allocation plan when possible.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass StorageRewrite();
+
+/*!
+ * \brief unroll the constant loop marked by unroll.
+ * This pass also automatically attaches pragma unroll tags to loops which meet the standard.
+ *
+ * \param auto_max_step The maximum step before stop attach automatic unroll
+ * \param auto_max_depth The maximum depth before stop attach automatic unroll
+ * \param auto_max_extent The maximum extent of the loop we can unroll,
+ * this is a legacy option that does not take the loop total steps into account.
+ * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
+ * \return The pass.
+ */
+TVM_DLL Pass UnrollLoop(int auto_max_step,
+ int auto_max_depth,
+ int auto_max_extent,
+ bool explicit_unroll);
+
+/*!
+ * \brief Remove No Op from the Stmt.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RemoveNoOp();
+
+/*!
+ * \brief Detect and rewrite unsafe select that contains memory access.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RewriteUnsafeSelect();
+
+/*!
+* \brief Run arithmetic simplifications on the statements and expressions.
+*
+* \return The pass.
+*/
+TVM_DLL Pass Simplify();
+
+/*!
+* \brief Instruments bound checkers.
+*
+* \return The pass.
+*/
+TVM_DLL Pass InstrumentBoundCheckers();
+
+/*!
* \brief Transform the high-level PrimFunc to a low-level version
* that can be used as an API function.
*
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)
+
for f in lower_phase2:
stmt = f(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
+
for f in lower_phase3:
stmt = f(stmt)
+
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
+
if simple_mode:
return stmt
return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
+def InjectCopyIntrin(pragma_key, fintrin):
+    """Inject copy intrinsics with optional pad.
+
+    Parameters
+    ----------
+    pragma_key : str
+        The pragma key for hint of copy.
+
+    fintrin : function
+        The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value)
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InjectCopyIntrin(pragma_key, fintrin)
+
+
+def CoProcSync():
+ """Detect and insert sync points to co-processor.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.CoProcSync()
+
+
+def LiftAttrScope(attr_key):
+ """Lift common attrs with attr_key to outer scope.
+
+ Parameters
+ ----------
+ attr_key : str
+ The attribute key to be checked.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LiftAttrScope(attr_key)
+
+
+def LoopPartition(split_const_loop):
+    """Partition loops in the stmt.
+
+    Parameters
+    ----------
+    split_const_loop : bool
+        Flag to enable partition for const loop.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LoopPartition(split_const_loop)
+
+
+def VectorizeLoop(enable_vectorize=True):
+ """Lower vectorization loops.
+
+ Parameters
+ ----------
+ enable_vectorize : bool
+ Whether vectorization is enabled.
+ Will lower to scalar loop when it is turned off.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.VectorizeLoop(enable_vectorize)
+
+
+def InjectVirtualThread():
+ """Inject virtual thread loops.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InjectVirtualThread()
+
+
+def InjectDoubleBuffer(split_loop_factor):
+ """Inject double buffer statements.
+
+ Parameters
+ ----------
+ split_loop_factor : int
+ Loop splitting factor.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InjectDoubleBuffer(split_loop_factor)
+
+
+def StorageRewrite():
+ """Rewrite storage allocation pattern.
+
+ Moves the allocation to outer most possible scope.
+ Trying to share space between allocations to make
+ a static allocation plan when possible.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.StorageRewrite()
+
+
+def UnrollLoop(auto_max_step,
+               auto_max_depth,
+               auto_max_extent,
+               explicit_unroll):
+    """Unroll the constant loop marked by unroll.
+
+    This pass also automatically attaches pragma unroll tags to loops which meet the standard.
+
+    Parameters
+    ----------
+    auto_max_step : int
+        The maximum step before stop attach automatic unroll
+
+    auto_max_depth : int
+        The maximum depth before stop attach automatic unroll
+
+    auto_max_extent : int
+        The maximum extent of the loop we can unroll.
+        This is a legacy option that does not take the loop total steps into account.
+
+    explicit_unroll : bool
+        Whether explicitly unroll the loop, or leave unroll annotation to codegen.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.UnrollLoop(
+        auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)
+
+
+def RemoveNoOp():
+ """Remove No Op from the Stmt.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RemoveNoOp()
+
+
+def RewriteUnsafeSelect():
+ """Detect and rewrite unsafe select that contains memory access.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RewriteUnsafeSelect()
+
+
+def Simplify():
+ """Run arithmetic simplifications on the statements and expressions.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.Simplify()
+
+
+def InstrumentBoundCheckers():
+ """Instruments bound checkers.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentBoundCheckers()
+
+
def LowerCustomDatatypes():
"""Lower custom datatypes.
#define TVM_ARITH_COMPUTE_EXPR_H_
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <limits>
#include <algorithm>
}
}
-/*!
-* \brief Build a Stmt given a schedule, args and binds. This function runs the IR passes.
-* \param sch The schedule to build.
-* \param args The arguments for the schedule.
-* \param binds Buffer assignments.
-* \param loop_partition True if the LoopPartition pass should be included.
-* \param out_arg_list Returns the arguments for the Stmt.
-* \param config The build configuration.
-* \return The built Stmt.
-*/
-tir::Stmt BuildStmt(te::Schedule sch,
- const Array<te::Tensor>& args,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- bool loop_partition,
- Array<ObjectRef> *out_arg_list,
- const BuildConfig& config) {
- sch = sch.normalize();
-
- // Phase 0
- auto bounds = te::InferBound(sch);
- auto stmt = te::ScheduleOps(sch, bounds, false);
- stmt = tir::InjectPrefetch(stmt);
-
- bool compact = tir::VerifyCompactBuffer(stmt);
- Map<te::Tensor, tir::Buffer> out_binds;
- GetBinds(args, compact, binds, &out_binds, out_arg_list, config);
-
- // Phase 1
- stmt = tir::StorageFlatten(stmt, out_binds, 64,
- config->instrument_bound_checkers);
- stmt = tir::CanonicalSimplify(stmt);
- if (loop_partition) {
- stmt = tir::LoopPartition(stmt, config->partition_const_loop);
- }
- if (config->disable_vectorize) {
- stmt = tir::SkipVectorize(stmt);
- } else {
- stmt = tir::VectorizeLoop(stmt);
- }
- stmt = tir::InjectVirtualThread(stmt);
- stmt = tir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
- stmt = tir::StorageRewrite(stmt);
- stmt = tir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth,
- config->auto_unroll_max_extent, config->unroll_explicit);
-
- // Phase 2
- stmt = tir::Simplify(stmt);
- stmt = tir::RemoveNoOp(stmt);
-
- if (!(config->disable_select_rewriting))
- stmt = tir::RewriteUnsafeSelect(stmt);
-
- if (config->instrument_bound_checkers)
- stmt = tir::InstrumentBoundCheckers(stmt);
-
- return stmt;
-}
-
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
return WithAttr(std::move(f), tvm::attr::kTarget, target);
template<typename FCond>
-transform::Pass FilterBy(FCond fcond) {
+transform::Pass Filter(FCond fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (fcond(f)) {
return f;
return tir::PrimFunc(nullptr);
}
};
- return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {});
+ return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-IRModule lower(te::Schedule sch,
- const Array<te::Tensor>& args,
- const std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- const BuildConfig& config) {
- Array<ObjectRef> out_arg_list;
- auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
-
+IRModule BuildIRModule(const Array<ObjectRef>& out_arg_list,
+ tir::Stmt stmt,
+ const std::string& name,
+ const BuildConfig& config) {
Array<tir::Var> params;
Map<tir::Var, tir::Buffer> buffer_map;
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
+
return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
}
+IRModule lower(te::Schedule sch,
+ const Array<te::Tensor>& args,
+ const std::string& name,
+ const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+ const BuildConfig& config) {
+ Array<ObjectRef> out_arg_list;
+
+ sch = sch.normalize();
+
+ // Phase 0
+ auto bounds = te::InferBound(sch);
+ auto stmt = te::ScheduleOps(sch, bounds, false);
+ stmt = tir::InjectPrefetch(stmt);
+
+ bool compact = tir::VerifyCompactBuffer(stmt);
+ Map<te::Tensor, tir::Buffer> out_binds;
+ GetBinds(args, compact, binds, &out_binds, &out_arg_list, config);
+
+ // Phase 1
+ stmt = tir::StorageFlatten(stmt, out_binds, 64,
+ config->instrument_bound_checkers);
+
+ // convert to IRModule.
+ auto mod = BuildIRModule(out_arg_list, stmt, name, config);
+ auto pass_list = Array<tvm::transform::Pass>();
+
+ pass_list.push_back(tir::transform::Simplify());
+ pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop));
+ pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize));
+ pass_list.push_back(tir::transform::InjectVirtualThread());
+ pass_list.push_back(tir::transform::InjectDoubleBuffer(config->double_buffer_split_loop));
+ pass_list.push_back(tir::transform::StorageRewrite());
+ pass_list.push_back(
+ tir::transform::UnrollLoop(config->auto_unroll_max_step,
+ config->auto_unroll_max_depth,
+ config->auto_unroll_max_extent,
+ config->unroll_explicit));
+ // Phase 2
+ pass_list.push_back(tir::transform::Simplify());
+ pass_list.push_back(tir::transform::RemoveNoOp());
+ if (!(config->disable_select_rewriting)) {
+ pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+ }
+ if (config->instrument_bound_checkers) {
+ pass_list.push_back(tir::transform::InstrumentBoundCheckers());
+ }
+ // run
+ auto optimize = transform::Sequential(pass_list);
+ mod = optimize(std::move(mod));
+ return mod;
+}
+
+
std::pair<IRModule, IRModule>
split_dev_host_funcs(IRModule mod_mixed,
const Target& target,
mod_mixed = opt_mixed(std::move(mod_mixed));
auto host_pass_list = {
- FilterBy([](const tir::PrimFunc& f) {
+ Filter([](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(
tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch;
// device pipeline
auto device_pass_list = {
- FilterBy([](const tir::PrimFunc& f) {
+ Filter([](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(
tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch;
REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
-REGISTER_PASS(RewriteUnsafeSelect);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform);
-REGISTER_PASS(VectorizeLoop);
-REGISTER_PASS(SkipVectorize);
-REGISTER_PASS(UnrollLoop);
-REGISTER_PASS(InjectCopyIntrin);
-REGISTER_PASS(StorageRewrite);
-REGISTER_PASS(CoProcSync);
-REGISTER_PASS(LowerStorageAccessInfo);
-REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch);
-REGISTER_PASS(InjectDoubleBuffer);
-REGISTER_PASS(LoopPartition);
-REGISTER_PASS(RemoveNoOp);
-REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
-REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
-REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
*/
// Instrument checkers for out of the bounds access.
+#include <tvm/runtime/registry.h>
+#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <vector>
#include <unordered_map>
}
// Try to simplify index and bound.
- index = tir::Simplify(index);
- upper_bound = tir::Simplify(upper_bound);
+ index = analyzer_.Simplify(index);
+ upper_bound = analyzer_.Simplify(upper_bound);
// Cast to the same type - signed, to be able to check lower bound.
index = CastNode::make(DataType::Int(64), index);
const char *const error_message_ = "OUT OF THE BOUNDS";
// Hashtable which maps buffer_var to shape.
std::unordered_map<const VarNode *, PrimExpr> mem_to_shape_;
+ // internal analyzer
+ arith::Analyzer analyzer_;
};
Stmt InstrumentBoundCheckers(Stmt stmt) {
bound_collector(stmt);
return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
}
+
+
+TVM_REGISTER_GLOBAL("ir_pass.InstrumentBoundCheckers")
+.set_body_typed(InstrumentBoundCheckers);
+
+namespace transform {
+
+Pass InstrumentBoundCheckers() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ BoundCollector bound_collector;
+ // At first walk recursively and collect bound attributes.
+ bound_collector(n->body);
+ n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers")
+.set_body_typed(InstrumentBoundCheckers);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
/*!
* \file coproc_sync.cc
*/
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>
#include <unordered_set>
-#include "ir_util.h"
-#include "storage_access.h"
+#include "../pass/ir_util.h"
+#include "../pass/storage_access.h"
namespace tvm {
namespace tir {
return CoProcSyncInserter().Insert(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.CoProcSync")
+.set_body_typed(CoProcSync);
+
+namespace transform {
+
+Pass CoProcSync() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = CoProcSyncInserter().Insert(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.CoProcSync")
+.set_body_typed(CoProcSync);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
* \brief Replace certain copy with copy intrinsics.
* \file copy_intrin_rewrite.cc
*/
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
#include "../../arith/pattern_match.h"
namespace tvm {
return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.InjectCopyIntrin")
+.set_body_typed(InjectCopyIntrin);
+
+namespace transform {
+
+Pass InjectCopyIntrin(std::string pragma_key,
+ PackedFunc flower_copy_fromto) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = CopyIntrinInjector(
+ pragma_key, flower_copy_fromto)(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin")
+.set_body_typed(InjectCopyIntrin);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/op.h>
-#include "ir_util.h"
+#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
namespace tvm {
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) {
return DoubleBufferInjector(split_loop).Inject(stmt);
}
+
+TVM_REGISTER_GLOBAL("ir_pass.InjectDoubleBuffer")
+.set_body_typed(InjectDoubleBuffer);
+
+
+namespace transform {
+
+Pass InjectDoubleBuffer(int split_loop) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = DoubleBufferInjector(split_loop).Inject(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer")
+.set_body_typed(InjectDoubleBuffer);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
/*!
* \file inject_virtual_thread.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h>
#include <unordered_set>
#include "../../arith/compute_expr.h"
return ConvertSSA(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.InjectVirtualThread")
+.set_body_typed(InjectVirtualThread);
+
+namespace transform {
+
+Pass InjectVirtualThread() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = ConvertSSA(VirtualThreadInjector()(std::move(n->body)));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread")
+.set_body_typed(InjectVirtualThread);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
* the body contains the same scope.
* \file lift_attr_scope.cc
*/
-#include <tvm/tir/ir_pass.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
-#include "ir_util.h"
+#include "../pass/ir_util.h"
namespace tvm {
namespace tir {
return AttrScopeLifter(attr_key).Lift(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.LiftAttrScope")
+.set_body_typed(LiftAttrScope);
+
+namespace transform {
+
+Pass LiftAttrScope(std::string attr_key) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope")
+.set_body_typed(LiftAttrScope);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
/*!
* \file loop_partition.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/stmt_functor.h>
#include <tvm/arith/analyzer.h>
#include <unordered_map>
#include <unordered_set>
Stmt pre_stmt;
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
- body_begin = tir::Simplify(middle_interval.min());
+ body_begin = analyzer_.Simplify(middle_interval.min());
if (!analyzer_.CanProve(body_begin == min)) {
PrimExpr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
Stmt post_stmt;
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
- post_doubt_begin = tir::Simplify(middle_interval.max() + 1);
+ post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
- for_node->for_type, for_node->device_api, body);
+ for_node->for_type, for_node->device_api, body);
}
}
return stmt;
}
+
+TVM_REGISTER_GLOBAL("ir_pass.LoopPartition")
+.set_body_typed(LoopPartition);
+
+namespace transform {
+
+Pass LoopPartition(bool split_const_loop) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = LoopPartition(std::move(n->body), split_const_loop);
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LoopPartition")
+.set_body_typed(LoopPartition);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
return StorageAccessInfoLower()(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccessInfo")
+.set_body_typed(LowerStorageAccessInfo);
namespace transform {
return DataTypeRewriter(target_bits)(stmt);
}
+TVM_REGISTER_GLOBAL("ir_pass.NarrowDataType")
+.set_body_typed(NarrowDataType);
+
+
namespace transform {
Pass NarrowDataType(int target_bits) {
* \file remove_no_op.cc
* \brief Remove no op from the stmt
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>
Stmt RemoveNoOp(Stmt stmt) {
return NoOpRemover()(std::move(stmt));
}
+
+TVM_REGISTER_GLOBAL("ir_pass.RemoveNoOp")
+.set_body_typed(RemoveNoOp);
+
+namespace transform {
+
+Pass RemoveNoOp() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = NoOpRemover()(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp")
+.set_body_typed(RemoveNoOp);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
* \file unsafe_select_rewrite.cc
 * \brief Rewrite unsafe select expression.
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
return UnsafeSelectRewriter()(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.RewriteUnsafeSelect")
+.set_body_typed(RewriteUnsafeSelect);
+
+namespace transform {
+
+Pass RewriteUnsafeSelect() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = UnsafeSelectRewriter()(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect")
+.set_body_typed(RewriteUnsafeSelect);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
*/
/*!
- * \file stmt_simplify.cc
+ * \file simplify.cc
* \brief Statement simplifier based on analyzer
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/arith/analyzer.h>
-#include "ir_mutator_with_analyzer.h"
+#include "../../arith/ir_mutator_with_analyzer.h"
namespace tvm {
namespace arith {
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return CanonicalSimplify(std::move(stmt), vrange);
}
+
+namespace transform {
+
+Pass Simplify() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ arith::Analyzer analyzer;
+ n->body = arith::StmtSimplifier(&analyzer).Simplify(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.Simplify")
+.set_body_typed(Simplify);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
* \brief Memory access pattern analysis and optimization.
* Re-write data access to enable memory sharing when possible.
*/
+#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/target_info.h>
#include <map>
#include <unordered_set>
#include <unordered_map>
-#include "ir_util.h"
+#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
return VectorAllocRewriter()(std::move(stmt));
}
+
+TVM_REGISTER_GLOBAL("ir_pass.StorageRewrite")
+.set_body_typed(StorageRewrite);
+
+namespace transform {
+
+Pass StorageRewrite() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
+ n->body = VectorAllocRewriter()(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite")
+.set_body_typed(StorageRewrite);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
* \file unroll_loop.cc
*/
// Unrolls the loop as in Halide pipeline.
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_set>
#include <unordered_map>
}
}
-Stmt UnrollLoopExplicitly(Stmt stmt) {
- const ForNode* op = stmt.as<ForNode>();
- if (!op) {
- LOG(FATAL) << "attempted to unroll a non-loop statement";
- }
- return LoopUnroller(0, 0, 0, false).Unroll(op);
+TVM_REGISTER_GLOBAL("ir_pass.UnrollLoop")
+.set_body_typed(UnrollLoop);
+
+namespace transform {
+
+Pass UnrollLoop(int auto_max_step,
+                int auto_max_depth,
+                int auto_max_extent,
+                bool explicit_unroll) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = UnrollLoop(std::move(n->body),  // n->body, not f->body: f->body is const, moving it would copy
+                         auto_max_step,
+                         auto_max_depth,
+                         auto_max_extent,
+                         explicit_unroll);
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {});
}
+TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop")
+.set_body_typed(UnrollLoop);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
* \file vectorize_loop.cc
*/
// Loop vectorizer as in Halide pipeline.
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <unordered_set>
#include <unordered_map>
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
if (op->for_type == ForType::Vectorized) {
- return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
- op->body);
+ return ForNode::make(op->loop_var, op->min, op->extent,
+ ForType::Serial, op->device_api,
+ op->body);
} else {
return stmt;
}
return VectorizeSkipper()(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.VectorizeLoop")
+.set_body_typed(VectorizeLoop);
+
+TVM_REGISTER_GLOBAL("ir_pass.SkipVectorize")
+.set_body_typed(SkipVectorize);
+
+namespace transform {
+
+// TODO(tvm-team): Make it as a target property.
+Pass VectorizeLoop(bool enable_vectorize) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();  // detach a mutable PrimFuncNode; safe to move n->body below
+    if (enable_vectorize) {
+      n->body = LoopVectorizer()(std::move(n->body));  // lower ForType::Vectorized loops
+    } else {
+      n->body = VectorizeSkipper()(std::move(n->body));  // demote vectorized loops to serial
+    }
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop")
+.set_body_typed(VectorizeLoop);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import tvm
-from tvm import te
-
-def test_virtual_thread():
- m = te.var('m')
- A = te.placeholder((m, ), name='A')
- A1 = te.compute((m,), lambda i: A[i], name='A1')
- A2 = te.compute((m,), lambda i: A1[i] + 3, name='A2')
-
- s = te.create_schedule(A2.op)
- vx = te.thread_axis("vthread", name="vx")
- xo, xi = s[A2].split(A2.op.axis[0], nparts=2)
- s[A2].bind(xo, vx)
- xo, xi = s[A2].split(xi, 8)
- s[A1].compute_at(s[A2], xo)
-
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt)
- print(stmt)
-
-if __name__ == "__main__":
- test_virtual_thread()
ib.scope_attr(cp, "coproc_scope", 1)
A[j] = A[j + k * 10] + 2
stmt = ib.get()
- stmt = tvm.tir.ir_pass.CoProcSync(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+ stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
+
body = stmt.body.body.body
blist = tvm.tir.stmt_list(body)
assert(blist[1].value.name == "cop.coproc_read_barrier")
ib.scope_attr(cp, "coproc_scope", 2)
A[ty] = 1.0
stmt = ib.get()
- stmt = tvm.tir.ir_pass.CoProcSync(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+ stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
+
def test_coproc_sync3():
def __check_list(tvm_array, py_list):
A[0] = 0.0
stmt = ib.get()
- stmt = tvm.tir.ir_pass.CoProcSync(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+ stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
+
slist = tvm.tir.stmt_list(stmt[0].body.body)
push_st = slist[2]
slist = tvm.tir.stmt_list(slist[-1])
assert src.strides[0] == l
assert tuple(src.shape) == (m, l)
return tvm.tir.Evaluate(0)
- stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
def test_copy_pad():
m = te.var('m')
assert pad_after[1].value == 0
assert pad_value.value == 1.0
return tvm.tir.Evaluate(0)
- stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
def test_single_point_test():
A = te.placeholder((1,), name='A')
assert tvm.tir.ir_pass.Simplify(src.strides[0]).value == 1
assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.tir.Evaluate(0)
- stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
def assert_expr_equal(a, b):
assert tvm.tir.ir_pass.Simplify(a - b).value == 0
assert_expr_equal(pad_after[0], rpad_after)
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.tir.Evaluate(0)
- stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
+
if __name__ == "__main__":
C[j] = B[j] + 1
stmt = ib.get()
- stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert isinstance(stmt.body.body, tvm.tir.Allocate)
- assert stmt.body.body.extents[0].value == 2
mod = tvm.IRModule({
"db" : tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)
})
+
+ opt = tvm.transform.Sequential(
+ [tvm.tir.transform.InjectDoubleBuffer(2),
+ tvm.tir.transform.Simplify()])
+ mod = opt(mod)
+ stmt = mod["db"].body
+
+ assert isinstance(stmt.body.body, tvm.tir.Allocate)
+ assert stmt.body.body.extents[0].value == 2
+
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
def count_sync(op):
C[i * nthread + tx] = B[i] + 1
return ib.get()
- stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread"))
+ stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([], get_vthread("vthread"))))["main"].body
+
assert stmt.body.body.extents[0].value == 2
- stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("cthread"))
+
+ stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
+
assert len(stmt.body.body.extents) == 3
A[tx] = tx + 1.0
B[ty] = ty + 1.0
ib.emit(tvm.tir.call_extern("int32", "Run",
- abuffer.access_ptr("r"),
- bbuffer.access_ptr("r"),
- cbuffer.access_ptr("rw")))
+ abuffer.access_ptr("r"),
+ bbuffer.access_ptr("r"),
+ cbuffer.access_ptr("rw")))
return ib.get()
- stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread"))
+
+ stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
+
assert stmt.body.body.extents[0].value == 2
assert stmt.body.body.body.body.body.body.extents[0].value == 2
assert len(stmt.body.body.body.body.body.body.extents) == 3
+
def test_vthread_if_then_else():
nthread = 2
tx = te.thread_axis("vthread")
with ib.if_scope(i == 0):
B[i] = A[i * nthread + tx] + 2
stmt = ib.get()
- stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt)
+
+ stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([], stmt)))["main"].body
+
assert stmt.body.body.body[0].else_case != None
assert stmt.body.body.body[1].else_case == None
import tvm
from tvm import te
import numpy as np
+
def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
return ret
-def lower(sch, args):
- binds = {}
- arg_list = []
- for x in args:
- if isinstance(x, te.tensor.Tensor):
- buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
- assert x not in binds
- binds[x] = buf
- arg_list.append(buf)
- else:
- raise ValueError("args must be Tensor, Buffer or Var")
- sch = sch.normalize()
- bounds = tvm.te.schedule.InferBound(sch)
- stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.RemoveNoOp(stmt)
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, True)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- return stmt
@pytest.mark.xfail
def test_out_of_bounds_llvm(index_a, index_b):
tgt = "llvm"
tgt_host = "llvm"
stmt = tvm.lower (s, [A, B, C], simple_mode=True)
- print (stmt)
fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
ctx = tvm.context(tgt, 0)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
tgt = "llvm"
tgt_host = "llvm"
stmt = tvm.lower (s, [a, b, c], simple_mode=True)
- print (stmt)
f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec")
ctx = tvm.cpu(0)
n = nn
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
- bounds = tvm.te.schedule.InferBound(s)
- stmt = lower (s, [A, B, T])
- # num_attributes = num_buffers * num_splits = 2 * 3
- # before instrumentation
- assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
- assert_bound_instrumentation(stmt, check_branch_stmt, 0)
- stmt = tvm.tir.ir_pass.InstrumentBoundCheckers(stmt)
+ with tvm.target.build_config(instrument_bound_checkers=True,
+ partition_const_loop=True):
+ mod = tvm.driver.lower(s, [A, B, T], name="main")
+
+ stmt = mod["main"].body
# after instrumentation
assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
assert_bound_instrumentation(stmt, check_branch_stmt, 2)
def test_in_bounds_const_loop_partition_llvm():
- with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True):
+ with tvm.target.build_config(instrument_bound_checkers=True,
+ partition_const_loop=True):
n = 21
A = te.placeholder((n, ), name='A')
B = te.placeholder((n, ), name='B')
A[j] = A[j] + 3
A[j] = A[j] + 3
body = ib.get()
- body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body
+
assert body.body.body.node == cp
# only able to lift to the common pattern of the last two fors.
A[i] = A[i] + 2
body = ib.get()
- body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body
+
assert body.body.body.body[1].node == cp
assert len(body.body.body.body) == 2
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
return ret
-def lower(sch, args):
- binds = {}
- arg_list = []
- for x in args:
- if isinstance(x, te.tensor.Tensor):
- buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
- assert x not in binds
- binds[x] = buf
- arg_list.append(buf)
- else:
- raise ValueError("args must be Tensor, Buffer or Var")
- sch = sch.normalize()
- bounds = tvm.te.schedule.InferBound(sch)
- stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- return stmt
def test_basic():
n = te.size_var('n')
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert('if' not in str(stmt.body.body[0]))
- assert('if' in str(stmt.body.body[1]))
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+ assert(not any(
+ collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+ assert(any(
+ collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
def test_const_loop():
n = 21
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert('if' not in str(stmt.body.body[0]))
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+ assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
def test_multi_loop():
ib = tvm.tir.ir_builder.create()
with ib.else_scope():
ib.emit(tvm.tir.Evaluate(n))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
def test_multi_if():
with ib.else_scope():
ib.emit(tvm.tir.Evaluate(n))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert('if' not in str(stmt.body[0]))
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+ assert(not any(
+ collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
def test_thread_axis():
m = te.size_var('m')
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert('if' not in str(stmt.body.body[0]))
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+ assert(not any(
+ collect_visit(stmt.body. body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
def test_vectorize():
n = te.size_var('n')
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
s[C].vectorize(x)
- stmt = lower(s, [A, B])
+ stmt = tvm.lower(s, [A, B], name="main")["main"].body
body = stmt.body.body.body.body
assert(x.var.name not in str(body.condition))
assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))))
+
def test_condition():
ib = tvm.tir.ir_builder.create()
m = te.size_var('m')
ib.emit(tvm.tir.Evaluate(
tvm.tir.Select(ib.likely(i*4+j<n), m, n)))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
+
def test_condition_EQ():
ib = tvm.tir.ir_builder.create()
m = te.size_var('m')
ib.emit(tvm.tir.Evaluate(
tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n)))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
+
def test_thread_axis2():
n = tvm.runtime.convert(4096)
m = te.size_var('m')
_, x = s[C].split(x, factor=m)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
- stmt = lower(s, [A, B])
+ stmt = tvm.lower(s, [A, B], name="main")["main"].body
for_body = stmt.body.body.body.body[0]
assert('threadIdx' not in str(for_body.extent))
# this guard will produce everything during deduction
ib.emit(tvm.tir.Evaluate(m))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+
assert(isinstance(stmt.body.body, tvm.tir.IfThenElse))
def test_single_likely():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
def test_multi_likely():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
def test_oneD_pool():
m = te.size_var('m')
ib = tvm.tir.ir_builder.create()
out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, data, out], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
def test_cce_loop_1():
ib = tvm.tir.ir_builder.create()
dtype = 'float16'
with ib.if_scope(ib.likely(((i*160) + j) < 1600)):
A[(i+1)*m+j+1] = B[(i)*m+j+1] + B[(i+1)*m+j+1] + B[(i+2)*m+j+1]
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
def test_cce_loop_2():
ib.emit(tvm.tir.call_extern('float32', "cce_intrisic", head, tail))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
ib.emit(tvm.tir.call_extern('float16',"cce_intrisic",head1))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt,True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
def test_conv_tiling():
HSTR = WSTR = 1
in_channel = 128
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
s.normalize()
bounds = tvm.te.schedule.InferBound(s)
-
stmt1 = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt1 = tvm.tir.ir_pass.Simplify(stmt1)
- stmt2 = tvm.tir.ir_pass.LoopPartition(stmt1, True)
- stmt2 = tvm.tir.ir_pass.Simplify(stmt2)
+ mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1))
+ stmt1 = tvm.tir.transform.Simplify()(mod1)["main"].body
+
+ mod2 = tvm.tir.transform.LoopPartition(True)(mod1)
+ stmt2 = tvm.tir.transform.Simplify()(mod2)["main"].body
- #make sure loop partition actually did something
+ # make sure loop partition actually did something
assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)
k, 0, m, 0, 0,
tvm.tir.IfThenElse(
(i*m+j+k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)))))
- ret = tvm.tir.ir_pass.RemoveNoOp(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
+ ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
+
assert(isinstance(ret, tvm.tir.Evaluate))
store = tvm.tir.Store(Ab.data,
tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1)
stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])])
- assert(tvm.tir.ir_pass.RemoveNoOp(stmt2) == store)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt2))
+ ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
+ assert(ret == store)
+
# remove zero extent loop
stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store)
- ret = tvm.tir.ir_pass.RemoveNoOp(stmt3)
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt3))
+ ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
assert(isinstance(ret, tvm.tir.Evaluate))
A = ib.allocate("float32", 100, name="A", scope="global")
i = te.var("i")
y = tvm.tir.Select(i > 1, A[i-1], 1.0)
- yy = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(y)).value
+
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([i], tvm.tir.Evaluate(y)))
+ yy = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
z = tvm.tir.Select(
tvm.tir.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
- zz = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(z)).value
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([i], tvm.tir.Evaluate(z)))
+ zz = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
+
+ a = tvm.tir.Select(tvm.tir.floordiv(i, 4) > 10, y, z)
- a = tvm.tir.Select(tvm.te.floordiv(i, 4) > 10, y, z)
- aa = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(a)).value
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a)))
+ aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
assert yy.name == "tvm_if_then_else"
assert zz.name == "tvm_if_then_else"
assert isinstance(aa, tvm.tir.Select)
A[i] = C[i]
body = tvm.tir.LetStmt(n, 10, ib.get())
- body = tvm.tir.ir_pass.CanonicalSimplify(body)
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([A, C, n], body))
+ body = tvm.tir.transform.Simplify()(mod)["main"].body
assert isinstance(body.body, tvm.tir.Store)
with ib.if_scope(tx + ty < 12):
A[tx] = C[tx + ty]
body = tvm.tir.LetStmt(n, 10, ib.get())
- body = tvm.tir.ir_pass.CanonicalSimplify(body)
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([A, C, n], body))
+ body = tvm.tir.transform.Simplify()(mod)["main"].body
assert isinstance(body.body.body.body, tvm.tir.Store)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0]
A[j] = 1.3
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
body = stmt_generater(dtype_list, length)
offset = offset_generater(dtype_list, length)
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
tvm.tir.ir_pass.PostOrderVisit(body, verify)
length = 1024
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0]
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A')
Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
alloc_stats = {"global": 0, "shared": 0}
def verify(n):
A[j] = A[j] + 2
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
assert (isinstance(body.body.body, tvm.tir.Allocate))
ib = tvm.tir.ir_builder.create()
A = ib.allocate("float32", n, name="A", scope="global")
A[j] = A[j] + 2
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
assert(isinstance(body.body.body.body.body, tvm.tir.Allocate))
Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C')
Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0]
B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B2a, B4: B4a, B5: B5a, B: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+ stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a, B5a, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
# verify only have one allocations.
# verify inplace folding works
def verify(n):
A2[j] = A[j]
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
C[j] = 1.2
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
E[j] = C[j]
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
if __name__ == "__main__":
+ test_storage_share()
test_alloc_seq()
test_alloc_different_dtypes()
test_inplace_rule()
- test_storage_share()
test_parallel_alloc()
test_storage_combine()
test_storage_share_gpu()
test_inplace_rule2()
+
test_exceed_mem()
test_inplace_rule3()
test_alloc_seq_type()
wrapped = ib.get()
wrapped = tvm.tir.SeqStmt([wrapped, stmt])
assert isinstance(ret, tvm.tir.For)
- ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped))
+ ret = tvm.tir.transform.UnrollLoop(0, 8, 0, False)(mod)["main"].body
+
+ # ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
assert isinstance(ret[0], tvm.tir.For)
assert ret[0].for_type == tvm.tir.For.Unrolled
assert isinstance(ret[1], tvm.tir.For)
Aptr[j + 1] = Aptr[i] + 1
stmt = ib.get()
- ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
+ ret = tvm.tir.transform.UnrollLoop(8, 0, 1, False)(mod)["main"].body
+
+ # ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
assert isinstance(ret[0], tvm.tir.Store)
def test_unroll_single_count_loops():
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
# all parameters to UnrolLoops are default values except for
# auto_unroll_max_extent which has been set to 1 (default:0)
- after_unroll_stmt = tvm.tir.ir_pass.UnrollLoop(stmt, 0, 8, 1, True)
- assert after_unroll_stmt == stmt
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ ret = tvm.tir.transform.UnrollLoop(0, 8, 1, True)(mod)["main"].body
+
+ assert ret == stmt
if __name__ == "__main__":
test_unroll_loop()
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
assert not isinstance(stmt.body, tvm.tir.For)
assert isinstance(stmt.body.index, tvm.tir.Ramp)
assert isinstance(stmt.body.value, tvm.tir.Broadcast)
+
def test_vectorize_vector():
dtype = 'int64'
n = te.var('n')
A[j] = tvm.tir.const(1, A.dtype)
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
assert not isinstance(stmt.body, tvm.tir.For)
assert isinstance(stmt.body.index, tvm.tir.Ramp)
with ib.if_scope(i < n):
A[i] = 2.0
stmt = ib.get()
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.IfThenElse)
assert isinstance(stmt.then_case.index, tvm.tir.Ramp)
assert isinstance(stmt.then_case.value, tvm.tir.Add)
assert stmt.then_case.value.dtype == "float32x4"
assert isinstance(stmt.else_case, tvm.tir.For)
+
def test_vectorize_with_le_cond():
n = te.var('n')
ib = tvm.tir.ir_builder.create()
with ib.if_scope(i <= n):
A[i] = A[i] + 1
stmt = ib.get()
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
+
def test_vectorize_with_ge_cond():
n = te.var('n')
ib = tvm.tir.ir_builder.create()
with ib.if_scope(i >= n):
A[i] = A[i] + 1
stmt = ib.get()
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
+
def test_vectorize_if_then_else():
n = te.var('n')
x = te.var('x')
i > 0,
A[i] + 1, A[i])
stmt = ib.get()
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
k > 0,
A[k * 4 + i], 0)
stmt = ib.get()
+
assert isinstance(stmt.body, tvm.tir.For)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert not isinstance(stmt.body, tvm.tir.For)
assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast)