* \return A Relay module.
*/
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
+
+ /*! \brief Declare the container type. */
+ using ContainerType = IRModuleNode;
};
/*!
const tvm::Array<tvm::PrimExpr>& required);
/*!
- * \brief Create PrimFuncPass to combine context calls in the host function.
+ * \brief Combine context calls in the host function.
*
* \return The pass.
*/
-Pass CombineContextCall();
+TVM_DLL Pass CombineContextCall();
+
+/*!
+ * \brief Lower the target specific function intrinsics in each of the function.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerIntrin();
+
+/*!
+ * \brief Lower attached storage access information on device.
+ *
+ * \note Run this pass after all storage access analysis finish.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerDeviceStorageAccessInfo();
+
+/*!
+ * \brief Lower warp memory access to low-level device related function calls.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerWarpMemory();
} // namespace transform
} // namespace tir
The result pass
"""
return _ffi_api.CombineContextCall()
+
+
+def LowerIntrin():
+    """Lower target specific intrinsic calls.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+
+    Note
+    ----
+    The underlying C++ pass requires the ``target`` attribute to be
+    attached to each PrimFunc before this pass runs; it selects the
+    intrinsic lowering rules by target name.
+    """
+    return _ffi_api.LowerIntrin()
+
+
+def LowerDeviceStorageAccessInfo():
+    """Lower attached storage access information on device.
+
+    Replaces tagged-scope allocations and access pointers with
+    device-level address arithmetic (e.g. using the memory's head
+    address when one is defined).
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+
+    Note
+    ----
+    Run this pass after all storage access analysis finish.
+    """
+    return _ffi_api.LowerDeviceStorageAccessInfo()
+
+
+def LowerWarpMemory():
+    """Lower warp memory access to low-level device related function calls.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+
+    Note
+    ----
+    The underlying C++ pass requires the ``target`` attribute on each
+    PrimFunc; the target's ``thread_warp_size`` drives the rewrite.
+    """
+    return _ffi_api.LowerWarpMemory()
bool update = args[3];
CHECK(val->IsInstance<RelayExprNode>());
- if (val->IsInstance<relay::FunctionNode>()) {
- mod->Add(var, Downcast<relay::Function>(val), update);
+ if (val->IsInstance<BaseFuncNode>()) {
+ mod->Add(var, Downcast<BaseFunc>(val), update);
} else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = IRModule(make_object<IRModuleNode>(*mod.operator->()));
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
IRModule mod = args[1];
+ ObjectRef ref = args[1];
*ret = pass(mod);
});
return (*bf)(module, target->str());
}
-
-
// convert legacy LoweredFunc to PrimFunc.
tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
// remap args to attach type annotations.
return it->second;
}
-
-class StorageAccessInfoLower : public StmtExprMutator {
- public:
- Stmt VisitStmt_(const AllocateNode* op) final {
- // Lower allocate to device allocate when needed.
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<AllocateNode>();
- // For special memory, remove allocate, or use head expr
- auto it = storage_info_.find(op->buffer_var.get());
- if (it != storage_info_.end() && it->second.info.defined()) {
- const MemoryInfo& info = it->second.info;
- ++it->second.alloc_count;
- CHECK_LE(it->second.alloc_count, 1)
- << "Double allocation of " << it->second.scope.to_string();
- if (info->head_address.defined()) {
- return AllocateNode::make(
- op->buffer_var, op->dtype, op->extents, op->condition,
- op->body, info->head_address, "nop");
- }
- return op->body;
- } else {
- return stmt;
- }
- }
- Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::storage_scope) {
- const VarNode* buf = op->node.as<VarNode>();
- StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
- StorageEntry e;
- e.scope = scope;
- if (scope.tag.length() != 0) {
- e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
- CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
- }
- storage_info_[buf] = e;
- return StmtExprMutator::VisitStmt_(op);
-
- } else {
- return StmtExprMutator::VisitStmt_(op);
- }
- }
-
- PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
- return MakeAccessPtr(op);
- } else {
- return StmtExprMutator::VisitExpr_(op);
- }
- }
-
- private:
- // tvm_access_ptr
- PrimExpr MakeAccessPtr(const CallNode* op) {
- // Specially handle the buffer packed intrinsic
- PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<CallNode>();
- CHECK_EQ(op->args.size(), 5U);
- DataType dtype = op->args[0].dtype();
- const VarNode* buffer = op->args[1].as<VarNode>();
- Var buffer_var = Downcast<Var>(op->args[1]);
- PrimExpr offset = op->args[2];
- auto it = storage_info_.find(buffer);
- if (it != storage_info_.end() && it->second.info.defined()) {
- return MakeTaggedAccessPtr(
- op->dtype, buffer_var, dtype, offset,
- it->second.info);
- }
- CHECK(op->dtype.is_handle());
- // Change to address_of
- return AddressOffset(buffer_var, dtype, offset);
- }
-
- PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
- Var buffer_var,
- DataType dtype,
- PrimExpr offset,
- const MemoryInfo& info) {
- if (ptr_type.is_handle()) {
- CHECK(info->head_address.defined())
- << buffer_var << " is not adddressable.";
- return AddressOffset(buffer_var, dtype, offset);
- }
- int dtype_bits = dtype.bits() * dtype.lanes();
- CHECK_EQ(info->unit_bits % dtype_bits, 0);
- return cast(ptr_type,
- tir::Simplify(offset / make_const(
- offset.dtype(), info->unit_bits / dtype_bits)));
- }
- // The storage entry.
- struct StorageEntry {
- // Whether it is tagged memory.
- StorageScope scope;
- // The memory info if any.
- MemoryInfo info;
- // Allocation counter
- int alloc_count{0};
- };
- // The storage scope of each buffer
- std::unordered_map<const VarNode*, StorageEntry> storage_info_;
-};
-
-Stmt LowerStorageAccessInfo(Stmt stmt) {
- return StorageAccessInfoLower()(std::move(stmt));
-}
-
-LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
- auto n = make_object<LoweredFuncNode>(*f.operator->());
- n->body = LowerStorageAccessInfo(f->body);
- return LoweredFunc(n);
-}
-
} // namespace tir
} // namespace tvm
n->body = ContextCallCombiner().Combine(n->body);
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "CombineContextCall", {});
+ return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});
}
TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall")
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_device_storage_access.cc
+ * \brief Lower the special device storage access.
+ */
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/target/target_info.h>
+#include <tvm/runtime/registry.h>
+
+#include <tvm/tir/ir_pass.h>
+
+#include "../pass/ir_util.h"
+#include "../../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace tir {
+
+using runtime::StorageScope;
+using runtime::StorageRank;
+
+/*!
+ * \brief Mutator that lowers tagged storage scopes and
+ *        tvm_access_ptr intrinsic calls into device-level
+ *        address arithmetic.
+ */
+class StorageAccessInfoLower : public StmtExprMutator {
+ public:
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    // Lower allocate to device allocate when needed.
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<AllocateNode>();
+    // For special memory, remove allocate, or use head expr
+    auto it = storage_info_.find(op->buffer_var.get());
+    if (it != storage_info_.end() && it->second.info.defined()) {
+      const MemoryInfo& info = it->second.info;
+      ++it->second.alloc_count;
+      // Tagged special memory may only be allocated once per buffer.
+      CHECK_LE(it->second.alloc_count, 1)
+          << "Double allocation of " << it->second.scope.to_string();
+      if (info->head_address.defined()) {
+        // Re-emit the allocation anchored at the memory's head address.
+        return AllocateNode::make(
+            op->buffer_var, op->dtype, op->extents, op->condition,
+            op->body, info->head_address, "nop");
+      }
+      // No head address: the allocation is implicit, drop it.
+      return op->body;
+    } else {
+      return stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::storage_scope) {
+      // Record the storage scope (and, for tagged scopes, the memory
+      // info) of the annotated buffer variable.
+      const VarNode* buf = op->node.as<VarNode>();
+      StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
+      StorageEntry e;
+      e.scope = scope;
+      if (scope.tag.length() != 0) {
+        e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
+        CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
+      }
+      storage_info_[buf] = e;
+      return StmtExprMutator::VisitStmt_(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+      return MakeAccessPtr(op);
+    } else {
+      return StmtExprMutator::VisitExpr_(op);
+    }
+  }
+
+ private:
+  // Lower a tvm_access_ptr call (5 args: dtype, buffer, offset, ...).
+  PrimExpr MakeAccessPtr(const CallNode* op) {
+    // Specially handle the buffer packed intrinsic
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<CallNode>();
+    CHECK_EQ(op->args.size(), 5U);
+    DataType dtype = op->args[0].dtype();
+    const VarNode* buffer = op->args[1].as<VarNode>();
+    Var buffer_var = Downcast<Var>(op->args[1]);
+    PrimExpr offset = op->args[2];
+    auto it = storage_info_.find(buffer);
+    if (it != storage_info_.end() && it->second.info.defined()) {
+      return MakeTaggedAccessPtr(
+          op->dtype, buffer_var, dtype, offset,
+          it->second.info);
+    }
+    CHECK(op->dtype.is_handle());
+    // Change to address_of
+    return AddressOffset(buffer_var, dtype, offset);
+  }
+
+  // Build the access pointer for tagged (special) memory.
+  PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
+                               Var buffer_var,
+                               DataType dtype,
+                               PrimExpr offset,
+                               const MemoryInfo& info) {
+    if (ptr_type.is_handle()) {
+      CHECK(info->head_address.defined())
+          << buffer_var << " is not addressable.";
+      return AddressOffset(buffer_var, dtype, offset);
+    }
+    // Rescale the element offset into the memory's addressing units.
+    int dtype_bits = dtype.bits() * dtype.lanes();
+    CHECK_EQ(info->unit_bits % dtype_bits, 0);
+    return cast(ptr_type,
+                tir::Simplify(offset / make_const(
+                    offset.dtype(), info->unit_bits / dtype_bits)));
+  }
+  // The storage entry.
+  struct StorageEntry {
+    // The storage scope of the buffer.
+    StorageScope scope;
+    // The memory info if any.
+    MemoryInfo info;
+    // Allocation counter
+    int alloc_count{0};
+  };
+  // The storage scope of each buffer
+  std::unordered_map<const VarNode*, StorageEntry> storage_info_;
+};
+
+/*!
+ * \brief Lower storage access information in the given statement.
+ * \param stmt The statement to transform.
+ * \return The transformed statement.
+ */
+Stmt LowerStorageAccessInfo(Stmt stmt) {
+  return StorageAccessInfoLower()(std::move(stmt));
+}
+
+/*!
+ * \brief Legacy LoweredFunc-based entry point: lower storage access
+ *        information in the function body.
+ * \param f The input function.
+ * \return A copy of f with the lowered body.
+ */
+LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
+  auto n = make_object<LoweredFuncNode>(*f.operator->());
+  n->body = LowerStorageAccessInfo(f->body);
+  return LoweredFunc(n);
+}
+
+namespace transform {
+
+/*!
+ * \brief Create a PrimFunc pass that lowers attached storage access
+ *        information (see StorageAccessInfoLower).
+ * \return The created pass.
+ */
+Pass LowerDeviceStorageAccessInfo() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = StorageAccessInfoLower()(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(
+      pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
+}
+
+// Expose the pass to the frontend via the global registry.
+TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo")
+.set_body_typed(LowerDeviceStorageAccessInfo);
+
+} // namespace transform
+} // namespace tir
+} // namespace tvm
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
+#include <tvm/target/target.h>
#include <unordered_set>
-#include "ir_util.h"
#include "../../arith/pattern_match.h"
#include "../../arith/ir_mutator_with_analyzer.h"
using IRMutatorWithAnalyzer::VisitStmt_;
using IRMutatorWithAnalyzer::VisitExpr_;
- IntrinInjecter(arith::Analyzer* analyzer, std::string target)
+ IntrinInjecter(arith::Analyzer* analyzer, std::string target_name)
: IRMutatorWithAnalyzer(analyzer) {
- std::istringstream is(target);
- std::string starget;
- is >> starget;
- patterns_.push_back("tvm.intrin.rule." + starget + ".");
+ patterns_.push_back("tvm.intrin.rule." + target_name + ".");
patterns_.push_back("tvm.intrin.rule.default.");
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
- if (target == "stackvm") {
+ if (target_name == "stackvm") {
support_bitwise_op_ = false;
}
}
bool support_bitwise_op_{true};
};
-Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
+Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) {
arith::Analyzer analyzer;
- return IntrinInjecter(&analyzer, target)(std::move(stmt));
+ return IntrinInjecter(&analyzer, target_name)(std::move(stmt));
}
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
- n->body = LowerIntrinStmt(n->body, target);
+ std::istringstream is(target);
+ std::string target_name;
+ is >> target_name;
+ n->body = LowerIntrinStmt(n->body, target_name);
return LoweredFunc(n);
}
-// Register the api only for test purposes
-TVM_REGISTER_GLOBAL("ir_pass._LowerIntrinStmt")
-.set_body_typed(LowerIntrinStmt);
+namespace transform {
+
+/*!
+ * \brief Create a PrimFunc pass that lowers target-specific
+ *        intrinsic calls in the function body.
+ * \return The created pass.
+ */
+Pass LowerIntrin() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    // The target attribute is mandatory: lowering rules are looked up
+    // by target name (tvm.intrin.rule.<target>.*).
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
+    CHECK(target.defined())
+        << "LowerIntrin: Require the target attribute";
+    arith::Analyzer analyzer;
+    n->body =
+        IntrinInjecter(&analyzer, target->target_name)(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
+}
+
+// Expose the pass to the frontend via the global registry.
+TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin")
+.set_body_typed(LowerIntrin);
+
+} // namespace transform
} // namespace tir
} // namespace tvm
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h>
+
#include <unordered_set>
-#include "ir_util.h"
+
+#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
return LoweredFunc(n);
}
+namespace transform {
+
+/*!
+ * \brief Create a PrimFunc pass that lowers warp memory access to
+ *        low-level device function calls.
+ * \return The created pass.
+ */
+Pass LowerWarpMemory() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    // The target attribute is mandatory: its thread_warp_size drives
+    // the rewrite.
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
+    CHECK(target.defined())
+        << "LowerWarpMemory: Require the target attribute";
+    n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(n->body);
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
+}
+
+// Expose the pass to the frontend via the global registry.
+TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory")
+.set_body_typed(LowerWarpMemory);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
# temp adapter to convert loweredFunc to IRModule
- # to test passes in the new style.
+    # to test passes in the new style.
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.CombineContextCall()(mod)
from tvm import te
import numpy as np
-def lower_intrin(stmt):
+def lower_intrin(params, stmt):
"""wrapper to call transformation in stmt"""
lower_expr = isinstance(stmt, tvm.tir.PrimExpr)
stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass._LowerIntrinStmt(stmt, "llvm")
+ func = tvm.tir.PrimFunc(params, stmt).with_attr(
+ "target", tvm.target.create("llvm"))
+ func = tvm.tir.transform.LowerIntrin()(tvm.IRModule.from_expr(func))["main"]
+ stmt = func.body
return stmt.value if lower_expr else stmt.body
y = te.var("y", dtype=dtype)
zero = tvm.tir.const(0, dtype)
# no constraints
- res = lower_intrin(tvm.te.floordiv(x, y))
+ res = lower_intrin([x, y], tvm.te.floordiv(x, y))
check_value(res, x, y, data, lambda a, b: a // b)
# rhs >= 0
- res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
+ res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
# involves max
- res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero))
+ res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero))
check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0)
# lhs >= 0
- res = lower_intrin(tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero))
+ res = lower_intrin([x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero))
check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
# const power of two
- res = lower_intrin(tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
+ res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
y = te.var("y", dtype=dtype)
zero = tvm.tir.const(0, dtype)
# no constraints
- res = lower_intrin(tvm.te.floormod(x, y))
+ res = lower_intrin([x, y], tvm.te.floormod(x, y))
check_value(res, x, y, data, lambda a, b: a % b)
# rhs >= 0
- res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
+ res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
# lhs >= 0
- res = lower_intrin(tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero))
+ res = lower_intrin([x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero))
check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
# const power of two
- res = lower_intrin(tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
+ res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
s = te.create_schedule(B.op)
AA = s.cache_read(A, "warp", [B])
- xo, xi = s[B].split(B.op.axis[0], 32)
- xi0, xi1 = s[B].split(xi, factor=16)
+ xo, xi = s[B].split(B.op.axis[0], 64)
+ xi0, xi1 = s[B].split(xi, factor=32)
tx = te.thread_axis("threadIdx.x")
s[B].bind(xi1, tx)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[AA].compute_at(s[B], xo)
- xo, xi = s[AA].split(s[AA].op.axis[0], 16)
+ xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx)
f = tvm.lower(s, [A, B])
fhost, fdevice = tvm.tir.ir_pass.SplitHostDevice(f)
- fdevice = tvm.tir.ir_pass.LowerWarpMemory(fdevice, 16)
+
+ # temp adapter to convert loweredFunc to IRModule
+ # to test passes in the new style.
+ fname = fdevice.name
+ mod = tvm.testing.LoweredFuncsToIRModule([fdevice])
+ cuda_target = tvm.target.create("cuda")
+ assert cuda_target.thread_warp_size == 32
+ mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
+ fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)