ConstIntBound operator()(const PrimExpr& expr);
/*!
+ * \brief Analyze the expr, memoizing intermediate results to avoid redundant computation.
+ * \param expr The expression of interest.
+ * \param bound The lookup table to store the intermediate results.
+ * \return the result of the analysis.
+ */
+ ConstIntBound operator()(const PrimExpr& expr,
+ std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
+
+ /*!
* \brief Update constant int bound information of var.
*
* \param var The variable of interest.
Stmt HoistIfThenElse(Stmt stmt);
/*!
+ * \brief Narrow down PrimExpr datatype in stmt to target_bits.
+ * \note Run this pass after StorageFlatten.
+ * \param stmt The statement to be rewritten.
+ * \param target_bits The target bit width of index datatypes.
+ * \return Transformed stmt.
+ */
+Stmt NarrowDataType(Stmt stmt, int target_bits);
+
+/*!
* \brief Make a user-callable API LoweredFunc.
*
* The main task of this function is to create code to :
*/
TVM_DLL Pass LowerWarpMemory();
+
+/*!
+ * \brief Narrow down PrimExpr datatype in stmt to target_bits.
+ *
+ * \note Run this pass after StorageFlatten.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass NarrowDataType();
+
} // namespace transform
} // namespace tir
} // namespace tvm
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
+ stmt = ir_pass.NarrowDataType(stmt, 32)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
raise TypeError("dom need to be Range")
name = var if var is not None else "iter"
- var = Var(name, dtype="int32") if not isinstance(var, Var) else var
+ dtype = "int32" if dom is None else dom.extent.dtype
+ var = Var(name, dtype=dtype) if not isinstance(var, Var) else var
self.__init_handle_by_constructor__(
_ffi_api.IterVar, dom, var, iter_type, thread_tag)
def __getitem__(self, index):
t = DataType(self._content_type)
if t.lanes > 1:
- index = _expr.Ramp(index * t.lanes, 1, t.lanes)
+ base = index * t.lanes
+ index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
return _expr.Load(self._content_type, self._buffer_var, index)
def __setitem__(self, index, value):
value.dtype, self._content_type))
t = DataType(self._content_type)
if t.lanes > 1:
- index = _expr.Ramp(index * t.lanes, 1, t.lanes)
+ base = index * t.lanes
+ index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
self._builder.emit(_stmt.Store(self._buffer_var, value, index))
The result pass
"""
return _ffi_api.LowerWarpMemory()
+
+
+def NarrowDataType():
+ """Narrow down PrimExpr datatype in stmt to target_bits.
+
+ Returns
+ -------
+ fpass : tvm.ir.transform.Pass
+ The result pass
+
+ Note
+ ----
+ Run this pass after StorageFlatten.
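+
+ Examples
+ --------
+ A minimal sketch, mirroring the unit tests (``params`` and ``stmt`` are
+ assumed to hold the buffer arguments and a flattened statement)::
+
+     func = tvm.tir.PrimFunc(params, stmt).with_attr("target_bits", 32)
+     mod = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))
+     narrowed_stmt = mod["main"].body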
+ """
+ return _ffi_api.NarrowDataType()
res = Intersect(res, info.bound);
}
}
+ if (bound_) {
+ const PrimExprNode* op = expr.as<PrimExprNode>();
+ auto val = bound_->find(op);
+ if (val != bound_->end()) {
+ CHECK(val->second->min_value == res.min_value &&
+ val->second->max_value == res.max_value)
+ << "Detected bound for " << expr
+ << "conflicts with memorization";
+ }
+ (*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
+ }
return res;
}
+ Entry VisitExpr_(const RampNode* op) final {
+ // op = {base + i * stride | 0 <= i < lanes}
+ // Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes)
+ // Since `base + i * stride` is linear w.r.t. `i`, the extrema are attained
+ // at the endpoints, so
+ // Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1)
+ Entry a = VisitExpr(op->base);
+ Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride);
+ return Union(a, b);
+ }
+
Entry VisitExpr_(const CastNode* op) final {
Entry a = VisitExpr(op->value);
Entry b = Everything(op->dtype);
}
private:
+ friend class ConstIntBoundAnalyzer;
// internal variable map
std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
+ // lookup table for memoization
+ std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
// constants: the limit value means unlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
return ConstIntBound(ret.min_value, ret.max_value);
}
+ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
+ std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
+ impl_->bound_ = bound;
+ Entry ret = impl_->VisitExpr(expr);
+ impl_->bound_ = nullptr;
+ return ConstIntBound(ret.min_value, ret.max_value);
+}
+
void ConstIntBoundAnalyzer::Update(const Var& var,
const ConstIntBound& info,
bool override) {
PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
- ConstInt32(1),
+ llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
op->loop_var,
op->body);
}
CHECK(op->for_type == ForType::Serial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
- ConstInt32(1), op->loop_var, op->body);
+ llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
+ op->loop_var, op->body);
}
n->buffer_type = buffer_type;
if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
for (size_t i = 0; i < n->shape.size(); ++i) {
- n->strides.push_back(Var("stride"));
+ n->strides.push_back(Var("stride", n->shape[i].dtype()));
}
}
return Buffer(n);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
+REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
- return ForNode::make(for_node->loop_var, 0, extent,
+ return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
for_node->for_type, for_node->device_api, body);
}
}
PrimExpr extent = tir::Simplify(op->extent);
const IntImmNode *v1 = extent.as<IntImmNode>();
int value = -1;
- if (v1 != nullptr) {
+ // integers that do not fit in int32_t are treated as symbolic,
+ // as it's impossible to unroll such large loops
+ if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
value = static_cast<int>(v1->value);
}
return value;
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file narrow_datatype.cc
+ * \brief narrow the datatype of indexing vars
+ */
+
+#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
+#include <tvm/runtime/registry.h>
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/ir_visitor_with_analyzer.h"
+
+namespace tvm {
+namespace tir {
+
+// This pass narrows indexing expressions (like StoreNode::index)
+// that trivially fit into i32/i16 (denoted by `target_bits_`) down to
+// i32/i16. i32/i16 indices may be more efficient on some backends
+// (while i64 may be more efficient on others, like llvm), so we apply
+// this pass when i32/i16 indices are preferred.
+//
+// For a Var v, we determine its dtype by examining all the PrimExprs
+// that contain v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
+// If every expression in E fits into i32/i16, then v can be narrowed
+// to i32/i16.
+//
+// To make an indexing expression i32/i16, we must make sure that every
+// component of that expression is of dtype i32/i16. So we rewrite the
+// following nodes inside an indexing expression:
+// - Var
+// - IntImm
+// - Cast
+//
+// Algorithm:
+// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
+// - Use DataTypeRewriter to rewrite the components of an indexing expression.
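+//
+// Illustrative sketch (the TIR text below is schematic, not an exact
+// printing; the narrowing behaviour matches the unit tests): with
+// target_bits = 32,
+//   before: for (i: int64, 0, 1024) { B[i] = A[i] + 1 }
+//   after:  for (i: int32, 0, 1024) { B[i] = A[i] + 1 }
+// while loops whose indices may exceed the int32 range keep int64 indices.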
+
+using arith::Analyzer;
+using arith::IRMutatorWithAnalyzer;
+using arith::ConstIntBound;
+
+// Determine the result dtype for Var, IntImm and Cast,
+// which will be stored in `vmap` eventually.
+//
+// Algorithm:
+// We propagate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`.
+// To be more specific, if every Expr `e` that contains `var`
+// (`var` is a child node of `e` in the AST) fits into `target_bits_`,
+// then we narrow `var` to `target_bits_`. That is,
+// `vmap[var] = min(target_bits_, var.dtype.bits())`
+// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()`
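+//
+// For example (hypothetical values, target_bits_ = 32): an int64 var whose
+// containing Exprs are all bounded within the int32 range gets
+// vmap[var] = int32; if the same var also occurs in an Expr that may exceed
+// that range, the maximum bits win and vmap[var] stays int64.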
+class DataTypeVisitor final : public StmtExprVisitor {
+ public:
+ explicit DataTypeVisitor(int target_bits)
+ : bits_(target_bits), target_bits_(target_bits) {}
+
+ void VisitExpr(const PrimExpr& e) {
+ if (e.dtype().is_int()) {
+ int bits = max_bits_;
+ const PrimExprNode* op = e.as<PrimExprNode>();
+ if (bound_.find(op) == bound_.end()) {
+ analyzer_.const_int_bound(e, &bound_);
+ }
+ ConstIntBound bound = bound_[op];
+ int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
+ int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
+ if (e.dtype().bits() <= target_bits_ ||
+ (bound->max_value <= ubound && bound->min_value >= lbound)) {
+ bits = target_bits_;
+ }
+ int tmp = bits > bits_ ? bits : bits_;
+ std::swap(bits_, tmp);
+ StmtExprVisitor::VisitExpr(e);
+ std::swap(bits_, tmp);
+ } else {
+ StmtExprVisitor::VisitExpr(e);
+ }
+ }
+
+ void VisitStmt_(const ForNode* op) {
+ analyzer_.Bind(op->loop_var,
+ Range::make_by_min_extent(op->min, op->extent));
+ vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
+ return StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitStmt_(const AttrStmtNode* op) {
+ if (op->attr_key == attr::thread_extent ||
+ op->attr_key == attr::virtual_thread) {
+ IterVar iv = Downcast<IterVar>(op->node);
+ CHECK_NE(iv->thread_tag.length(), 0U);
+ analyzer_.Bind(iv->var,
+ Range::make_by_min_extent(0, op->value));
+ vextent_[iv->var.as<VarNode>()] = op->value.dtype();
+ StmtExprVisitor::VisitStmt_(op);
+ } else {
+ StmtExprVisitor::VisitStmt_(op);
+ }
+ }
+
+ void VisitExpr_(const ReduceNode* op) {
+ // Setup the domain information before simplification.
+ for (const IterVar& iv : op->axis) {
+ analyzer_.Bind(iv->var, iv->dom);
+ vextent_[iv->var.as<VarNode>()] = iv->dom->extent.dtype();
+ }
+ // Recursively call simplification when necessary.
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const VarNode* op) {
+ if (vextent_.find(op) != vextent_.end()) {
+ // We only narrow and never promote, so the result dtype
+ // is upper-bounded by its original dtype before rewrite.
+ int bits = std::min(vextent_[op].bits(), bits_);
+ if (vmap.find(op) == vmap.end()) {
+ vmap[op] = op->dtype.with_bits(bits);
+ } else {
+ // We take the maximum bits over all the possible Exprs where a var occurs
+ vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
+ }
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const IntImmNode* op) {
+ if (op->dtype.is_int()) {
+ // We only narrow and never promote, so the result dtype
+ // is upper-bounded by its original dtype before rewrite.
+ int bits = std::min(op->dtype.bits(), bits_);
+ if (vmap.find(op) == vmap.end()) {
+ vmap[op] = op->dtype.with_bits(bits);
+ } else {
+ vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
+ }
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const CastNode* op) {
+ if (op->dtype.is_int()) {
+ // We only narrow and never promote, so the result dtype
+ // is upper-bounded by its original dtype before rewrite.
+ int bits = std::min(op->dtype.bits(), bits_);
+ if (vmap.find(op) == vmap.end()) {
+ vmap[op] = op->dtype.with_bits(bits);
+ } else {
+ vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
+ }
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ // the narrowed datatype of Var and IntImm
+ std::unordered_map<const PrimExprNode*, DataType> vmap;
+
+ protected:
+ // internal analyzer
+ arith::Analyzer analyzer_;
+
+ private:
+ // the maximum possible bits, which serves as an init value
+ static constexpr const int max_bits_ = 64;
+ // the maximum possible bit of the current expression's return dtype
+ int bits_;
+ // the target bits
+ int target_bits_;
+ // the extent of vars to be rewritten
+ std::unordered_map<const VarNode*, DataType> vextent_;
+ // the memoized bounds generated by ConstIntBoundAnalyzer
+ std::unordered_map<const PrimExprNode*, ConstIntBound> bound_;
+};
+
+class DataTypeRewriter : public StmtExprMutator {
+ public:
+ explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {}
+
+ Stmt operator()(Stmt s) {
+ visitor_(s);
+ for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) {
+ PrimExpr e = GetRef<PrimExpr>(i->first);
+ if (e.dtype() == i->second) {
+ i = visitor_.vmap.erase(i);
+ } else {
+ ++i;
+ }
+ }
+ return VisitStmt(s);
+ }
+
+ Stmt VisitStmt_(const StoreNode* op) final {
+ PrimExpr value = this->VisitExpr(op->value);
+ is_index_ = true;
+ PrimExpr index = this->VisitExpr(op->index);
+ is_index_ = false;
+ Stmt s = StoreNode::make(op->buffer_var,
+ op->value,
+ index,
+ op->predicate);
+ return StmtExprMutator::VisitStmt_(s.as<StoreNode>());
+ }
+
+ Stmt VisitStmt_(const ForNode* op) final {
+ Stmt s = StmtExprMutator::VisitStmt_(op);
+ op = s.as<ForNode>();
+ CHECK(op != nullptr)
+ << "Expected type to be ForNode"
+ << ", but get " << s->GetTypeKey();
+ PrimExpr e = VisitExpr(op->loop_var);
+ Var var = Downcast<Var>(e);
+ return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
+ op->for_type, op->device_api, op->body);
+ }
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ if (op->attr_key == attr::thread_extent ||
+ op->attr_key == attr::virtual_thread) {
+ Stmt s = StmtExprMutator::VisitStmt_(op);
+ op = s.as<AttrStmtNode>();
+ CHECK(op != nullptr)
+ << "Expected type to be AttrStmtNode"
+ << ", but get " << s->GetTypeKey();
+ const IterVarNode* iv = op->node.as<IterVarNode>();
+ CHECK(iv != nullptr)
+ << "Expected type to be IterVarNode"
+ << ", but get " << op->node->GetTypeKey();
+ PrimExpr e = VisitExpr(iv->var);
+ Var var = Downcast<Var>(e);
+ if (ivmap_.find(iv) == ivmap_.end()) {
+ ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag);
+ }
+ return AttrStmtNode::make(
+ ivmap_[iv],
+ op->attr_key,
+ cast(var.dtype(), op->value),
+ op->body);
+ }
+ return StmtExprMutator::VisitStmt_(op);
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
+ if (vmap_.find(op) == vmap_.end()) {
+ vmap_[op] = Var(op->name_hint, visitor_.vmap[op]);
+ }
+ return vmap_[op];
+ }
+ return StmtExprMutator::VisitExpr_(op);
+ }
+
+ PrimExpr VisitExpr_(const SizeVarNode* op) final {
+ if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
+ if (vmap_.find(op) == vmap_.end()) {
+ vmap_[op] = SizeVar(op->name_hint, visitor_.vmap[op]);
+ }
+ return vmap_[op];
+ }
+ return StmtExprMutator::VisitExpr_(op);
+ }
+
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ is_index_ = true;
+ PrimExpr index = this->VisitExpr(op->index);
+ is_index_ = false;
+ PrimExpr e = LoadNode::make(op->dtype, op->buffer_var, index, op->predicate);
+ return StmtExprMutator::VisitExpr_(e.as<LoadNode>());
+ }
+
+ PrimExpr VisitExpr_(const IntImmNode* op) final {
+ if (is_index_) {
+ if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
+ return IntImm(visitor_.vmap[op], op->value);
+ }
+ }
+ return StmtExprMutator::VisitExpr_(op);
+ }
+
+ PrimExpr VisitExpr_(const CastNode* op) final {
+ if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) {
+ PrimExpr e = StmtExprMutator::VisitExpr_(op);
+ const CastNode* new_op = e.as<CastNode>();
+ CHECK(new_op != nullptr)
+ << "Expected type to be CastNode"
+ << ", but get " << e->GetTypeKey();
+ return CastNode::make(visitor_.vmap[op], new_op->value);
+ }
+ return StmtExprMutator::VisitExpr_(op);
+ }
+
+ PrimExpr VisitExpr_(const AddNode* op) final;
+ PrimExpr VisitExpr_(const SubNode* op) final;
+ PrimExpr VisitExpr_(const MulNode* op) final;
+ PrimExpr VisitExpr_(const DivNode* op) final;
+ PrimExpr VisitExpr_(const ModNode* op) final;
+ PrimExpr VisitExpr_(const FloorDivNode* op) final;
+ PrimExpr VisitExpr_(const FloorModNode* op) final;
+ PrimExpr VisitExpr_(const MinNode* op) final;
+ PrimExpr VisitExpr_(const MaxNode* op) final;
+ PrimExpr VisitExpr_(const EQNode* op) final;
+ PrimExpr VisitExpr_(const NENode* op) final;
+ PrimExpr VisitExpr_(const LTNode* op) final;
+ PrimExpr VisitExpr_(const LENode* op) final;
+ PrimExpr VisitExpr_(const GTNode* op) final;
+ PrimExpr VisitExpr_(const GENode* op) final;
+ PrimExpr VisitExpr_(const CallNode* op) final;
+
+ private:
+ // the internal visitor to deduce the narrowed dtype
+ DataTypeVisitor visitor_;
+ // a map from Var before rewrite to that after rewrite,
+ // ensures one old Var maps to exactly one new Var
+ std::unordered_map<const VarNode*, Var> vmap_;
+ // a map from IterVar before rewrite to that after rewrite,
+ // ensures one old IterVar maps to exactly one new IterVar
+ std::unordered_map<const IterVarNode*, IterVar> ivmap_;
+ // indicator of LoadNode::index and StoreNode::index
+ bool is_index_{false};
+};
+
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
+ PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \
+ PrimExpr a = this->VisitExpr(op->a); \
+ PrimExpr b = this->VisitExpr(op->b); \
+ if (a.same_as(op->a) && \
+ b.same_as(op->b)) { \
+ return GetRef<PrimExpr>(op); \
+ } else { \
+ return FUNC(a, b); \
+ } \
+ }
+
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=)
+
+PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
+ PrimExpr e = StmtExprMutator::VisitExpr_(op);
+ op = e.as<CallNode>();
+ CHECK(op != nullptr)
+ << "Expected type to be CallNode"
+ << ", but get " << e->GetTypeKey();
+ if (op->call_type == CallNode::PureIntrinsic) {
+ if (op->name == intrinsic::tvm_if_then_else) {
+ return if_then_else(op->args[0], op->args[1], op->args[2]);
+ } else if (op->name == CallNode::shift_right) {
+ return op->args[0] >> op->args[1];
+ } else if (op->name == CallNode::shift_left) {
+ return op->args[0] << op->args[1];
+ } else if (op->name == CallNode::bitwise_and) {
+ return op->args[0] & op->args[1];
+ } else if (op->name == CallNode::bitwise_or) {
+ return op->args[0] | op->args[1];
+ } else if (op->name == CallNode::bitwise_xor) {
+ return op->args[0] ^ op->args[1];
+ } else if (op->name == "pow") {
+ return pow(op->args[0], op->args[1]);
+ }
+ }
+ return e;
+}
+
+Stmt NarrowDataType(Stmt stmt, int target_bits) {
+ return DataTypeRewriter(target_bits)(stmt);
+}
+
+namespace transform {
+
+Pass NarrowDataType() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ IntImm target_bits = f->GetAttr<IntImm>("target_bits");
+ CHECK(target_bits.defined())
+ << "NarrowDataType: Require the target_bits";
+ n->body = DataTypeRewriter(target_bits->value)(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(
+ pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType")
+.set_body_typed(NarrowDataType);
+
+} // namespace transform
+} // namespace tir
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+from tvm.tir import const
+
+
+def lower_stmt(params, stmt, target_bits):
+ func = tvm.tir.PrimFunc(params, stmt).with_attr(
+ "target_bits", target_bits)
+ func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"]
+ stmt = func.body
+ return stmt
+
+
+def lower_sch(sch, args, target_bits):
+ binds = {}
+ arg_list = []
+ for x in args:
+ if isinstance(x, te.tensor.Tensor):
+ buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
+ assert x not in binds
+ binds[x] = buf
+ arg_list.append(buf)
+ else:
+ raise ValueError("args must be Tensor, Buffer or Var")
+ bounds = te.schedule.InferBound(sch)
+ stmt = te.schedule.ScheduleOps(sch, bounds)
+ stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
+ return lower_stmt(arg_list, stmt, target_bits)
+
+
+def test_basic():
+ def check(m, n, target_bits, target_dtype):
+ ib = tvm.tir.ir_builder.create()
+ Ab = tvm.tir.decl_buffer((m, n), name='A')
+ A = ib.buffer_ptr(Ab)
+ Bb = tvm.tir.decl_buffer((m, n), name='B')
+ B = ib.buffer_ptr(Bb)
+ with ib.for_range(0, m, name='i') as i:
+ with ib.for_range(0, n, name='j') as j:
+ B[i * n + j] = A[i * n + j] + 1
+ stmt = ib.get()
+ stmt = lower_stmt([Ab, Bb], stmt, target_bits)
+ assert stmt.loop_var.dtype == target_dtype
+ assert stmt.body.loop_var.dtype == target_dtype
+
+ # const shape
+ # i32 -> i32
+ check(2, 2, 32, "int32")
+ check(2**16, 2**16, 32, "int32") # i32 + i32 is not promoted to i64 even if it overflows
+ # i64 -> i32
+ check(const(2, dtype='int64'), const(2, dtype='int64'), 32, "int32")
+ check(const(2**16, dtype='int64'), const(2**16, dtype='int64'), 32, "int64")
+ # i32 -> i16
+ check(2, 2, 16, "int16")
+ check(2**10, 2**10, 16, "int32")
+
+ # symbolic shape
+ check(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), 32, "int32")
+ check(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), 32, "int64")
+
+
+def test_thread_axis():
+ def check(m, n, target_bits, target_dtype):
+ ib = tvm.tir.ir_builder.create()
+ Ab = tvm.tir.decl_buffer((m, n), name='A')
+ A = ib.buffer_ptr(Ab)
+ Bb = tvm.tir.decl_buffer((m, n), name='B')
+ B = ib.buffer_ptr(Bb)
+ bx = te.thread_axis("blockIdx.x")
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(bx, "thread_extent", m)
+ ib.scope_attr(tx, "thread_extent", n)
+ B[bx * n + tx] = A[bx * n + tx] + 1
+ stmt = ib.get()
+ stmt = lower_stmt([Ab, Bb], stmt, target_bits)
+ assert stmt.node.var.dtype == target_dtype
+ assert stmt.body.node.var.dtype == target_dtype
+
+ # i32 -> i32
+ check(2, 32,
+ target_bits=32, target_dtype='int32')
+ check(2**30, 32, # i32 + i32 is not promoted to i64 even in the case of overflow
+ target_bits=32, target_dtype='int32')
+ # i64 -> i32
+ check(const(2, dtype='int64'),
+ const(32, dtype='int64'),
+ target_bits=32, target_dtype='int32')
+ check(const(2**30, dtype='int64'),
+ const(32, dtype='int64'),
+ target_bits=32, target_dtype='int64')
+ # i32 -> i16
+ check(2, 32,
+ target_bits=16, target_dtype='int16')
+ check(2**14, 32,
+ target_bits=16, target_dtype='int32')
+
+
+def test_multilanes():
+ def check(m, lanes, target_bits, target_dtype):
+ ib = tvm.tir.ir_builder.create()
+ Ab = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='A')
+ A = ib.buffer_ptr(Ab)
+ Bb = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='B')
+ B = ib.buffer_ptr(Bb)
+ with ib.for_range(0, m, name='i', dtype=m.dtype) as i:
+ B[i] = A[i] + 1
+ stmt = ib.get()
+ stmt = lower_stmt([Ab, Bb], stmt, target_bits)
+ assert stmt.loop_var.dtype == target_dtype
+
+ # i32 -> i32
+ check(const(2 ** 10, dtype='int32'), 2,
+ target_bits=32, target_dtype='int32')
+ check(const(2 ** 32, dtype='int32'), 2,
+ target_bits=32, target_dtype='int32')
+ # i64 -> i32
+ check(const(2 ** 10, dtype='int64'), 2,
+ target_bits=32, target_dtype='int32')
+ check(const(2 ** 32, dtype='int64'), 2,
+ target_bits=32, target_dtype='int64')
+ # i32 -> i16
+ check(const(2 ** 10, dtype='int32'), 2,
+ target_bits=16, target_dtype='int16')
+ check(const(2 ** 16, dtype='int32'), 2,
+ target_bits=16, target_dtype='int32')
+
+
+def test_reduce():
+ def check(m, target_bits, target_dtype):
+ A = te.placeholder((m,), name='A', dtype='float32')
+ k = te.reduce_axis((0, m), "k")
+ B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
+ s = te.create_schedule(B.op)
+ stmt = lower_sch(s, [A, B], target_bits)
+ assert stmt.body[1].loop_var.dtype == target_dtype
+
+ # i32 -> i32
+ check(const(64, dtype='int32'), 32, 'int32')
+ # i64 -> i32
+ check(const(64, dtype='int64'), 32, 'int32')
+ # i32 -> i16
+ check(const(64, dtype='int32'), 16, 'int16')
+ check(const(2**16, dtype='int32'), 16, 'int32')
+ # symbolic
+ check(te.var('n', dtype='int32'), 32, 'int32')
+ check(te.var('n', dtype='int64'), 32, 'int64')
+
+
+def test_slice():
+ def check(m, n, target_bits, target_dtype):
+ # The index may overflow in B, while not in A
+ ib = tvm.tir.ir_builder.create()
+ Ab = tvm.tir.decl_buffer((m, n), name='A')
+ A = ib.buffer_ptr(Ab)
+ Bb = tvm.tir.decl_buffer((m, n * 2), name='B')
+ B = ib.buffer_ptr(Bb)
+ with ib.for_range(0, m, name='i') as i:
+ with ib.for_range(0, n, name='j') as j:
+ A[i * n + j] = B[i * 2 * n + 2 * j] + 1
+ stmt = ib.get()
+ stmt = lower_stmt([Ab, Bb], stmt, target_bits)
+ assert stmt.loop_var.dtype == target_dtype
+ assert stmt.body.loop_var.dtype == target_dtype
+
+ # The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1
+ check(const(2**15, 'int64'), const(2**15, 'int64'),
+ target_bits=32, target_dtype='int32')
+ # The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1
+ check(const(2**15, 'int64'), const((2**15 + 1), 'int64'),
+ target_bits=32, target_dtype='int64')
+
+
+if __name__ == "__main__":
+ test_basic()
+ test_thread_axis()
+ test_multilanes()
+ test_reduce()
+ test_slice()