From: Tianqi Chen
Date: Wed, 1 Apr 2020 02:44:41 +0000 (-0700)
Subject: [REFACTOR][TIR] Migrate Low-level Passes to Pass Manager (#5198)
X-Git-Tag: upstream/0.7.0~1014
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2b6d69c62c07acc102c6ca42ee5c4edcc3de41f1;p=platform%2Fupstream%2Ftvm.git

[REFACTOR][TIR] Migrate Low-level Passes to Pass Manager (#5198)

* [TIR][TRANSFORM] Migrate LowerIntrin

* LowerDeviceStorageAccessInfo

* Migrate LowerWarpMemory
---

diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index eae936d..f6ea918 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -321,6 +321,9 @@ class IRModule : public ObjectRef {
    * \return A Relay module.
    */
   TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
+
+  /*! \brief Declare the container type. */
+  using ContainerType = IRModuleNode;
 };
 
 /*!
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 5149677..9d55db5 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -59,11 +59,33 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
     const tvm::Array& required);
 
 /*!
- * \brief Create PrimFuncPass to combine context calls in the host function.
+ * \brief Combine context calls in the host function.
  *
  * \return The pass.
  */
-Pass CombineContextCall();
+TVM_DLL Pass CombineContextCall();
+
+/*!
+ * \brief Lower the target specific function intrinsics in each of the function.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerIntrin();
+
+/*!
+ * \brief Lower attached storage access information on device.
+ *
+ * \note Run this pass after all storage access analysis finish.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerDeviceStorageAccessInfo();
+
+/*!
+ * \brief Lower warp memory access to low-level device related function calls.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerWarpMemory();
 
 }  // namespace transform
 }  // namespace tir
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 1eec94e..2b50387 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -29,3 +29,40 @@ def CombineContextCall():
         The result pass
     """
     return _ffi_api.CombineContextCall()
+
+
+def LowerIntrin():
+    """Lower target specific intrinsic calls.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerIntrin()
+
+
+def LowerDeviceStorageAccessInfo():
+    """Lower attached storage access information on device.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+
+    Note
+    ----
+    Run this pass after all storage access analysis finish.
+    """
+    return _ffi_api.LowerDeviceStorageAccessInfo()
+
+
+def LowerWarpMemory():
+    """Lower warp memory access to low-level device related function calls.
+ + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.LowerWarpMemory() diff --git a/src/ir/module.cc b/src/ir/module.cc index c7474de..ea74f4c 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -408,8 +408,8 @@ TVM_REGISTER_GLOBAL("ir.Module_Add") bool update = args[3]; CHECK(val->IsInstance()); - if (val->IsInstance()) { - mod->Add(var, Downcast(val), update); + if (val->IsInstance()) { + mod->Add(var, Downcast(val), update); } else if (val->IsInstance()) { GlobalVar gv = Downcast(val); auto mod_copy = IRModule(make_object(*mod.operator->())); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 6878abc..61c1fc2 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -382,6 +382,7 @@ TVM_REGISTER_GLOBAL("transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; IRModule mod = args[1]; + ObjectRef ref = args[1]; *ret = pass(mod); }); diff --git a/src/target/codegen.cc b/src/target/codegen.cc index e9ff234..1981c21 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -54,8 +54,6 @@ runtime::Module BuildForIRModule(const IRModule& module, return (*bf)(module, target->str()); } - - // convert legacy LoweredFunc to PrimFunc. tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) { // remap args to attach type annotations. diff --git a/src/tir/pass/storage_access.cc b/src/tir/pass/storage_access.cc index aaee582..f6bba48 100644 --- a/src/tir/pass/storage_access.cc +++ b/src/tir/pass/storage_access.cc @@ -235,116 +235,5 @@ StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const { return it->second; } - -class StorageAccessInfoLower : public StmtExprMutator { - public: - Stmt VisitStmt_(const AllocateNode* op) final { - // Lower allocate to device allocate when needed. 
- Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - // For special memory, remove allocate, or use head expr - auto it = storage_info_.find(op->buffer_var.get()); - if (it != storage_info_.end() && it->second.info.defined()) { - const MemoryInfo& info = it->second.info; - ++it->second.alloc_count; - CHECK_LE(it->second.alloc_count, 1) - << "Double allocation of " << it->second.scope.to_string(); - if (info->head_address.defined()) { - return AllocateNode::make( - op->buffer_var, op->dtype, op->extents, op->condition, - op->body, info->head_address, "nop"); - } - return op->body; - } else { - return stmt; - } - } - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::make(op->value.as()->value); - StorageEntry e; - e.scope = scope; - if (scope.tag.length() != 0) { - e.info = GetMemoryInfo(op->value.as()->value); - CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); - } - storage_info_[buf] = e; - return StmtExprMutator::VisitStmt_(op); - - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - - PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { - return MakeAccessPtr(op); - } else { - return StmtExprMutator::VisitExpr_(op); - } - } - - private: - // tvm_access_ptr - PrimExpr MakeAccessPtr(const CallNode* op) { - // Specially handle the buffer packed intrinsic - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - CHECK_EQ(op->args.size(), 5U); - DataType dtype = op->args[0].dtype(); - const VarNode* buffer = op->args[1].as(); - Var buffer_var = Downcast(op->args[1]); - PrimExpr offset = op->args[2]; - auto it = storage_info_.find(buffer); - if (it != storage_info_.end() && it->second.info.defined()) { - return MakeTaggedAccessPtr( - op->dtype, buffer_var, dtype, offset, - it->second.info); - } - CHECK(op->dtype.is_handle()); - // Change to address_of - return AddressOffset(buffer_var, dtype, offset); - } - - PrimExpr MakeTaggedAccessPtr(DataType ptr_type, - Var buffer_var, - DataType dtype, - PrimExpr offset, - const MemoryInfo& info) { - if (ptr_type.is_handle()) { - CHECK(info->head_address.defined()) - << buffer_var << " is not adddressable."; - return AddressOffset(buffer_var, dtype, offset); - } - int dtype_bits = dtype.bits() * dtype.lanes(); - CHECK_EQ(info->unit_bits % dtype_bits, 0); - return cast(ptr_type, - tir::Simplify(offset / make_const( - offset.dtype(), info->unit_bits / dtype_bits))); - } - // The storage entry. - struct StorageEntry { - // Whether it is tagged memory. - StorageScope scope; - // The memory info if any. 
- MemoryInfo info; - // Allocation counter - int alloc_count{0}; - }; - // The storage scope of each buffer - std::unordered_map storage_info_; -}; - -Stmt LowerStorageAccessInfo(Stmt stmt) { - return StorageAccessInfoLower()(std::move(stmt)); -} - -LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { - auto n = make_object(*f.operator->()); - n->body = LowerStorageAccessInfo(f->body); - return LoweredFunc(n); -} - } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index ed352c1..069de57 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -126,7 +126,7 @@ Pass CombineContextCall() { n->body = ContextCallCombiner().Combine(n->body); return f; }; - return CreatePrimFuncPass(pass_func, 0, "CombineContextCall", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); } TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall") diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc new file mode 100644 index 0000000..5797665 --- /dev/null +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_device_storage_access.cc + * \brief Lower the special device storage access. + */ +#include +#include +#include +#include +#include + +#include + +#include "../pass/ir_util.h" +#include "../../runtime/thread_storage_scope.h" + +namespace tvm { +namespace tir { + +using runtime::StorageScope; +using runtime::StorageRank; + +class StorageAccessInfoLower : public StmtExprMutator { + public: + Stmt VisitStmt_(const AllocateNode* op) final { + // Lower allocate to device allocate when needed. 
+ Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + // For special memory, remove allocate, or use head expr + auto it = storage_info_.find(op->buffer_var.get()); + if (it != storage_info_.end() && it->second.info.defined()) { + const MemoryInfo& info = it->second.info; + ++it->second.alloc_count; + CHECK_LE(it->second.alloc_count, 1) + << "Double allocation of " << it->second.scope.to_string(); + if (info->head_address.defined()) { + return AllocateNode::make( + op->buffer_var, op->dtype, op->extents, op->condition, + op->body, info->head_address, "nop"); + } + return op->body; + } else { + return stmt; + } + } + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + StorageScope scope = StorageScope::make(op->value.as()->value); + StorageEntry e; + e.scope = scope; + if (scope.tag.length() != 0) { + e.info = GetMemoryInfo(op->value.as()->value); + CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); + } + storage_info_[buf] = e; + return StmtExprMutator::VisitStmt_(op); + + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + PrimExpr VisitExpr_(const CallNode* op) final { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + return MakeAccessPtr(op); + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + + private: + // tvm_access_ptr + PrimExpr MakeAccessPtr(const CallNode* op) { + // Specially handle the buffer packed intrinsic + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + CHECK_EQ(op->args.size(), 5U); + DataType dtype = op->args[0].dtype(); + const VarNode* buffer = op->args[1].as(); + Var buffer_var = Downcast(op->args[1]); + PrimExpr offset = op->args[2]; + auto it = storage_info_.find(buffer); + if (it != storage_info_.end() && it->second.info.defined()) { + return MakeTaggedAccessPtr( + op->dtype, buffer_var, dtype, offset, + it->second.info); + } + CHECK(op->dtype.is_handle()); + // Change to address_of + return AddressOffset(buffer_var, dtype, offset); + } + + PrimExpr MakeTaggedAccessPtr(DataType ptr_type, + Var buffer_var, + DataType dtype, + PrimExpr offset, + const MemoryInfo& info) { + if (ptr_type.is_handle()) { + CHECK(info->head_address.defined()) + << buffer_var << " is not adddressable."; + return AddressOffset(buffer_var, dtype, offset); + } + int dtype_bits = dtype.bits() * dtype.lanes(); + CHECK_EQ(info->unit_bits % dtype_bits, 0); + return cast(ptr_type, + tir::Simplify(offset / make_const( + offset.dtype(), info->unit_bits / dtype_bits))); + } + // The storage entry. + struct StorageEntry { + // Whether it is tagged memory. + StorageScope scope; + // The memory info if any. 
+ MemoryInfo info; + // Allocation counter + int alloc_count{0}; + }; + // The storage scope of each buffer + std::unordered_map storage_info_; +}; + +Stmt LowerStorageAccessInfo(Stmt stmt) { + return StorageAccessInfoLower()(std::move(stmt)); +} + +LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { + auto n = make_object(*f.operator->()); + n->body = LowerStorageAccessInfo(f->body); + return LoweredFunc(n); +} + +namespace transform { + +Pass LowerDeviceStorageAccessInfo() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = StorageAccessInfoLower()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass( + pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo") +.set_body_typed(LowerDeviceStorageAccessInfo); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/pass/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc similarity index 90% rename from src/tir/pass/lower_intrin.cc rename to src/tir/transforms/lower_intrin.cc index d39624d..6d4863d 100644 --- a/src/tir/pass/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -23,11 +23,12 @@ */ #include #include +#include #include #include +#include #include -#include "ir_util.h" #include "../../arith/pattern_match.h" #include "../../arith/ir_mutator_with_analyzer.h" @@ -39,15 +40,12 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitStmt_; using IRMutatorWithAnalyzer::VisitExpr_; - IntrinInjecter(arith::Analyzer* analyzer, std::string target) + IntrinInjecter(arith::Analyzer* analyzer, std::string target_name) : IRMutatorWithAnalyzer(analyzer) { - std::istringstream is(target); - std::string starget; - is >> starget; - patterns_.push_back("tvm.intrin.rule." + starget + "."); + patterns_.push_back("tvm.intrin.rule." 
+ target_name + "."); patterns_.push_back("tvm.intrin.rule.default."); fma_ = runtime::Registry::Get(patterns_[0] + "fma"); - if (target == "stackvm") { + if (target_name == "stackvm") { support_bitwise_op_ = false; } } @@ -280,21 +278,41 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { bool support_bitwise_op_{true}; }; -Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { +Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) { arith::Analyzer analyzer; - return IntrinInjecter(&analyzer, target)(std::move(stmt)); + return IntrinInjecter(&analyzer, target_name)(std::move(stmt)); } LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target) { auto n = make_object(*f.operator->()); - n->body = LowerIntrinStmt(n->body, target); + std::istringstream is(target); + std::string target_name; + is >> target_name; + n->body = LowerIntrinStmt(n->body, target_name); return LoweredFunc(n); } -// Register the api only for test purposes -TVM_REGISTER_GLOBAL("ir_pass._LowerIntrinStmt") -.set_body_typed(LowerIntrinStmt); +namespace transform { + +Pass LowerIntrin() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto target = f->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) + << "LowerIntrin: Require the target attribute"; + arith::Analyzer analyzer; + n->body = + IntrinInjecter(&analyzer, target->target_name)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin") +.set_body_typed(LowerIntrin); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/pass/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc similarity index 94% rename from src/tir/pass/lower_warp_memory.cc rename to src/tir/transforms/lower_warp_memory.cc index 385a5b4..808b081 100644 --- a/src/tir/pass/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -30,9 +30,14 @@ #include #include +#include +#include +#include #include + #include -#include "ir_util.h" + +#include "../pass/ir_util.h" #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" @@ -388,5 +393,24 @@ LowerWarpMemory(LoweredFunc f, int warp_size) { return LoweredFunc(n); } +namespace transform { + +Pass LowerWarpMemory() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto target = f->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) + << "LowerWarpMemory: Require the target attribute"; + n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory") +.set_body_typed(LowerWarpMemory); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index e76fb33..8140ddb 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -39,7 +39,7 @@ def test_for(): f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True) # temp adapter to convert loweredFunc to IRModule - # to test passes in the new style. 
+    # to test passes in the new style.
     mod = tvm.testing.LoweredFuncsToIRModule([f])
     mod = tvm.tir.transform.CombineContextCall()(mod)

diff --git a/tests/python/unittest/test_tir_pass_lower_intrin.py b/tests/python/unittest/test_tir_transform_lower_intrin.py
similarity index 77%
rename from tests/python/unittest/test_tir_pass_lower_intrin.py
rename to tests/python/unittest/test_tir_transform_lower_intrin.py
index f36b4a5..b2e984a 100644
--- a/tests/python/unittest/test_tir_pass_lower_intrin.py
+++ b/tests/python/unittest/test_tir_transform_lower_intrin.py
@@ -18,12 +18,15 @@ import tvm
 from tvm import te
 import numpy as np
 
-def lower_intrin(stmt):
+def lower_intrin(params, stmt):
     """wrapper to call transformation in stmt"""
     lower_expr = isinstance(stmt, tvm.tir.PrimExpr)
     stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt
     stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass._LowerIntrinStmt(stmt, "llvm")
+    func = tvm.tir.PrimFunc(params, stmt).with_attr(
+        "target", tvm.target.create("llvm"))
+    func = tvm.tir.transform.LowerIntrin()(tvm.IRModule.from_expr(func))["main"]
+    stmt = func.body
     return stmt.value if lower_expr else stmt.body
 
 
@@ -70,19 +73,19 @@ def test_lower_floordiv():
     y = te.var("y", dtype=dtype)
     zero = tvm.tir.const(0, dtype)
     # no constraints
-    res = lower_intrin(tvm.te.floordiv(x, y))
+    res = lower_intrin([x, y], tvm.te.floordiv(x, y))
     check_value(res, x, y, data, lambda a, b: a // b)
     # rhs >= 0
-    res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
+    res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
     check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
     # involves max
-    res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero))
+    res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero))
     check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0)
     # lhs >= 0
-    res = lower_intrin(tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero))
+    res = lower_intrin([x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero))
     check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
     # const power of two
-    res = lower_intrin(tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
+    res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
     check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
 
 
@@ -93,16 +96,16 @@ def test_lower_floormod():
     y = te.var("y", dtype=dtype)
     zero = tvm.tir.const(0, dtype)
     # no constraints
-    res = lower_intrin(tvm.te.floormod(x, y))
+    res = lower_intrin([x, y], tvm.te.floormod(x, y))
     check_value(res, x, y, data, lambda a, b: a % b)
     # rhs >= 0
-    res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
+    res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
     check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
     # lhs >= 0
-    res = lower_intrin(tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero))
+    res = lower_intrin([x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero))
     check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
     # const power of two
-    res = lower_intrin(tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
+    res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
     check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)

diff --git a/tests/python/unittest/test_tir_pass_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
similarity index 72%
rename from tests/python/unittest/test_tir_pass_lower_warp_memory.py
rename to tests/python/unittest/test_tir_transform_lower_warp_memory.py
index 266ca7e..66d3cfb 100644
--- a/tests/python/unittest/test_tir_pass_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -24,18 +24,26 @@ def test_lower_warp_mem():
     s = te.create_schedule(B.op)
     AA = s.cache_read(A, "warp", [B])
-    xo, xi = s[B].split(B.op.axis[0], 32)
-    xi0, xi1 = s[B].split(xi, factor=16)
+    xo, xi = s[B].split(B.op.axis[0], 64)
+    xi0, xi1 = s[B].split(xi, factor=32)
     tx = te.thread_axis("threadIdx.x")
     s[B].bind(xi1, tx)
     s[B].bind(xo, te.thread_axis("blockIdx.x"))
     s[AA].compute_at(s[B], xo)
-    xo, xi = s[AA].split(s[AA].op.axis[0], 16)
+    xo, xi = s[AA].split(s[AA].op.axis[0], 32)
     s[AA].bind(xi, tx)
 
     f = tvm.lower(s, [A, B])
     fhost, fdevice = tvm.tir.ir_pass.SplitHostDevice(f)
-    fdevice = tvm.tir.ir_pass.LowerWarpMemory(fdevice, 16)
+
+    # temp adapter to convert loweredFunc to IRModule
+    # to test passes in the new style.
+    fname = fdevice.name
+    mod = tvm.testing.LoweredFuncsToIRModule([fdevice])
+    cuda_target = tvm.target.create("cuda")
+    assert cuda_target.thread_warp_size == 32
+    mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
+    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
     assert(fdevice.body.body.value.value == "local")
     assert(fdevice.body.body.body.extents[0].value == 2)
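
A minimal usage sketch of the migrated passes under the new pass manager, following the temporary LoweredFunc-to-IRModule adapter used in the tests above. The toy compute definition and the choice of the llvm target are illustrative assumptions; the adapter helpers and the requirement that LowerIntrin / LowerWarpMemory read the "target" function attribute come directly from this diff.

    import tvm
    from tvm import te

    # Illustrative schedule; any lowered function works the same way.
    A = te.placeholder((128,), name="A")
    B = te.compute((128,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    f = tvm.lower(s, [A, B])  # legacy LoweredFunc

    # Temp adapter (as in the tests): wrap the LoweredFunc into an IRModule of PrimFuncs.
    fname = f.name
    mod = tvm.testing.LoweredFuncsToIRModule([f])

    # Passes such as LowerIntrin and LowerWarpMemory now look up the target from the
    # function's "target" attribute instead of taking it as a positional argument.
    target = tvm.target.create("llvm")
    mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", target))

    # The migrated passes compose as ordinary IRModule -> IRModule passes.
    mod = tvm.tir.transform.LowerIntrin()(mod)
    mod = tvm.tir.transform.LowerDeviceStorageAccessInfo()(mod)
    print(mod["main"].body)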