Summary: For example, DenseElementsAttr currently does not properly round-trip unsigned integer values.
Differential Revision: https://reviews.llvm.org/D75374
return false;
auto type = op->getResult(0).getType();
- if (type.isSignlessIntOrIndex()) {
+ if (type.isa<IntegerType>() || type.isa<IndexType>())
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
- }
if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
def I32 : I<32>;
def I64 : I<64>;
+// Unsigned integer types.
+// Any unsigned integer type irrespective of its width.
+def AnyUnsignedInteger : Type<
+ CPred<"$_self.isUnsignedInteger()">, "unsigned integer">;
+
+// Unsigned integer type of a specific width.
+class UI<int width>
+ : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
+ width # "-bit unsigned integer">,
+ BuildableType<"$_builder.getIntegerType(" # width #
+ ", /*isSigned=*/false)"> {
+ int bitwidth = width;
+}
+
+class UnsignedIntOfWidths<list<int> widths> :
+ AnyTypeOf<!foreach(w, widths, UI<w>),
+ StrJoinInt<widths, "/">.result # "-bit unsigned integer">;
+
+def UI1 : UI<1>;
+def UI8 : UI<8>;
+def UI16 : UI<16>;
+def UI32 : UI<32>;
+def UI64 : UI<64>;
+
// Floating point types.
// Any float type irrespective of its width.
// Note: Non standard/builtin types are allowed to exist within tensor
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
- return type.isSignlessIntOrFloat() || type.isa<ComplexType>() ||
- type.isa<VectorType>() || type.isa<OpaqueType>() ||
+ return type.isa<ComplexType>() || type.isa<FloatType>() ||
+ type.isa<IntegerType>() || type.isa<OpaqueType>() ||
+ type.isa<VectorType>() ||
(type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
}
/// Return true of this is a signless integer or a float type.
bool isSignlessIntOrFloat();
+ /// Return true of this is an integer(of any signedness) or a float type.
+ bool isIntOrFloat();
+
/// Print the current type.
void print(raw_ostream &os);
void dump();
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
- if (elementType.isSignlessIntOrFloat()) {
+ if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
if (!memRefType.hasStaticShape())
return None;
auto elementType = memRefType.getElementType();
- if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>())
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
return None;
uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
/// Print the integer element of the given DenseElementsAttr at 'index'.
static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
- unsigned index) {
+ unsigned index, bool isSigned) {
APInt value = *std::next(attr.int_value_begin(), index);
if (value.getBitWidth() == 1)
os << (value.getBoolValue() ? "true" : "false");
else
- value.print(os, /*isSigned=*/true);
+ value.print(os, isSigned);
}
/// Print the float element of the given DenseElementsAttr at 'index'.
static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
- unsigned index) {
+ unsigned index, bool isSigned) {
+ assert(isSigned && "floating point values are always signed");
APFloat value = *std::next(attr.float_value_begin(), index);
printFloatValue(value, os);
}
auto type = attr.getType();
auto shape = type.getShape();
auto rank = type.getRank();
+ bool isSigned = !type.getElementType().isUnsignedInteger();
// The function used to print elements of this attribute.
auto printEltFn = type.getElementType().isa<IntegerType>()
// Special case for 0-d and splat tensors.
if (attr.isSplat()) {
- printEltFn(attr, os, 0);
+ printEltFn(attr, os, 0, isSigned);
return;
}
while (openBrackets++ < rank)
os << '[';
openBrackets = rank;
- printEltFn(attr, os, idx);
+ printEltFn(attr, os, idx, isSigned);
bumpCounter();
}
while (openBrackets-- > 0)
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
- assert(type.getElementType().isSignlessIntOrFloat() &&
+ assert(type.getElementType().isIntOrFloat() &&
"expected int or float element type");
assert(hasSameElementsOrSplat(type, values));
return isSignlessInteger() || isa<FloatType>();
}
+bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
+
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
}
unsigned Type::getIntOrFloatBitWidth() {
- assert(isSignlessIntOrFloat() && "only ints and floats have a bitwidth");
- if (auto intType = dyn_cast<IntegerType>()) {
+ assert(isIntOrFloat() && "only integers and floats have a bitwidth");
+ if (auto intType = dyn_cast<IntegerType>())
return intType.getWidth();
- }
-
- auto floatType = cast<FloatType>();
- return floatType.getWidth();
+ return cast<FloatType>().getWidth();
}
//===----------------------------------------------------------------------===//
"cannot get the bit size of an aggregate with a dynamic shape");
auto elementType = getElementType();
- if (elementType.isSignlessIntOrFloat())
+ if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() * getNumElements();
// Tensors can have vectors and other tensors as elements, other shaped types
auto *context = elementType.getContext();
// Check that memref is formed from allowed types.
- if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitOptionalError(location, "invalid memref element type"),
MemRefType();
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) {
// Check that memref is formed from allowed types.
- if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitError(loc, "invalid memref element type");
return success();
return nullptr;
// Check that memref is formed from allowed types.
- if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitError(typeLoc, "invalid memref element type"), nullptr;
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
- if (elementType.isSignlessIntOrFloat()) {
+ if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
// CHECK: "splatBoolTensor"() {bar = dense<false> : tensor<i1>} : () -> ()
"splatBoolTensor"(){bar = dense<false> : tensor<i1>} : () -> ()
+ // CHECK: "splatUIntTensor"() {bar = dense<222> : tensor<2x1x4xui8>} : () -> ()
+ "splatUIntTensor"(){bar = dense<222> : tensor<2x1x4xui8>} : () -> ()
+
// CHECK: "splatIntTensor"() {bar = dense<5> : tensor<2x1x4xi32>} : () -> ()
"splatIntTensor"(){bar = dense<5> : tensor<2x1x4xi32>} : () -> ()