[mlir] Update several usages of IntegerType to properly handled unsignedness.
authorRiver Riddle <riddleriver@gmail.com>
Mon, 2 Mar 2020 17:18:45 +0000 (09:18 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Mon, 2 Mar 2020 17:19:26 +0000 (09:19 -0800)
Summary: For example, DenseElementsAttr currently does not properly round-trip unsigned integer values.

Differential Revision: https://reviews.llvm.org/D75374

mlir/include/mlir/IR/Matchers.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/Types.h
mlir/lib/Analysis/Utils.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Transforms/LoopFusion.cpp
mlir/test/IR/parser.mlir

index 6321e88..d9979b8 100644 (file)
@@ -93,9 +93,8 @@ struct constant_int_op_binder {
       return false;
     auto type = op->getResult(0).getType();
 
-    if (type.isSignlessIntOrIndex()) {
+    if (type.isa<IntegerType>() || type.isa<IndexType>())
       return attr_value_binder<IntegerAttr>(bind_value).match(attr);
-    }
     if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
       if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
         return attr_value_binder<IntegerAttr>(bind_value)
index 25c0238..d431d4e 100644 (file)
@@ -339,6 +339,30 @@ def I16 : I<16>;
 def I32 : I<32>;
 def I64 : I<64>;
 
+// Unsigned integer types.
+// Any unsigned integer type irrespective of its width.
+def AnyUnsignedInteger : Type<
+  CPred<"$_self.isUnsignedInteger()">, "unsigned integer">;
+
+// Unsigned integer type of a specific width.
+class UI<int width>
+    : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
+                  width # "-bit unsigned integer">,
+      BuildableType<"$_builder.getIntegerType(" # width #
+                    ", /*isSigned=*/false)"> {
+  int bitwidth = width;
+}
+
+class UnsignedIntOfWidths<list<int> widths> :
+    AnyTypeOf<!foreach(w, widths, UI<w>),
+              StrJoinInt<widths, "/">.result # "-bit unsigned integer">;
+
+def UI1  : UI<1>;
+def UI8  : UI<8>;
+def UI16 : UI<16>;
+def UI32 : UI<32>;
+def UI64 : UI<64>;
+
 // Floating point types.
 
 // Any float type irrespective of its width.
index 9bb9a8c..cd5ba07 100644 (file)
@@ -328,8 +328,9 @@ public:
     // Note: Non standard/builtin types are allowed to exist within tensor
     // types. Dialects are expected to verify that tensor types have a valid
     // element type within that dialect.
-    return type.isSignlessIntOrFloat() || type.isa<ComplexType>() ||
-           type.isa<VectorType>() || type.isa<OpaqueType>() ||
+    return type.isa<ComplexType>() || type.isa<FloatType>() ||
+           type.isa<IntegerType>() || type.isa<OpaqueType>() ||
+           type.isa<VectorType>() ||
            (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
   }
 
index 40f1d48..eccc90c 100644 (file)
@@ -169,6 +169,9 @@ public:
   /// Return true of this is a signless integer or a float type.
   bool isSignlessIntOrFloat();
 
+  /// Return true of this is an integer(of any signedness) or a float type.
+  bool isIntOrFloat();
+
   /// Print the current type.
   void print(raw_ostream &os);
   void dump();
index b76c0c0..14635a1 100644 (file)
@@ -314,7 +314,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
   auto elementType = memRefType.getElementType();
 
   unsigned sizeInBits;
-  if (elementType.isSignlessIntOrFloat()) {
+  if (elementType.isIntOrFloat()) {
     sizeInBits = elementType.getIntOrFloatBitWidth();
   } else {
     auto vectorType = elementType.cast<VectorType>();
@@ -358,7 +358,7 @@ Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
   if (!memRefType.hasStaticShape())
     return None;
   auto elementType = memRefType.getElementType();
-  if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>())
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
     return None;
 
   uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
index 140f533..ac26488 100644 (file)
@@ -1372,17 +1372,18 @@ void ModulePrinter::printAttribute(Attribute attr,
 
 /// Print the integer element of the given DenseElementsAttr at 'index'.
 static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
-                                 unsigned index) {
+                                 unsigned index, bool isSigned) {
   APInt value = *std::next(attr.int_value_begin(), index);
   if (value.getBitWidth() == 1)
     os << (value.getBoolValue() ? "true" : "false");
   else
-    value.print(os, /*isSigned=*/true);
+    value.print(os, isSigned);
 }
 
 /// Print the float element of the given DenseElementsAttr at 'index'.
 static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
-                                   unsigned index) {
+                                   unsigned index, bool isSigned) {
+  assert(isSigned && "floating point values are always signed");
   APFloat value = *std::next(attr.float_value_begin(), index);
   printFloatValue(value, os);
 }
@@ -1392,6 +1393,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
   auto type = attr.getType();
   auto shape = type.getShape();
   auto rank = type.getRank();
+  bool isSigned = !type.getElementType().isUnsignedInteger();
 
   // The function used to print elements of this attribute.
   auto printEltFn = type.getElementType().isa<IntegerType>()
@@ -1400,7 +1402,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
 
   // Special case for 0-d and splat tensors.
   if (attr.isSplat()) {
-    printEltFn(attr, os, 0);
+    printEltFn(attr, os, 0, isSigned);
     return;
   }
 
@@ -1452,7 +1454,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
     while (openBrackets++ < rank)
       os << '[';
     openBrackets = rank;
-    printEltFn(attr, os, idx);
+    printEltFn(attr, os, idx, isSigned);
     bumpCounter();
   }
   while (openBrackets-- > 0)
index 5beb12a..4526d7d 100644 (file)
@@ -608,7 +608,7 @@ DenseElementsAttr::FloatElementIterator::FloatElementIterator(
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
-  assert(type.getElementType().isSignlessIntOrFloat() &&
+  assert(type.getElementType().isIntOrFloat() &&
          "expected int or float element type");
   assert(hasSameElementsOrSplat(type, values));
 
index 30d5bbc..774f80a 100644 (file)
@@ -84,6 +84,8 @@ bool Type::isSignlessIntOrFloat() {
   return isSignlessInteger() || isa<FloatType>();
 }
 
+bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
+
 //===----------------------------------------------------------------------===//
 // Integer Type
 //===----------------------------------------------------------------------===//
@@ -147,13 +149,10 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
 }
 
 unsigned Type::getIntOrFloatBitWidth() {
-  assert(isSignlessIntOrFloat() && "only ints and floats have a bitwidth");
-  if (auto intType = dyn_cast<IntegerType>()) {
+  assert(isIntOrFloat() && "only integers and floats have a bitwidth");
+  if (auto intType = dyn_cast<IntegerType>())
     return intType.getWidth();
-  }
-
-  auto floatType = cast<FloatType>();
-  return floatType.getWidth();
+  return cast<FloatType>().getWidth();
 }
 
 //===----------------------------------------------------------------------===//
@@ -202,7 +201,7 @@ int64_t ShapedType::getSizeInBits() const {
          "cannot get the bit size of an aggregate with a dynamic shape");
 
   auto elementType = getElementType();
-  if (elementType.isSignlessIntOrFloat())
+  if (elementType.isIntOrFloat())
     return elementType.getIntOrFloatBitWidth() * getNumElements();
 
   // Tensors can have vectors and other tensors as elements, other shaped types
@@ -373,7 +372,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
   auto *context = elementType.getContext();
 
   // Check that memref is formed from allowed types.
-  if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
       !elementType.isa<ComplexType>())
     return emitOptionalError(location, "invalid memref element type"),
            MemRefType();
@@ -451,7 +450,7 @@ LogicalResult
 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                  unsigned memorySpace) {
   // Check that memref is formed from allowed types.
-  if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
       !elementType.isa<ComplexType>())
     return emitError(loc, "invalid memref element type");
   return success();
index 668fb69..661bddf 100644 (file)
@@ -1102,7 +1102,7 @@ Type Parser::parseMemRefType() {
     return nullptr;
 
   // Check that memref is formed from allowed types.
-  if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
       !elementType.isa<ComplexType>())
     return emitError(typeLoc, "invalid memref element type"), nullptr;
 
index ef1af5d..bcb0c16 100644 (file)
@@ -869,7 +869,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
   auto elementType = memRefType.getElementType();
 
   unsigned sizeInBits;
-  if (elementType.isSignlessIntOrFloat()) {
+  if (elementType.isIntOrFloat()) {
     sizeInBits = elementType.getIntOrFloatBitWidth();
   } else {
     auto vectorType = elementType.cast<VectorType>();
index bec1fbd..3baf064 100644 (file)
@@ -616,6 +616,9 @@ func @splattensorattr() -> () {
   // CHECK: "splatBoolTensor"() {bar = dense<false> : tensor<i1>} : () -> ()
   "splatBoolTensor"(){bar = dense<false> : tensor<i1>} : () -> ()
 
+  // CHECK: "splatUIntTensor"() {bar = dense<222> : tensor<2x1x4xui8>} : () -> ()
+  "splatUIntTensor"(){bar = dense<222> : tensor<2x1x4xui8>} : () -> ()
+
   // CHECK: "splatIntTensor"() {bar = dense<5> : tensor<2x1x4xi32>} : () -> ()
   "splatIntTensor"(){bar = dense<5> : tensor<2x1x4xi32>} : () -> ()