[TIR][PASS] dtype rewrite for indexing variables (#5092)
authorHaozheng Fan <fanhaozh@amazon.com>
Thu, 2 Apr 2020 16:04:33 +0000 (00:04 +0800)
committerGitHub <noreply@github.com>
Thu, 2 Apr 2020 16:04:33 +0000 (09:04 -0700)
16 files changed:
include/tvm/arith/analyzer.h
include/tvm/tir/ir_pass.h
include/tvm/tir/transform.h
python/tvm/driver/build_module.py
python/tvm/tir/expr.py
python/tvm/tir/ir_builder.py
python/tvm/tir/transform/transform.py
src/arith/const_int_bound.cc
src/target/llvm/codegen_cpu.cc
src/target/llvm/codegen_llvm.cc
src/tir/ir/buffer.cc
src/tir/pass/ffi_api.cc
src/tir/pass/loop_partition.cc
src/tir/pass/unroll_loop.cc
src/tir/transforms/narrow_datatype.cc [new file with mode: 0644]
tests/python/unittest/test_tir_transform_narrow_datatype.py [new file with mode: 0644]

index e7f5ede..1889e16 100644 (file)
@@ -115,6 +115,15 @@ class ConstIntBoundAnalyzer {
   ConstIntBound operator()(const PrimExpr& expr);
 
   /*!
+   * \brief analyze the expr with the intermediate memorized to avoid redundant computation
+   * \param expr The expression of interest.
+   * \param bound The lookup table to store the intermediate results
+   * \return the result of the analysis.
+   */
+  ConstIntBound operator()(const PrimExpr& expr,
+                           std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
+
+  /*!
    * \brief Update constant int bound information of var.
    *
    * \param var The variable of interest.
index d54e094..f056d9f 100644 (file)
@@ -358,6 +358,15 @@ Stmt DecorateDeviceScope(Stmt stmt);
 Stmt HoistIfThenElse(Stmt stmt);
 
 /*!
+ * \brief Narrow down PrimExpr datatype in stmt to target_bits.
+ * \note  Run this pass after StorageFlatten.
+ * \param stmt The stmt to do datatype rewrite
+ * \param target_bits the bit of target datatype
+ * \return Transformed stmt.
+ */
+Stmt NarrowDataType(Stmt stmt, int target_bits);
+
+/*!
  * \brief Make an user callable API LoweredFunc.
  *
  *  The main task of this function is to create code to :
index 9d55db5..a414bcc 100644 (file)
@@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo();
  */
 TVM_DLL Pass LowerWarpMemory();
 
+
+/*!
+ * \brief Narrow down PrimExpr datatype in stmt to target_bits.
+ *
+ * \note Run this pass after StorageFlatten.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass NarrowDataType();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
index 67eb224..88231aa 100644 (file)
@@ -159,6 +159,7 @@ def lower(sch,
     # Phase 1
     stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
     stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
+    stmt = ir_pass.NarrowDataType(stmt, 32)
     stmt = ir_pass.CanonicalSimplify(stmt)
     for f in lower_phase1:
         stmt = f(stmt)
index deb8d34..a192fce 100644 (file)
@@ -370,7 +370,8 @@ class IterVar(Object, ExprOp):
                 raise TypeError("dom need to be Range")
 
         name = var if var is not None else "iter"
-        var = Var(name, dtype="int32") if not isinstance(var, Var) else var
+        dtype = "int32" if dom is None else dom.extent.dtype
+        var = Var(name, dtype=dtype) if not isinstance(var, Var) else var
         self.__init_handle_by_constructor__(
             _ffi_api.IterVar, dom, var, iter_type, thread_tag)
 
index 885b847..0c4c368 100644 (file)
@@ -76,7 +76,8 @@ class BufferVar(ObjectGeneric):
     def __getitem__(self, index):
         t = DataType(self._content_type)
         if t.lanes > 1:
-            index = _expr.Ramp(index * t.lanes, 1, t.lanes)
+            base = index * t.lanes
+            index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
         return _expr.Load(self._content_type, self._buffer_var, index)
 
     def __setitem__(self, index, value):
@@ -87,7 +88,8 @@ class BufferVar(ObjectGeneric):
                     value.dtype, self._content_type))
         t = DataType(self._content_type)
         if t.lanes > 1:
-            index = _expr.Ramp(index * t.lanes, 1, t.lanes)
+            base = index * t.lanes
+            index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
         self._builder.emit(_stmt.Store(self._buffer_var, value, index))
 
 
index 2b50387..7c2b3c8 100644 (file)
@@ -66,3 +66,18 @@ def LowerWarpMemory():
         The result pass
     """
     return _ffi_api.LowerWarpMemory()
+
+
+def NarrowDataType():
+    """Narrow down PrimExpr datatype in stmt to target_bits.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+
+    Note
+    ----
+    Run this pass after StorageFlatten.
+    """
+    return _ffi_api.NarrowDataType()
index 702e775..57dfc15 100644 (file)
@@ -146,9 +146,30 @@ class ConstIntBoundAnalyzer::Impl :
         res = Intersect(res, info.bound);
       }
     }
+    if (bound_) {
+      const PrimExprNode* op = expr.as<PrimExprNode>();
+      auto val = bound_->find(op);
+      if (val != bound_->end()) {
+        CHECK(val->second->min_value == res.min_value &&
+              val->second->max_value == res.max_value)
+          << "Detected bound for " << expr
+          << "conflicts with memorization";
+      }
+      (*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
+    }
     return res;
   }
 
+  Entry VisitExpr_(const RampNode* op) final {
+    // op = {base + i * stride | 0 <= i < lanes}
+    // Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes)
+    // Note that `base + i * stride` is linear w.r.t. `i`
+    // Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1)
+    Entry a = VisitExpr(op->base);
+    Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride);
+    return Union(a, b);
+  }
+
   Entry VisitExpr_(const CastNode* op) final {
     Entry a = VisitExpr(op->value);
     Entry b = Everything(op->dtype);
@@ -340,10 +361,13 @@ class ConstIntBoundAnalyzer::Impl :
   }
 
  private:
+  friend class ConstIntBoundAnalyzer;
   // internal variable map
   std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
   // additional bound info
   std::vector<BoundInfo> additional_info_;
+  // look up table for memorization
+  std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
   // constants: the limit value means umlimited
   // NOTE: kNegInf/kPosInf are used to represent infinity.
   static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
@@ -536,6 +560,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
   return ConstIntBound(ret.min_value, ret.max_value);
 }
 
+ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
+  std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
+  impl_->bound_ = bound;
+  Entry ret = impl_->VisitExpr(expr);
+  impl_->bound_ = nullptr;
+  return ConstIntBound(ret.min_value, ret.max_value);
+}
+
 void ConstIntBoundAnalyzer::Update(const Var& var,
                                    const ConstIntBound& info,
                                    bool override) {
index ba3dee7..70bcfe8 100644 (file)
@@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
         PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
         CreateSerialFor(MakeValue(begin),
                         MakeValue(end),
-                        ConstInt32(1),
+                        llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
                         op->loop_var,
                         op->body);
       }
index 68d004c..31465cd 100644 (file)
@@ -1121,7 +1121,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
     CHECK(op->for_type == ForType::Serial);
   }
   CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
-                  ConstInt32(1), op->loop_var, op->body);
+                  llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
+                  op->loop_var, op->body);
 }
 
 
index d663c30..eec7c10 100644 (file)
@@ -452,7 +452,7 @@ Buffer BufferNode::make(Var data,
   n->buffer_type = buffer_type;
   if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
     for (size_t i = 0; i < n->shape.size(); ++i) {
-      n->strides.push_back(Var("stride"));
+      n->strides.push_back(Var("stride", n->shape[i].dtype()));
     }
   }
   return Buffer(n);
index f4d8193..d13461e 100644 (file)
@@ -156,5 +156,6 @@ REGISTER_PASS(InstrumentBoundCheckers);
 REGISTER_PASS(VerifyCompactBuffer);
 REGISTER_PASS(HoistIfThenElse);
 REGISTER_PASS(InferFragment)
+REGISTER_PASS(NarrowDataType);
 }  // namespace tir
 }  // namespace tvm
index d1fa46e..e9157e7 100644 (file)
@@ -587,7 +587,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b
     // If the loop extent is 1, do not create the loop anymore
     return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
   } else {
-    return ForNode::make(for_node->loop_var, 0, extent,
+    return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
                      for_node->for_type, for_node->device_api, body);
   }
 }
index 3d669d0..0167dbc 100644 (file)
@@ -160,7 +160,9 @@ class LoopUnroller : public StmtExprMutator {
     PrimExpr extent = tir::Simplify(op->extent);
     const IntImmNode  *v1 = extent.as<IntImmNode>();
     int value = -1;
-    if (v1 != nullptr) {
+    // integers that do not fit in int32_t are treated as symbolic,
+    // as it's impossible to unroll such large loops
+    if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
       value = static_cast<int>(v1->value);
     }
     return value;
diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
new file mode 100644 (file)
index 0000000..00bc45a
--- /dev/null
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file narrow_datatype.cc
+ * \brief narrow the datatype of indexing vars
+ */
+
+#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
+#include <tvm/runtime/registry.h>
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/ir_visitor_with_analyzer.h"
+
+namespace tvm {
+namespace tir {
+
+// This pass narrows indexing expressions (like StoreNode::Index)
+// that trivially fit into i32/i16 (denoted by `target_bits_`) to
+// i32/i16. Considering that i32/i16 indices may be more
+// efficient on some backends (while i64 may be more efficient
+// on others, like llvm), we may want this pass when i32/i16
+// indices are more efficient.
+//
+// For Var v, we determine its dtype by examining all the PrimExpr
+// that contains v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
+// If all expressions in E fit into i32/i16, then we think v can be narrowed
+// to i32/i16.
+//
+// To make an indexing expression i32/i16, we must make sure that every
+// component of that expression is of dtype i32/i16. So besides Var, we
+// rewrite the following inside an indexing expression
+// - Var
+// - IntImm
+// - Cast
+//
+// Algorithm:
+// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
+// - Use DataTypeRewritter to rewrite the components of an indexing expression.
+
+using arith::Analyzer;
+using arith::IRMutatorWithAnalyzer;
+using arith::ConstIntBound;
+
+// Determine the result dtype for Var, IntImm and Cast,
+// which will be stored in `vmap` eventually.
+//
+// Algorithm:
+// We propogate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`.
+// To be more specific, if for each Expr `e` which contains `var`
+// (`var` is a child node of `e` in AST), `e` fits into `target_bits_`,
+// then we narrow `var` into `target_bits_`. That is,
+// `vmap[var] = min(target_bits_, var.dtype.bits())`
+// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()`
+class DataTypeVisitor final : public StmtExprVisitor {
+ public:
+  explicit DataTypeVisitor(int target_bits)
+    : bits_(target_bits), target_bits_(target_bits) {}
+
+  void VisitExpr(const PrimExpr& e) {
+    if (e.dtype().is_int()) {
+      int bits = max_bits_;
+      const PrimExprNode* op = e.as<PrimExprNode>();
+      if (bound_.find(op) == bound_.end()) {
+        analyzer_.const_int_bound(e, &bound_);
+      }
+      ConstIntBound bound = bound_[op];
+      int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
+      int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
+      if (e.dtype().bits() <= target_bits_ ||
+          (bound->max_value <= ubound && bound->min_value >= lbound)) {
+        bits = target_bits_;
+      }
+      int tmp = bits > bits_ ? bits :  bits_;
+      std::swap(bits_, tmp);
+      StmtExprVisitor::VisitExpr(e);
+      std::swap(bits_, tmp);
+    } else {
+      StmtExprVisitor::VisitExpr(e);
+    }
+  }
+
+  void VisitStmt_(const ForNode* op) {
+    analyzer_.Bind(op->loop_var,
+                   Range::make_by_min_extent(op->min, op->extent));
+    vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
+    return StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) {
+    if (op->attr_key == attr::thread_extent ||
+        op->attr_key == attr::virtual_thread) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      CHECK_NE(iv->thread_tag.length(), 0U);
+      analyzer_.Bind(iv->var,
+                      Range::make_by_min_extent(0, op->value));
+      vextent_[iv->var.as<VarNode>()] = op->value.dtype();
+      StmtExprVisitor::VisitStmt_(op);
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitExpr_(const ReduceNode* op) {
+    // Setup the domain information before simplification.
+    for (const IterVar& iv : op->axis) {
+      analyzer_.Bind(iv->var, iv->dom);
+      vextent_[iv->var.as<VarNode>()] = iv->dom->extent.dtype();
+    }
+    // Recursively call simplification when necessary.
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) {
+    if (vextent_.find(op) != vextent_.end()) {
+      // We only narrow and never promote, so the result dtype
+      // is upperbounded by its original dtype before rewrite.
+      int bits = std::min(vextent_[op].bits(), bits_);
+      if (vmap.find(op) == vmap.end()) {
+        vmap[op] = op->dtype.with_bits(bits);
+      } else {
+        // We take maximum bits for all the possible Expr where a var occurs
+        vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
+      }
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const IntImmNode* op) {
+    if (op->dtype.is_int()) {
+      // We only narrow and never promote, so the result dtype
+      // is upperbounded by its original dtype before rewrite.
+      int bits = std::min(op->dtype.bits(), bits_);
+      if (vmap.find(op) == vmap.end()) {
+        vmap[op] = op->dtype.with_bits(bits);
+      } else {
+        vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
+      }
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const CastNode* op) {
+    if (op->dtype.is_int()) {
+      // We only narrow and never promote, so the result dtype
+      // is upperbounded by its original dtype before rewrite.
+      int bits = std::min(op->dtype.bits(), bits_);
+      if (vmap.find(op) == vmap.end()) {
+        vmap[op] = op->dtype.with_bits(bits);
+      } else {
+        vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
+      }
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // the narrowed datatype of Var and IntImm
+  std::unordered_map<const PrimExprNode*, DataType> vmap;
+
+ protected:
+  // internal analyzer
+  arith::Analyzer analyzer_;
+
+ private:
+  // the maximum possible bits, which serves as an init value
+  static constexpr const int max_bits_ = 64;
+  // the maximum possible bit of the current expression's return dtype
+  int bits_;
+  // the target bits
+  int target_bits_;
+  // the extent of vars to be rewritten
+  std::unordered_map<const VarNode*, DataType> vextent_;
+  // the memorized bound generated by ConstIntBoundAnalyzer
+  std::unordered_map<const PrimExprNode*, ConstIntBound> bound_;
+};
+
+class DataTypeRewriter : public StmtExprMutator {
+ public:
+  explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {}
+
+  Stmt operator()(Stmt s) {
+    visitor_(s);
+    for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) {
+      PrimExpr e = GetRef<PrimExpr>(i->first);
+      if (e.dtype() == i->second) {
+        i = visitor_.vmap.erase(i);
+      } else {
+        ++i;
+      }
+    }
+    return VisitStmt(s);
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    PrimExpr value = this->VisitExpr(op->value);
+    is_index_ = true;
+    PrimExpr index = this->VisitExpr(op->index);
+    is_index_ = false;
+    Stmt s = StoreNode::make(op->buffer_var,
+                             op->value,
+                             index,
+                             op->predicate);
+    return StmtExprMutator::VisitStmt_(s.as<StoreNode>());
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    Stmt s = StmtExprMutator::VisitStmt_(op);
+    op = s.as<ForNode>();
+    CHECK(op != nullptr)
+      << "Expected type to be ForNode"
+      << ", but get " << s->GetTypeKey();
+    PrimExpr e = VisitExpr(op->loop_var);
+    Var var = Downcast<Var>(e);
+    return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
+                         op->for_type, op->device_api, op->body);
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent ||
+        op->attr_key == attr::virtual_thread) {
+      Stmt s = StmtExprMutator::VisitStmt_(op);
+      op = s.as<AttrStmtNode>();
+      CHECK(op != nullptr)
+        << "Expected type to be AttrStmtNode"
+        << ", but get " << s->GetTypeKey();
+      const IterVarNode* iv = op->node.as<IterVarNode>();
+      CHECK(iv != nullptr)
+        << "Expected type to be IterVarNode"
+        << ", but get " << op->node->GetTypeKey();
+      PrimExpr e = VisitExpr(iv->var);
+      Var var = Downcast<Var>(e);
+      if (ivmap_.find(iv) == ivmap_.end()) {
+        ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag);
+      }
+      return AttrStmtNode::make(
+        ivmap_[iv],
+        op->attr_key,
+        cast(var.dtype(), op->value),
+        op->body);
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
+      if (vmap_.find(op) == vmap_.end()) {
+        vmap_[op] = Var(op->name_hint, visitor_.vmap[op]);
+      }
+      return vmap_[op];
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  PrimExpr VisitExpr_(const SizeVarNode* op) final {
+    if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
+      if (vmap_.find(op) == vmap_.end()) {
+        vmap_[op] = SizeVar(op->name_hint, visitor_.vmap[op]);
+      }
+      return vmap_[op];
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    is_index_ = true;
+    PrimExpr index = this->VisitExpr(op->index);
+    is_index_ = false;
+    PrimExpr e = LoadNode::make(op->dtype, op->buffer_var, index, op->predicate);
+    return StmtExprMutator::VisitExpr_(e.as<LoadNode>());
+  }
+
+  PrimExpr VisitExpr_(const IntImmNode* op) final {
+    if (is_index_) {
+      if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
+        return IntImm(visitor_.vmap[op], op->value);
+      }
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  PrimExpr VisitExpr_(const CastNode* op) final {
+    if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) {
+      PrimExpr e = StmtExprMutator::VisitExpr_(op);
+      const CastNode* new_op = e.as<CastNode>();
+      CHECK(new_op != nullptr)
+        << "Expected type to be CastNode"
+        << ", but get " << e->GetTypeKey();
+      return CastNode::make(visitor_.vmap[op], new_op->value);
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const DivNode* op) final;
+  PrimExpr VisitExpr_(const ModNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+  PrimExpr VisitExpr_(const MinNode* op) final;
+  PrimExpr VisitExpr_(const MaxNode* op) final;
+  PrimExpr VisitExpr_(const EQNode* op) final;
+  PrimExpr VisitExpr_(const NENode* op) final;
+  PrimExpr VisitExpr_(const LTNode* op) final;
+  PrimExpr VisitExpr_(const LENode* op) final;
+  PrimExpr VisitExpr_(const GTNode* op) final;
+  PrimExpr VisitExpr_(const GENode* op) final;
+  PrimExpr VisitExpr_(const CallNode* op) final;
+
+ private:
+  // the internal visitor to deduce the narrowed dtype
+  DataTypeVisitor visitor_;
+  // a map from Var before rewrite to that after rewrite,
+  // ensures one old Var maps to exactly one new Var
+  std::unordered_map<const VarNode*, Var> vmap_;
+  // a map from IterVar before rewrite to that after rewrite,
+  // ensures one old IterVar maps to exactly one new IterVar
+  std::unordered_map<const IterVarNode*, IterVar> ivmap_;
+  // indicator of LoadNode::index and StoreNode::index
+  bool is_index_{false};
+};
+
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)               \
+  PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) {                 \
+    PrimExpr a = this->VisitExpr(op->a);                                \
+    PrimExpr b = this->VisitExpr(op->b);                                \
+    if (a.same_as(op->a) &&                                             \
+        b.same_as(op->b)) {                                             \
+      return GetRef<PrimExpr>(op);                                      \
+    } else {                                                            \
+      return FUNC(a, b);                                                \
+    }                                                                   \
+  }
+
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=)
+
+PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
+  PrimExpr e = StmtExprMutator::VisitExpr_(op);
+  op = e.as<CallNode>();
+  CHECK(op != nullptr)
+    << "Expected type to be CallNode"
+    << ", but get " << e->GetTypeKey();
+  if (op->call_type == CallNode::PureIntrinsic) {
+    if (op->name == intrinsic::tvm_if_then_else) {
+      return if_then_else(op->args[0], op->args[1], op->args[2]);
+    } else if (op->name == CallNode::shift_right) {
+      return op->args[0] >> op->args[1];
+    } else if (op->name == CallNode::shift_left) {
+      return op->args[0] << op->args[1];
+    } else if (op->name == CallNode::bitwise_and) {
+      return op->args[0] & op->args[1];
+    } else if (op->name == CallNode::bitwise_or) {
+      return op->args[0] | op->args[1];
+    } else if (op->name == CallNode::bitwise_xor) {
+      return op->args[0] ^ op->args[1];
+    } else if (op->name == "pow") {
+      return pow(op->args[0], op->args[1]);
+    }
+  }
+  return e;
+}
+
+Stmt NarrowDataType(Stmt stmt, int target_bits) {
+  return DataTypeRewriter(target_bits)(stmt);
+}
+
+namespace transform {
+
+Pass NarrowDataType() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    IntImm target_bits = f->GetAttr<IntImm>("target_bits");
+    CHECK(target_bits.defined())
+      << "NarrowDataType: Require the target_bits";
+    n->body = DataTypeRewriter(target_bits->value)(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(
+      pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType")
+.set_body_typed(NarrowDataType);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py
new file mode 100644 (file)
index 0000000..49df1c2
--- /dev/null
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+from tvm.tir import const
+
+
+def lower_stmt(params, stmt, target_bits):
+    func = tvm.tir.PrimFunc(params, stmt).with_attr(
+        "target_bits", target_bits)
+    func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"]
+    stmt = func.body
+    return stmt
+
+
+def lower_sch(sch, args, target_bits):
+    binds = {}
+    arg_list = []
+    for x in args:
+        if isinstance(x, te.tensor.Tensor):
+            buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
+            assert x not in binds
+            binds[x] = buf
+            arg_list.append(buf)
+        else:
+            raise ValueError("args must be Tensor, Buffer or Var")
+    bounds = te.schedule.InferBound(sch)
+    stmt = te.schedule.ScheduleOps(sch, bounds)
+    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
+    return lower_stmt(arg_list, stmt, target_bits)
+
+
+def test_basic():
+    def check(m, n, target_bits, target_dtype):
+        ib = tvm.tir.ir_builder.create()
+        Ab = tvm.tir.decl_buffer((m, n), name='A')
+        A = ib.buffer_ptr(Ab)
+        Bb = tvm.tir.decl_buffer((m, n), name='B')
+        B = ib.buffer_ptr(Bb)
+        with ib.for_range(0, m, name='i') as i:
+            with ib.for_range(0, n, name='j') as j:
+                B[i * n + j] = A[i * n + j] + 1
+        stmt = ib.get()
+        stmt = lower_stmt([Ab, Bb], stmt, target_bits)
+        assert stmt.loop_var.dtype == target_dtype
+        assert stmt.body.loop_var.dtype == target_dtype
+
+    # const shape
+    # i32 -> i32
+    check(2, 2, 32, "int32")
+    check(2**16, 2**16, 32, "int32")  # i32 + i32 is not promoted to i64 even if overflow
+    # i64 -> i32
+    check(const(2, dtype='int64'), const(2, dtype='int64'), 32, "int32")
+    check(const(2**16, dtype='int64'), const(2**16, dtype='int64'), 32, "int64")
+    # i32 -> i16
+    check(2, 2, 16, "int16")
+    check(2**10, 2**10, 16, "int32")
+
+    # symbolic shape
+    check(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), 32, "int32")
+    check(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), 32, "int64")
+
+
+def test_thread_axis():
+    def check(m, n, target_bits, target_dtype):
+        ib = tvm.tir.ir_builder.create()
+        Ab = tvm.tir.decl_buffer((m, n), name='A')
+        A = ib.buffer_ptr(Ab)
+        Bb = tvm.tir.decl_buffer((m, n), name='B')
+        B = ib.buffer_ptr(Bb)
+        bx = te.thread_axis("blockIdx.x")
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(bx, "thread_extent", m)
+        ib.scope_attr(tx, "thread_extent", n)
+        B[bx * n + tx] = A[bx * n + tx] + 1
+        stmt = ib.get()
+        stmt = lower_stmt([Ab, Bb], stmt, target_bits)
+        assert stmt.node.var.dtype == target_dtype
+        assert stmt.body.node.var.dtype == target_dtype
+
+    # i32 -> i32
+    check(2, 32,
+          target_bits=32, target_dtype='int32')
+    check(2**30, 32,  # i32 + i32 is not promoted to i64 even in the case of overflow
+          target_bits=32, target_dtype='int32')
+    # i64 -> i32
+    check(const(2, dtype='int64'),
+          const(32, dtype='int64'),
+          target_bits=32, target_dtype='int32')
+    check(const(2**30, dtype='int64'),
+          const(32, dtype='int64'),
+          target_bits=32, target_dtype='int64')
+    # i32 -> i16
+    check(2, 32,
+          target_bits=16, target_dtype='int16')
+    check(2**14, 32,
+          target_bits=16, target_dtype='int32')
+
+
+def test_multilanes():
+    def check(m, lanes, target_bits, target_dtype):
+        ib = tvm.tir.ir_builder.create()
+        Ab = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='A')
+        A = ib.buffer_ptr(Ab)
+        Bb = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='B')
+        B = ib.buffer_ptr(Bb)
+        with ib.for_range(0, m, name='i', dtype=m.dtype) as i:
+            B[i] = A[i] + 1
+        stmt = ib.get()
+        stmt = lower_stmt([Ab, Bb], stmt, target_bits)
+        assert stmt.loop_var.dtype == target_dtype
+
+    # i32 -> i32
+    check(const(2 ** 10, dtype='int32'), 2,
+          target_bits=32, target_dtype='int32')
+    check(const(2 ** 32, dtype='int32'), 2,
+          target_bits=32, target_dtype='int32')
+    # i64 -> i32
+    check(const(2 ** 10, dtype='int64'), 2,
+          target_bits=32, target_dtype='int32')
+    check(const(2 ** 32, dtype='int64'), 2,
+          target_bits=32, target_dtype='int64')
+    # i32 -> i16
+    check(const(2 ** 10, dtype='int32'), 2,
+          target_bits=16, target_dtype='int16')
+    check(const(2 ** 16, dtype='int32'), 2,
+          target_bits=16, target_dtype='int32')
+
+
+def test_reduce():
+    def check(m, target_bits, target_dtype):
+        A = te.placeholder((m,), name='A', dtype='float32')
+        k = te.reduce_axis((0, m), "k")
+        B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
+        s = te.create_schedule(B.op)
+        stmt = lower_sch(s, [A, B], target_bits)
+        assert stmt.body[1].loop_var.dtype == target_dtype
+
+    # i32 -> i32
+    check(const(64, dtype='int32'), 32, 'int32')
+    # i64 -> i32
+    check(const(64, dtype='int64'), 32, 'int32')
+    # i32 -> i16
+    check(const(64, dtype='int32'), 16, 'int16')
+    check(const(2**16, dtype='int32'), 16, 'int32')
+    # symbolic
+    check(te.var('n', dtype='int32'), 32, 'int32')
+    check(te.var('n', dtype='int64'), 32, 'int64')
+
+
+def test_slice():
+    def check(m, n, target_bits, target_dtype):
+        # The index may overflow in B, while not in A
+        ib = tvm.tir.ir_builder.create()
+        Ab = tvm.tir.decl_buffer((m, n), name='A')
+        A = ib.buffer_ptr(Ab)
+        Bb = tvm.tir.decl_buffer((m, n * 2), name='B')
+        B = ib.buffer_ptr(Bb)
+        with ib.for_range(0, m, name='i') as i:
+            with ib.for_range(0, n, name='j') as j:
+                A[i * n + j] = B[i * 2 * n + 2 * j] + 1
+        stmt = ib.get()
+        stmt = lower_stmt([Ab, Bb], stmt, target_bits)
+        assert stmt.loop_var.dtype == target_dtype
+        assert stmt.body.loop_var.dtype == target_dtype
+
+    # The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1
+    check(const(2**15, 'int64'), const(2**15, 'int64'),
+          target_bits=32, target_dtype='int32')
+    # The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1
+    check(const(2**15, 'int64'), const((2**15 + 1), 'int64'),
+          target_bits=32, target_dtype='int64')
+
+
+if __name__ == "__main__":
+    test_basic()
+    test_thread_axis()
+    test_multilanes()
+    test_reduce()
+    test_slice()