Add support for hexadecimal float literals
authorAlex Zinenko <zinenko@google.com>
Tue, 30 Jul 2019 21:05:49 +0000 (14:05 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 30 Jul 2019 21:06:26 +0000 (14:06 -0700)
MLIR does not have support for parsing special floating point values such as
infinities and NaNs.  If programmatically constructed, these values are printed
as NaN and (+-)Inf and cannot be parsed back.  Add parser support for
hexadecimal literals in float attributes, following LLVM IR.  The literal
corresponds to the in-memory representation of the floating point value.
IEEE 754 defines a range of possible values for NaNs, storing the bitwise
representation allows MLIR to properly roundtrip NaNs with different bit values
of significands.

The initial version of this commit was missing support for float literals that
used to be printed in decimal notation as a fallback, but ended up being
printed in hexadecimal format which became the fallback for special values.
The decimal fallback behavior was not exercised by tests.  It is currently
reinstated and tested by the newly added test @f32_potential_precision_loss in
parser.mlir.

PiperOrigin-RevId: 260790900

mlir/g3doc/LangRef.md
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir

index 9f28f93..f259b6c 100644 (file)
@@ -840,10 +840,26 @@ Syntax:
 
 ``` {.ebnf}
 float-attribute ::= float-literal (`:` float-type)?
+                  | hexadecimal-literal `:` float-type
 ```
 
 A float attribute is a literal attribute that represents a floating point value
-of the specified [float type](#floating-point-types).
+of the specified [float type](#floating-point-types). It can be represented in
+the hexadecimal form where the hexadecimal value is interpreted as bits of the
+underlying binary representation. This form is useful for representing infinity
+and NaN floating point values. To avoid confusion with integer attributes,
+hexadecimal literals _must_ be followed by a float type to define a float
+attribute.
+
+Examples:
+
+``` {.mlir}
+42.0         // float attribute defaults to f64 type
+42.0 : f32   // float attribute of f32 type
+0x7C00 : f16 // positive infinity
+0x7CFF : f16 // NaN (one of possible values)
+42 : f32     // Error: expected integer type
+```
 
 #### String Attribute
 
index bb17b3e..31d45bd 100644 (file)
@@ -471,15 +471,24 @@ static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
             ((strValue[0] == '-' || strValue[0] == '+') &&
              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
            "[-+]?[0-9] regex does not match!");
-    // Reparse stringized version!
-    if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
-      os << strValue;
-      return;
+
+    // Parse back the stringized version and check that the value is equal
+    // (i.e., there is no precision loss). If it is not, use the default format
+    // of APFloat instead of the exponential notation.
+    if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
+      strValue.clear();
+      apValue.toString(strValue);
     }
+    os << strValue;
+    return;
   }
 
+  // Print special values in hexadecimal format.  The sign bit should be
+  // included in the literal.
   SmallVector<char, 16> str;
-  apValue.toString(str);
+  APInt apInt = apValue.bitcastToAPInt();
+  apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
+                 /*formatAsCLiteral=*/true);
   os << str;
 }
 
index cee3a53..a5dd981 100644 (file)
@@ -227,8 +227,9 @@ public:
   /// Parse a float attribute.
   Attribute parseFloatAttr(Type type, bool isNegative);
 
-  /// Parse an integer attribute.
-  Attribute parseIntegerAttr(Type type, bool isSigned);
+  /// Parse a decimal or a hexadecimal literal, which can be either an integer
+  /// or a float attribute.
+  Attribute parseDecOrHexAttr(Type type, bool isNegative);
 
   /// Parse an opaque elements attribute.
   Attribute parseOpaqueElementsAttr();
@@ -998,11 +999,11 @@ Attribute Parser::parseAttribute(Type type) {
   case Token::floatliteral:
     return parseFloatAttr(type, /*isNegative=*/false);
   case Token::integer:
-    return parseIntegerAttr(type, /*isSigned=*/false);
+    return parseDecOrHexAttr(type, /*isNegative=*/false);
   case Token::minus: {
     consumeToken(Token::minus);
     if (getToken().is(Token::integer))
-      return parseIntegerAttr(type, /*isSigned=*/true);
+      return parseDecOrHexAttr(type, /*isNegative=*/true);
     if (getToken().is(Token::floatliteral))
       return parseFloatAttr(type, /*isNegative=*/true);
 
@@ -1151,12 +1152,17 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
   return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
 }
 
-/// Parse an integer attribute.
-Attribute Parser::parseIntegerAttr(Type type, bool isSigned) {
+/// Parse a decimal or a hexadecimal literal, which can be either an integer
+/// or a float attribute.
+Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
   auto val = getToken().getUInt64IntegerValue();
-  if (!val.hasValue() ||
-      (isSigned ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0))
+  if (!val.hasValue())
     return (emitError("integer constant out of range for attribute"), nullptr);
+
+  // Remember if the literal is hexadecimal.
+  StringRef spelling = getToken().getSpelling();
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
   consumeToken(Token::integer);
   if (!type) {
     // Default to i64 if not type is specified.
@@ -1165,14 +1171,47 @@ Attribute Parser::parseIntegerAttr(Type type, bool isSigned) {
     else if (!(type = parseType()))
       return nullptr;
   }
+
+  // Hexadecimal representation of float literals is not supported for bfloat16.
+  // When supported, the literal should be unsigned.
+  auto floatType = type.dyn_cast<FloatType>();
+  if (floatType && !type.isBF16()) {
+    if (isNegative) {
+      emitError("hexadecimal float literal should not have a leading minus");
+      return nullptr;
+    }
+    if (!isHex) {
+      emitError("unexpected decimal integer literal for a float attribute")
+              .attachNote()
+          << "add a trailing dot to make the literal a float";
+      return nullptr;
+    }
+
+    // Construct a float attribute bitwise equivalent to the integer literal.
+    int width = type.getIntOrFloatBitWidth();
+    APInt apInt(width, *val, isNegative);
+    if (apInt != *val) {
+      emitError("hexadecimal float constant out of range for attribute");
+      return nullptr;
+    }
+    APFloat apFloat(floatType.getFloatSemantics(), apInt);
+    return builder.getFloatAttr(type, apFloat);
+  }
+
   if (!type.isIntOrIndex())
-    return (emitError("integer value not valid for specified type"), nullptr);
+    return (emitError("integer literal not valid for specified type"), nullptr);
 
+  // Parse the integer literal.
   int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
-  APInt apInt(width, *val, isSigned);
+  APInt apInt(width, *val, isNegative);
   if (apInt != *val)
     return (emitError("integer constant out of range for attribute"), nullptr);
-  return builder.getIntegerAttr(type, isSigned ? -apInt : apInt);
+
+  // Otherwise construct an integer attribute.
+  if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0)
+    return (emitError("integer constant out of range for attribute"), nullptr);
+
+  return builder.getIntegerAttr(type, isNegative ? -apInt : apInt);
 }
 
 /// Parse an opaque elements attribute.
index 4250ab4..d6032a9 100644 (file)
@@ -1044,3 +1044,32 @@ func @invalid_region_dominance() {
   }) : () -> ()
   return
 }
+
+// -----
+
+func @hexadecimal_bf16() {
+  // expected-error @+1 {{integer literal not valid for specified type}}
+  "foo"() {value = 0xffff : bf16} : () -> ()
+}
+
+// -----
+
+func @hexadecimal_float_leading_minus() {
+  // expected-error @+1 {{hexadecimal float literal should not have a leading minus}}
+  "foo"() {value = -0x7fff : f16} : () -> ()
+}
+
+// -----
+
+func @hexadecimal_float_literal_overflow() {
+  // expected-error @+1 {{hexadecimal float constant out of range for attribute}}
+  "foo"() {value = 0xffffffff : f16} : () -> ()
+}
+
+// -----
+
+func @decimal_float_literal() {
+  // expected-error @+2 {{unexpected decimal integer literal for a float attribute}}
+  // expected-note @+1 {{add a trailing dot to make the literal a float}}
+  "foo"() {value = 42 : f32} : () -> ()
+}
index 6f0ad06..104af5e 100644 (file)
@@ -945,3 +945,81 @@ func @dialect_attribute_with_type() {
   // CHECK-NEXT: foo = #foo.attr : i32
   "foo.unknown_op"() {foo = #foo.attr : i32} : () -> ()
 }
+
+// CHECK-LABEL: @f16_special_values
+func @f16_special_values() {
+  // F16 NaNs.
+  // CHECK: constant 0x7C01 : f16
+  %0 = constant 0x7C01 : f16
+  // CHECK: constant 0x7FFF : f16
+  %1 = constant 0x7FFF : f16
+  // CHECK: constant 0xFFFF : f16
+  %2 = constant 0xFFFF : f16
+
+  // F16 positive infinity.
+  // CHECK: constant 0x7C00 : f16
+  %3 = constant 0x7C00 : f16
+  // F16 negative inifinity.
+  // CHECK: constant 0xFC00 : f16
+  %4 = constant 0xFC00 : f16
+
+  return
+}
+
+// CHECK-LABEL: @f32_special_values
+func @f32_special_values() {
+  // F32 signaling NaNs.
+  // CHECK: constant 0x7F800001 : f32
+  %0 = constant 0x7F800001 : f32
+  // CHECK: constant 0x7FBFFFFF : f32
+  %1 = constant 0x7FBFFFFF : f32
+
+  // F32 quiet NaNs.
+  // CHECK: constant 0x7FC00000 : f32
+  %2 = constant 0x7FC00000 : f32
+  // CHECK: constant 0xFFFFFFFF : f32
+  %3 = constant 0xFFFFFFFF : f32
+
+  // F32 positive infinity.
+  // CHECK: constant 0x7F800000 : f32
+  %4 = constant 0x7F800000 : f32
+  // F32 negative infinity.
+  // CHECK: constant 0xFF800000 : f32
+  %5 = constant 0xFF800000 : f32
+
+  return
+}
+
+// CHECK-LABEL: @f64_special_values
+func @f64_special_values() {
+  // F64 signaling NaNs.
+  // CHECK: constant 0x7FF0000000000001 : f64
+  %0 = constant 0x7FF0000000000001 : f64
+  // CHECK: constant 0x7FF8000000000000 : f64
+  %1 = constant 0x7FF8000000000000 : f64
+
+  // F64 quiet NaNs.
+  // CHECK: constant 0x7FF0000001000000 : f64
+  %2 = constant 0x7FF0000001000000 : f64
+  // CHECK: constant 0xFFF0000001000000 : f64
+  %3 = constant 0xFFF0000001000000 : f64
+
+  // F64 positive inifinity.
+  // CHECK: constant 0x7FF0000000000000 : f64
+  %4 = constant 0x7FF0000000000000 : f64
+  // F64 negative infinity.
+  // CHECK: constant 0xFFF0000000000000 : f64
+  %5 = constant 0xFFF0000000000000 : f64
+
+  return
+}
+
+// We want to print floats in exponential notation with 6 significant digits,
+// but it may lead to precision loss when parsing back, in which case we print
+// the decimal form instead.
+// CHECK-LABEL: @f32_potential_precision_loss()
+func @f32_potential_precision_loss() {
+  // CHECK: constant -1.23697901 : f32
+  %0 = constant -1.23697901 : f32
+  return
+}