[TIR] Enforce buffer pointer var type to be consistent with dtype. (#6317)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 21 Aug 2020 17:53:39 +0000 (10:53 -0700)
committerGitHub <noreply@github.com>
Fri, 21 Aug 2020 17:53:39 +0000 (10:53 -0700)
Now that we have type_annotation in tir::Var.
We should make sure that the type annotation to be consistent with the dtype
in Buffer declaration and Allocation.

This change allows future passes to directly use the content type information via type_annotation.

This PR turns on the enforcement on Buffer and also fixed a few cases for Allocate.
A follow up PR need to fix a few more cases in the hybrid script parsing
before everything can be made consistent.

include/tvm/tir/op.h
python/tvm/tir/buffer.py
python/tvm/tir/ir_builder.py
src/driver/driver_api.cc
src/tir/ir/buffer.cc
src/tir/ir/stmt.cc
src/tir/transforms/bf16_legalize.cc
src/tir/transforms/storage_flatten.cc

index 68ca266..93a54b0 100644 (file)
@@ -617,6 +617,23 @@ TVM_DECLARE_INTRIN_BINARY(hypot);
 TVM_DECLARE_INTRIN_BINARY(ldexp);
 
 namespace tir {
+
+/*!
+ * \brief Check if type is a pointer to a runtime element type.
+ * \param type The type to be checked.
+ * \param element_type The corresponding element type.
+ * \return The check results
+ */
+inline bool IsPointerType(const Type& type, const DataType& element_type) {
+  if (!type.defined()) return false;
+  if (const auto* ptr_type = type.as<PointerTypeNode>()) {
+    if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
+      return prim_type->dtype == element_type;
+    }
+  }
+  return false;
+}
+
 /*!
  * \brief Make a const value with certain data type.
  * \param t The target type.
index 11bfb4c..bd7672a 100644 (file)
@@ -20,7 +20,7 @@ import tvm._ffi
 
 from tvm._ffi.base import string_types
 from tvm.runtime import Object, convert
-from tvm.ir import PrimExpr
+from tvm.ir import PrimExpr, PointerType, PrimType
 from . import _ffi_api
 
 
@@ -241,7 +241,7 @@ def decl_buffer(shape,
         shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
         elem_offset = Var('%s_elem_offset' % name, shape_dtype)
     if data is None:
-        data = Var(name, "handle")
+        data = Var(name, PointerType(PrimType(dtype)))
     return _ffi_api.Buffer(
         data, dtype, shape, strides, elem_offset, name, scope,
         data_alignment, offset_factor, buffer_type)
index 20180d1..b313e58 100644 (file)
@@ -17,7 +17,7 @@
 """Developer API of IR node builder make function."""
 from tvm._ffi.base import string_types
 from tvm.runtime import ObjectGeneric, DataType, convert, const
-from tvm.ir import container as _container
+from tvm.ir import container as _container, PointerType, PrimType
 
 from . import stmt as _stmt
 from . import expr as _expr
@@ -325,7 +325,7 @@ class IRBuilder(object):
         buffer : BufferVar
             The buffer var representing the buffer.
         """
-        buffer_var = _expr.Var(name, dtype="handle")
+        buffer_var = _expr.Var(name, PointerType(PrimType(dtype)))
         if not isinstance(shape, (list, tuple, _container.Array)):
             shape = [shape]
         if scope:
index 142bdfc..14aa4fc 100644 (file)
@@ -69,7 +69,7 @@ Target DefaultTargetHost(Target target) {
 
 tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
                                       int data_alignment, int offset_factor, bool compact) {
-  auto data = tir::Var(name, DataType::Handle());
+  auto data = tir::Var(name, PointerType(PrimType(dtype)));
   bool has_any = false;
   if (!compact) {
     for (const auto& it : shape) {
index 00e3335..d33f2dd 100644 (file)
@@ -383,9 +383,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
 Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
                PrimExpr elem_offset, String name, String scope, int data_alignment,
                int offset_factor, BufferType buffer_type) {
+  CHECK(IsPointerType(data->type_annotation, dtype))
+      << "Buffer data field expect to have the right pointer type annotation"
+      << " annotation=" << data->type_annotation << ", dtype=" << dtype;
+
   auto n = make_object<BufferNode>();
   n->data = std::move(data);
   n->dtype = dtype;
+
   n->shape = std::move(shape);
   n->strides = std::move(strides);
   n->name = std::move(name);
index 296f492..d9e1df4 100644 (file)
@@ -263,6 +263,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 // Allocate
 Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
                    Stmt body) {
+  // TODO(tvm-team): Add invariant check to make sure
+  // IsPointerPType(buffer_var->type_annotation, dtype)
+  // once we fix the allocate hybrid script printing.
   for (size_t i = 0; i < extents.size(); ++i) {
     CHECK(extents[i].defined());
     CHECK(extents[i].dtype().is_scalar());
index 4a44b85..97c96ed 100644 (file)
@@ -172,14 +172,11 @@ uint16_t RoundToNearestEven(float src) {
  * Lower cast between bf16 and fp32
  * Lower bf16 FloatImm to int16
  */
-class BF16LowerRewriter : StmtExprMutator {
+class BF16LowerRewriter : public StmtExprMutator {
  public:
   BF16LowerRewriter() {}
 
-  std::unordered_map<const BufferNode*, Buffer> buffer_remap;
-  std::unordered_map<const VarNode*, Var> var_remap;
-
-  Stmt operator()(Stmt s) { return VisitStmt(s); }
+  using StmtExprMutator::operator();
 
   PrimExpr VisitExpr_(const CastNode* op) final {
     auto op_val = StmtExprMutator::VisitExpr(op->value);
@@ -190,7 +187,6 @@ class BF16LowerRewriter : StmtExprMutator {
       auto uint32_v = Cast(uint32_dtype, op_val);
       // to be endian invariant.
       return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16});
-
     } else if (op->dtype.is_bfloat16()) {
       // if is cast_to_bf16, check if op->value is fp32
       CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
@@ -209,104 +205,104 @@ class BF16LowerRewriter : StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
-    auto itr = var_remap.find(op);
-    if (itr != var_remap.end()) {
+    Var var = GetRef<Var>(op);
+
+    auto itr = var_remap_.find(var);
+    if (itr != var_remap_.end()) {
       return itr->second;
+    } else {
+      return std::move(var);
     }
-    if (op->dtype.is_bfloat16()) {
-      CHECK(!op->type_annotation.defined());
-      auto ret = Var(op->name_hint, op->dtype);
-      var_remap[op] = ret;
-      return std::move(ret);
-    }
-    return StmtExprMutator::VisitExpr_(op);
   }
 
   Stmt VisitStmt_(const AllocateNode* op) final {
-    Stmt node_holder;
-    const AllocateNode* newop;
     if (op->dtype.is_bfloat16()) {
-      auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents,
-                        op->condition, op->body);
-      node_holder = v;
-      newop = static_cast<const AllocateNode*>(v.operator->());
+      DataType dtype = DataType::UInt(16, op->dtype.lanes());
+      Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
+      var_remap_[op->buffer_var] = buffer_var;
+      return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body));
     } else {
-      newop = op;
+      return StmtExprMutator::VisitStmt_(op);
     }
-    return StmtExprMutator::VisitStmt_(newop);
   }
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
-    auto itr = buffer_remap.find(op->buffer.operator->());
-    const BufferStoreNode* newop;
-    BufferStore newop_holder;
-    if (itr != buffer_remap.end()) {
-      newop_holder = BufferStore(itr->second, op->value, op->indices);
-      newop = newop_holder.operator->();
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<BufferStoreNode>();
+
+    auto it = buffer_remap_.find(op->buffer);
+    if (it != buffer_remap_.end()) {
+      return BufferStore(it->second, op->value, op->indices);
     } else {
-      newop = op;
+      return ret;
     }
-    return StmtExprMutator::VisitStmt_(newop);
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    const AttrStmtNode* newop = op;
-    Stmt newop_holder;
-    if (auto buffer = op->node.as<BufferNode>()) {
-      auto itr = buffer_remap.find(buffer);
-      if (itr != buffer_remap.end()) {
-        newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
-        newop = newop_holder.as<AttrStmtNode>();
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<AttrStmtNode>();
+
+    if (auto* buffer = op->node.as<BufferNode>()) {
+      auto it = buffer_remap_.find(GetRef<Buffer>(buffer));
+      if (it != buffer_remap_.end()) {
+        return AttrStmt(it->second, op->attr_key, op->value, op->body);
       }
-    } else if (auto buffer = op->node.as<VarNode>()) {
-      auto itr = var_remap.find(buffer);
-      if (itr != var_remap.end()) {
-        newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
-        newop = newop_holder.as<AttrStmtNode>();
+    } else if (auto* var = op->node.as<VarNode>()) {
+      auto it = var_remap_.find(GetRef<Var>(var));
+      if (it != var_remap_.end()) {
+        return AttrStmt(it->second, op->attr_key, op->value, op->body);
       }
     }
-    return StmtExprMutator::VisitStmt_(newop);
+    return ret;
   }
 
   Stmt VisitStmt_(const BufferRealizeNode* op) final {
-    auto itr = buffer_remap.find(op->buffer.operator->());
-    const BufferRealizeNode* newop;
-    Stmt newop_holder;
-    if (itr != buffer_remap.end()) {
-      auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body);
-      newop_holder = v;
-      newop = v.operator->();
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<BufferRealizeNode>();
+
+    auto it = buffer_remap_.find(op->buffer);
+    if (it != buffer_remap_.end()) {
+      return BufferRealize(it->second, op->bounds, op->condition, op->body);
     } else {
-      newop = op;
+      return ret;
+    }
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    // NOTE: we do not explicit recursivly mutate op->buffer_var
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<StoreNode>();
+
+    auto it = var_remap_.find(op->buffer_var);
+    if (it != var_remap_.end()) {
+      return Store(it->second, op->value, op->index, op->predicate);
+    } else {
+      return ret;
     }
-    return StmtExprMutator::VisitStmt_(newop);
   }
 
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    auto itr = buffer_remap.find(op->buffer.operator->());
-    const BufferLoadNode* newop;
-    BufferLoad newop_holder;
-    if (itr != buffer_remap.end()) {
-      newop_holder = BufferLoad(itr->second, op->indices);
-      newop = newop_holder.operator->();
+    PrimExpr ret = StmtExprMutator::VisitExpr_(op);
+    op = ret.as<BufferLoadNode>();
+
+    auto it = buffer_remap_.find(op->buffer);
+    if (it != buffer_remap_.end()) {
+      return BufferLoad(it->second, op->indices);
     } else {
-      newop = op;
+      return ret;
     }
-    return StmtExprMutator::VisitExpr_(newop);
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    bool is_bf16 = false;
+    PrimExpr ret = StmtExprMutator::VisitExpr_(op);
+    op = ret.as<LoadNode>();
+
     if (op->dtype.is_bfloat16()) {
-      is_bf16 = true;
-    }
-    PrimExpr index = this->VisitExpr(op->index);
-    PrimExpr predicate = this->VisitExpr(op->predicate);
-    if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) {
-      return GetRef<PrimExpr>(op);
+      auto it = var_remap_.find(op->buffer_var);
+      CHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped";
+      return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate);
     } else {
-      return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var,
-                  index, predicate);
+      return ret;
     }
   }
 
@@ -320,20 +316,31 @@ class BF16LowerRewriter : StmtExprMutator {
 
   void AlterBuffers(PrimFuncNode* op) {
     std::vector<std::pair<Var, Buffer>> changes;
+
     for (auto& itr : op->buffer_map) {
       auto oldbuf = itr.second;
       if (oldbuf->dtype.is_bfloat16()) {
-        auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape,
-                             oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope,
-                             oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type);
-        buffer_remap[oldbuf.operator->()] = newbuf;
+        DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
+        Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype)));
+        auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset,
+                             oldbuf->name, oldbuf->scope, oldbuf->data_alignment,
+                             oldbuf->offset_factor, oldbuf->buffer_type);
+        buffer_remap_[oldbuf] = newbuf;
+        var_remap_[oldbuf->data] = buffer_var;
         changes.emplace_back(itr.first, newbuf);
+      } else {
+        changes.emplace_back(itr);
       }
     }
-    if (buffer_remap.size() != 0) {
+
+    if (buffer_remap_.size() != 0) {
       op->buffer_map = Map<Var, Buffer>(changes.begin(), changes.end());
     }
   }
+
+ private:
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
+  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
 };
 
 namespace transform {
index 8eb43f8..7475bf6 100644 (file)
@@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator {
         strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
       }
 
-      e.buffer =
-          Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape,
-                 strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault);
+      e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation),
+                        op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
+                        skey.to_string(), align, 0, kDefault);
 
       buf_map_[key] = e;
       Stmt body = this->VisitStmt(op->body);