Type system: replace Type::getBitWidth with getIntOrFloatBitWidth
authorAlex Zinenko <zinenko@google.com>
Mon, 17 Dec 2018 18:05:56 +0000 (10:05 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:30:43 +0000 (14:30 -0700)
As MLIR moves towards dialect-specific types, a generic Type::getBitWidth does
not make sense for all of them.  Even with the current type system, the bit
width is not defined (and causes the method in question to abort) for all
TensorFlow types.

This commit restricts the bit width definition to primitive standard types that
have a number of bits appearing verbatim in their type, i.e., integers and
floats.  As a side effect, it delegates the decision on the bit width of the
`index` to the backends.  Existing backends currently hardcode it to 64 bits.

The Type::getBitWidth method is replaced by Type::getIntOrFloatBitWidth that
only applies to integers and floats.  The call sites are updated to use the new
method, where applicable, or rewritten so as not rely on it.  Incidentally,
this fixes a utility method that did not account for memrefs being allowed to
have vectors as element types in the size computation.

As an observation, several places in the code use Type in places where a more
specific type could be used instead.  Some of those are fixed by this commit.

PiperOrigin-RevId: 225844792

mlir/g3doc/Rationale.md
mlir/include/mlir/IR/Types.h
mlir/lib/Analysis/Utils.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Builders.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Types.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp

index 09aa88d5a339c39a2739523e06671f57d7d154fd..a2e38466dfcfc77ca7e9bdd71310bec6f1637197 100644 (file)
@@ -259,6 +259,26 @@ fixed-width integer types, at the SSA value level. It has an additional benefit
 of supporting smaller integer types, e.g. `i8` or `i16`, for small indices
 instead of (presumably larger) `index` type.
 
+### Bit width of a non-primitive types and `index` is undefined {#bit-width-of-a-compound-type}
+
+The bit width of a compound type is not defined by MLIR, it may be defined by a
+specific lowering pass. In MLIR, bit width is a property of certain primitive
+_type_, in particular integers and floats. It is equal to the number that
+appears in the type definition, e.g. the bit width of `i32` is `32`, so is the
+bit width of `f32`. The bit width is not _necessarily_ related to the amount of
+memory (in bytes) or the size of register (in bits) that is necessary to store
+the value of the given type. These quantities are target and ABI-specific and
+should be defined during the lowering process rather than imposed from above.
+For example, `vector<3xi57>` is likely to be lowered to a vector of four 64-bit
+integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes, rather
+than `(3 x 57) ceildiv 8 = 22` bytes as can be naively computed from the
+bitwidth. Individual components of MLIR that allocate space for storing values
+may use the bit size as the baseline and query the target description when it is
+introduced.
+
+The bit width is not defined for dialect-specific types at MLIR level. Dialects
+are free to define their own quantities for type sizes.
+
 ### Splitting floating point vs integer operations {#splitting-floating-point-vs-integer-operations}
 
 The MLIR operation set is likely to
index 552fd3713dc788c628d66a3d4367485a0570b38c..28653af019ad7e07407dba15dcd13b371dc12ced 100644 (file)
@@ -135,9 +135,9 @@ public:
   /// Return true if this is an integer type with the specified width.
   bool isInteger(unsigned width) const;
 
-  /// Return the bitwidth of this type. For vector or tensor types, returns the
-  /// element type's bitwidth.
-  unsigned getBitWidth() const;
+  /// Return the bit width of an integer or a float type, assert failure on
+  /// other types.
+  unsigned getIntOrFloatBitWidth() const;
 
   /// Return true if this is an integer or index type.
   bool isIntOrIndex() const;
@@ -251,6 +251,9 @@ public:
     return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
            kind <= Kind::LAST_FLOATING_POINT_TYPE;
   }
+
+  /// Return the bitwidth of this float type.
+  unsigned getWidth() const;
 };
 
 inline FloatType Type::getBF16(MLIRContext *ctx) {
@@ -357,8 +360,13 @@ public:
   VectorOrTensorType() = default;
   /* implicit */ VectorOrTensorType(Type::ImplType *ptr);
 
+  /// Return the element type.
   Type getElementType() const;
 
+  /// If an element type is an integer or a float, return its width.  Abort
+  /// otherwise.
+  unsigned getElementTypeBitWidth() const;
+
   /// If this is ranked tensor or vector type, return the number of elements. If
   /// it is an unranked tensor, abort.
   unsigned getNumElements() const;
@@ -381,6 +389,14 @@ public:
   /// the getRank call method).
   int getDimSize(unsigned i) const;
 
+  /// Get the total amount of bits occupied by a value of this type.  This does
+  /// not take into account any memory layout or widening constraints, e.g. a
+  /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
+  /// it will likely be stored as in a 4xi64 vector register.  Fail an assertion
+  /// if the size cannot be computed statically, i.e. if the tensor has a
+  /// dynamic shape or if its elemental type does not have a known bit width.
+  long getSizeInBits() const;
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool kindof(Kind kind) {
     return kind == Kind::Vector || kind == Kind::RankedTensor ||
index 3fe22e9dd9a72aa0cbecd3058e25367dbc84dd16..5fb2320308571e612580f47dc7188db97f1109a0 100644 (file)
@@ -241,11 +241,23 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
 }
 
 /// Returns the size of memref data in bytes if it's statically shaped, None
-/// otherwise.
+/// otherwise.  If the element of the memref has vector type, takes into account
+/// size of the vector as well.
 Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
   if (memRefType.getNumDynamicDims() > 0)
     return None;
-  uint64_t sizeInBits = memRefType.getElementType().getBitWidth();
+  auto elementType = memRefType.getElementType();
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
+    return None;
+
+  uint64_t sizeInBits;
+  if (elementType.isIntOrFloat()) {
+    sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else {
+    auto vectorType = elementType.cast<VectorType>();
+    sizeInBits =
+        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  }
   for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
     sizeInBits = sizeInBits * memRefType.getDimSize(i);
   }
index 7445302edad3a932a9bc445b28611f94275710cf..b798e3890a028e31f495cc1c84ac425df436e507 100644 (file)
@@ -420,8 +420,8 @@ void ModulePrinter::printAttribute(Attribute attr) {
   case Attribute::Kind::Integer: {
     auto intAttr = attr.cast<IntegerAttr>();
     // Print all integer attributes as signed unless i1.
-    bool isSigned =
-        intAttr.getType().isIndex() || intAttr.getType().getBitWidth() != 1;
+    bool isSigned = intAttr.getType().isIndex() ||
+                    intAttr.getType().getIntOrFloatBitWidth() != 1;
     intAttr.getValue().print(os, isSigned);
     break;
   }
index 1b8ded8a27de660f9f40d856a88f0d3235900ab2..635d5940a16373d9824675a660bf924e14e7d9b1 100644 (file)
@@ -62,7 +62,7 @@ struct IntegerAttributeStorage final
   APInt getValue() const {
     if (type.isIndex())
       return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
-    return APInt(type.getBitWidth(),
+    return APInt(type.getIntOrFloatBitWidth(),
                  {getTrailingObjects<uint64_t>(), numObjects});
   }
 };
index 41cc6cf2d5c2c7024a066c1ae787dd4ba8022396..a6391777ea7ae337f38f9a9c88d044ee727e5ad7 100644 (file)
@@ -128,7 +128,7 @@ IntegerAttr Builder::getIntegerAttr(int64_t value) {
 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
   if (type.isIndex())
     return IntegerAttr::get(type, APInt(64, value));
-  return IntegerAttr::get(type, APInt(type.getBitWidth(), value));
+  return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value));
 }
 
 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
index b14b598457f1905b6e18faf109aab229c4138dd6..1346431bd014a027bff6ce40e57cb7a1bf84bb32 100644 (file)
@@ -230,7 +230,10 @@ struct IntegerAttrKeyInfo : DenseMapInfo<IntegerAttributeStorage *> {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
     assert(lhs.first.isIndex() ||
-           (lhs.first.getBitWidth() == lhs.second.getBitWidth()));
+           (lhs.first.isa<IntegerType>() &&
+            lhs.first.cast<IntegerType>().getWidth() ==
+                lhs.second.getBitWidth()) &&
+               "mismatching integer type and value bitwidth");
     return lhs.first == rhs->type && lhs.second == rhs->getValue();
   }
 };
@@ -1167,8 +1170,12 @@ IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
 
 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
   // This uses 64 bit APInts by default for index type.
-  auto width = type.isIndex() ? 64 : type.getBitWidth();
-  return get(type, APInt(width, value));
+  if (type.isIndex())
+    return get(type, APInt(64, value));
+
+  auto intType = type.dyn_cast<IntegerType>();
+  assert(intType && "expected an integer type for an integer attribute");
+  return get(type, APInt(intType.getWidth(), value));
 }
 
 /// Returns the floating semantics for the given type.
@@ -1461,7 +1468,7 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
 
 DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
                                          ArrayRef<char> data) {
-  auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
+  auto bitsRequired = type.getSizeInBits();
   (void)bitsRequired;
   assert((bitsRequired <= data.size() * 8L) &&
          "Input data bit size should be larger than that type requires");
index 685f140692289d8565a62fe8cc22a063e604ba15..2b4beabb3f2f62b71960c9b0c440e0c07f67ae2d 100644 (file)
@@ -28,41 +28,43 @@ Type::Kind Type::getKind() const { return type->kind; }
 
 MLIRContext *Type::getContext() const { return type->context; }
 
-unsigned Type::getBitWidth() const {
+unsigned Type::getSubclassData() const { return type->getSubclassData(); }
+void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
+
+IndexType::IndexType(Type::ImplType *ptr) : Type(ptr) {}
+
+IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
+
+unsigned IntegerType::getWidth() const {
+  return static_cast<ImplType *>(type)->width;
+}
+
+FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
+
+unsigned FloatType::getWidth() const {
   switch (getKind()) {
-    // TODO: Currently the IR uses host double type to store all the float
-    // datatypes. This is completely incorrect for BF16 and other datatypes.
-    // We have to fix this once APFloat is used in the IR.
   case Type::Kind::BF16:
   case Type::Kind::F16:
+    return 16;
   case Type::Kind::F32:
+    return 32;
   case Type::Kind::F64:
     return 64;
-  case Type::Kind::Integer:
-    return cast<IntegerType>().getWidth();
-  case Type::Kind::Vector:
-  case Type::Kind::RankedTensor:
-  case Type::Kind::UnrankedTensor:
-    return cast<VectorOrTensorType>().getElementType().getBitWidth();
-    // TODO: Handle more types.
   default:
     llvm_unreachable("unexpected type");
   }
 }
 
-unsigned Type::getSubclassData() const { return type->getSubclassData(); }
-void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
-
-IndexType::IndexType(Type::ImplType *ptr) : Type(ptr) {}
-
-IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
+unsigned Type::getIntOrFloatBitWidth() const {
+  assert(isIntOrFloat() && "only ints and floats have a bitwidth");
+  if (auto intType = dyn_cast<IntegerType>()) {
+    return intType.getWidth();
+  }
 
-unsigned IntegerType::getWidth() const {
-  return static_cast<ImplType *>(type)->width;
+  auto floatType = cast<FloatType>();
+  return floatType.getWidth();
 }
 
-FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
-
 OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
 
 FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
@@ -85,6 +87,10 @@ Type VectorOrTensorType::getElementType() const {
   return static_cast<ImplType *>(type)->elementType;
 }
 
+unsigned VectorOrTensorType::getElementTypeBitWidth() const {
+  return getElementType().getIntOrFloatBitWidth();
+}
+
 unsigned VectorOrTensorType::getNumElements() const {
   switch (getKind()) {
   case Kind::Vector:
@@ -124,6 +130,24 @@ int VectorOrTensorType::getDimSize(unsigned i) const {
   }
 }
 
+// Get the number of number of bits require to store a value of the given vector
+// or tensor types.  Compute the value recursively since tensors are allowed to
+// have vectors as elements.
+long VectorOrTensorType::getSizeInBits() const {
+  assert(hasStaticShape() &&
+         "cannot get the bit size of an aggregate with a dynamic shape");
+
+  auto elementType = getElementType();
+  if (elementType.isIntOrFloat())
+    return elementType.getIntOrFloatBitWidth() * getNumElements();
+
+  // Tensors can have vectors and other tensors as elements, vectors cannot.
+  assert(!isa<VectorType>() && "unsupported vector element type");
+  auto elementVectorOrTensorType = elementType.dyn_cast<VectorOrTensorType>();
+  assert(elementVectorOrTensorType && "unsupported tensor element type");
+  return getNumElements() * elementVectorOrTensorType.getSizeInBits();
+}
+
 ArrayRef<int> VectorOrTensorType::getShape() const {
   switch (getKind()) {
   case Kind::Vector:
index a4a70a083d359300def72f8f18e7135f04427bc3..0eabc67e9658df6089a3fe475f53686b61dc1246 100644 (file)
@@ -624,7 +624,7 @@ namespace {
 class TensorLiteralParser {
 public:
   TensorLiteralParser(Parser &p, Type eltTy)
-      : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {}
+      : p(p), eltTy(eltTy), currBitPos(0) {}
 
   ParseResult parse() { return parseList(shape); }
 
@@ -648,21 +648,27 @@ private:
   ParseResult parseList(llvm::SmallVectorImpl<int> &dims);
 
   void addToStorage(uint64_t value) {
-    if (bitsWidth == 64)
+    // Only tensors of integers or floats are supported.
+    // TODO: we currently use 64 bit for all floating point constants for legacy
+    // reasoins.  For f16 and f32, this is fixable by bitcasting APFloat value
+    // to APInt, but APFloat does not support bf16 semantics.
+    auto eltIntTy = eltTy.dyn_cast<IntegerType>();
+    size_t bitWidth = eltIntTy ? eltIntTy.getWidth() : 64;
+
+    if (bitWidth == 64)
       storage.push_back(value);
 
-    if (currBitPos + bitsWidth > storage.size() * 64)
+    if (currBitPos + bitWidth > storage.size() * 64)
       storage.push_back(0L);
 
     auto *rawData = reinterpret_cast<char *>(storage.data());
-    DenseIntElementsAttr::writeBits(rawData, currBitPos, bitsWidth, value);
-    currBitPos += bitsWidth;
+    DenseIntElementsAttr::writeBits(rawData, currBitPos, bitWidth, value);
+    currBitPos += bitWidth;
   }
 
   Parser &p;
   Type eltTy;
   size_t currBitPos;
-  size_t bitsWidth;
   SmallVector<int, 4> shape;
   std::vector<uint64_t> storage;
 };
@@ -705,7 +711,8 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
       if (!result.isa<IntegerAttr>())
         return p.emitError("expected tensor literal element has integer type");
       auto value = result.cast<IntegerAttr>().getValue();
-      if (value.getMinSignedBits() > bitsWidth)
+      auto bitWidth = eltTy.getIntOrFloatBitWidth();
+      if (value.getMinSignedBits() > bitWidth)
         return p.emitError("tensor literal element has more bits than that "
                            "specified in the type");
       addToStorage(value.getSExtValue());
@@ -849,7 +856,7 @@ Attribute Parser::parseAttribute(Type type) {
     }
     if (!type.isIntOrIndex())
       return (emitError("integer value not valid for specified type"), nullptr);
-    int width = type.isIndex() ? 64 : type.getBitWidth();
+    int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
     APInt apInt(width, val.getValue());
     if (apInt != *val)
       return emitError("integer constant out of range for attribute"), nullptr;
@@ -875,7 +882,7 @@ Attribute Parser::parseAttribute(Type type) {
       }
       if (!type.isIntOrIndex())
         return (emitError("integer value not valid for type"), nullptr);
-      int width = type.isIndex() ? 64 : type.getBitWidth();
+      int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
       APInt apInt(width, *val, /*isSigned=*/true);
       if (apInt != *val)
         return (emitError("integer constant out of range for attribute"),
index 57c1d94d6159feeb2df254a875f7e1a978cfb349..1fc7e48942663d1c6b3418611bd469cb0cfa0813 100644 (file)
@@ -157,7 +157,7 @@ llvm::IntegerType *ModuleLowerer::convertIndexType(IndexType type) {
 }
 
 llvm::IntegerType *ModuleLowerer::convertIntegerType(IntegerType type) {
-  return builder.getIntNTy(type.getBitWidth());
+  return builder.getIntNTy(type.getWidth());
 }
 
 llvm::Type *ModuleLowerer::convertFloatType(FloatType type) {