of supporting smaller integer types, e.g. `i8` or `i16`, for small indices
instead of (presumably larger) `index` type.
+### Bit width of a non-primitive types and `index` is undefined {#bit-width-of-a-compound-type}
+
+The bit width of a compound type is not defined by MLIR, it may be defined by a
+specific lowering pass. In MLIR, bit width is a property of certain primitive
+_type_, in particular integers and floats. It is equal to the number that
+appears in the type definition, e.g. the bit width of `i32` is `32`, so is the
+bit width of `f32`. The bit width is not _necessarily_ related to the amount of
+memory (in bytes) or the size of register (in bits) that is necessary to store
+the value of the given type. These quantities are target and ABI-specific and
+should be defined during the lowering process rather than imposed from above.
+For example, `vector<3xi57>` is likely to be lowered to a vector of four 64-bit
+integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes, rather
+than `(3 x 57) ceildiv 8 = 22` bytes as can be naively computed from the
+bitwidth. Individual components of MLIR that allocate space for storing values
+may use the bit size as the baseline and query the target description when it is
+introduced.
+
+The bit width is not defined for dialect-specific types at MLIR level. Dialects
+are free to define their own quantities for type sizes.
+
### Splitting floating point vs integer operations {#splitting-floating-point-vs-integer-operations}
The MLIR operation set is likely to
/// Return true if this is an integer type with the specified width.
bool isInteger(unsigned width) const;
- /// Return the bitwidth of this type. For vector or tensor types, returns the
- /// element type's bitwidth.
- unsigned getBitWidth() const;
+ /// Return the bit width of an integer or a float type, assert failure on
+ /// other types.
+ unsigned getIntOrFloatBitWidth() const;
/// Return true if this is an integer or index type.
bool isIntOrIndex() const;
return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE;
}
+
+ /// Return the bitwidth of this float type.
+ unsigned getWidth() const;
};
inline FloatType Type::getBF16(MLIRContext *ctx) {
VectorOrTensorType() = default;
/* implicit */ VectorOrTensorType(Type::ImplType *ptr);
+ /// Return the element type.
Type getElementType() const;
+ /// If an element type is an integer or a float, return its width. Abort
+ /// otherwise.
+ unsigned getElementTypeBitWidth() const;
+
/// If this is ranked tensor or vector type, return the number of elements. If
/// it is an unranked tensor, abort.
unsigned getNumElements() const;
/// the getRank call method).
int getDimSize(unsigned i) const;
+ /// Get the total amount of bits occupied by a value of this type. This does
+ /// not take into account any memory layout or widening constraints, e.g. a
+ /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
+ /// it will likely be stored as in a 4xi64 vector register. Fail an assertion
+ /// if the size cannot be computed statically, i.e. if the tensor has a
+ /// dynamic shape or if its elemental type does not have a known bit width.
+ long getSizeInBits() const;
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) {
return kind == Kind::Vector || kind == Kind::RankedTensor ||
}
/// Returns the size of memref data in bytes if it's statically shaped, None
-/// otherwise.
+/// otherwise. If the element of the memref has vector type, takes into account
+/// size of the vector as well.
Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
if (memRefType.getNumDynamicDims() > 0)
return None;
- uint64_t sizeInBits = memRefType.getElementType().getBitWidth();
+ auto elementType = memRefType.getElementType();
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
+ return None;
+
+ uint64_t sizeInBits;
+ if (elementType.isIntOrFloat()) {
+ sizeInBits = elementType.getIntOrFloatBitWidth();
+ } else {
+ auto vectorType = elementType.cast<VectorType>();
+ sizeInBits =
+ vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+ }
for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
sizeInBits = sizeInBits * memRefType.getDimSize(i);
}
case Attribute::Kind::Integer: {
auto intAttr = attr.cast<IntegerAttr>();
// Print all integer attributes as signed unless i1.
- bool isSigned =
- intAttr.getType().isIndex() || intAttr.getType().getBitWidth() != 1;
+ bool isSigned = intAttr.getType().isIndex() ||
+ intAttr.getType().getIntOrFloatBitWidth() != 1;
intAttr.getValue().print(os, isSigned);
break;
}
APInt getValue() const {
if (type.isIndex())
return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
- return APInt(type.getBitWidth(),
+ return APInt(type.getIntOrFloatBitWidth(),
{getTrailingObjects<uint64_t>(), numObjects});
}
};
IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
if (type.isIndex())
return IntegerAttr::get(type, APInt(64, value));
- return IntegerAttr::get(type, APInt(type.getBitWidth(), value));
+ return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value));
}
IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
assert(lhs.first.isIndex() ||
- (lhs.first.getBitWidth() == lhs.second.getBitWidth()));
+ (lhs.first.isa<IntegerType>() &&
+ lhs.first.cast<IntegerType>().getWidth() ==
+ lhs.second.getBitWidth()) &&
+ "mismatching integer type and value bitwidth");
return lhs.first == rhs->type && lhs.second == rhs->getValue();
}
};
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
// This uses 64 bit APInts by default for index type.
- auto width = type.isIndex() ? 64 : type.getBitWidth();
- return get(type, APInt(width, value));
+ if (type.isIndex())
+ return get(type, APInt(64, value));
+
+ auto intType = type.dyn_cast<IntegerType>();
+ assert(intType && "expected an integer type for an integer attribute");
+ return get(type, APInt(intType.getWidth(), value));
}
/// Returns the floating semantics for the given type.
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
ArrayRef<char> data) {
- auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
+ auto bitsRequired = type.getSizeInBits();
(void)bitsRequired;
assert((bitsRequired <= data.size() * 8L) &&
"Input data bit size should be larger than that type requires");
MLIRContext *Type::getContext() const { return type->context; }
-unsigned Type::getBitWidth() const {
+unsigned Type::getSubclassData() const { return type->getSubclassData(); }
+void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
+
+IndexType::IndexType(Type::ImplType *ptr) : Type(ptr) {}
+
+IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
+
+unsigned IntegerType::getWidth() const {
+ return static_cast<ImplType *>(type)->width;
+}
+
+FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
+
+unsigned FloatType::getWidth() const {
switch (getKind()) {
- // TODO: Currently the IR uses host double type to store all the float
- // datatypes. This is completely incorrect for BF16 and other datatypes.
- // We have to fix this once APFloat is used in the IR.
case Type::Kind::BF16:
case Type::Kind::F16:
+ return 16;
case Type::Kind::F32:
+ return 32;
case Type::Kind::F64:
return 64;
- case Type::Kind::Integer:
- return cast<IntegerType>().getWidth();
- case Type::Kind::Vector:
- case Type::Kind::RankedTensor:
- case Type::Kind::UnrankedTensor:
- return cast<VectorOrTensorType>().getElementType().getBitWidth();
- // TODO: Handle more types.
default:
llvm_unreachable("unexpected type");
}
}
-unsigned Type::getSubclassData() const { return type->getSubclassData(); }
-void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
-
-IndexType::IndexType(Type::ImplType *ptr) : Type(ptr) {}
-
-IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
+unsigned Type::getIntOrFloatBitWidth() const {
+ assert(isIntOrFloat() && "only ints and floats have a bitwidth");
+ if (auto intType = dyn_cast<IntegerType>()) {
+ return intType.getWidth();
+ }
-unsigned IntegerType::getWidth() const {
- return static_cast<ImplType *>(type)->width;
+ auto floatType = cast<FloatType>();
+ return floatType.getWidth();
}
-FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
-
OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
return static_cast<ImplType *>(type)->elementType;
}
+unsigned VectorOrTensorType::getElementTypeBitWidth() const {
+ return getElementType().getIntOrFloatBitWidth();
+}
+
unsigned VectorOrTensorType::getNumElements() const {
switch (getKind()) {
case Kind::Vector:
}
}
+// Get the number of number of bits require to store a value of the given vector
+// or tensor types. Compute the value recursively since tensors are allowed to
+// have vectors as elements.
+long VectorOrTensorType::getSizeInBits() const {
+ assert(hasStaticShape() &&
+ "cannot get the bit size of an aggregate with a dynamic shape");
+
+ auto elementType = getElementType();
+ if (elementType.isIntOrFloat())
+ return elementType.getIntOrFloatBitWidth() * getNumElements();
+
+ // Tensors can have vectors and other tensors as elements, vectors cannot.
+ assert(!isa<VectorType>() && "unsupported vector element type");
+ auto elementVectorOrTensorType = elementType.dyn_cast<VectorOrTensorType>();
+ assert(elementVectorOrTensorType && "unsupported tensor element type");
+ return getNumElements() * elementVectorOrTensorType.getSizeInBits();
+}
+
ArrayRef<int> VectorOrTensorType::getShape() const {
switch (getKind()) {
case Kind::Vector:
class TensorLiteralParser {
public:
TensorLiteralParser(Parser &p, Type eltTy)
- : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {}
+ : p(p), eltTy(eltTy), currBitPos(0) {}
ParseResult parse() { return parseList(shape); }
ParseResult parseList(llvm::SmallVectorImpl<int> &dims);
void addToStorage(uint64_t value) {
- if (bitsWidth == 64)
+ // Only tensors of integers or floats are supported.
+ // TODO: we currently use 64 bit for all floating point constants for legacy
+ // reasoins. For f16 and f32, this is fixable by bitcasting APFloat value
+ // to APInt, but APFloat does not support bf16 semantics.
+ auto eltIntTy = eltTy.dyn_cast<IntegerType>();
+ size_t bitWidth = eltIntTy ? eltIntTy.getWidth() : 64;
+
+ if (bitWidth == 64)
storage.push_back(value);
- if (currBitPos + bitsWidth > storage.size() * 64)
+ if (currBitPos + bitWidth > storage.size() * 64)
storage.push_back(0L);
auto *rawData = reinterpret_cast<char *>(storage.data());
- DenseIntElementsAttr::writeBits(rawData, currBitPos, bitsWidth, value);
- currBitPos += bitsWidth;
+ DenseIntElementsAttr::writeBits(rawData, currBitPos, bitWidth, value);
+ currBitPos += bitWidth;
}
Parser &p;
Type eltTy;
size_t currBitPos;
- size_t bitsWidth;
SmallVector<int, 4> shape;
std::vector<uint64_t> storage;
};
if (!result.isa<IntegerAttr>())
return p.emitError("expected tensor literal element has integer type");
auto value = result.cast<IntegerAttr>().getValue();
- if (value.getMinSignedBits() > bitsWidth)
+ auto bitWidth = eltTy.getIntOrFloatBitWidth();
+ if (value.getMinSignedBits() > bitWidth)
return p.emitError("tensor literal element has more bits than that "
"specified in the type");
addToStorage(value.getSExtValue());
}
if (!type.isIntOrIndex())
return (emitError("integer value not valid for specified type"), nullptr);
- int width = type.isIndex() ? 64 : type.getBitWidth();
+ int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
APInt apInt(width, val.getValue());
if (apInt != *val)
return emitError("integer constant out of range for attribute"), nullptr;
}
if (!type.isIntOrIndex())
return (emitError("integer value not valid for type"), nullptr);
- int width = type.isIndex() ? 64 : type.getBitWidth();
+ int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
APInt apInt(width, *val, /*isSigned=*/true);
if (apInt != *val)
return (emitError("integer constant out of range for attribute"),
}
llvm::IntegerType *ModuleLowerer::convertIntegerType(IntegerType type) {
- return builder.getIntNTy(type.getBitWidth());
+ return builder.getIntNTy(type.getWidth());
}
llvm::Type *ModuleLowerer::convertFloatType(FloatType type) {