Add support for hexadecimal float literals
authorAlex Zinenko <zinenko@google.com>
Thu, 25 Jul 2019 21:15:33 +0000 (14:15 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 25 Jul 2019 21:16:02 +0000 (14:16 -0700)
MLIR does not have support for parsing special floating point values such as
infinities and NaNs.  If programmatically constructed, these values are printed
as NaN and (+-)Inf and cannot be parsed back.  Add parser support for
hexadecimal literals in float attributes, following LLVM IR.  The literal
corresponds to the in-memory representation of the floating point value.
IEEE 754 defines a range of possible values for NaNs, storing the bitwise
representation allows MLIR to properly roundtrip NaNs with different bit values
of significands.

PiperOrigin-RevId: 260018802

mlir/g3doc/LangRef.md
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir

index 9f28f93..f259b6c 100644 (file)
@@ -840,10 +840,26 @@ Syntax:
 
 ``` {.ebnf}
 float-attribute ::= float-literal (`:` float-type)?
+                  | hexadecimal-literal `:` float-type
 ```
 
 A float attribute is a literal attribute that represents a floating point value
-of the specified [float type](#floating-point-types).
+of the specified [float type](#floating-point-types). It can be represented in
+the hexadecimal form where the hexadecimal value is interpreted as bits of the
+underlying binary representation. This form is useful for representing infinity
+and NaN floating point values. To avoid confusion with integer attributes,
+hexadecimal literals _must_ be followed by a float type to define a float
+attribute.
+
+Examples:
+
+``` {.mlir}
+42.0         // float attribute defaults to f64 type
+42.0 : f32   // float attribute of f32 type
+0x7C00 : f16 // positive infinity
+0x7CFF : f16 // NaN (one of possible values)
+42 : f32     // Error: expected integer type
+```
 
 #### String Attribute
 
index bb17b3e..b8d23ed 100644 (file)
@@ -478,8 +478,12 @@ static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
     }
   }
 
+  // Print special values in hexadecimal format.  The sign bit should be
+  // included in the literal.
   SmallVector<char, 16> str;
-  apValue.toString(str);
+  APInt apInt = apValue.bitcastToAPInt();
+  apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
+                 /*formatAsCLiteral=*/true);
   os << str;
 }
 
index cee3a53..a5dd981 100644 (file)
@@ -227,8 +227,9 @@ public:
   /// Parse a float attribute.
   Attribute parseFloatAttr(Type type, bool isNegative);
 
-  /// Parse an integer attribute.
-  Attribute parseIntegerAttr(Type type, bool isSigned);
+  /// Parse a decimal or a hexadecimal literal, which can be either an integer
+  /// or a float attribute.
+  Attribute parseDecOrHexAttr(Type type, bool isNegative);
 
   /// Parse an opaque elements attribute.
   Attribute parseOpaqueElementsAttr();
@@ -998,11 +999,11 @@ Attribute Parser::parseAttribute(Type type) {
   case Token::floatliteral:
     return parseFloatAttr(type, /*isNegative=*/false);
   case Token::integer:
-    return parseIntegerAttr(type, /*isSigned=*/false);
+    return parseDecOrHexAttr(type, /*isNegative=*/false);
   case Token::minus: {
     consumeToken(Token::minus);
     if (getToken().is(Token::integer))
-      return parseIntegerAttr(type, /*isSigned=*/true);
+      return parseDecOrHexAttr(type, /*isNegative=*/true);
     if (getToken().is(Token::floatliteral))
       return parseFloatAttr(type, /*isNegative=*/true);
 
@@ -1151,12 +1152,17 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
   return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
 }
 
-/// Parse an integer attribute.
-Attribute Parser::parseIntegerAttr(Type type, bool isSigned) {
+/// Parse a decimal or a hexadecimal literal, which can be either an integer
+/// or a float attribute.
+Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
   auto val = getToken().getUInt64IntegerValue();
-  if (!val.hasValue() ||
-      (isSigned ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0))
+  if (!val.hasValue())
     return (emitError("integer constant out of range for attribute"), nullptr);
+
+  // Remember if the literal is hexadecimal.
+  StringRef spelling = getToken().getSpelling();
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
   consumeToken(Token::integer);
   if (!type) {
     // Default to i64 if not type is specified.
@@ -1165,14 +1171,47 @@ Attribute Parser::parseIntegerAttr(Type type, bool isSigned) {
     else if (!(type = parseType()))
       return nullptr;
   }
+
+  // Hexadecimal representation of float literals is not supported for bfloat16.
+  // When supported, the literal should be unsigned.
+  auto floatType = type.dyn_cast<FloatType>();
+  if (floatType && !type.isBF16()) {
+    if (isNegative) {
+      emitError("hexadecimal float literal should not have a leading minus");
+      return nullptr;
+    }
+    if (!isHex) {
+      emitError("unexpected decimal integer literal for a float attribute")
+              .attachNote()
+          << "add a trailing dot to make the literal a float";
+      return nullptr;
+    }
+
+    // Construct a float attribute bitwise equivalent to the integer literal.
+    int width = type.getIntOrFloatBitWidth();
+    APInt apInt(width, *val, isNegative);
+    if (apInt != *val) {
+      emitError("hexadecimal float constant out of range for attribute");
+      return nullptr;
+    }
+    APFloat apFloat(floatType.getFloatSemantics(), apInt);
+    return builder.getFloatAttr(type, apFloat);
+  }
+
   if (!type.isIntOrIndex())
-    return (emitError("integer value not valid for specified type"), nullptr);
+    return (emitError("integer literal not valid for specified type"), nullptr);
 
+  // Parse the integer literal.
   int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
-  APInt apInt(width, *val, isSigned);
+  APInt apInt(width, *val, isNegative);
   if (apInt != *val)
     return (emitError("integer constant out of range for attribute"), nullptr);
-  return builder.getIntegerAttr(type, isSigned ? -apInt : apInt);
+
+  // Otherwise construct an integer attribute.
+  if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0)
+    return (emitError("integer constant out of range for attribute"), nullptr);
+
+  return builder.getIntegerAttr(type, isNegative ? -apInt : apInt);
 }
 
 /// Parse an opaque elements attribute.
index 4250ab4..d6032a9 100644 (file)
@@ -1044,3 +1044,32 @@ func @invalid_region_dominance() {
   }) : () -> ()
   return
 }
+
+// -----
+
+func @hexadecimal_bf16() {
+  // expected-error @+1 {{integer literal not valid for specified type}}
+  "foo"() {value = 0xffff : bf16} : () -> ()
+}
+
+// -----
+
+func @hexadecimal_float_leading_minus() {
+  // expected-error @+1 {{hexadecimal float literal should not have a leading minus}}
+  "foo"() {value = -0x7fff : f16} : () -> ()
+}
+
+// -----
+
+func @hexadecimal_float_literal_overflow() {
+  // expected-error @+1 {{hexadecimal float constant out of range for attribute}}
+  "foo"() {value = 0xffffffff : f16} : () -> ()
+}
+
+// -----
+
+func @decimal_float_literal() {
+  // expected-error @+2 {{unexpected decimal integer literal for a float attribute}}
+  // expected-note @+1 {{add a trailing dot to make the literal a float}}
+  "foo"() {value = 42 : f32} : () -> ()
+}
index 6f0ad06..985b517 100644 (file)
@@ -945,3 +945,71 @@ func @dialect_attribute_with_type() {
   // CHECK-NEXT: foo = #foo.attr : i32
   "foo.unknown_op"() {foo = #foo.attr : i32} : () -> ()
 }
+
+// CHECK-LABEL: @f16_special_values
+func @f16_special_values() {
+  // F16 NaNs.
+  // CHECK: constant 0x7C01 : f16
+  %0 = constant 0x7C01 : f16
+  // CHECK: constant 0x7FFF : f16
+  %1 = constant 0x7FFF : f16
+  // CHECK: constant 0xFFFF : f16
+  %2 = constant 0xFFFF : f16
+
+  // F16 positive infinity.
+  // CHECK: constant 0x7C00 : f16
+  %3 = constant 0x7C00 : f16
+  // F16 negative inifinity.
+  // CHECK: constant 0xFC00 : f16
+  %4 = constant 0xFC00 : f16
+
+  return
+}
+
+// CHECK-LABEL: @f32_special_values
+func @f32_special_values() {
+  // F32 signaling NaNs.
+  // CHECK: constant 0x7F800001 : f32
+  %0 = constant 0x7F800001 : f32
+  // CHECK: constant 0x7FBFFFFF : f32
+  %1 = constant 0x7FBFFFFF : f32
+
+  // F32 quiet NaNs.
+  // CHECK: constant 0x7FC00000 : f32
+  %2 = constant 0x7FC00000 : f32
+  // CHECK: constant 0xFFFFFFFF : f32
+  %3 = constant 0xFFFFFFFF : f32
+
+  // F32 positive infinity.
+  // CHECK: constant 0x7F800000 : f32
+  %4 = constant 0x7F800000 : f32
+  // F32 negative infinity.
+  // CHECK: constant 0xFF800000 : f32
+  %5 = constant 0xFF800000 : f32
+
+  return
+}
+
+// CHECK-LABEL: @f64_special_values
+func @f64_special_values() {
+  // F64 signaling NaNs.
+  // CHECK: constant 0x7FF0000000000001 : f64
+  %0 = constant 0x7FF0000000000001 : f64
+  // CHECK: constant 0x7FF8000000000000 : f64
+  %1 = constant 0x7FF8000000000000 : f64
+
+  // F64 quiet NaNs.
+  // CHECK: constant 0x7FF0000001000000 : f64
+  %2 = constant 0x7FF0000001000000 : f64
+  // CHECK: constant 0xFFF0000001000000 : f64
+  %3 = constant 0xFFF0000001000000 : f64
+
+  // F64 positive inifinity.
+  // CHECK: constant 0x7FF0000000000000 : f64
+  %4 = constant 0x7FF0000000000000 : f64
+  // F64 negative infinity.
+  // CHECK: constant 0xFFF0000000000000 : f64
+  %5 = constant 0xFFF0000000000000 : f64
+
+  return
+}