[TIR][REFACTOR] Migrate low-level passes in tvm.lower to the Unified IR pass manager...
authorTianqi Chen <tqchen@users.noreply.github.com>
Sat, 18 Apr 2020 19:33:58 +0000 (12:33 -0700)
committerGitHub <noreply@github.com>
Sat, 18 Apr 2020 19:33:58 +0000 (12:33 -0700)
- Migrate BoundCheckers and Simplify
- Migrate RewriteUnsafeSelect and RemoveNoOp
- Migrate UnrollLoop and StorageRewrite
- Migrate InjectDoubleBuffer and InjectVirtualThread
- Migrate LoopPartition and Vectorize
- Migrate CoProcSync, LiftAttrScope, InjectCopyIntrin

We still keep ir_pass registerations for now.
Need a separate PR to refactor the parts before the StorageFlatten.

37 files changed:
include/tvm/tir/analysis.h
include/tvm/tir/ir_pass.h
include/tvm/tir/transform.h
python/tvm/driver/build_module.py
python/tvm/tir/transform/transform.py
src/arith/compute_expr.h
src/driver/driver_api.cc
src/tir/pass/ffi_api.cc
src/tir/transforms/bound_checker.cc [moved from src/tir/pass/bound_checker.cc with 88% similarity]
src/tir/transforms/coproc_sync.cc [moved from src/tir/pass/coproc_sync.cc with 97% similarity]
src/tir/transforms/inject_copy_intrin.cc [moved from src/tir/pass/inject_copy_intrin.cc with 91% similarity]
src/tir/transforms/inject_double_buffer.cc [moved from src/tir/pass/inject_double_buffer.cc with 93% similarity]
src/tir/transforms/inject_virtual_thread.cc [moved from src/tir/pass/inject_virtual_thread.cc with 96% similarity]
src/tir/transforms/lift_attr_scope.cc [moved from src/tir/pass/lift_attr_scope.cc with 90% similarity]
src/tir/transforms/loop_partition.cc [moved from src/tir/pass/loop_partition.cc with 96% similarity]
src/tir/transforms/lower_device_storage_access_info.cc
src/tir/transforms/narrow_datatype.cc
src/tir/transforms/remove_no_op.cc [moved from src/tir/pass/remove_no_op.cc with 89% similarity]
src/tir/transforms/rewrite_unsafe_select.cc [moved from src/tir/pass/rewrite_unsafe_select.cc with 89% similarity]
src/tir/transforms/simplify.cc [moved from src/arith/stmt_simplify.cc with 86% similarity]
src/tir/transforms/storage_rewrite.cc [moved from src/tir/pass/storage_rewrite.cc with 98% similarity]
src/tir/transforms/unroll_loop.cc [moved from src/tir/pass/unroll_loop.cc with 88% similarity]
src/tir/transforms/vectorize_loop.cc [moved from src/tir/pass/vectorize_loop.cc with 95% similarity]
tests/python/unittest/test_tir_pass_virtual_thread.py [deleted file]
tests/python/unittest/test_tir_transform_coproc_sync.py [moved from tests/python/unittest/test_tir_pass_coproc_sync.py with 91% similarity]
tests/python/unittest/test_tir_transform_inject_copy_intrin.py [moved from tests/python/unittest/test_tir_pass_inject_copy_intrin.py with 89% similarity]
tests/python/unittest/test_tir_transform_inject_double_buffer.py [moved from tests/python/unittest/test_tir_pass_inject_double_buffer.py with 91% similarity]
tests/python/unittest/test_tir_transform_inject_virtual_thread.py [moved from tests/python/unittest/test_tir_pass_inject_vthread.py with 83% similarity]
tests/python/unittest/test_tir_transform_instrument_bound_checkers.py [moved from tests/python/unittest/test_tir_pass_bound_checkers.py with 94% similarity]
tests/python/unittest/test_tir_transform_lift_attr_scope.py [moved from tests/python/unittest/test_tir_pass_lift_attr_scope.py with 88% similarity]
tests/python/unittest/test_tir_transform_loop_partition.py [moved from tests/python/unittest/test_tir_pass_loop_partition.py with 79% similarity]
tests/python/unittest/test_tir_transform_remove_no_op.py [moved from tests/python/unittest/test_tir_pass_remove_no_op.py with 81% similarity]
tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py [moved from tests/python/unittest/test_tir_pass_rewrite_unsafe_select.py with 70% similarity]
tests/python/unittest/test_tir_transform_simplify.py [moved from tests/python/unittest/test_arith_stmt_simplify.py with 93% similarity]
tests/python/unittest/test_tir_transform_storage_rewrite.py [moved from tests/python/unittest/test_tir_pass_storage_rewrite.py with 89% similarity]
tests/python/unittest/test_tir_transform_unroll_loop.py [moved from tests/python/unittest/test_tir_pass_unroll.py with 84% similarity]
tests/python/unittest/test_tir_transform_vectorize.py [moved from tests/python/unittest/test_tir_pass_vectorize.py with 82% similarity]

index 6af9958..5c4990a 100644 (file)
@@ -53,7 +53,6 @@ struct ExprDeepEqual {
   TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
 };
 
-
 /*!
  * \brief Find undefined vars in the statment.
  * \param stmt The function to be checked.
index e228ce3..f3d447e 100644 (file)
@@ -203,59 +203,6 @@ Stmt RewriteForTensorCore(Stmt stmt,
 bool VerifyCompactBuffer(Stmt stmt);
 
 /*!
- * \brief Remove No Op from the Stmt.
- * \param stmt The stmt to be trasnformed
- * \return Transformed stmt.
- */
-Stmt RemoveNoOp(Stmt stmt);
-
-/*!
- * \brief unroll the constant loop marked by unroll.
- * This pass also automatically attach pragma unroll tag to loops which meets the standard.
- *
- * \param stmt The statment to be unrolled.
- * \param auto_max_step The maximum step before stop attach automatic unroll
- * \param auto_max_depth The maximum depth before stop attach automatic unroll
- * \param auto_max_extent The maximum extent of the loop we can unroll,
- *                     this is an legacy option that do not take the loop total steps into account.
- * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
- * \return Transformed stmt.
- */
-Stmt UnrollLoop(Stmt stmt,
-                int auto_max_step,
-                int auto_max_depth,
-                int auto_max_extent,
-                bool explicit_unroll);
-
-/*!
- * \brief vectorize the constant loops
- * \param stmt The statement to be vectorized.
- * \return Transformed stmt.
- */
-Stmt VectorizeLoop(Stmt stmt);
-
-/*!
- * \brief convert vectorized loops into serialized loops
- * \param stmt The statement to skip vectorization on.
- * \return Transformed stmt.
- */
-Stmt SkipVectorize(Stmt stmt);
-
-/*!
-* \brief instruments bound checkers.
-* \param stmt The statement to be instrumented.
-* \return Instrumented stmt.
-*/
-Stmt InstrumentBoundCheckers(Stmt stmt);
-
-/*!
- * \brief Inject virtual thread loops into stmt.
- * \param stmt The statement to be transformed.
- * \return Transformed stmt.
- */
-Stmt InjectVirtualThread(Stmt stmt);
-
-/*!
  * \brief Inject prefetch instructions into stmt.
  * \param stmt The statement to be transformed.
  * \return Transformed stmt.
@@ -263,84 +210,6 @@ Stmt InjectVirtualThread(Stmt stmt);
 Stmt InjectPrefetch(Stmt stmt);
 
 /*!
- * \brief Inject double buffer into stmt.
- * \param stmt The statement to be transformed.
- * \param split_loop Loop splitting factor.
- * \return Transformed stmt.
- */
-Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
-
-/*!
- * \brief Inject copy intrinsics with optional pad.
- *
- * \param stmt The statement to be transformed.
- * \param pragma_key The pragma key for hint of copy.
- * \param fintrin The function with signature
- *
- *   Stmt fintrin(Buffer src,
- *                Buffer dst,
- *                Array<Expr> pad_before,
- *                Array<Expr> pad_after,
- *                Expr pad_value)
- * \return Transformed stmt.
- */
-Stmt InjectCopyIntrin(Stmt stmt,
-                      const std::string& pragma_key,
-                      const runtime::PackedFunc& fintrin);
-
-/*!
- * \brief Rewrite storage allocation pattern.
- *  Moves the allocation to outer most possible scope.
- *  Trying to share space between allocations to make
- *  a static allocation plan when possible.
- *
- * \param stmt The stmt to be transformed
- * \return Transformed stmt.
- */
-Stmt StorageRewrite(Stmt stmt);
-
-/*!
- * \brief partition loops in the stmt
- * \param stmt The stmt to do loop partition
- * \param split_const_loop flag to enable partition for const loop
- * \return Transformed stmt.
- */
-Stmt LoopPartition(Stmt stmt, bool split_const_loop);
-
-/*!
- * \brief Detect and insert sync points to co-processor.
- *
- * \param stmt The stmt to be transformed
- * \return Transformed stmt.
- */
-Stmt CoProcSync(Stmt stmt);
-
-/*!
- * \brief Lift common attrs with attr_key to outer scope.
- *
- * \param stmt The stmt to be transformed
- * \param attr_key The attribute key to be checked.
- * \return Transformed stmt.
- */
-Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
-
-/*!
- * \brief Detect and rewrite unsafe select that contains memory access.
- * \param stmt The statement to be rewritten.
- * \return Transformed stmt.
- */
-Stmt RewriteUnsafeSelect(Stmt stmt);
-
-/*!
- * \brief Lower attached storage access information.
- * Do this pass after all storage access analysis finish.
- *
- * \param stmt The stmt to be transformed
- * \return Transformed stmt.
- */
-Stmt LowerStorageAccessInfo(Stmt stmt);
-
-/*!
  * \brief Decorate the stmt with a device scope, this is helpful for
  * hardware accelerator without thread blocks.
  *
@@ -357,15 +226,6 @@ Stmt DecorateDeviceScope(Stmt stmt);
 Stmt HoistIfThenElse(Stmt stmt);
 
 /*!
- * \brief Narrow down PrimExpr datatype in stmt to target_bits.
- * \note  Run this pass after StorageFlatten.
- * \param stmt The stmt to do datatype rewrite
- * \param target_bits the bit of target datatype
- * \return Transformed stmt.
- */
-Stmt NarrowDataType(Stmt stmt, int target_bits);
-
-/*!
  * \brief Rewrite the pointer content type of arguments,
  *  as well as Alloc internal to the function to use
  *  the most frequently accessed type for load/store
index 23c1955..e593e1b 100644 (file)
@@ -59,6 +59,124 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
                                 const tvm::Array<runtime::String>& required);
 
 /*!
+ * \brief Inject copy intrinsics with optional pad.
+ *
+ * \param pragma_key The pragma key for hint of copy.
+ * \param fintrin The function with signature
+ *
+ *   Stmt fintrin(Buffer src,
+ *                Buffer dst,
+ *                Array<Expr> pad_before,
+ *                Array<Expr> pad_after,
+ *                Expr pad_value)
+ * \return The pass.
+ */
+TVM_DLL Pass InjectCopyIntrin(std::string pragma_key,
+                              runtime::PackedFunc fintrin);
+
+/*!
+ * \brief Detect and insert sync points to co-processor.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CoProcSync();
+
+/*!
+ * \brief Lift common attrs with attr_key to outer scope.
+ *
+ * \param attr_key The attribute key to be checked.
+ * \return The pass.
+ */
+TVM_DLL Pass LiftAttrScope(std::string attr_key);
+
+/*!
+ * \brief partition loops in the stmt.
+ *
+ * \param split_const_loop flag to enable partition for const loop
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LoopPartition(bool split_const_loop);
+
+/*!
+ * \brief Lower vectorization loops.
+ *
+ * \param enable_vectorize Whether vectorization is enabled.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
+
+/*!
+ * \brief Inject virtual thread loops.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InjectVirtualThread();
+
+/*!
+ * \brief Inject double buffer statements.
+ *
+ * \param split_loop_factor Loop splitting factor.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectDoubleBuffer(int split_loop_factor);
+
+/*!
+ * \brief Rewrite storage allocation pattern.
+ *  Moves the allocation to outer most possible scope.
+ *  Trying to share space between allocations to make
+ *  a static allocation plan when possible.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass StorageRewrite();
+
+/*!
+ * \brief unroll the constant loop marked by unroll.
+ * This pass also automatically attach pragma unroll tag to loops which meets the standard.
+ *
+ * \param auto_max_step The maximum step before stop attach automatic unroll
+ * \param auto_max_depth The maximum depth before stop attach automatic unroll
+ * \param auto_max_extent The maximum extent of the loop we can unroll,
+ *        this is an legacy option that do not take the loop total steps into account.
+ * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
+ * \return The pass.
+ */
+TVM_DLL Pass UnrollLoop(int auto_max_step,
+                        int auto_max_depth,
+                        int auto_max_extent,
+                        bool explicit_unroll);
+
+/*!
+ * \brief Remove No Op from the Stmt.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RemoveNoOp();
+
+/*!
+ * \brief Detect and rewrite unsafe select that contains memory access.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RewriteUnsafeSelect();
+
+/*!
+* \brief Run arithmetic simplifications on the statements and expressions.
+*
+* \return The pass.
+*/
+TVM_DLL Pass Simplify();
+
+/*!
+* \brief Instruments bound checkers.
+*
+* \return The pass.
+*/
+TVM_DLL Pass InstrumentBoundCheckers();
+
+/*!
  * \brief Transform the high-level PrimFunc to a low-level version
  *        that can be used as an API function.
  *
index a429d07..18a8a47 100644 (file)
@@ -179,6 +179,7 @@ def lower(sch,
         cfg.auto_unroll_max_depth,
         cfg.auto_unroll_max_extent,
         cfg.unroll_explicit)
+
     for f in lower_phase2:
         stmt = f(stmt)
 
@@ -187,11 +188,14 @@ def lower(sch,
     stmt = ir_pass.RemoveNoOp(stmt)
     if not cfg.disable_select_rewriting:
         stmt = ir_pass.RewriteUnsafeSelect(stmt)
+
     for f in lower_phase3:
         stmt = f(stmt)
+
     # Instrument BoundCheckers
     if cfg.instrument_bound_checkers:
         stmt = ir_pass.InstrumentBoundCheckers(stmt)
+
     if simple_mode:
         return stmt
 
index 9f64a93..f83bb11 100644 (file)
@@ -60,6 +60,203 @@ def Filter(fcond):
     return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
 
 
+def InjectCopyIntrin(pragma_key, fintrin):
+    """Inject virtual thread loops.
+
+    Parameters
+    ----------
+    pragma_key : str
+        The pragma key for hint of copy.
+
+    fintrin : function
+        The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value)
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InjectCopyIntrin(pragma_key, fintrin)
+
+
+def CoProcSync():
+    """Detect and insert sync points to co-processor.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.CoProcSync()
+
+
+def LiftAttrScope(attr_key):
+    """Lift common attrs with attr_key to outer scope.
+
+    Parameters
+    ----------
+    attr_key : str
+        The attribute key to be checked.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LiftAttrScope(attr_key)
+
+
+def LoopPartition(split_const_loop):
+    """Inject virtual thread loops.
+
+    Parameters
+    ----------
+    split_const_loop : bool
+        Flag to enable partition for const loop.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LoopPartition(split_const_loop)
+
+
+def VectorizeLoop(enable_vectorize=True):
+    """Lower vectorization loops.
+
+    Parameters
+    ----------
+    enable_vectorize : bool
+        Whether vectorization is enabled.
+        Will lower to scalar loop when it is turned off.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.VectorizeLoop(enable_vectorize)
+
+
+def InjectVirtualThread():
+    """Inject virtual thread loops.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InjectVirtualThread()
+
+
+def InjectDoubleBuffer(split_loop_factor):
+    """Inject double buffer statements.
+
+    Parameters
+    ----------
+    split_loop_factor : int
+        Loop splitting factor.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InjectDoubleBuffer(split_loop_factor)
+
+
+def StorageRewrite():
+    """Rewrite storage allocation pattern.
+
+    Moves the allocation to outer most possible scope.
+    Trying to share space between allocations to make
+    a static allocation plan when possible.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.StorageRewrite()
+
+
+def UnrollLoop(auto_max_step,
+               auto_max_depth,
+               auto_max_extent,
+               explicit_unroll):
+    """Unroll the constant loop marked by unroll.
+
+    This pass also automatically attach pragma unroll tag to loops which meets the standard.
+
+    Parameters
+    ----------
+    auto_max_step : int
+        The maximum step before stop attach automatic unroll
+
+    auto_max_depth : int
+        The maximum depth before stop attach automatic unroll
+
+     auto_max_extent : int
+        The maximum extent of the loop we can unroll.
+        This is an legacy option that do not take the loop total steps into account.
+
+    explicit_unroll : bool
+        Whether explicitly unroll the loop, or leave unroll annotation to codegen.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.UnrollLoop(
+        auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)
+
+
+def RemoveNoOp():
+    """Remove No Op from the Stmt.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.RemoveNoOp()
+
+
+def RewriteUnsafeSelect():
+    """Detect and rewrite unsafe select that contains memory access.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.RewriteUnsafeSelect()
+
+
+def Simplify():
+    """Run arithmetic simplifications on the statements and expressions.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.Simplify()
+
+
+def InstrumentBoundCheckers():
+    """Instruments bound checkers.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InstrumentBoundCheckers()
+
+
 def LowerCustomDatatypes():
     """Lower custom datatypes.
 
index adb4f30..f842780 100644 (file)
@@ -25,6 +25,7 @@
 #define TVM_ARITH_COMPUTE_EXPR_H_
 
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 #include <limits>
 #include <algorithm>
 
index f576c84..e38179e 100644 (file)
@@ -109,64 +109,6 @@ void GetBinds(const Array<te::Tensor>& args,
   }
 }
 
-/*!
-* \brief Build a Stmt given a schedule, args and binds. This function runs the IR passes.
-* \param sch The schedule to build.
-* \param args The arguments for the schedule.
-* \param binds Buffer assignments.
-* \param loop_partition True if the LoopPartition pass should be included.
-* \param out_arg_list Returns the arguments for the Stmt.
-* \param config The build configuration.
-* \return The built Stmt.
-*/
-tir::Stmt BuildStmt(te::Schedule sch,
-                    const Array<te::Tensor>& args,
-                    const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-                    bool loop_partition,
-                    Array<ObjectRef> *out_arg_list,
-                    const BuildConfig& config) {
-  sch = sch.normalize();
-
-  // Phase 0
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  stmt = tir::InjectPrefetch(stmt);
-
-  bool compact = tir::VerifyCompactBuffer(stmt);
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, out_arg_list, config);
-
-  // Phase 1
-  stmt = tir::StorageFlatten(stmt, out_binds, 64,
-                            config->instrument_bound_checkers);
-  stmt = tir::CanonicalSimplify(stmt);
-  if (loop_partition) {
-    stmt = tir::LoopPartition(stmt, config->partition_const_loop);
-  }
-  if (config->disable_vectorize) {
-    stmt = tir::SkipVectorize(stmt);
-  } else {
-    stmt = tir::VectorizeLoop(stmt);
-  }
-  stmt = tir::InjectVirtualThread(stmt);
-  stmt = tir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
-  stmt = tir::StorageRewrite(stmt);
-  stmt = tir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth,
-    config->auto_unroll_max_extent, config->unroll_explicit);
-
-  // Phase 2
-  stmt = tir::Simplify(stmt);
-  stmt = tir::RemoveNoOp(stmt);
-
-  if (!(config->disable_select_rewriting))
-    stmt = tir::RewriteUnsafeSelect(stmt);
-
-  if (config->instrument_bound_checkers)
-    stmt = tir::InstrumentBoundCheckers(stmt);
-
-  return stmt;
-}
-
 transform::Pass BindTarget(Target target) {
   auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
     return WithAttr(std::move(f), tvm::attr::kTarget, target);
@@ -176,7 +118,7 @@ transform::Pass BindTarget(Target target) {
 
 
 template<typename FCond>
-transform::Pass FilterBy(FCond fcond) {
+transform::Pass Filter(FCond fcond) {
   auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
     if (fcond(f)) {
       return f;
@@ -184,18 +126,14 @@ transform::Pass FilterBy(FCond fcond) {
       return tir::PrimFunc(nullptr);
     }
   };
-  return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {});
+  return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
 
-IRModule lower(te::Schedule sch,
-               const Array<te::Tensor>& args,
-               const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-               const BuildConfig& config) {
-  Array<ObjectRef> out_arg_list;
-  auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
-
+IRModule BuildIRModule(const Array<ObjectRef>& out_arg_list,
+                       tir::Stmt stmt,
+                       const std::string& name,
+                       const BuildConfig& config) {
   Array<tir::Var> params;
   Map<tir::Var, tir::Buffer> buffer_map;
 
@@ -216,10 +154,64 @@ IRModule lower(te::Schedule sch,
   if (config->restricted_func) {
     f = WithAttr(std::move(f), "tir.noalias", Integer(1));
   }
+
   return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
 }
 
 
+IRModule lower(te::Schedule sch,
+               const Array<te::Tensor>& args,
+               const std::string& name,
+               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+               const BuildConfig& config) {
+  Array<ObjectRef> out_arg_list;
+
+  sch = sch.normalize();
+
+  // Phase 0
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  stmt = tir::InjectPrefetch(stmt);
+
+  bool compact = tir::VerifyCompactBuffer(stmt);
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list, config);
+
+  // Phase 1
+  stmt = tir::StorageFlatten(stmt, out_binds, 64,
+                             config->instrument_bound_checkers);
+
+  // convert to IRModule.
+  auto mod = BuildIRModule(out_arg_list, stmt, name, config);
+  auto pass_list = Array<tvm::transform::Pass>();
+
+  pass_list.push_back(tir::transform::Simplify());
+  pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop));
+  pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize));
+  pass_list.push_back(tir::transform::InjectVirtualThread());
+  pass_list.push_back(tir::transform::InjectDoubleBuffer(config->double_buffer_split_loop));
+  pass_list.push_back(tir::transform::StorageRewrite());
+  pass_list.push_back(
+      tir::transform::UnrollLoop(config->auto_unroll_max_step,
+                                 config->auto_unroll_max_depth,
+                                 config->auto_unroll_max_extent,
+                                 config->unroll_explicit));
+  // Phase 2
+  pass_list.push_back(tir::transform::Simplify());
+  pass_list.push_back(tir::transform::RemoveNoOp());
+  if (!(config->disable_select_rewriting)) {
+    pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  }
+  if (config->instrument_bound_checkers) {
+    pass_list.push_back(tir::transform::InstrumentBoundCheckers());
+  }
+  // run
+  auto optimize = transform::Sequential(pass_list);
+  mod = optimize(std::move(mod));
+  return mod;
+}
+
+
 std::pair<IRModule, IRModule>
 split_dev_host_funcs(IRModule mod_mixed,
                      const Target& target,
@@ -242,7 +234,7 @@ split_dev_host_funcs(IRModule mod_mixed,
   mod_mixed = opt_mixed(std::move(mod_mixed));
 
   auto host_pass_list = {
-    FilterBy([](const tir::PrimFunc& f) {
+    Filter([](const tir::PrimFunc& f) {
       return f->GetAttr<Integer>(
           tvm::attr::kCallingConv,
           Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch;
@@ -258,7 +250,7 @@ split_dev_host_funcs(IRModule mod_mixed,
 
   // device pipeline
   auto device_pass_list = {
-    FilterBy([](const tir::PrimFunc& f) {
+    Filter([](const tir::PrimFunc& f) {
       return f->GetAttr<Integer>(
           tvm::attr::kCallingConv,
           Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch;
index 3083b68..65981b9 100644 (file)
@@ -114,27 +114,12 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
 
 REGISTER_PASS(ConvertSSA);
 REGISTER_PASS(VerifySSA);
-REGISTER_PASS(RewriteUnsafeSelect);
 REGISTER_PASS(Inline);
 REGISTER_PASS(IRTransform);
-REGISTER_PASS(VectorizeLoop);
-REGISTER_PASS(SkipVectorize);
-REGISTER_PASS(UnrollLoop);
-REGISTER_PASS(InjectCopyIntrin);
-REGISTER_PASS(StorageRewrite);
-REGISTER_PASS(CoProcSync);
-REGISTER_PASS(LowerStorageAccessInfo);
-REGISTER_PASS(InjectVirtualThread);
 REGISTER_PASS(InjectPrefetch);
-REGISTER_PASS(InjectDoubleBuffer);
-REGISTER_PASS(LoopPartition);
-REGISTER_PASS(RemoveNoOp);
-REGISTER_PASS(LiftAttrScope);
 REGISTER_PASS(VerifyGPUCode);
 REGISTER_PASS(DecorateDeviceScope);
-REGISTER_PASS(InstrumentBoundCheckers);
 REGISTER_PASS(VerifyCompactBuffer);
 REGISTER_PASS(HoistIfThenElse);
-REGISTER_PASS(NarrowDataType);
 }  // namespace tir
 }  // namespace tvm
similarity index 88%
rename from src/tir/pass/bound_checker.cc
rename to src/tir/transforms/bound_checker.cc
index ee24d0f..f770bc7 100644 (file)
  */
 // Instrument checkers for out of the bounds access.
 
+#include <tvm/runtime/registry.h>
+#include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/stmt_functor.h>
 #include <vector>
 #include <unordered_map>
@@ -173,8 +176,8 @@ class BoundChecker : public StmtExprMutator {
       }
 
       // Try to simplify index and bound.
-      index = tir::Simplify(index);
-      upper_bound = tir::Simplify(upper_bound);
+      index = analyzer_.Simplify(index);
+      upper_bound = analyzer_.Simplify(upper_bound);
 
       // Cast to the same type - signed, to be able to check lower bound.
       index = CastNode::make(DataType::Int(64), index);
@@ -201,6 +204,8 @@ class BoundChecker : public StmtExprMutator {
   const char *const error_message_ = "OUT OF THE BOUNDS";
   // Hashtable which maps buffer_var to shape.
   std::unordered_map<const VarNode *, PrimExpr> mem_to_shape_;
+  // internal analyzer
+  arith::Analyzer analyzer_;
 };
 
 Stmt InstrumentBoundCheckers(Stmt stmt) {
@@ -209,5 +214,29 @@ Stmt InstrumentBoundCheckers(Stmt stmt) {
   bound_collector(stmt);
   return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
 }
+
+
+TVM_REGISTER_GLOBAL("ir_pass.InstrumentBoundCheckers")
+.set_body_typed(InstrumentBoundCheckers);
+
+namespace transform {
+
+Pass InstrumentBoundCheckers() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    BoundCollector bound_collector;
+    // At first walk recursively and collect bound attributes.
+    bound_collector(n->body);
+    n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers")
+.set_body_typed(InstrumentBoundCheckers);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 97%
rename from src/tir/pass/coproc_sync.cc
rename to src/tir/transforms/coproc_sync.cc
index 38b7798..fc20285 100644 (file)
 /*!
  * \file coproc_sync.cc
  */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
 #include <tvm/tir/stmt_functor.h>
 #include <unordered_map>
 #include <unordered_set>
-#include "ir_util.h"
-#include "storage_access.h"
+#include "../pass/ir_util.h"
+#include "../pass/storage_access.h"
 
 namespace tvm {
 namespace tir {
@@ -677,5 +678,24 @@ Stmt CoProcSync(Stmt stmt) {
   return CoProcSyncInserter().Insert(std::move(stmt));
 }
 
+TVM_REGISTER_GLOBAL("ir_pass.CoProcSync")
+.set_body_typed(CoProcSync);
+
+namespace transform {
+
+Pass CoProcSync() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = CoProcSyncInserter().Insert(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.CoProcSync")
+.set_body_typed(CoProcSync);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 91%
rename from src/tir/pass/inject_copy_intrin.cc
rename to src/tir/transforms/inject_copy_intrin.cc
index 4805caf..5e40eb2 100644 (file)
  * \brief Replace certain copy with copy intrinsics.
  * \file copy_intrin_rewrite.cc
  */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
 #include <tvm/arith/pattern.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
 #include "../../arith/pattern_match.h"
 
 namespace tvm {
@@ -196,5 +197,26 @@ Stmt InjectCopyIntrin(Stmt stmt,
   return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt));
 }
 
+TVM_REGISTER_GLOBAL("ir_pass.InjectCopyIntrin")
+.set_body_typed(InjectCopyIntrin);
+
+namespace transform {
+
+Pass InjectCopyIntrin(std::string pragma_key,
+                      PackedFunc flower_copy_fromto) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = CopyIntrinInjector(
+        pragma_key, flower_copy_fromto)(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin")
+.set_body_typed(InjectCopyIntrin);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 93%
rename from src/tir/pass/inject_double_buffer.cc
rename to src/tir/transforms/inject_double_buffer.cc
index b9aa5a9..e9422fa 100644 (file)
  * \brief Inject double buffering optimization for data fetch.
  * \file inject_double_buffer.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/op.h>
-#include "ir_util.h"
+#include "../pass/ir_util.h"
 #include "../../arith/compute_expr.h"
 
 namespace tvm {
@@ -273,5 +275,26 @@ class DoubleBufferInjector : public StmtExprMutator {
 Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) {
   return DoubleBufferInjector(split_loop).Inject(stmt);
 }
+
+TVM_REGISTER_GLOBAL("ir_pass.InjectDoubleBuffer")
+.set_body_typed(InjectDoubleBuffer);
+
+
+namespace transform {
+
+Pass InjectDoubleBuffer(int split_loop) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = DoubleBufferInjector(split_loop).Inject(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer")
+.set_body_typed(InjectDoubleBuffer);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 96%
rename from src/tir/pass/inject_virtual_thread.cc
rename to src/tir/transforms/inject_virtual_thread.cc
index e9c403c..c70962d 100644 (file)
 /*!
  * \file inject_virtual_thread.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/ir_pass.h>
 #include <unordered_set>
 #include "../../arith/compute_expr.h"
@@ -500,5 +502,24 @@ Stmt InjectVirtualThread(Stmt stmt) {
   return ConvertSSA(std::move(stmt));
 }
 
+TVM_REGISTER_GLOBAL("ir_pass.InjectVirtualThread")
+.set_body_typed(InjectVirtualThread);
+
+namespace transform {
+
+Pass InjectVirtualThread() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = ConvertSSA(VirtualThreadInjector()(std::move(n->body)));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread")
+.set_body_typed(InjectVirtualThread);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 90%
rename from src/tir/pass/lift_attr_scope.cc
rename to src/tir/transforms/lift_attr_scope.cc
index 9aa037f..a1d9223 100644 (file)
  *   the body contains the same scope.
  * \file lift_attr_scope.cc
  */
-#include <tvm/tir/ir_pass.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/stmt_functor.h>
-#include "ir_util.h"
+#include "../pass/ir_util.h"
 
 namespace tvm {
 namespace tir {
@@ -191,5 +192,24 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
   return AttrScopeLifter(attr_key).Lift(std::move(stmt));
 }
 
+TVM_REGISTER_GLOBAL("ir_pass.LiftAttrScope")
+.set_body_typed(LiftAttrScope);
+
+namespace transform {
+
+Pass LiftAttrScope(std::string attr_key) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope")
+.set_body_typed(LiftAttrScope);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 96%
rename from src/tir/pass/loop_partition.cc
rename to src/tir/transforms/loop_partition.cc
index e9157e7..dbed5f2 100644 (file)
 /*!
  * \file loop_partition.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
-#include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/stmt_functor.h>
 #include <tvm/arith/analyzer.h>
 #include <unordered_map>
 #include <unordered_set>
@@ -500,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
   Stmt pre_stmt;
   bool pre_stmt_recurse = true;
   if (middle_interval_i->HasLowerBound()) {
-    body_begin = tir::Simplify(middle_interval.min());
+    body_begin = analyzer_.Simplify(middle_interval.min());
     if (!analyzer_.CanProve(body_begin == min)) {
       PrimExpr cond = (body_begin - min >= 0);
       if (!analyzer_.CanProve(cond)) {
@@ -525,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
   Stmt post_stmt;
   bool post_stmt_recurse = true;
   if (middle_interval_i->HasUpperBound()) {
-    post_doubt_begin = tir::Simplify(middle_interval.max() + 1);
+    post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1);
     if (!analyzer_.CanProve(middle_interval.max() == max)) {
       // require the extent to be non-negative
       PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
@@ -588,7 +590,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b
     return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
   } else {
     return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
-                     for_node->for_type, for_node->device_api, body);
+                         for_node->for_type, for_node->device_api, body);
   }
 }
 
@@ -610,5 +612,25 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
   return stmt;
 }
 
+
+TVM_REGISTER_GLOBAL("ir_pass.LoopPartition")
+.set_body_typed(LoopPartition);
+
+namespace transform {
+
+Pass LoopPartition(bool split_const_loop) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = LoopPartition(std::move(n->body), split_const_loop);
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LoopPartition")
+.set_body_typed(LoopPartition);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
index e7f81ed..9fa7230 100644 (file)
@@ -143,6 +143,8 @@ Stmt LowerStorageAccessInfo(Stmt stmt) {
   return StorageAccessInfoLower()(std::move(stmt));
 }
 
+TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccessInfo")
+.set_body_typed(LowerStorageAccessInfo);
 
 namespace transform {
 
index 1f9d976..4aeaafd 100644 (file)
@@ -395,6 +395,10 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) {
   return DataTypeRewriter(target_bits)(stmt);
 }
 
+TVM_REGISTER_GLOBAL("ir_pass.NarrowDataType")
+.set_body_typed(NarrowDataType);
+
+
 namespace transform {
 
 Pass NarrowDataType(int target_bits) {
similarity index 89%
rename from src/tir/pass/remove_no_op.cc
rename to src/tir/transforms/remove_no_op.cc
index 181a8c4..44c974f 100644 (file)
  * \file remove_no_op.cc
  * \brief Remove no op from the stmt
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/ir_pass.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/stmt_functor.h>
 #include <unordered_map>
 
@@ -147,5 +150,25 @@ class NoOpRemover : public StmtMutator {
 Stmt RemoveNoOp(Stmt stmt) {
   return NoOpRemover()(std::move(stmt));
 }
+
+TVM_REGISTER_GLOBAL("ir_pass.RemoveNoOp")
+.set_body_typed(RemoveNoOp);
+
+namespace transform {
+
+Pass RemoveNoOp() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = NoOpRemover()(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp")
+.set_body_typed(RemoveNoOp);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 89%
rename from src/tir/pass/rewrite_unsafe_select.cc
rename to src/tir/transforms/rewrite_unsafe_select.cc
index 5016492..386b4cc 100644 (file)
  * \file unsafe_select_rewrite.cc
  * \brief Rewrite uinsafe select expression.
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
 
 namespace tvm {
 namespace tir {
@@ -132,5 +133,24 @@ Stmt RewriteUnsafeSelect(Stmt stmt) {
   return UnsafeSelectRewriter()(std::move(stmt));
 }
 
+TVM_REGISTER_GLOBAL("ir_pass.RewriteUnsafeSelect")
+.set_body_typed(RewriteUnsafeSelect);
+
+namespace transform {
+
+Pass RewriteUnsafeSelect() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = UnsafeSelectRewriter()(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect")
+.set_body_typed(RewriteUnsafeSelect);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 86%
rename from src/arith/stmt_simplify.cc
rename to src/tir/transforms/simplify.cc
index 6c3dd02..ecfa25e 100644 (file)
  */
 
 /*!
- * \file stmt_simplify.cc
+ * \file simplify.cc
  * \brief Statement simplifier based on analyzer
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/arith/analyzer.h>
 
 #include <tvm/tir/op.h>
 #include <tvm/arith/analyzer.h>
-#include "ir_mutator_with_analyzer.h"
+#include "../../arith/ir_mutator_with_analyzer.h"
 
 namespace tvm {
 namespace arith {
@@ -125,5 +127,23 @@ PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange) {
 Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
   return CanonicalSimplify(std::move(stmt), vrange);
 }
+
+namespace transform {
+
+Pass Simplify() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    arith::Analyzer analyzer;
+    n->body = arith::StmtSimplifier(&analyzer).Simplify(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.Simplify")
+.set_body_typed(Simplify);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 98%
rename from src/tir/pass/storage_rewrite.cc
rename to src/tir/transforms/storage_rewrite.cc
index f3604b6..c13879c 100644 (file)
  * \brief Memory access pattern analysis and optimization.
  *  Re-write data access to enable memory sharing when possible.
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/target/target_info.h>
 #include <map>
 #include <unordered_set>
 #include <unordered_map>
-#include "ir_util.h"
+#include "../pass/ir_util.h"
 #include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 
@@ -1039,5 +1041,26 @@ Stmt StorageRewrite(Stmt stmt) {
   stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
   return VectorAllocRewriter()(std::move(stmt));
 }
+
+TVM_REGISTER_GLOBAL("ir_pass.StorageRewrite")
+.set_body_typed(StorageRewrite);
+
+namespace transform {
+
+Pass StorageRewrite() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
+    n->body = VectorAllocRewriter()(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite")
+.set_body_typed(StorageRewrite);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 88%
rename from src/tir/pass/unroll_loop.cc
rename to src/tir/transforms/unroll_loop.cc
index 0167dbc..27c39d4 100644 (file)
  * \file unroll_loop.cc
  */
 // Unrolls the loop as in Halide pipeline.
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 #include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/stmt_functor.h>
 #include <unordered_set>
 #include <unordered_map>
@@ -201,13 +204,31 @@ Stmt UnrollLoop(Stmt stmt,
   }
 }
 
-Stmt UnrollLoopExplicitly(Stmt stmt) {
-  const ForNode* op = stmt.as<ForNode>();
-  if (!op) {
-    LOG(FATAL) << "attempted to unroll a non-loop statement";
-  }
-  return LoopUnroller(0, 0, 0, false).Unroll(op);
+TVM_REGISTER_GLOBAL("ir_pass.UnrollLoop")
+.set_body_typed(UnrollLoop);
+
+namespace transform {
+
+Pass UnrollLoop(int auto_max_step,
+                int auto_max_depth,
+                int auto_max_extent,
+                bool explicit_unroll) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = UnrollLoop(std::move(f->body),
+                         auto_max_step,
+                         auto_max_depth,
+                         auto_max_extent,
+                         explicit_unroll);
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {});
 }
 
+TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop")
+.set_body_typed(UnrollLoop);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 95%
rename from src/tir/pass/vectorize_loop.cc
rename to src/tir/transforms/vectorize_loop.cc
index b73587d..cc4361d 100644 (file)
  * \file vectorize_loop.cc
  */
 // Loop vectorizer as in Halide pipeline.
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/ir_pass.h>
 #include <tvm/arith/analyzer.h>
 #include <unordered_set>
 #include <unordered_map>
@@ -539,8 +541,9 @@ class VectorizeSkipper : public StmtMutator {
     Stmt stmt = StmtMutator::VisitStmt_(op);
     op = stmt.as<ForNode>();
     if (op->for_type == ForType::Vectorized) {
-      return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
-                       op->body);
+      return ForNode::make(op->loop_var, op->min, op->extent,
+                           ForType::Serial, op->device_api,
+                           op->body);
     } else {
        return stmt;
     }
@@ -551,5 +554,32 @@ Stmt SkipVectorize(Stmt stmt) {
   return VectorizeSkipper()(std::move(stmt));
 }
 
+TVM_REGISTER_GLOBAL("ir_pass.VectorizeLoop")
+.set_body_typed(VectorizeLoop);
+
+TVM_REGISTER_GLOBAL("ir_pass.SkipVectorize")
+.set_body_typed(SkipVectorize);
+
+namespace transform {
+
+// TODO(tvm-team): Make it as a target property.
+Pass VectorizeLoop(bool enable_vectorize) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    if (enable_vectorize) {
+      n->body = LoopVectorizer()(std::move(n->body));
+    } else {
+      n->body = VectorizeSkipper()(std::move(n->body));
+    }
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop")
+.set_body_typed(VectorizeLoop);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tir_pass_virtual_thread.py b/tests/python/unittest/test_tir_pass_virtual_thread.py
deleted file mode 100644 (file)
index 2d96696..0000000
+++ /dev/null
@@ -1,45 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import tvm
-from tvm import te
-
-def test_virtual_thread():
-    m = te.var('m')
-    A = te.placeholder((m, ), name='A')
-    A1 = te.compute((m,), lambda i: A[i], name='A1')
-    A2 = te.compute((m,), lambda i: A1[i] + 3, name='A2')
-
-    s = te.create_schedule(A2.op)
-    vx = te.thread_axis("vthread", name="vx")
-    xo, xi = s[A2].split(A2.op.axis[0], nparts=2)
-    s[A2].bind(xo, vx)
-    xo, xi = s[A2].split(xi, 8)
-    s[A1].compute_at(s[A2], xo)
-
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
-    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
-    A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
-    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt)
-    print(stmt)
-
-if __name__ == "__main__":
-    test_virtual_thread()
@@ -37,7 +37,10 @@ def test_coproc_sync():
                 ib.scope_attr(cp, "coproc_scope", 1)
                 A[j] = A[j + k * 10] + 2
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.CoProcSync(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+    stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
+
     body = stmt.body.body.body
     blist = tvm.tir.stmt_list(body)
     assert(blist[1].value.name == "cop.coproc_read_barrier")
@@ -65,7 +68,10 @@ def test_coproc_sync2():
             ib.scope_attr(cp, "coproc_scope", 2)
             A[ty] = 1.0
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.CoProcSync(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+    stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
+
 
 def test_coproc_sync3():
     def __check_list(tvm_array, py_list):
@@ -91,7 +97,10 @@ def test_coproc_sync3():
         A[0] = 0.0
 
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.CoProcSync(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+    stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
+
     slist = tvm.tir.stmt_list(stmt[0].body.body)
     push_st = slist[2]
     slist = tvm.tir.stmt_list(slist[-1])
@@ -35,7 +35,10 @@ def test_copy2d():
         assert src.strides[0] == l
         assert tuple(src.shape) == (m, l)
         return tvm.tir.Evaluate(0)
-    stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
 
 def test_copy_pad():
     m = te.var('m')
@@ -59,7 +62,10 @@ def test_copy_pad():
         assert pad_after[1].value == 0
         assert pad_value.value == 1.0
         return tvm.tir.Evaluate(0)
-    stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
 
 def test_single_point_test():
     A = te.placeholder((1,), name='A')
@@ -78,7 +84,10 @@ def test_single_point_test():
         assert tvm.tir.ir_pass.Simplify(src.strides[0]).value == 1
         assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1
         return tvm.tir.Evaluate(0)
-    stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
 
 def assert_expr_equal(a, b):
     assert tvm.tir.ir_pass.Simplify(a - b).value == 0
@@ -111,7 +120,11 @@ def test_copy_pad_split():
         assert_expr_equal(pad_after[0], rpad_after)
         assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
         return tvm.tir.Evaluate(0)
-    stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
+
 
 
 if __name__ == "__main__":
@@ -36,13 +36,19 @@ def test_double_buffer():
             C[j] = B[j] + 1
 
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    assert isinstance(stmt.body.body, tvm.tir.Allocate)
-    assert stmt.body.body.extents[0].value == 2
     mod = tvm.IRModule({
         "db" : tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)
     })
+
+    opt = tvm.transform.Sequential(
+        [tvm.tir.transform.InjectDoubleBuffer(2),
+         tvm.tir.transform.Simplify()])
+    mod = opt(mod)
+    stmt = mod["db"].body
+
+    assert isinstance(stmt.body.body, tvm.tir.Allocate)
+    assert stmt.body.body.extents[0].value == 2
+
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
     count = [0]
     def count_sync(op):
@@ -40,9 +40,14 @@ def test_vthread():
             C[i * nthread + tx] = B[i] + 1
         return ib.get()
 
-    stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread"))
+    stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([], get_vthread("vthread"))))["main"].body
+
     assert stmt.body.body.extents[0].value == 2
-    stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("cthread"))
+
+    stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
+
     assert len(stmt.body.body.extents) == 3
 
 
@@ -67,16 +72,20 @@ def test_vthread_extern():
             A[tx] = tx + 1.0
             B[ty] = ty + 1.0
             ib.emit(tvm.tir.call_extern("int32", "Run",
-                                    abuffer.access_ptr("r"),
-                                    bbuffer.access_ptr("r"),
-                                    cbuffer.access_ptr("rw")))
+                                        abuffer.access_ptr("r"),
+                                        bbuffer.access_ptr("r"),
+                                        cbuffer.access_ptr("rw")))
         return ib.get()
 
-    stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread"))
+
+    stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
+
     assert stmt.body.body.extents[0].value == 2
     assert stmt.body.body.body.body.body.body.extents[0].value == 2
     assert len(stmt.body.body.body.body.body.body.extents) == 3
 
+
 def test_vthread_if_then_else():
     nthread = 2
     tx = te.thread_axis("vthread")
@@ -92,7 +101,10 @@ def test_vthread_if_then_else():
         with ib.if_scope(i == 0):
             B[i] = A[i * nthread + tx] + 2
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt)
+
+    stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([], stmt)))["main"].body
+
     assert stmt.body.body.body[0].else_case != None
     assert stmt.body.body.body[1].else_case == None
 
@@ -18,32 +18,12 @@ import pytest
 import tvm
 from tvm import te
 import numpy as np
+
 def collect_visit(stmt, f):
     ret = []
     tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
     return ret
 
-def lower(sch, args):
-    binds = {}
-    arg_list = []
-    for x in args:
-        if isinstance(x, te.tensor.Tensor):
-            buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
-            assert x not in binds
-            binds[x] = buf
-            arg_list.append(buf)
-        else:
-            raise ValueError("args must be Tensor, Buffer or Var")
-    sch = sch.normalize()
-    bounds = tvm.te.schedule.InferBound(sch)
-    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
-    stmt = tvm.tir.ir_pass.RemoveNoOp(stmt)
-    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, True)
-    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    return stmt
 
 @pytest.mark.xfail
 def test_out_of_bounds_llvm(index_a, index_b):
@@ -72,7 +52,6 @@ def test_in_bounds_llvm():
     tgt = "llvm"
     tgt_host = "llvm"
     stmt = tvm.lower (s, [A, B, C], simple_mode=True)
-    print (stmt)
     fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
     ctx = tvm.context(tgt, 0)
     a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
@@ -93,7 +72,6 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b):
     tgt = "llvm"
     tgt_host = "llvm"
     stmt = tvm.lower (s, [a, b, c], simple_mode=True)
-    print (stmt)
     f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec")
     ctx = tvm.cpu(0)
     n = nn
@@ -192,13 +170,11 @@ def test_in_bounds_const_loop_partition_ir():
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=4)
 
-    bounds = tvm.te.schedule.InferBound(s)
-    stmt = lower (s, [A, B, T])
-    # num_attributes = num_buffers * num_splits = 2 * 3
-    # before instrumentation
-    assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
-    assert_bound_instrumentation(stmt, check_branch_stmt, 0)
-    stmt = tvm.tir.ir_pass.InstrumentBoundCheckers(stmt)
+    with tvm.target.build_config(instrument_bound_checkers=True,
+                                 partition_const_loop=True):
+        mod = tvm.driver.lower(s, [A, B, T], name="main")
+
+    stmt = mod["main"].body
     # after instrumentation
     assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
     assert_bound_instrumentation(stmt, check_branch_stmt, 2)
@@ -209,7 +185,8 @@ def test_in_bounds_const_loop_partition_ir():
 
 
 def test_in_bounds_const_loop_partition_llvm():
-    with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True):
+    with tvm.target.build_config(instrument_bound_checkers=True,
+                                 partition_const_loop=True):
         n = 21
         A = te.placeholder((n, ), name='A')
         B = te.placeholder((n, ), name='B')
@@ -35,7 +35,10 @@ def test_coproc_lift():
                 A[j] = A[j] + 3
                 A[j] = A[j] + 3
     body = ib.get()
-    body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+    body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body
+
     assert body.body.body.node == cp
 
     # only able to lift to the common pattern of the last two fors.
@@ -52,7 +55,10 @@ def test_coproc_lift():
             A[i] = A[i] + 2
 
     body = ib.get()
-    body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+    body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body
+
     assert body.body.body.body[1].node == cp
     assert len(body.body.body.body) == 2
 
@@ -23,26 +23,6 @@ def collect_visit(stmt, f):
     tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
     return ret
 
-def lower(sch, args):
-    binds = {}
-    arg_list = []
-    for x in args:
-        if isinstance(x, te.tensor.Tensor):
-            buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
-            assert x not in binds
-            binds[x] = buf
-            arg_list.append(buf)
-        else:
-            raise ValueError("args must be Tensor, Buffer or Var")
-    sch = sch.normalize()
-    bounds = tvm.te.schedule.InferBound(sch)
-    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
-    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
-    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    return stmt
 
 def test_basic():
     n = te.size_var('n')
@@ -55,10 +35,16 @@ def test_basic():
 
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    assert('if' not in str(stmt.body.body[0]))
-    assert('if' in str(stmt.body.body[1]))
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    assert(not any(
+        collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert(any(
+        collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
 
 def test_const_loop():
     n = 21
@@ -71,9 +57,12 @@ def test_const_loop():
 
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    assert('if' not in str(stmt.body.body[0]))
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    mod = tvm.tir.transform.LoopPartition(True)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_multi_loop():
     ib = tvm.tir.ir_builder.create()
@@ -87,8 +76,11 @@ def test_multi_loop():
                 with ib.else_scope():
                     ib.emit(tvm.tir.Evaluate(n))
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt))
+    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
     assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_multi_if():
@@ -107,9 +99,14 @@ def test_multi_if():
                 with ib.else_scope():
                     ib.emit(tvm.tir.Evaluate(n))
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    assert('if' not in str(stmt.body[0]))
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    assert(not any(
+        collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
 
 def test_thread_axis():
     m = te.size_var('m')
@@ -126,9 +123,14 @@ def test_thread_axis():
 
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    assert('if' not in str(stmt.body.body[0]))
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    assert(not any(
+        collect_visit(stmt.body. body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
 
 def test_vectorize():
     n = te.size_var('n')
@@ -147,11 +149,12 @@ def test_vectorize():
     s[C].bind(bx, te.thread_axis("blockIdx.x"))
     s[C].bind(tx, te.thread_axis("threadIdx.x"))
     s[C].vectorize(x)
-    stmt = lower(s, [A, B])
+    stmt = tvm.lower(s, [A, B], name="main")["main"].body
     body = stmt.body.body.body.body
     assert(x.var.name not in str(body.condition))
     assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))))
 
+
 def test_condition():
     ib = tvm.tir.ir_builder.create()
     m = te.size_var('m')
@@ -161,10 +164,14 @@ def test_condition():
         ib.emit(tvm.tir.Evaluate(
           tvm.tir.Select(ib.likely(i*4+j<n), m, n)))
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
+    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
     assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
 
+
 def test_condition_EQ():
     ib = tvm.tir.ir_builder.create()
     m = te.size_var('m')
@@ -173,10 +180,14 @@ def test_condition_EQ():
             ib.emit(tvm.tir.Evaluate(
                 tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n)))
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
+    mod = tvm.tir.transform.LoopPartition(True)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
     assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
 
+
 def test_thread_axis2():
     n = tvm.runtime.convert(4096)
     m = te.size_var('m')
@@ -190,7 +201,7 @@ def test_thread_axis2():
     _,  x = s[C].split(x, factor=m)
     s[C].bind(bx, te.thread_axis("blockIdx.x"))
     s[C].bind(tx, te.thread_axis("threadIdx.x"))
-    stmt = lower(s, [A, B])
+    stmt = tvm.lower(s, [A, B], name="main")["main"].body
     for_body = stmt.body.body.body.body[0]
     assert('threadIdx' not in str(for_body.extent))
 
@@ -204,8 +215,12 @@ def test_everything_during_deduction():
                 # this guard will produce everything during deduction
                 ib.emit(tvm.tir.Evaluate(m))
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
+    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+
     assert(isinstance(stmt.body.body, tvm.tir.IfThenElse))
 
 def test_single_likely():
@@ -220,8 +235,11 @@ def test_single_likely():
 
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    mod = tvm.tir.transform.LoopPartition(True)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_multi_likely():
@@ -241,10 +259,14 @@ def test_multi_likely():
 
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    mod = tvm.tir.transform.LoopPartition(True)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
+
 def test_oneD_pool():
     m = te.size_var('m')
     ib = tvm.tir.ir_builder.create()
@@ -268,10 +290,14 @@ def test_oneD_pool():
                     out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
 
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, data, out], stmt))
+    mod = tvm.tir.transform.LoopPartition(True)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
+
 def test_cce_loop_1():
   ib = tvm.tir.ir_builder.create()
   dtype = 'float16'
@@ -289,8 +315,11 @@ def test_cce_loop_1():
           with ib.if_scope(ib.likely(((i*160) + j) < 1600)):
                A[(i+1)*m+j+1] = B[(i)*m+j+1] + B[(i+1)*m+j+1] + B[(i+2)*m+j+1]
   stmt = ib.get()
-  stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
-  stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+  mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+  mod = tvm.tir.transform.LoopPartition(True)(mod)
+  stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
   assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_cce_loop_2():
@@ -308,8 +337,12 @@ def test_cce_loop_2():
       ib.emit(tvm.tir.call_extern('float32', "cce_intrisic", head, tail))
 
   stmt = ib.get()
-  stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
-  stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+
+  mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+  mod = tvm.tir.transform.LoopPartition(True)(mod)
+  stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
   assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 
@@ -326,10 +359,14 @@ def test_cce_loop_3():
                 ib.emit(tvm.tir.call_extern('float16',"cce_intrisic",head1))
 
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt,True)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    mod = tvm.tir.transform.LoopPartition(True)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
+
 def test_conv_tiling():
     HSTR = WSTR = 1
     in_channel = 128
@@ -355,8 +392,11 @@ def test_conv_tiling():
     oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    mod = tvm.tir.transform.LoopPartition(True)(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 
@@ -426,14 +466,15 @@ def test_simple_rfactor():
 
     s.normalize()
     bounds = tvm.te.schedule.InferBound(s)
-
     stmt1 = tvm.te.schedule.ScheduleOps(s, bounds)
-    stmt1 = tvm.tir.ir_pass.Simplify(stmt1)
 
-    stmt2 = tvm.tir.ir_pass.LoopPartition(stmt1, True)
-    stmt2 = tvm.tir.ir_pass.Simplify(stmt2)
+    mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1))
+    stmt1 = tvm.tir.transform.Simplify()(mod1)["main"].body
+
+    mod2 = tvm.tir.transform.LoopPartition(True)(mod1)
+    stmt2 = tvm.tir.transform.Simplify()(mod2)["main"].body
 
-    #make sure loop partition actually did something
+    # make sure loop partition actually did something
     assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)
 
 
@@ -36,16 +36,24 @@ def test_remove_no_op():
                 k, 0, m, 0, 0,
                 tvm.tir.IfThenElse(
                     (i*m+j+k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)))))
-    ret = tvm.tir.ir_pass.RemoveNoOp(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
+    ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
+
     assert(isinstance(ret, tvm.tir.Evaluate))
     store = tvm.tir.Store(Ab.data,
                            tvm.tir.Load(dtype, Ab.data, i) + 1,
                            i + 1)
     stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])])
-    assert(tvm.tir.ir_pass.RemoveNoOp(stmt2) == store)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt2))
+    ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
+    assert(ret == store)
+
     # remove zero extent loop
     stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store)
-    ret = tvm.tir.ir_pass.RemoveNoOp(stmt3)
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt3))
+    ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
     assert(isinstance(ret, tvm.tir.Evaluate))
 
 
@@ -23,14 +23,22 @@ def test_rewrite_Select():
     A = ib.allocate("float32", 100, name="A", scope="global")
     i = te.var("i")
     y = tvm.tir.Select(i > 1, A[i-1], 1.0)
-    yy = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(y)).value
+
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([i], tvm.tir.Evaluate(y)))
+    yy = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
 
     z = tvm.tir.Select(
         tvm.tir.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
-    zz = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(z)).value
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([i], tvm.tir.Evaluate(z)))
+    zz = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
+
+    a = tvm.tir.Select(tvm.tir.floordiv(i, 4) > 10, y, z)
 
-    a = tvm.tir.Select(tvm.te.floordiv(i, 4) > 10, y, z)
-    aa = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(a)).value
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a)))
+    aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
     assert yy.name == "tvm_if_then_else"
     assert zz.name == "tvm_if_then_else"
     assert isinstance(aa, tvm.tir.Select)
@@ -27,7 +27,9 @@ def test_stmt_simplify():
             A[i] = C[i]
 
     body = tvm.tir.LetStmt(n, 10, ib.get())
-    body = tvm.tir.ir_pass.CanonicalSimplify(body)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A, C, n], body))
+    body = tvm.tir.transform.Simplify()(mod)["main"].body
     assert isinstance(body.body, tvm.tir.Store)
 
 
@@ -44,7 +46,9 @@ def test_thread_extent_simplify():
     with ib.if_scope(tx + ty < 12):
         A[tx] = C[tx + ty]
     body = tvm.tir.LetStmt(n, 10, ib.get())
-    body = tvm.tir.ir_pass.CanonicalSimplify(body)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A, C, n], body))
+    body = tvm.tir.transform.Simplify()(mod)["main"].body
     assert isinstance(body.body.body.body, tvm.tir.Store)
 
 
@@ -33,9 +33,12 @@ def test_storage_share():
     Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
     Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
-    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+    mod = tvm.tir.transform.Simplify()(mod)
+    mod = tvm.tir.transform.StorageRewrite()(mod)
+    stmt = mod["main"].body
+
     # verify only have one allocations.
     # verify inplace folding works
     num_alloc = [0]
@@ -72,7 +75,10 @@ def test_alloc_seq():
             A[j] = 1.3
 
     body = ib.get()
-    body = tvm.tir.ir_pass.StorageRewrite(body)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+    body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
     num_alloc = [0]
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
@@ -129,7 +135,10 @@ def test_alloc_different_dtypes():
 
         body = stmt_generater(dtype_list, length)
         offset = offset_generater(dtype_list, length)
-        body = tvm.tir.ir_pass.StorageRewrite(body)
+
+        mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body))
+        body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
         tvm.tir.ir_pass.PostOrderVisit(body, verify)
 
     length = 1024
@@ -160,9 +169,12 @@ def test_inplace_rule():
     Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
     Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
-    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+    mod = tvm.tir.transform.Simplify()(mod)
+    mod = tvm.tir.transform.StorageRewrite()(mod)
+    stmt = mod["main"].body
+
     # verify only have one allocations.
     # verify inplace folding works
     num_alloc = [0]
@@ -192,9 +204,12 @@ def test_storage_combine():
     Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
     Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
-    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+    mod = tvm.tir.transform.Simplify()(mod)
+    mod = tvm.tir.transform.StorageRewrite()(mod)
+    stmt = mod["main"].body
+
     num_alloc = [0]
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
@@ -226,9 +241,12 @@ def test_storage_share_gpu():
     Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A')
     Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B')
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64)
-    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+    mod = tvm.tir.transform.Simplify()(mod)
+    mod = tvm.tir.transform.StorageRewrite()(mod)
+    stmt = mod["main"].body
+
     alloc_stats = {"global": 0, "shared": 0}
 
     def verify(n):
@@ -248,7 +266,9 @@ def test_parallel_alloc():
             A[j] = A[j] + 2
 
     body = ib.get()
-    body = tvm.tir.ir_pass.StorageRewrite(body)
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+    body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
     assert (isinstance(body.body.body, tvm.tir.Allocate))
 
     ib = tvm.tir.ir_builder.create()
@@ -262,7 +282,9 @@ def test_parallel_alloc():
                 A = ib.allocate("float32", n, name="A", scope="global")
                 A[j] = A[j] + 2
     body = ib.get()
-    body = tvm.tir.ir_pass.StorageRewrite(body)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+    body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
 
     assert(isinstance(body.body.body.body.body, tvm.tir.Allocate))
 
@@ -289,9 +311,12 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
     Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C')
     Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D')
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64)
-    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt))
+    mod = tvm.tir.transform.Simplify()(mod)
+    mod = tvm.tir.transform.StorageRewrite()(mod)
+    stmt = mod["main"].body
+
     # verify only have one allocations.
     # verify inplace folding works
     num_alloc = [0]
@@ -381,10 +406,13 @@ def test_inplace_rule3():
     B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5')
 
     Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
-    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B2a, B4: B4a, B5: B5a, B: Bb}, 64)
-    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a, B5a, Bb], stmt))
+    mod = tvm.tir.transform.Simplify()(mod)
+    mod = tvm.tir.transform.StorageRewrite()(mod)
+    stmt = mod["main"].body
+
     # verify only have one allocations.
     # verify inplace folding works
     def verify(n):
@@ -411,7 +439,10 @@ def test_alloc_seq_type():
             A2[j] = A[j]
 
     body = ib.get()
-    body = tvm.tir.ir_pass.StorageRewrite(body)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+    body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
     num_alloc = [0]
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
@@ -440,7 +471,10 @@ def test_alloc_seq_type2():
             C[j] = 1.2
 
     body = ib.get()
-    body = tvm.tir.ir_pass.StorageRewrite(body)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+    body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
     num_alloc = [0]
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
@@ -469,7 +503,9 @@ def test_reuse_small_buffer():
             E[j] = C[j]
 
     body = ib.get()
-    body = tvm.tir.ir_pass.StorageRewrite(body)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+    body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
 
     num_alloc = [0]
 
@@ -519,14 +555,15 @@ def test_large_input():
 
 
 if __name__ == "__main__":
+    test_storage_share()
     test_alloc_seq()
     test_alloc_different_dtypes()
     test_inplace_rule()
-    test_storage_share()
     test_parallel_alloc()
     test_storage_combine()
     test_storage_share_gpu()
     test_inplace_rule2()
+
     test_exceed_mem()
     test_inplace_rule3()
     test_alloc_seq_type()
@@ -46,7 +46,11 @@ def test_unroll_loop():
     wrapped = ib.get()
     wrapped = tvm.tir.SeqStmt([wrapped, stmt])
     assert isinstance(ret, tvm.tir.For)
-    ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped))
+    ret = tvm.tir.transform.UnrollLoop(0, 8, 0, False)(mod)["main"].body
+
+    # ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
     assert isinstance(ret[0], tvm.tir.For)
     assert ret[0].for_type == tvm.tir.For.Unrolled
     assert isinstance(ret[1], tvm.tir.For)
@@ -65,7 +69,11 @@ def test_unroll_fake_loop():
             Aptr[j + 1] = Aptr[i] + 1
 
     stmt = ib.get()
-    ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
+    ret = tvm.tir.transform.UnrollLoop(8, 0, 1, False)(mod)["main"].body
+
+    # ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
     assert isinstance(ret[0], tvm.tir.Store)
 
 def test_unroll_single_count_loops():
@@ -78,8 +86,10 @@ def test_unroll_single_count_loops():
     stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
     # all parameters to UnrolLoops are default values except for
     # auto_unroll_max_extent which has been set to 1 (default:0)
-    after_unroll_stmt = tvm.tir.ir_pass.UnrollLoop(stmt, 0, 8, 1, True)
-    assert after_unroll_stmt == stmt
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    ret = tvm.tir.transform.UnrollLoop(0, 8, 1, True)(mod)["main"].body
+
+    assert ret == stmt
 
 if __name__ == "__main__":
     test_unroll_loop()
@@ -28,12 +28,16 @@ def test_vectorize_loop():
     stmt = ib.get()
 
     assert isinstance(stmt.body, tvm.tir.For)
-    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
     assert isinstance(stmt, tvm.tir.For)
     assert not isinstance(stmt.body, tvm.tir.For)
     assert isinstance(stmt.body.index, tvm.tir.Ramp)
     assert isinstance(stmt.body.value, tvm.tir.Broadcast)
 
+
 def test_vectorize_vector():
     dtype = 'int64'
     n = te.var('n')
@@ -44,7 +48,10 @@ def test_vectorize_vector():
             A[j] = tvm.tir.const(1, A.dtype)
     stmt = ib.get()
     assert isinstance(stmt.body, tvm.tir.For)
-    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
     assert isinstance(stmt, tvm.tir.For)
     assert not isinstance(stmt.body, tvm.tir.For)
     assert isinstance(stmt.body.index, tvm.tir.Ramp)
@@ -63,13 +70,17 @@ def test_vectorize_with_if():
             with ib.if_scope(i < n):
                 A[i] = 2.0
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
+    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
     assert isinstance(stmt, tvm.tir.IfThenElse)
     assert isinstance(stmt.then_case.index, tvm.tir.Ramp)
     assert isinstance(stmt.then_case.value, tvm.tir.Add)
     assert stmt.then_case.value.dtype == "float32x4"
     assert isinstance(stmt.else_case, tvm.tir.For)
 
+
 def test_vectorize_with_le_cond():
     n = te.var('n')
     ib = tvm.tir.ir_builder.create()
@@ -78,9 +89,13 @@ def test_vectorize_with_le_cond():
         with ib.if_scope(i <= n):
             A[i] = A[i] + 1
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
     assert isinstance(stmt, tvm.tir.For)
 
+
 def test_vectorize_with_ge_cond():
     n = te.var('n')
     ib = tvm.tir.ir_builder.create()
@@ -89,9 +104,13 @@ def test_vectorize_with_ge_cond():
         with ib.if_scope(i >= n):
             A[i] = A[i] + 1
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
     assert isinstance(stmt, tvm.tir.For)
 
+
 def test_vectorize_if_then_else():
     n = te.var('n')
     x = te.var('x')
@@ -102,7 +121,10 @@ def test_vectorize_if_then_else():
                                i > 0,
                                A[i] + 1, A[i])
     stmt = ib.get()
-    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
+    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
     assert isinstance(stmt, tvm.tir.For)
 
 
@@ -114,8 +136,12 @@ def test_vectorize_if_then_else():
                                            k > 0,
                                            A[k * 4 + i], 0)
     stmt = ib.get()
+
     assert isinstance(stmt.body, tvm.tir.For)
-    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
     assert not isinstance(stmt.body, tvm.tir.For)
     assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast)