Fix the MLIR integer attribute parser to be correct in the face of large integer...
authorChris Lattner <clattner@nondot.org>
Mon, 13 Apr 2020 23:37:00 +0000 (16:37 -0700)
committerChris Lattner <clattner@nondot.org>
Tue, 14 Apr 2020 04:50:36 +0000 (21:50 -0700)
Reviewers: rriddle!

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, grosul1, frgossen, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D78065

mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Token.cpp
mlir/lib/Parser/Token.h
mlir/test/IR/attribute.mlir
mlir/test/IR/invalid.mlir

index e339588..0832e4b 100644 (file)
@@ -1790,37 +1790,45 @@ static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
 /// Construct an APint from a parsed value, a known attribute type and
 /// sign.
 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
-                                           uint64_t value) {
-  // We have the integer literal as an uint64_t in val, now convert it into an
-  // APInt and check that we don't overflow.
-  int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
-  APInt apInt(width, value, isNegative);
-  if (apInt != value)
+                                           StringRef spelling) {
+  // Parse the integer value into an APInt that is big enough to hold the value.
+  APInt result;
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+  if (spelling.getAsInteger(isHex ? 0 : 10, result))
     return llvm::None;
 
+  // Extend or truncate the bitwidth to the right size.
+  unsigned width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
+  if (width > result.getBitWidth()) {
+    result = result.zext(width);
+  } else if (width < result.getBitWidth()) {
+    // The parser can return an unnecessarily wide result with leading zeros.
+    // This isn't a problem, but truncating off bits is bad.
+    if (result.countLeadingZeros() < result.getBitWidth() - width)
+      return llvm::None;
+
+    result = result.trunc(width);
+  }
+
   if (isNegative) {
     // The value is negative, we have an overflow if the sign bit is not set
     // in the negated apInt.
-    apInt.negate();
-    if (!apInt.isSignBitSet())
+    result.negate();
+    if (!result.isSignBitSet())
       return llvm::None;
   } else if ((type.isSignedInteger() || type.isIndex()) &&
-             apInt.isSignBitSet()) {
+             result.isSignBitSet()) {
     // The value is a positive signed integer or index,
     // we have an overflow if the sign bit is set.
     return llvm::None;
   }
 
-  return apInt;
+  return result;
 }
 
 /// Parse a decimal or a hexadecimal literal, which can be either an integer
 /// or a float attribute.
 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
-  auto val = getToken().getUInt64IntegerValue();
-  if (!val.hasValue())
-    return (emitError("integer constant out of range for attribute"), nullptr);
-
   // Remember if the literal is hexadecimal.
   StringRef spelling = getToken().getSpelling();
   auto loc = state.curToken.getLoc();
@@ -1848,6 +1856,10 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
       return nullptr;
     }
 
+    auto val = Token::getUInt64IntegerValue(spelling);
+    if (!val.hasValue())
+      return emitError("integer constant out of range for attribute"), nullptr;
+
     // Construct a float attribute bitwise equivalent to the integer literal.
     Optional<APFloat> apVal =
         buildHexadecimalFloatLiteral(this, floatType, *val);
@@ -1864,8 +1876,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
     return nullptr;
   }
 
-  Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, *val);
-
+  Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
   if (!apInt)
     return emitError(loc, "integer constant out of range for attribute"),
            nullptr;
@@ -2085,12 +2096,8 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
     }
 
     // Create APInt values for each element with the correct bitwidth.
-    auto val = token.getUInt64IntegerValue();
-    if (!val.hasValue()) {
-      p.emitError(tokenLoc, "integer constant out of range for attribute");
-      return nullptr;
-    }
-    Optional<APInt> apInt = buildAttributeAPInt(eltTy, isNegative, *val);
+    Optional<APInt> apInt =
+        buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
     if (!apInt)
       return (p.emitError(tokenLoc, "integer constant out of range for type"),
               nullptr);
index b619af0..7b5fb9a 100644 (file)
@@ -37,7 +37,7 @@ Optional<unsigned> Token::getUnsignedIntegerValue() const {
 
 /// For an integer token, return its value as a uint64_t.  If it doesn't fit,
 /// return None.
-Optional<uint64_t> Token::getUInt64IntegerValue() const {
+Optional<uint64_t> Token::getUInt64IntegerValue(StringRef spelling) {
   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
 
   uint64_t result = 0;
index 7952aca..4f9c098 100644 (file)
@@ -65,7 +65,10 @@ public:
 
   /// For an integer token, return its value as an uint64_t.  If it doesn't fit,
   /// return None.
-  Optional<uint64_t> getUInt64IntegerValue() const;
+  static Optional<uint64_t> getUInt64IntegerValue(StringRef spelling);
+  Optional<uint64_t> getUInt64IntegerValue() const {
+    return getUInt64IntegerValue(getSpelling());
+  }
 
   /// For a floatliteral token, return its value as a double. Returns None in
   /// the case of underflow or overflow.
index c13c12c..31804b2 100644 (file)
@@ -83,8 +83,11 @@ func @int_attrs_pass() {
     // CHECK-SAME: attr_20 = 1 : ui1
     attr_20 = 1 : ui1,
     // CHECK-SAME: attr_21 = -1 : si1
-    attr_21 = -1 : si1
-
+    attr_21 = -1 : si1,
+    // CHECK-SAME: attr_22 = 79228162514264337593543950335 : ui96
+    attr_22 = 79228162514264337593543950335 : ui96,
+    // CHECK-SAME: attr_23 = -39614081257132168796771975168 : si96
+    attr_23 = -39614081257132168796771975168 : si96
   } : () -> ()
 
   return
index 255f6cc..cf7f7eb 100644 (file)
@@ -1459,3 +1459,33 @@ func @large_bound() {
   } : () -> ()
   return
 }
+
+// -----
+
+func @really_large_bound() {
+  "test.out_of_range_attribute"() {
+    // expected-error @+1 {{integer constant out of range for attribute}}
+    attr = 79228162514264337593543950336 : ui96
+  } : () -> ()
+  return
+}
+
+// -----
+
+func @really_large_bound() {
+  "test.out_of_range_attribute"() {
+    // expected-error @+1 {{integer constant out of range for attribute}}
+    attr = 79228162514264337593543950336 : i96
+  } : () -> ()
+  return
+}
+
+// -----
+
+func @really_large_bound() {
+  "test.out_of_range_attribute"() {
+    // expected-error @+1 {{integer constant out of range for attribute}}
+    attr = 39614081257132168796771975168 : si96
+  } : () -> ()
+  return
+}