Use APFloat for FloatAttribute
authorFeng Liu <fengliuai@google.com>
Sun, 21 Oct 2018 01:31:49 +0000 (18:31 -0700)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 20:34:09 +0000 (13:34 -0700)
We should be able to represent arbitrary precision Float-point values inside
the IR, so compiler optimizations, such as constant folding can be done
independently on the compiling platform.

This CL also added a new field, AttrValueGetter, to the Attr class definition
for TableGen. This field is used to customize which mlir::Attr getter method to
get the defined PrimitiveType.

PiperOrigin-RevId: 218034983

mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinOps.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinOps.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Parser/Parser.cpp

index a676bf276f53281dfb435ddb61d680d31db03b62..28a3939fb1e1c58c6b751772da3217a654aec95d 100644 (file)
@@ -20,7 +20,8 @@
 
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Support/LLVM.h"
-#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/TrailingObjects.h"
 
 namespace mlir {
 class Function;
@@ -121,15 +122,15 @@ private:
   int64_t value;
 };
 
-class FloatAttr : public Attribute {
+class FloatAttr final : public Attribute,
+                        public llvm::TrailingObjects<FloatAttr, uint64_t> {
 public:
   static FloatAttr *get(double value, MLIRContext *context);
+  static FloatAttr *get(const APFloat &value, MLIRContext *context);
 
-  // TODO: This should really be implemented in terms of APFloat for
-  // correctness, otherwise constant folding will be done with host math.  This
-  // is completely incorrect for BF16 and other datatypes, and subtly wrong
-  // for float32.
-  double getValue() const { return value; }
+  APFloat getValue() const;
+
+  double getDouble() const { return getValue().convertToDouble(); }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Attribute *attr) {
@@ -137,10 +138,18 @@ public:
   }
 
 private:
-  FloatAttr(double value)
-      : Attribute(Kind::Float, /*isOrContainsFunction=*/false), value(value) {}
+  FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects)
+      : Attribute(Kind::Float, /*isOrContainsFunction=*/false),
+        semantics(semantics), numObjects(numObjects) {}
+  FloatAttr(const FloatAttr &value) = delete;
   ~FloatAttr() = delete;
-  double value;
+
+  size_t numTrailingObjects(OverloadToken<uint64_t>) const {
+    return numObjects;
+  }
+
+  const llvm::fltSemantics &semantics;
+  size_t numObjects;
 };
 
 class StringAttr : public Attribute {
index 4e44211c34442d754278697b7f986c00bfd34def..0415637f3f07e681ee0a7846b3220680865d7d1e 100644 (file)
@@ -96,6 +96,7 @@ public:
   BoolAttr *getBoolAttr(bool value);
   IntegerAttr *getIntegerAttr(int64_t value);
   FloatAttr *getFloatAttr(double value);
+  FloatAttr *getFloatAttr(const APFloat &value);
   StringAttr *getStringAttr(StringRef bytes);
   ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
   AffineMapAttr *getAffineMapAttr(AffineMap map);
index fb2c2acdfd8ea2b734d11a2b578757fbaa81b515..9030f97572b998b7fdbe36d1098afb1e3b869922 100644 (file)
@@ -116,10 +116,10 @@ protected:
 class ConstantFloatOp : public ConstantOp {
 public:
   /// Builds a constant float op producing a float of the specified type.
-  static void build(Builder *builder, OperationState *result, double value,
-                    FloatType *type);
+  static void build(Builder *builder, OperationState *result,
+                    const APFloat &value, FloatType *type);
 
-  double getValue() const {
+  APFloat getValue() const {
     return getAttrOfType<FloatAttr>("value")->getValue();
   }
 
index e70dcae29360b28a376137f9cd6e2041c2f56608..3c71f9a6b48cd9886d5e6f4fb50e34c39d52c8b2 100644 (file)
@@ -373,9 +373,7 @@ void ModulePrinter::print(const Module *module) {
 
 /// Print a floating point value in a way that the parser will be able to
 /// round-trip losslessly.
-static void printFloatValue(double value, raw_ostream &os) {
-  APFloat apValue(value);
-
+static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
   // We would like to output the FP constant value in exponential notation,
   // but we cannot do this if doing so will lose precision.  Check here to
   // make sure that we only output it in exponential format if we can parse
@@ -394,25 +392,15 @@ static void printFloatValue(double value, raw_ostream &os) {
              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
            "[-+]?[0-9] regex does not match!");
     // Reparse stringized version!
-    if (APFloat(APFloat::IEEEdouble(), strValue).convertToDouble() == value) {
+    if (APFloat(APFloat::IEEEdouble(), strValue).bitwiseIsEqual(apValue)) {
       os << strValue;
       return;
     }
   }
 
-  // Otherwise, print it in a hexadecimal form.  Convert it to an integer so we
-  // can print it out using integer math.
-  union {
-    double doubleValue;
-    uint64_t integerValue;
-  };
-  doubleValue = value;
-  os << "0x";
-  // Print out 16 nibbles worth of hex digit.
-  for (unsigned i = 0; i != 16; ++i) {
-    os << llvm::hexdigit(integerValue >> 60);
-    integerValue <<= 4;
-  }
+  SmallVector<char, 16> str;
+  apValue.toString(str);
+  os << str;
 }
 
 void ModulePrinter::printFunctionReference(const Function *func) {
index 3e22c852ee42f3c6fae1adaec3fad83223983ebf..66192f0a867a81203cfd04509a084834268c3bb9 100644 (file)
@@ -121,6 +121,10 @@ IntegerAttr *Builder::getIntegerAttr(int64_t value) {
 }
 
 FloatAttr *Builder::getFloatAttr(double value) {
+  return FloatAttr::get(APFloat(value), context);
+}
+
+FloatAttr *Builder::getFloatAttr(const APFloat &value) {
   return FloatAttr::get(value, context);
 }
 
index 2acc26d73af30d6f2955301b4477859f1851629a..fe943026ecbc46fea9c64916b1a49bcbd40bb86c 100644 (file)
@@ -238,7 +238,7 @@ Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands,
 }
 
 void ConstantFloatOp::build(Builder *builder, OperationState *result,
-                            double value, FloatType *type) {
+                            const APFloat &value, FloatType *type) {
   ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
 }
 
index 9983428892627a2c90b977176219a12495d49e8c..1b9fe0cde63c13c8cd0b7a0df00b4963f043c1fd 100644 (file)
@@ -32,7 +32,6 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Allocator.h"
@@ -147,6 +146,21 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
   }
 };
 
+struct FloatAttrKeyInfo : DenseMapInfo<FloatAttr *> {
+  // Float attributes are uniqued based on wrapped APFloat.
+  using KeyTy = APFloat;
+  using DenseMapInfo<FloatAttr *>::getHashValue;
+  using DenseMapInfo<FloatAttr *>::isEqual;
+
+  static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
+
+  static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs.bitwiseIsEqual(rhs->getValue());
+  }
+};
+
 struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> {
   // Array attributes are uniqued based on their elements.
   using KeyTy = ArrayRef<Attribute *>;
@@ -282,7 +296,7 @@ public:
   // Attribute uniquing.
   BoolAttr *boolAttrs[2] = {nullptr};
   DenseMap<int64_t, IntegerAttr *> integerAttrs;
-  DenseMap<int64_t, FloatAttr *> floatAttrs;
+  DenseSet<FloatAttr *, FloatAttrKeyInfo> floatAttrs;
   StringMap<StringAttr *> stringAttrs;
   using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
   ArrayAttrSet arrayAttrs;
@@ -638,21 +652,36 @@ IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) {
 }
 
 FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
-  // We hash based on the bit representation of the double to ensure we don't
-  // merge things like -0.0 and 0.0 in the hash comparison.
-  union {
-    double floatValue;
-    int64_t intValue;
-  };
-  floatValue = value;
-
-  auto *&result = context->getImpl().floatAttrs[intValue];
-  if (result)
-    return result;
+  return get(APFloat(value), context);
+}
 
-  result = context->getImpl().allocator.Allocate<FloatAttr>();
-  new (result) FloatAttr(value);
-  return result;
+FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
+  auto &impl = context->getImpl();
+
+  // Look to see if the float attribute has been created already.
+  auto existing = impl.floatAttrs.insert_as(nullptr, value);
+
+  // If it has been created, return it.
+  if (!existing.second)
+    return *existing.first;
+
+  // If it doesn't, create one, unique it and return it.
+  const auto &apint = value.bitcastToAPInt();
+  // Here one word's bitwidth equals to that of uint64_t.
+  auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
+
+  auto byteSize = FloatAttr::totalSizeToAlloc<uint64_t>(elements.size());
+  auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr));
+  auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size());
+  std::uninitialized_copy(elements.begin(), elements.end(),
+                          result->getTrailingObjects<uint64_t>());
+  return *existing.first = result;
+}
+
+APFloat FloatAttr::getValue() const {
+  auto val = APInt(APFloat::getSizeInBits(semantics),
+                   {getTrailingObjects<uint64_t>(), numObjects});
+  return APFloat(semantics, val);
 }
 
 StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
index cbd3d8b7bd726a95612e7d2e1ed4c2d86c9014ec..448f974b348ec477cc409cf8fb1f1c1aa964b944 100644 (file)
@@ -699,7 +699,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
     case Type::Kind::F64: {
       if (!isa<FloatAttr>(result))
         return p.emitError("expected tensor literal element has float type");
-      double value = cast<FloatAttr>(result)->getValue();
+      double value = cast<FloatAttr>(result)->getDouble();
       addToStorage(*(uint64_t *)(&value));
       break;
     }
@@ -823,7 +823,7 @@ Attribute *Parser::parseAttribute() {
       return (emitError("floating point value too large for attribute"),
               nullptr);
     consumeToken(Token::floatliteral);
-    return builder.getFloatAttr(val.getValue());
+    return builder.getFloatAttr(APFloat(val.getValue()));
   }
   case Token::integer: {
     auto val = getToken().getUInt64IntegerValue();
@@ -848,7 +848,7 @@ Attribute *Parser::parseAttribute() {
         return (emitError("floating point value too large for attribute"),
                 nullptr);
       consumeToken(Token::floatliteral);
-      return builder.getFloatAttr(-val.getValue());
+      return builder.getFloatAttr(APFloat(-val.getValue()));
     }
 
     return (emitError("expected constant integer or floating point value"),