From: Haozheng Fan Date: Thu, 2 Apr 2020 16:04:33 +0000 (+0800) Subject: [TIR][PASS] dtype rewrite for indexing variables (#5092) X-Git-Tag: upstream/0.7.0~1000 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4e5c5843e7c078e39d89c4ee2163f1fd40aef952;p=platform%2Fupstream%2Ftvm.git [TIR][PASS] dtype rewrite for indexing variables (#5092) --- diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index e7f5ede..1889e16 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -115,6 +115,15 @@ class ConstIntBoundAnalyzer { ConstIntBound operator()(const PrimExpr& expr); /*! + * \brief analyze the expr with the intermediate memorized to avoid redundant computation + * \param expr The expression of interest. + * \param bound The lookup table to store the intermediate results + * \return the result of the analysis. + */ + ConstIntBound operator()(const PrimExpr& expr, + std::unordered_map* bound); + + /*! * \brief Update constant int bound information of var. * * \param var The variable of interest. diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index d54e094..f056d9f 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -358,6 +358,15 @@ Stmt DecorateDeviceScope(Stmt stmt); Stmt HoistIfThenElse(Stmt stmt); /*! + * \brief Narrow down PrimExpr datatype in stmt to target_bits. + * \note Run this pass after StorageFlatten. + * \param stmt The stmt to do datatype rewrite + * \param target_bits the bit of target datatype + * \return Transformed stmt. + */ +Stmt NarrowDataType(Stmt stmt, int target_bits); + +/*! * \brief Make an user callable API LoweredFunc. * * The main task of this function is to create code to : diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 9d55db5..a414bcc 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo(); */ TVM_DLL Pass LowerWarpMemory(); + +/*! + * \brief Narrow down PrimExpr datatype in stmt to target_bits. + * + * \note Run this pass after StorageFlatten. + * + * \return The pass. + */ +TVM_DLL Pass NarrowDataType(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 67eb224..88231aa 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -159,6 +159,7 @@ def lower(sch, # Phase 1 stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) + stmt = ir_pass.NarrowDataType(stmt, 32) stmt = ir_pass.CanonicalSimplify(stmt) for f in lower_phase1: stmt = f(stmt) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index deb8d34..a192fce 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -370,7 +370,8 @@ class IterVar(Object, ExprOp): raise TypeError("dom need to be Range") name = var if var is not None else "iter" - var = Var(name, dtype="int32") if not isinstance(var, Var) else var + dtype = "int32" if dom is None else dom.extent.dtype + var = Var(name, dtype=dtype) if not isinstance(var, Var) else var self.__init_handle_by_constructor__( _ffi_api.IterVar, dom, var, iter_type, thread_tag) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 885b847..0c4c368 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -76,7 +76,8 @@ class BufferVar(ObjectGeneric): def __getitem__(self, index): t = DataType(self._content_type) if t.lanes > 1: - index = _expr.Ramp(index * t.lanes, 1, t.lanes) + base = index * t.lanes + index = _expr.Ramp(base, const(1, base.dtype), t.lanes) return _expr.Load(self._content_type, self._buffer_var, index) def __setitem__(self, index, value): @@ -87,7 +88,8 @@ class BufferVar(ObjectGeneric): value.dtype, self._content_type)) t = DataType(self._content_type) if t.lanes > 1: - index = _expr.Ramp(index * t.lanes, 1, t.lanes) + base = index * t.lanes + index = _expr.Ramp(base, const(1, base.dtype), t.lanes) self._builder.emit(_stmt.Store(self._buffer_var, value, index)) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 2b50387..7c2b3c8 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -66,3 +66,18 @@ def LowerWarpMemory(): The result pass """ return _ffi_api.LowerWarpMemory() + + +def NarrowDataType(): + """Narrow down PrimExpr datatype in stmt to target_bits. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + + Note + ---- + Run this pass after StorageFlatten. + """ + return _ffi_api.NarrowDataType() diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 702e775..57dfc15 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -146,9 +146,30 @@ class ConstIntBoundAnalyzer::Impl : res = Intersect(res, info.bound); } } + if (bound_) { + const PrimExprNode* op = expr.as(); + auto val = bound_->find(op); + if (val != bound_->end()) { + CHECK(val->second->min_value == res.min_value && + val->second->max_value == res.max_value) + << "Detected bound for " << expr + << "conflicts with memorization"; + } + (*bound_)[op] = ConstIntBound(res.min_value, res.max_value); + } return res; } + Entry VisitExpr_(const RampNode* op) final { + // op = {base + i * stride | 0 <= i < lanes} + // Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes) + // Note that `base + i * stride` is linear w.r.t. `i` + // Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1) + Entry a = VisitExpr(op->base); + Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride); + return Union(a, b); + } + Entry VisitExpr_(const CastNode* op) final { Entry a = VisitExpr(op->value); Entry b = Everything(op->dtype); @@ -340,10 +361,13 @@ class ConstIntBoundAnalyzer::Impl : } private: + friend class ConstIntBoundAnalyzer; // internal variable map std::unordered_map var_map_; // additional bound info std::vector additional_info_; + // look up table for memorization + std::unordered_map* bound_{nullptr}; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; @@ -536,6 +560,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) { return ConstIntBound(ret.min_value, ret.max_value); } +ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, + std::unordered_map* bound) { + impl_->bound_ = bound; + Entry ret = impl_->VisitExpr(expr); + impl_->bound_ = nullptr; + return ConstIntBound(ret.min_value, ret.max_value); +} + void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool override) { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index ba3dee7..70bcfe8 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); CreateSerialFor(MakeValue(begin), MakeValue(end), - ConstInt32(1), + llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 68d004c..31465cd 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1121,7 +1121,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { CHECK(op->for_type == ForType::Serial); } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - ConstInt32(1), op->loop_var, op->body); + llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), + op->loop_var, op->body); } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index d663c30..eec7c10 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -452,7 +452,7 @@ Buffer BufferNode::make(Var data, n->buffer_type = buffer_type; if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) { for (size_t i = 0; i < n->shape.size(); ++i) { - n->strides.push_back(Var("stride")); + n->strides.push_back(Var("stride", n->shape[i].dtype())); } } return Buffer(n); diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index f4d8193..d13461e 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -156,5 +156,6 @@ REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(InferFragment) +REGISTER_PASS(NarrowDataType); } // namespace tir } // namespace tvm diff --git a/src/tir/pass/loop_partition.cc b/src/tir/pass/loop_partition.cc index d1fa46e..e9157e7 100644 --- a/src/tir/pass/loop_partition.cc +++ b/src/tir/pass/loop_partition.cc @@ -587,7 +587,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { - return ForNode::make(for_node->loop_var, 0, extent, + return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->for_type, for_node->device_api, body); } } diff --git a/src/tir/pass/unroll_loop.cc b/src/tir/pass/unroll_loop.cc index 3d669d0..0167dbc 100644 --- a/src/tir/pass/unroll_loop.cc +++ b/src/tir/pass/unroll_loop.cc @@ -160,7 +160,9 @@ class LoopUnroller : public StmtExprMutator { PrimExpr extent = tir::Simplify(op->extent); const IntImmNode *v1 = extent.as(); int value = -1; - if (v1 != nullptr) { + // integers that do not fit in int32_t are treated as symbolic, + // as it's impossible to unroll such large loops + if (v1 != nullptr && v1->value <= std::numeric_limits::max()) { value = static_cast(v1->value); } return value; diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc new file mode 100644 index 0000000..00bc45a --- /dev/null +++ b/src/tir/transforms/narrow_datatype.cc @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file narrow_datatype.cc + * \brief narrow the datatype of indexing vars + */ + +#include +#include +#include +#include +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tir { + +// This pass narrows indexing expressions (like StoreNode::Index) +// that trivially fit into i32/i16 (denoted by `target_bits_`) to +// i32/i16. Considering that i32/i16 indices may be more +// efficient on some backends (while i64 may be more efficient +// on others, like llvm), we may want this pass when i32/i16 +// indices are more efficient. +// +// For Var v, we determine its dtype by examining all the PrimExpr +// that contains v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}. +// If all expressions in E fit into i32/i16, then we think v can be narrowed +// to i32/i16. +// +// To make an indexing expression i32/i16, we must make sure that every +// component of that expression is of dtype i32/i16. So besides Var, we +// rewrite the following inside an indexing expression +// - Var +// - IntImm +// - Cast +// +// Algorithm: +// - Use DataTypeVisitor to determine whether a Var can be narrowed or not. +// - Use DataTypeRewritter to rewrite the components of an indexing expression. + +using arith::Analyzer; +using arith::IRMutatorWithAnalyzer; +using arith::ConstIntBound; + +// Determine the result dtype for Var, IntImm and Cast, +// which will be stored in `vmap` eventually. +// +// Algorithm: +// We propogate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`. +// To be more specific, if for each Expr `e` which contains `var` +// (`var` is a child node of `e` in AST), `e` fits into `target_bits_`, +// then we narrow `var` into `target_bits_`. That is, +// `vmap[var] = min(target_bits_, var.dtype.bits())` +// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()` +class DataTypeVisitor final : public StmtExprVisitor { + public: + explicit DataTypeVisitor(int target_bits) + : bits_(target_bits), target_bits_(target_bits) {} + + void VisitExpr(const PrimExpr& e) { + if (e.dtype().is_int()) { + int bits = max_bits_; + const PrimExprNode* op = e.as(); + if (bound_.find(op) == bound_.end()) { + analyzer_.const_int_bound(e, &bound_); + } + ConstIntBound bound = bound_[op]; + int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; + int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; + if (e.dtype().bits() <= target_bits_ || + (bound->max_value <= ubound && bound->min_value >= lbound)) { + bits = target_bits_; + } + int tmp = bits > bits_ ? bits : bits_; + std::swap(bits_, tmp); + StmtExprVisitor::VisitExpr(e); + std::swap(bits_, tmp); + } else { + StmtExprVisitor::VisitExpr(e); + } + } + + void VisitStmt_(const ForNode* op) { + analyzer_.Bind(op->loop_var, + Range::make_by_min_extent(op->min, op->extent)); + vextent_[op->loop_var.as()] = op->extent.dtype(); + return StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == attr::thread_extent || + op->attr_key == attr::virtual_thread) { + IterVar iv = Downcast(op->node); + CHECK_NE(iv->thread_tag.length(), 0U); + analyzer_.Bind(iv->var, + Range::make_by_min_extent(0, op->value)); + vextent_[iv->var.as()] = op->value.dtype(); + StmtExprVisitor::VisitStmt_(op); + } else { + StmtExprVisitor::VisitStmt_(op); + } + } + + void VisitExpr_(const ReduceNode* op) { + // Setup the domain information before simplification. + for (const IterVar& iv : op->axis) { + analyzer_.Bind(iv->var, iv->dom); + vextent_[iv->var.as()] = iv->dom->extent.dtype(); + } + // Recursively call simplification when necessary. + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const VarNode* op) { + if (vextent_.find(op) != vextent_.end()) { + // We only narrow and never promote, so the result dtype + // is upperbounded by its original dtype before rewrite. + int bits = std::min(vextent_[op].bits(), bits_); + if (vmap.find(op) == vmap.end()) { + vmap[op] = op->dtype.with_bits(bits); + } else { + // We take maximum bits for all the possible Expr where a var occurs + vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const IntImmNode* op) { + if (op->dtype.is_int()) { + // We only narrow and never promote, so the result dtype + // is upperbounded by its original dtype before rewrite. + int bits = std::min(op->dtype.bits(), bits_); + if (vmap.find(op) == vmap.end()) { + vmap[op] = op->dtype.with_bits(bits); + } else { + vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const CastNode* op) { + if (op->dtype.is_int()) { + // We only narrow and never promote, so the result dtype + // is upperbounded by its original dtype before rewrite. + int bits = std::min(op->dtype.bits(), bits_); + if (vmap.find(op) == vmap.end()) { + vmap[op] = op->dtype.with_bits(bits); + } else { + vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + // the narrowed datatype of Var and IntImm + std::unordered_map vmap; + + protected: + // internal analyzer + arith::Analyzer analyzer_; + + private: + // the maximum possible bits, which serves as an init value + static constexpr const int max_bits_ = 64; + // the maximum possible bit of the current expression's return dtype + int bits_; + // the target bits + int target_bits_; + // the extent of vars to be rewritten + std::unordered_map vextent_; + // the memorized bound generated by ConstIntBoundAnalyzer + std::unordered_map bound_; +}; + +class DataTypeRewriter : public StmtExprMutator { + public: + explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {} + + Stmt operator()(Stmt s) { + visitor_(s); + for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) { + PrimExpr e = GetRef(i->first); + if (e.dtype() == i->second) { + i = visitor_.vmap.erase(i); + } else { + ++i; + } + } + return VisitStmt(s); + } + + Stmt VisitStmt_(const StoreNode* op) final { + PrimExpr value = this->VisitExpr(op->value); + is_index_ = true; + PrimExpr index = this->VisitExpr(op->index); + is_index_ = false; + Stmt s = StoreNode::make(op->buffer_var, + op->value, + index, + op->predicate); + return StmtExprMutator::VisitStmt_(s.as()); + } + + Stmt VisitStmt_(const ForNode* op) final { + Stmt s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + CHECK(op != nullptr) + << "Expected type to be ForNode" + << ", but get " << s->GetTypeKey(); + PrimExpr e = VisitExpr(op->loop_var); + Var var = Downcast(e); + return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), + op->for_type, op->device_api, op->body); + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::thread_extent || + op->attr_key == attr::virtual_thread) { + Stmt s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + CHECK(op != nullptr) + << "Expected type to be AttrStmtNode" + << ", but get " << s->GetTypeKey(); + const IterVarNode* iv = op->node.as(); + CHECK(iv != nullptr) + << "Expected type to be IterVarNode" + << ", but get " << op->node->GetTypeKey(); + PrimExpr e = VisitExpr(iv->var); + Var var = Downcast(e); + if (ivmap_.find(iv) == ivmap_.end()) { + ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag); + } + return AttrStmtNode::make( + ivmap_[iv], + op->attr_key, + cast(var.dtype(), op->value), + op->body); + } + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + if (vmap_.find(op) == vmap_.end()) { + vmap_[op] = Var(op->name_hint, visitor_.vmap[op]); + } + return vmap_[op]; + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const SizeVarNode* op) final { + if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + if (vmap_.find(op) == vmap_.end()) { + vmap_[op] = SizeVar(op->name_hint, visitor_.vmap[op]); + } + return vmap_[op]; + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + is_index_ = true; + PrimExpr index = this->VisitExpr(op->index); + is_index_ = false; + PrimExpr e = LoadNode::make(op->dtype, op->buffer_var, index, op->predicate); + return StmtExprMutator::VisitExpr_(e.as()); + } + + PrimExpr VisitExpr_(const IntImmNode* op) final { + if (is_index_) { + if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + return IntImm(visitor_.vmap[op], op->value); + } + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const CastNode* op) final { + if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { + PrimExpr e = StmtExprMutator::VisitExpr_(op); + const CastNode* new_op = e.as(); + CHECK(new_op != nullptr) + << "Expected type to be CastNode" + << ", but get " << e->GetTypeKey(); + return CastNode::make(visitor_.vmap[op], new_op->value); + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const AddNode* op) final; + PrimExpr VisitExpr_(const SubNode* op) final; + PrimExpr VisitExpr_(const MulNode* op) final; + PrimExpr VisitExpr_(const DivNode* op) final; + PrimExpr VisitExpr_(const ModNode* op) final; + PrimExpr VisitExpr_(const FloorDivNode* op) final; + PrimExpr VisitExpr_(const FloorModNode* op) final; + PrimExpr VisitExpr_(const MinNode* op) final; + PrimExpr VisitExpr_(const MaxNode* op) final; + PrimExpr VisitExpr_(const EQNode* op) final; + PrimExpr VisitExpr_(const NENode* op) final; + PrimExpr VisitExpr_(const LTNode* op) final; + PrimExpr VisitExpr_(const LENode* op) final; + PrimExpr VisitExpr_(const GTNode* op) final; + PrimExpr VisitExpr_(const GENode* op) final; + PrimExpr VisitExpr_(const CallNode* op) final; + + private: + // the internal visitor to deduce the narrowed dtype + DataTypeVisitor visitor_; + // a map from Var before rewrite to that after rewrite, + // ensures one old Var maps to exactly one new Var + std::unordered_map vmap_; + // a map from IterVar before rewrite to that after rewrite, + // ensures one old IterVar maps to exactly one new IterVar + std::unordered_map ivmap_; + // indicator of LoadNode::index and StoreNode::index + bool is_index_{false}; +}; + +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && \ + b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return FUNC(a, b); \ + } \ + } + +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) + +PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { + PrimExpr e = StmtExprMutator::VisitExpr_(op); + op = e.as(); + CHECK(op != nullptr) + << "Expected type to be CallNode" + << ", but get " << e->GetTypeKey(); + if (op->call_type == CallNode::PureIntrinsic) { + if (op->name == intrinsic::tvm_if_then_else) { + return if_then_else(op->args[0], op->args[1], op->args[2]); + } else if (op->name == CallNode::shift_right) { + return op->args[0] >> op->args[1]; + } else if (op->name == CallNode::shift_left) { + return op->args[0] << op->args[1]; + } else if (op->name == CallNode::bitwise_and) { + return op->args[0] & op->args[1]; + } else if (op->name == CallNode::bitwise_or) { + return op->args[0] | op->args[1]; + } else if (op->name == CallNode::bitwise_xor) { + return op->args[0] ^ op->args[1]; + } else if (op->name == "pow") { + return pow(op->args[0], op->args[1]); + } + } + return e; +} + +Stmt NarrowDataType(Stmt stmt, int target_bits) { + return DataTypeRewriter(target_bits)(stmt); +} + +namespace transform { + +Pass NarrowDataType() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + IntImm target_bits = f->GetAttr("target_bits"); + CHECK(target_bits.defined()) + << "NarrowDataType: Require the target_bits"; + n->body = DataTypeRewriter(target_bits->value)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass( + pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType") +.set_body_typed(NarrowDataType); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py new file mode 100644 index 0000000..49df1c2 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +from tvm.tir import const + + +def lower_stmt(params, stmt, target_bits): + func = tvm.tir.PrimFunc(params, stmt).with_attr( + "target_bits", target_bits) + func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"] + stmt = func.body + return stmt + + +def lower_sch(sch, args, target_bits): + binds = {} + arg_list = [] + for x in args: + if isinstance(x, te.tensor.Tensor): + buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name) + assert x not in binds + binds[x] = buf + arg_list.append(buf) + else: + raise ValueError("args must be Tensor, Buffer or Var") + bounds = te.schedule.InferBound(sch) + stmt = te.schedule.ScheduleOps(sch, bounds) + stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) + return lower_stmt(arg_list, stmt, target_bits) + + +def test_basic(): + def check(m, n, target_bits, target_dtype): + ib = tvm.tir.ir_builder.create() + Ab = tvm.tir.decl_buffer((m, n), name='A') + A = ib.buffer_ptr(Ab) + Bb = tvm.tir.decl_buffer((m, n), name='B') + B = ib.buffer_ptr(Bb) + with ib.for_range(0, m, name='i') as i: + with ib.for_range(0, n, name='j') as j: + B[i * n + j] = A[i * n + j] + 1 + stmt = ib.get() + stmt = lower_stmt([Ab, Bb], stmt, target_bits) + assert stmt.loop_var.dtype == target_dtype + assert stmt.body.loop_var.dtype == target_dtype + + # const shape + # i32 -> i32 + check(2, 2, 32, "int32") + check(2**16, 2**16, 32, "int32") # i32 + i32 is not promoted to i64 even if overflow + # i64 -> i32 + check(const(2, dtype='int64'), const(2, dtype='int64'), 32, "int32") + check(const(2**16, dtype='int64'), const(2**16, dtype='int64'), 32, "int64") + # i32 -> i16 + check(2, 2, 16, "int16") + check(2**10, 2**10, 16, "int32") + + # symbolic shape + check(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), 32, "int32") + check(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), 32, "int64") + + +def test_thread_axis(): + def check(m, n, target_bits, target_dtype): + ib = tvm.tir.ir_builder.create() + Ab = tvm.tir.decl_buffer((m, n), name='A') + A = ib.buffer_ptr(Ab) + Bb = tvm.tir.decl_buffer((m, n), name='B') + B = ib.buffer_ptr(Bb) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", m) + ib.scope_attr(tx, "thread_extent", n) + B[bx * n + tx] = A[bx * n + tx] + 1 + stmt = ib.get() + stmt = lower_stmt([Ab, Bb], stmt, target_bits) + assert stmt.node.var.dtype == target_dtype + assert stmt.body.node.var.dtype == target_dtype + + # i32 -> i32 + check(2, 32, + target_bits=32, target_dtype='int32') + check(2**30, 32, # i32 + i32 is not promoted to i64 even in the case of overflow + target_bits=32, target_dtype='int32') + # i64 -> i32 + check(const(2, dtype='int64'), + const(32, dtype='int64'), + target_bits=32, target_dtype='int32') + check(const(2**30, dtype='int64'), + const(32, dtype='int64'), + target_bits=32, target_dtype='int64') + # i32 -> i16 + check(2, 32, + target_bits=16, target_dtype='int16') + check(2**14, 32, + target_bits=16, target_dtype='int32') + + +def test_multilanes(): + def check(m, lanes, target_bits, target_dtype): + ib = tvm.tir.ir_builder.create() + Ab = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='A') + A = ib.buffer_ptr(Ab) + Bb = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='B') + B = ib.buffer_ptr(Bb) + with ib.for_range(0, m, name='i', dtype=m.dtype) as i: + B[i] = A[i] + 1 + stmt = ib.get() + stmt = lower_stmt([Ab, Bb], stmt, target_bits) + assert stmt.loop_var.dtype == target_dtype + + # i32 -> i32 + check(const(2 ** 10, dtype='int32'), 2, + target_bits=32, target_dtype='int32') + check(const(2 ** 32, dtype='int32'), 2, + target_bits=32, target_dtype='int32') + # i64 -> i32 + check(const(2 ** 10, dtype='int64'), 2, + target_bits=32, target_dtype='int32') + check(const(2 ** 32, dtype='int64'), 2, + target_bits=32, target_dtype='int64') + # i32 -> i16 + check(const(2 ** 10, dtype='int32'), 2, + target_bits=16, target_dtype='int16') + check(const(2 ** 16, dtype='int32'), 2, + target_bits=16, target_dtype='int32') + + +def test_reduce(): + def check(m, target_bits, target_dtype): + A = te.placeholder((m,), name='A', dtype='float32') + k = te.reduce_axis((0, m), "k") + B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B') + s = te.create_schedule(B.op) + stmt = lower_sch(s, [A, B], target_bits) + assert stmt.body[1].loop_var.dtype == target_dtype + + # i32 -> i32 + check(const(64, dtype='int32'), 32, 'int32') + # i64 -> i32 + check(const(64, dtype='int64'), 32, 'int32') + # i32 -> i16 + check(const(64, dtype='int32'), 16, 'int16') + check(const(2**16, dtype='int32'), 16, 'int32') + # symbolic + check(te.var('n', dtype='int32'), 32, 'int32') + check(te.var('n', dtype='int64'), 32, 'int64') + + +def test_slice(): + def check(m, n, target_bits, target_dtype): + # The index may overflow in B, while not in A + ib = tvm.tir.ir_builder.create() + Ab = tvm.tir.decl_buffer((m, n), name='A') + A = ib.buffer_ptr(Ab) + Bb = tvm.tir.decl_buffer((m, n * 2), name='B') + B = ib.buffer_ptr(Bb) + with ib.for_range(0, m, name='i') as i: + with ib.for_range(0, n, name='j') as j: + A[i * n + j] = B[i * 2 * n + 2 * j] + 1 + stmt = ib.get() + stmt = lower_stmt([Ab, Bb], stmt, target_bits) + assert stmt.loop_var.dtype == target_dtype + assert stmt.body.loop_var.dtype == target_dtype + + # The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1 + check(const(2**15, 'int64'), const(2**15, 'int64'), + target_bits=32, target_dtype='int32') + # The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1 + check(const(2**15, 'int64'), const((2**15 + 1), 'int64'), + target_bits=32, target_dtype='int64') + + +if __name__ == "__main__": + test_basic() + test_thread_axis() + test_multilanes() + test_reduce() + test_slice()