#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/TrailingObjects.h"
namespace mlir {
class Function;
int64_t value;
};
-class FloatAttr : public Attribute {
+class FloatAttr final : public Attribute,
+ public llvm::TrailingObjects<FloatAttr, uint64_t> {
public:
static FloatAttr *get(double value, MLIRContext *context);
+ static FloatAttr *get(const APFloat &value, MLIRContext *context);
- // TODO: This should really be implemented in terms of APFloat for
- // correctness, otherwise constant folding will be done with host math. This
- // is completely incorrect for BF16 and other datatypes, and subtly wrong
- // for float32.
- double getValue() const { return value; }
+ APFloat getValue() const;
+
+ double getDouble() const { return getValue().convertToDouble(); }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
}
private:
- FloatAttr(double value)
- : Attribute(Kind::Float, /*isOrContainsFunction=*/false), value(value) {}
+ FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects)
+ : Attribute(Kind::Float, /*isOrContainsFunction=*/false),
+ semantics(semantics), numObjects(numObjects) {}
+ FloatAttr(const FloatAttr &value) = delete;
~FloatAttr() = delete;
- double value;
+
+ size_t numTrailingObjects(OverloadToken<uint64_t>) const {
+ return numObjects;
+ }
+
+ const llvm::fltSemantics &semantics;
+ size_t numObjects;
};
class StringAttr : public Attribute {
BoolAttr *getBoolAttr(bool value);
IntegerAttr *getIntegerAttr(int64_t value);
FloatAttr *getFloatAttr(double value);
+ FloatAttr *getFloatAttr(const APFloat &value);
StringAttr *getStringAttr(StringRef bytes);
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
AffineMapAttr *getAffineMapAttr(AffineMap map);
class ConstantFloatOp : public ConstantOp {
public:
/// Builds a constant float op producing a float of the specified type.
- static void build(Builder *builder, OperationState *result, double value,
- FloatType *type);
+ static void build(Builder *builder, OperationState *result,
+ const APFloat &value, FloatType *type);
- double getValue() const {
+ APFloat getValue() const {
return getAttrOfType<FloatAttr>("value")->getValue();
}
/// Print a floating point value in a way that the parser will be able to
/// round-trip losslessly.
-static void printFloatValue(double value, raw_ostream &os) {
- APFloat apValue(value);
-
+static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
// We would like to output the FP constant value in exponential notation,
// but we cannot do this if doing so will lose precision. Check here to
// make sure that we only output it in exponential format if we can parse
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
"[-+]?[0-9] regex does not match!");
// Reparse stringized version!
- if (APFloat(APFloat::IEEEdouble(), strValue).convertToDouble() == value) {
+ if (APFloat(APFloat::IEEEdouble(), strValue).bitwiseIsEqual(apValue)) {
os << strValue;
return;
}
}
- // Otherwise, print it in a hexadecimal form. Convert it to an integer so we
- // can print it out using integer math.
- union {
- double doubleValue;
- uint64_t integerValue;
- };
- doubleValue = value;
- os << "0x";
- // Print out 16 nibbles worth of hex digit.
- for (unsigned i = 0; i != 16; ++i) {
- os << llvm::hexdigit(integerValue >> 60);
- integerValue <<= 4;
- }
+ SmallVector<char, 16> str;
+ apValue.toString(str);
+ os << str;
}
void ModulePrinter::printFunctionReference(const Function *func) {
}
FloatAttr *Builder::getFloatAttr(double value) {
+ return FloatAttr::get(APFloat(value), context);
+}
+
+FloatAttr *Builder::getFloatAttr(const APFloat &value) {
return FloatAttr::get(value, context);
}
}
void ConstantFloatOp::build(Builder *builder, OperationState *result,
- double value, FloatType *type) {
+ const APFloat &value, FloatType *type) {
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
}
#include "mlir/IR/Types.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"
}
};
+struct FloatAttrKeyInfo : DenseMapInfo<FloatAttr *> {
+ // Float attributes are uniqued based on wrapped APFloat.
+ using KeyTy = APFloat;
+ using DenseMapInfo<FloatAttr *>::getHashValue;
+ using DenseMapInfo<FloatAttr *>::isEqual;
+
+ static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
+
+ static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) {
+ if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+ return false;
+ return lhs.bitwiseIsEqual(rhs->getValue());
+ }
+};
+
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> {
// Array attributes are uniqued based on their elements.
using KeyTy = ArrayRef<Attribute *>;
// Attribute uniquing.
BoolAttr *boolAttrs[2] = {nullptr};
DenseMap<int64_t, IntegerAttr *> integerAttrs;
- DenseMap<int64_t, FloatAttr *> floatAttrs;
+ DenseSet<FloatAttr *, FloatAttrKeyInfo> floatAttrs;
StringMap<StringAttr *> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
ArrayAttrSet arrayAttrs;
}
FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
- // We hash based on the bit representation of the double to ensure we don't
- // merge things like -0.0 and 0.0 in the hash comparison.
- union {
- double floatValue;
- int64_t intValue;
- };
- floatValue = value;
-
- auto *&result = context->getImpl().floatAttrs[intValue];
- if (result)
- return result;
+ return get(APFloat(value), context);
+}
- result = context->getImpl().allocator.Allocate<FloatAttr>();
- new (result) FloatAttr(value);
- return result;
+FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
+ auto &impl = context->getImpl();
+
+ // Look to see if the float attribute has been created already.
+ auto existing = impl.floatAttrs.insert_as(nullptr, value);
+
+ // If it has been created, return it.
+ if (!existing.second)
+ return *existing.first;
+
+ // If it doesn't, create one, unique it and return it.
+ const auto &apint = value.bitcastToAPInt();
+ // Here one word's bitwidth equals to that of uint64_t.
+ auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
+
+ auto byteSize = FloatAttr::totalSizeToAlloc<uint64_t>(elements.size());
+ auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr));
+ auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size());
+ std::uninitialized_copy(elements.begin(), elements.end(),
+ result->getTrailingObjects<uint64_t>());
+ return *existing.first = result;
+}
+
+APFloat FloatAttr::getValue() const {
+ auto val = APInt(APFloat::getSizeInBits(semantics),
+ {getTrailingObjects<uint64_t>(), numObjects});
+ return APFloat(semantics, val);
}
StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
case Type::Kind::F64: {
if (!isa<FloatAttr>(result))
return p.emitError("expected tensor literal element has float type");
- double value = cast<FloatAttr>(result)->getValue();
+ double value = cast<FloatAttr>(result)->getDouble();
addToStorage(*(uint64_t *)(&value));
break;
}
return (emitError("floating point value too large for attribute"),
nullptr);
consumeToken(Token::floatliteral);
- return builder.getFloatAttr(val.getValue());
+ return builder.getFloatAttr(APFloat(val.getValue()));
}
case Token::integer: {
auto val = getToken().getUInt64IntegerValue();
return (emitError("floating point value too large for attribute"),
nullptr);
consumeToken(Token::floatliteral);
- return builder.getFloatAttr(-val.getValue());
+ return builder.getFloatAttr(APFloat(-val.getValue()));
}
return (emitError("expected constant integer or floating point value"),