From 1beffb92d17c8ba54a6c4b3d4e96867403717a3f Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Mon, 13 Apr 2020 16:37:00 -0700 Subject: [PATCH] Fix the MLIR integer attribute parser to be correct in the face of large integer attributes, it was previously artificially limited to 64 bits. Reviewers: rriddle! Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, grosul1, frgossen, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D78065 --- mlir/lib/Parser/Parser.cpp | 51 ++++++++++++++++++++++++++------------------- mlir/lib/Parser/Token.cpp | 2 +- mlir/lib/Parser/Token.h | 5 ++++- mlir/test/IR/attribute.mlir | 7 +++++-- mlir/test/IR/invalid.mlir | 30 ++++++++++++++++++++++++++ 5 files changed, 69 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e339588..0832e4b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1790,37 +1790,45 @@ static Optional buildHexadecimalFloatLiteral(Parser *p, FloatType type, /// Construct an APint from a parsed value, a known attribute type and /// sign. static Optional buildAttributeAPInt(Type type, bool isNegative, - uint64_t value) { - // We have the integer literal as an uint64_t in val, now convert it into an - // APInt and check that we don't overflow. - int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); - APInt apInt(width, value, isNegative); - if (apInt != value) + StringRef spelling) { + // Parse the integer value into an APInt that is big enough to hold the value. + APInt result; + bool isHex = spelling.size() > 1 && spelling[1] == 'x'; + if (spelling.getAsInteger(isHex ? 0 : 10, result)) return llvm::None; + // Extend or truncate the bitwidth to the right size. + unsigned width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); + if (width > result.getBitWidth()) { + result = result.zext(width); + } else if (width < result.getBitWidth()) { + // The parser can return an unnecessarily wide result with leading zeros. + // This isn't a problem, but truncating off bits is bad. + if (result.countLeadingZeros() < result.getBitWidth() - width) + return llvm::None; + + result = result.trunc(width); + } + if (isNegative) { // The value is negative, we have an overflow if the sign bit is not set // in the negated apInt. - apInt.negate(); - if (!apInt.isSignBitSet()) + result.negate(); + if (!result.isSignBitSet()) return llvm::None; } else if ((type.isSignedInteger() || type.isIndex()) && - apInt.isSignBitSet()) { + result.isSignBitSet()) { // The value is a positive signed integer or index, // we have an overflow if the sign bit is set. return llvm::None; } - return apInt; + return result; } /// Parse a decimal or a hexadecimal literal, which can be either an integer /// or a float attribute. Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { - auto val = getToken().getUInt64IntegerValue(); - if (!val.hasValue()) - return (emitError("integer constant out of range for attribute"), nullptr); - // Remember if the literal is hexadecimal. StringRef spelling = getToken().getSpelling(); auto loc = state.curToken.getLoc(); @@ -1848,6 +1856,10 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { return nullptr; } + auto val = Token::getUInt64IntegerValue(spelling); + if (!val.hasValue()) + return emitError("integer constant out of range for attribute"), nullptr; + // Construct a float attribute bitwise equivalent to the integer literal. Optional apVal = buildHexadecimalFloatLiteral(this, floatType, *val); @@ -1864,8 +1876,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { return nullptr; } - Optional apInt = buildAttributeAPInt(type, isNegative, *val); - + Optional apInt = buildAttributeAPInt(type, isNegative, spelling); if (!apInt) return emitError(loc, "integer constant out of range for attribute"), nullptr; @@ -2085,12 +2096,8 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, } // Create APInt values for each element with the correct bitwidth. - auto val = token.getUInt64IntegerValue(); - if (!val.hasValue()) { - p.emitError(tokenLoc, "integer constant out of range for attribute"); - return nullptr; - } - Optional apInt = buildAttributeAPInt(eltTy, isNegative, *val); + Optional apInt = + buildAttributeAPInt(eltTy, isNegative, token.getSpelling()); if (!apInt) return (p.emitError(tokenLoc, "integer constant out of range for type"), nullptr); diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp index b619af0..7b5fb9a 100644 --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -37,7 +37,7 @@ Optional Token::getUnsignedIntegerValue() const { /// For an integer token, return its value as a uint64_t. If it doesn't fit, /// return None. -Optional Token::getUInt64IntegerValue() const { +Optional Token::getUInt64IntegerValue(StringRef spelling) { bool isHex = spelling.size() > 1 && spelling[1] == 'x'; uint64_t result = 0; diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h index 7952aca..4f9c098 100644 --- a/mlir/lib/Parser/Token.h +++ b/mlir/lib/Parser/Token.h @@ -65,7 +65,10 @@ public: /// For an integer token, return its value as an uint64_t. If it doesn't fit, /// return None. - Optional getUInt64IntegerValue() const; + static Optional getUInt64IntegerValue(StringRef spelling); + Optional getUInt64IntegerValue() const { + return getUInt64IntegerValue(getSpelling()); + } /// For a floatliteral token, return its value as a double. Returns None in /// the case of underflow or overflow. diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index c13c12c..31804b2 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -83,8 +83,11 @@ func @int_attrs_pass() { // CHECK-SAME: attr_20 = 1 : ui1 attr_20 = 1 : ui1, // CHECK-SAME: attr_21 = -1 : si1 - attr_21 = -1 : si1 - + attr_21 = -1 : si1, + // CHECK-SAME: attr_22 = 79228162514264337593543950335 : ui96 + attr_22 = 79228162514264337593543950335 : ui96, + // CHECK-SAME: attr_23 = -39614081257132168796771975168 : si96 + attr_23 = -39614081257132168796771975168 : si96 } : () -> () return diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 255f6cc..cf7f7eb 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -1459,3 +1459,33 @@ func @large_bound() { } : () -> () return } + +// ----- + +func @really_large_bound() { + "test.out_of_range_attribute"() { + // expected-error @+1 {{integer constant out of range for attribute}} + attr = 79228162514264337593543950336 : ui96 + } : () -> () + return +} + +// ----- + +func @really_large_bound() { + "test.out_of_range_attribute"() { + // expected-error @+1 {{integer constant out of range for attribute}} + attr = 79228162514264337593543950336 : i96 + } : () -> () + return +} + +// ----- + +func @really_large_bound() { + "test.out_of_range_attribute"() { + // expected-error @+1 {{integer constant out of range for attribute}} + attr = 39614081257132168796771975168 : si96 + } : () -> () + return +} -- 2.7.4