[REFACTOR][TIR] Migrate Low-level Passes to Pass Manager (#5198)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 1 Apr 2020 02:44:41 +0000 (19:44 -0700)
committerGitHub <noreply@github.com>
Wed, 1 Apr 2020 02:44:41 +0000 (19:44 -0700)
* [TIR][TRANSFORM] Migrate LowerIntrin

* LowerDeviceStorageAccessInfo

* Migrate LowerWarpMemory

14 files changed:
include/tvm/ir/module.h
include/tvm/tir/transform.h
python/tvm/tir/transform/transform.py
src/ir/module.cc
src/ir/transform.cc
src/target/codegen.cc
src/tir/pass/storage_access.cc
src/tir/transforms/combine_context_call.cc
src/tir/transforms/lower_device_storage_access_info.cc [new file with mode: 0644]
src/tir/transforms/lower_intrin.cc [moved from src/tir/pass/lower_intrin.cc with 90% similarity]
src/tir/transforms/lower_warp_memory.cc [moved from src/tir/pass/lower_warp_memory.cc with 94% similarity]
tests/python/unittest/test_tir_transform_combine_context_call.py
tests/python/unittest/test_tir_transform_lower_intrin.py [moved from tests/python/unittest/test_tir_pass_lower_intrin.py with 77% similarity]
tests/python/unittest/test_tir_transform_lower_warp_memory.py [moved from tests/python/unittest/test_tir_pass_lower_warp_memory.py with 72% similarity]

index eae936d..f6ea918 100644 (file)
@@ -321,6 +321,9 @@ class IRModule : public ObjectRef {
    * \return A Relay module.
    */
   TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
+
+  /*! \brief Declare the container type. */
+  using ContainerType = IRModuleNode;
 };
 
 /*!
index 5149677..9d55db5 100644 (file)
@@ -59,11 +59,33 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
                                 const tvm::Array<tvm::PrimExpr>& required);
 
 /*!
- * \brief Create PrimFuncPass to combine context calls in the host function.
+ * \brief Combine context calls in the host function.
  *
  * \return The pass.
  */
-Pass CombineContextCall();
+TVM_DLL Pass CombineContextCall();
+
+/*!
+ * \brief Lower the target specific function intrinsics in each of the function.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerIntrin();
+
+/*!
+ * \brief Lower attached storage access information on device.
+ *
+ * \note Run this pass after all storage access analysis finish.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerDeviceStorageAccessInfo();
+
+/*!
+ * \brief Lower warp memory access to low-level device related function calls.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerWarpMemory();
 
 }  // namespace transform
 }  // namespace tir
index 1eec94e..2b50387 100644 (file)
@@ -29,3 +29,40 @@ def CombineContextCall():
         The result pass
     """
     return _ffi_api.CombineContextCall()
+
+
+def LowerIntrin():
+    """Lower target specific intrinsic calls.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerIntrin()
+
+
+def LowerDeviceStorageAccessInfo():
+    """Lower attached storage access information on device.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+
+    Note
+    ----
+    Run this pass after all storage access analysis finish.
+    """
+    return _ffi_api.LowerDeviceStorageAccessInfo()
+
+
+def LowerWarpMemory():
+    """Lower warp memory access to low-level device related function calls.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerWarpMemory()
index c7474de..ea74f4c 100644 (file)
@@ -408,8 +408,8 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
   bool update = args[3];
   CHECK(val->IsInstance<RelayExprNode>());
 
-  if (val->IsInstance<relay::FunctionNode>()) {
-    mod->Add(var, Downcast<relay::Function>(val), update);
+  if (val->IsInstance<BaseFuncNode>()) {
+    mod->Add(var, Downcast<BaseFunc>(val), update);
   } else if (val->IsInstance<GlobalVarNode>()) {
     GlobalVar gv = Downcast<GlobalVar>(val);
     auto mod_copy = IRModule(make_object<IRModuleNode>(*mod.operator->()));
index 6878abc..61c1fc2 100644 (file)
@@ -382,6 +382,7 @@ TVM_REGISTER_GLOBAL("transform.RunPass")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   Pass pass = args[0];
   IRModule mod = args[1];
+  ObjectRef ref = args[1];
   *ret = pass(mod);
 });
 
index e9ff234..1981c21 100644 (file)
@@ -54,8 +54,6 @@ runtime::Module BuildForIRModule(const IRModule& module,
   return (*bf)(module, target->str());
 }
 
-
-
 // convert legacy LoweredFunc to PrimFunc.
 tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
   // remap args to attach type annotations.
index aaee582..f6bba48 100644 (file)
@@ -235,116 +235,5 @@ StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const {
   return it->second;
 }
 
-
-class StorageAccessInfoLower : public StmtExprMutator {
- public:
-  Stmt VisitStmt_(const AllocateNode* op) final {
-    // Lower allocate to device allocate when needed.
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<AllocateNode>();
-    // For special memory, remove allocate, or use head expr
-    auto it = storage_info_.find(op->buffer_var.get());
-    if (it != storage_info_.end() && it->second.info.defined()) {
-      const MemoryInfo& info = it->second.info;
-      ++it->second.alloc_count;
-      CHECK_LE(it->second.alloc_count, 1)
-          << "Double allocation of " << it->second.scope.to_string();
-      if (info->head_address.defined()) {
-        return AllocateNode::make(
-            op->buffer_var, op->dtype, op->extents, op->condition,
-            op->body, info->head_address, "nop");
-      }
-      return op->body;
-    } else {
-      return stmt;
-    }
-  }
-  Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::storage_scope) {
-      const VarNode* buf = op->node.as<VarNode>();
-      StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
-      StorageEntry e;
-      e.scope = scope;
-      if (scope.tag.length() != 0) {
-        e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
-        CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
-      }
-      storage_info_[buf] = e;
-      return StmtExprMutator::VisitStmt_(op);
-
-    } else {
-      return StmtExprMutator::VisitStmt_(op);
-    }
-  }
-
-  PrimExpr VisitExpr_(const CallNode* op) final {
-    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
-      return MakeAccessPtr(op);
-    } else {
-      return StmtExprMutator::VisitExpr_(op);
-    }
-  }
-
- private:
-  // tvm_access_ptr
-  PrimExpr MakeAccessPtr(const CallNode* op) {
-    // Specially handle the buffer packed intrinsic
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<CallNode>();
-    CHECK_EQ(op->args.size(), 5U);
-    DataType dtype = op->args[0].dtype();
-    const VarNode* buffer = op->args[1].as<VarNode>();
-    Var buffer_var = Downcast<Var>(op->args[1]);
-    PrimExpr offset = op->args[2];
-    auto it = storage_info_.find(buffer);
-    if (it != storage_info_.end() && it->second.info.defined()) {
-      return MakeTaggedAccessPtr(
-          op->dtype, buffer_var, dtype, offset,
-          it->second.info);
-    }
-    CHECK(op->dtype.is_handle());
-    // Change to address_of
-    return AddressOffset(buffer_var, dtype, offset);
-  }
-
-  PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
-                           Var buffer_var,
-                           DataType dtype,
-                           PrimExpr offset,
-                           const MemoryInfo& info) {
-    if (ptr_type.is_handle()) {
-      CHECK(info->head_address.defined())
-          << buffer_var << " is not adddressable.";
-      return AddressOffset(buffer_var, dtype, offset);
-    }
-    int dtype_bits = dtype.bits() * dtype.lanes();
-    CHECK_EQ(info->unit_bits % dtype_bits, 0);
-    return cast(ptr_type,
-                   tir::Simplify(offset / make_const(
-                       offset.dtype(), info->unit_bits / dtype_bits)));
-  }
-  // The storage entry.
-  struct StorageEntry {
-    // Whether it is tagged memory.
-    StorageScope scope;
-    // The memory info if any.
-    MemoryInfo info;
-    // Allocation counter
-    int alloc_count{0};
-  };
-  // The storage scope of each buffer
-  std::unordered_map<const VarNode*, StorageEntry> storage_info_;
-};
-
-Stmt LowerStorageAccessInfo(Stmt stmt) {
-  return StorageAccessInfoLower()(std::move(stmt));
-}
-
-LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
-  n->body = LowerStorageAccessInfo(f->body);
-  return LoweredFunc(n);
-}
-
 }  // namespace tir
 }  // namespace tvm
index ed352c1..069de57 100644 (file)
@@ -126,7 +126,7 @@ Pass CombineContextCall() {
     n->body = ContextCallCombiner().Combine(n->body);
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "CombineContextCall", {});
+  return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});
 }
 
 TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall")
diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc
new file mode 100644 (file)
index 0000000..5797665
--- /dev/null
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_device_storage_access.cc
+ * \brief Lower the special device storage access.
+ */
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/target/target_info.h>
+#include <tvm/runtime/registry.h>
+
+#include <tvm/tir/ir_pass.h>
+
+#include "../pass/ir_util.h"
+#include "../../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace tir {
+
+using runtime::StorageScope;
+using runtime::StorageRank;
+
+class StorageAccessInfoLower : public StmtExprMutator {
+ public:
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    // Lower allocate to device allocate when needed.
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<AllocateNode>();
+    // For special memory, remove allocate, or use head expr
+    auto it = storage_info_.find(op->buffer_var.get());
+    if (it != storage_info_.end() && it->second.info.defined()) {
+      const MemoryInfo& info = it->second.info;
+      ++it->second.alloc_count;
+      CHECK_LE(it->second.alloc_count, 1)
+          << "Double allocation of " << it->second.scope.to_string();
+      if (info->head_address.defined()) {
+        return AllocateNode::make(
+            op->buffer_var, op->dtype, op->extents, op->condition,
+            op->body, info->head_address, "nop");
+      }
+      return op->body;
+    } else {
+      return stmt;
+    }
+  }
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::storage_scope) {
+      const VarNode* buf = op->node.as<VarNode>();
+      StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
+      StorageEntry e;
+      e.scope = scope;
+      if (scope.tag.length() != 0) {
+        e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
+        CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
+      }
+      storage_info_[buf] = e;
+      return StmtExprMutator::VisitStmt_(op);
+
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+      return MakeAccessPtr(op);
+    } else {
+      return StmtExprMutator::VisitExpr_(op);
+    }
+  }
+
+ private:
+  // tvm_access_ptr
+  PrimExpr MakeAccessPtr(const CallNode* op) {
+    // Specially handle the buffer packed intrinsic
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<CallNode>();
+    CHECK_EQ(op->args.size(), 5U);
+    DataType dtype = op->args[0].dtype();
+    const VarNode* buffer = op->args[1].as<VarNode>();
+    Var buffer_var = Downcast<Var>(op->args[1]);
+    PrimExpr offset = op->args[2];
+    auto it = storage_info_.find(buffer);
+    if (it != storage_info_.end() && it->second.info.defined()) {
+      return MakeTaggedAccessPtr(
+          op->dtype, buffer_var, dtype, offset,
+          it->second.info);
+    }
+    CHECK(op->dtype.is_handle());
+    // Change to address_of
+    return AddressOffset(buffer_var, dtype, offset);
+  }
+
+  PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
+                           Var buffer_var,
+                           DataType dtype,
+                           PrimExpr offset,
+                           const MemoryInfo& info) {
+    if (ptr_type.is_handle()) {
+      CHECK(info->head_address.defined())
+          << buffer_var << " is not adddressable.";
+      return AddressOffset(buffer_var, dtype, offset);
+    }
+    int dtype_bits = dtype.bits() * dtype.lanes();
+    CHECK_EQ(info->unit_bits % dtype_bits, 0);
+    return cast(ptr_type,
+                   tir::Simplify(offset / make_const(
+                       offset.dtype(), info->unit_bits / dtype_bits)));
+  }
+  // The storage entry.
+  struct StorageEntry {
+    // Whether it is tagged memory.
+    StorageScope scope;
+    // The memory info if any.
+    MemoryInfo info;
+    // Allocation counter
+    int alloc_count{0};
+  };
+  // The storage scope of each buffer
+  std::unordered_map<const VarNode*, StorageEntry> storage_info_;
+};
+
+Stmt LowerStorageAccessInfo(Stmt stmt) {
+  return StorageAccessInfoLower()(std::move(stmt));
+}
+
+LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
+  auto n = make_object<LoweredFuncNode>(*f.operator->());
+  n->body = LowerStorageAccessInfo(f->body);
+  return LoweredFunc(n);
+}
+
+namespace transform {
+
+Pass LowerDeviceStorageAccessInfo() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = StorageAccessInfoLower()(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(
+      pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo")
+.set_body_typed(LowerDeviceStorageAccessInfo);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
similarity index 90%
rename from src/tir/pass/lower_intrin.cc
rename to src/tir/transforms/lower_intrin.cc
index d39624d..6d4863d 100644 (file)
  */
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
 #include <tvm/runtime/registry.h>
 
 #include <tvm/tir/op.h>
+#include <tvm/target/target.h>
 #include <unordered_set>
-#include "ir_util.h"
 #include "../../arith/pattern_match.h"
 #include "../../arith/ir_mutator_with_analyzer.h"
 
@@ -39,15 +40,12 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   using IRMutatorWithAnalyzer::VisitStmt_;
   using IRMutatorWithAnalyzer::VisitExpr_;
 
-  IntrinInjecter(arith::Analyzer* analyzer, std::string target)
+  IntrinInjecter(arith::Analyzer* analyzer, std::string target_name)
       : IRMutatorWithAnalyzer(analyzer) {
-    std::istringstream is(target);
-    std::string starget;
-    is >> starget;
-    patterns_.push_back("tvm.intrin.rule." + starget + ".");
+    patterns_.push_back("tvm.intrin.rule." + target_name + ".");
     patterns_.push_back("tvm.intrin.rule.default.");
     fma_ = runtime::Registry::Get(patterns_[0] + "fma");
-    if (target == "stackvm") {
+    if (target_name == "stackvm") {
       support_bitwise_op_ = false;
     }
   }
@@ -280,21 +278,41 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   bool support_bitwise_op_{true};
 };
 
-Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
+Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) {
   arith::Analyzer analyzer;
-  return IntrinInjecter(&analyzer, target)(std::move(stmt));
+  return IntrinInjecter(&analyzer, target_name)(std::move(stmt));
 }
 
 LoweredFunc
 LowerIntrin(LoweredFunc f, const std::string& target) {
   auto n = make_object<LoweredFuncNode>(*f.operator->());
-  n->body = LowerIntrinStmt(n->body, target);
+  std::istringstream is(target);
+  std::string target_name;
+  is >> target_name;
+  n->body = LowerIntrinStmt(n->body, target_name);
   return LoweredFunc(n);
 }
 
-// Register the api only for test purposes
-TVM_REGISTER_GLOBAL("ir_pass._LowerIntrinStmt")
-.set_body_typed(LowerIntrinStmt);
+namespace transform {
+
+Pass LowerIntrin() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
+    CHECK(target.defined())
+        << "LowerIntrin: Require the target attribute";
+    arith::Analyzer analyzer;
+    n->body =
+        IntrinInjecter(&analyzer, target->target_name)(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin")
+.set_body_typed(LowerIntrin);
+
+}  // namespace transform
 
 }  // namespace tir
 }  // namespace tvm
similarity index 94%
rename from src/tir/pass/lower_warp_memory.cc
rename to src/tir/transforms/lower_warp_memory.cc
index 385a5b4..808b081 100644 (file)
 
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/ir_pass.h>
+
 #include <unordered_set>
-#include "ir_util.h"
+
+#include "../pass/ir_util.h"
 #include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 
@@ -388,5 +393,24 @@ LowerWarpMemory(LoweredFunc f, int warp_size) {
   return LoweredFunc(n);
 }
 
+namespace transform {
+
+Pass LowerWarpMemory() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
+    CHECK(target.defined())
+        << "LowerWarpMemory: Require the target attribute";
+    n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(n->body);
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory")
+.set_body_typed(LowerWarpMemory);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
index e76fb33..8140ddb 100644 (file)
@@ -39,7 +39,7 @@ def test_for():
     f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
 
     # temp adapter to convert loweredFunc to IRModule
-    # to test passes in the new style.
+    # to test passes in the new style.x
     mod = tvm.testing.LoweredFuncsToIRModule([f])
 
     mod = tvm.tir.transform.CombineContextCall()(mod)
@@ -18,12 +18,15 @@ import tvm
 from tvm import te
 import numpy as np
 
-def lower_intrin(stmt):
+def lower_intrin(params, stmt):
     """wrapper to call transformation in stmt"""
     lower_expr = isinstance(stmt, tvm.tir.PrimExpr)
     stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt
     stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt  = tvm.tir.ir_pass._LowerIntrinStmt(stmt, "llvm")
+    func = tvm.tir.PrimFunc(params, stmt).with_attr(
+        "target", tvm.target.create("llvm"))
+    func = tvm.tir.transform.LowerIntrin()(tvm.IRModule.from_expr(func))["main"]
+    stmt = func.body
     return stmt.value if lower_expr else stmt.body
 
 
@@ -70,19 +73,19 @@ def test_lower_floordiv():
         y = te.var("y", dtype=dtype)
         zero = tvm.tir.const(0, dtype)
         # no constraints
-        res = lower_intrin(tvm.te.floordiv(x, y))
+        res = lower_intrin([x, y], tvm.te.floordiv(x, y))
         check_value(res, x, y, data, lambda a, b: a // b)
         # rhs >= 0
-        res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
+        res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
         # involves max
-        res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero))
+        res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero))
         check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0)
         # lhs >= 0
-        res = lower_intrin(tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero))
+        res = lower_intrin([x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
         # const power of two
-        res = lower_intrin(tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
+        res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
         check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
 
 
@@ -93,16 +96,16 @@ def test_lower_floormod():
         y = te.var("y", dtype=dtype)
         zero = tvm.tir.const(0, dtype)
         # no constraints
-        res = lower_intrin(tvm.te.floormod(x, y))
+        res = lower_intrin([x, y], tvm.te.floormod(x, y))
         check_value(res, x, y, data, lambda a, b: a % b)
         # rhs >= 0
-        res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
+        res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
         # lhs >= 0
-        res = lower_intrin(tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero))
+        res = lower_intrin([x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
         # const power of two
-        res = lower_intrin(tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
+        res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
         check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
 
 
@@ -24,18 +24,26 @@ def test_lower_warp_mem():
 
     s = te.create_schedule(B.op)
     AA = s.cache_read(A, "warp", [B])
-    xo, xi = s[B].split(B.op.axis[0], 32)
-    xi0, xi1 = s[B].split(xi, factor=16)
+    xo, xi = s[B].split(B.op.axis[0], 64)
+    xi0, xi1 = s[B].split(xi, factor=32)
     tx = te.thread_axis("threadIdx.x")
     s[B].bind(xi1, tx)
     s[B].bind(xo, te.thread_axis("blockIdx.x"))
     s[AA].compute_at(s[B], xo)
-    xo, xi = s[AA].split(s[AA].op.axis[0], 16)
+    xo, xi = s[AA].split(s[AA].op.axis[0], 32)
     s[AA].bind(xi, tx)
 
     f = tvm.lower(s, [A, B])
     fhost, fdevice = tvm.tir.ir_pass.SplitHostDevice(f)
-    fdevice = tvm.tir.ir_pass.LowerWarpMemory(fdevice, 16)
+
+    # temp adapter to convert loweredFunc to IRModule
+    # to test passes in the new style.
+    fname = fdevice.name
+    mod = tvm.testing.LoweredFuncsToIRModule([fdevice])
+    cuda_target = tvm.target.create("cuda")
+    assert cuda_target.thread_warp_size == 32
+    mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
+    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
     assert(fdevice.body.body.value.value == "local")
     assert(fdevice.body.body.body.extents[0].value == 2)