if (auto *ComplexTy = OrigType->getAs<ComplexType>())
Type = ComplexTy->getElementType();
if (Type->isRealFloatingType()) {
- llvm::APFloat InitValue =
- llvm::APFloat::getAllOnesValue(Context.getTypeSize(Type),
- /*isIEEE=*/true);
+ llvm::APFloat InitValue = llvm::APFloat::getAllOnesValue(
+ Context.getFloatTypeSemantics(Type),
+ Context.getTypeSize(Type));
Init = FloatingLiteral::Create(Context, InitValue, /*isexact=*/true,
Type, ELoc);
} else if (Type->isScalarType()) {
The ``HALF`` record (code 10) adds a ``half`` (16-bit floating point) type to
the type table.
+TYPE_CODE_BFLOAT Record
+^^^^^^^^^^^^^^^^^^^^^
+
+``[BFLOAT]``
+
+The ``BFLOAT`` record (code 23) adds a ``bfloat`` (16-bit brain floating point)
+type to the type table.
+
TYPE_CODE_FLOAT Record
^^^^^^^^^^^^^^^^^^^^^^
* - ``half``
- 16-bit floating-point value
+ * - ``bfloat``
+ - 16-bit "brain" floating-point value (7-bit significand). Provides the
+ same number of exponent bits as ``float``, so that it matches its dynamic
+ range, but with greatly reduced precision. Used in Intel's AVX-512 BF16
+ extensions and Arm's ARMv8.6-A extensions, among others.
+
* - ``float``
- 32-bit floating-point value
- 64-bit floating-point value
* - ``fp128``
- - 128-bit floating-point value (112-bit mantissa)
+ - 128-bit floating-point value (112-bit significand)
* - ``x86_fp80``
- 80-bit floating-point value (X87)
values are represented in their IEEE hexadecimal format so that assembly
and disassembly do not cause any bits to change in the constants.
-When using the hexadecimal form, constants of types half, float, and
-double are represented using the 16-digit form shown above (which
-matches the IEEE754 representation for double); half and float values
-must, however, be exactly representable as IEEE 754 half and single
-precision, respectively. Hexadecimal format is always used for long
-double, and there are three forms of long double. The 80-bit format used
-by x86 is represented as ``0xK`` followed by 20 hexadecimal digits. The
-128-bit format used by PowerPC (two adjacent doubles) is represented by
-``0xM`` followed by 32 hexadecimal digits. The IEEE 128-bit format is
-represented by ``0xL`` followed by 32 hexadecimal digits. Long doubles
-will only work if they match the long double format on your target.
-The IEEE 16-bit format (half precision) is represented by ``0xH``
-followed by 4 hexadecimal digits. All hexadecimal formats are big-endian
-(sign bit at the left).
+When using the hexadecimal form, constants of types bfloat, half, float, and
+double are represented using the 16-digit form shown above (which matches the
+IEEE754 representation for double); bfloat, half and float values must, however,
+be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single
+precision respectively. Hexadecimal format is always used for long double, and
+there are three forms of long double. The 80-bit format used by x86 is
+represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit format
+used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32
+hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed
+by 32 hexadecimal digits. Long doubles will only work if they match the long
+double format on your target. The IEEE 16-bit format (half precision) is
+represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit
+format is represented by ``0xR`` followed by 4 hexadecimal digits. All
+hexadecimal formats are big-endian (sign bit at the left).
There are no constants of type x86_mmx.
typedef enum {
LLVMVoidTypeKind, /**< type with no size */
LLVMHalfTypeKind, /**< 16 bit floating point type */
+ LLVMBFloatTypeKind, /**< 16 bit brain floating point type */
LLVMFloatTypeKind, /**< 32 bit floating point type */
LLVMDoubleTypeKind, /**< 64 bit floating point type */
LLVMX86_FP80TypeKind, /**< 80 bit floating point type (X87) */
LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C);
/**
+ * Obtain a 16-bit brain floating point type from a context.
+ */
+LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C);
+
+/**
* Obtain a 32-bit floating point type from a context.
*/
LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C);
* These map to the functions in this group of the same name.
*/
LLVMTypeRef LLVMHalfType(void);
+LLVMTypeRef LLVMBFloatType(void);
LLVMTypeRef LLVMFloatType(void);
LLVMTypeRef LLVMDoubleType(void);
LLVMTypeRef LLVMX86FP80Type(void);
/// @{
enum Semantics {
S_IEEEhalf,
+ S_BFloat,
S_IEEEsingle,
S_IEEEdouble,
S_x87DoubleExtended,
static Semantics SemanticsToEnum(const llvm::fltSemantics &Sem);
static const fltSemantics &IEEEhalf() LLVM_READNONE;
+ static const fltSemantics &BFloat() LLVM_READNONE;
static const fltSemantics &IEEEsingle() LLVM_READNONE;
static const fltSemantics &IEEEdouble() LLVM_READNONE;
static const fltSemantics &IEEEquad() LLVM_READNONE;
/// @}
APInt convertHalfAPFloatToAPInt() const;
+ APInt convertBFloatAPFloatToAPInt() const;
APInt convertFloatAPFloatToAPInt() const;
APInt convertDoubleAPFloatToAPInt() const;
APInt convertQuadrupleAPFloatToAPInt() const;
APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
void initFromAPInt(const fltSemantics *Sem, const APInt &api);
void initFromHalfAPInt(const APInt &api);
+ void initFromBFloatAPInt(const APInt &api);
void initFromFloatAPInt(const APInt &api);
void initFromDoubleAPInt(const APInt &api);
void initFromQuadrupleAPInt(const APInt &api);
/// Returns a float which is bitcasted from an all one value int.
///
+ /// \param Semantics - type float semantics
/// \param BitWidth - Select float type
- /// \param isIEEE - If 128 bit number, select between PPC and IEEE
- static APFloat getAllOnesValue(unsigned BitWidth, bool isIEEE = false);
+ static APFloat getAllOnesValue(const fltSemantics &Semantics,
+ unsigned BitWidth);
/// Used to insert APFloat objects, or objects that contain APFloat objects,
/// into FoldingSets.
TYPE_CODE_FUNCTION = 21, // FUNCTION: [vararg, retty, paramty x N]
- TYPE_CODE_TOKEN = 22 // TOKEN
+ TYPE_CODE_TOKEN = 22, // TOKEN
+
+ TYPE_CODE_BFLOAT = 23 // BRAIN FLOATING POINT
};
enum OperandBundleTagCode {
return getImpl(Data, Ty);
}
- /// getFP() constructors - Return a constant with array type with an element
- /// count and element type of float with precision matching the number of
- /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
- /// double for 64bits) Note that this can return a ConstantAggregateZero
- /// object.
- static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
- static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
- static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
+ /// getFP() constructors - Return a constant of array type with a float
+ /// element type taken from argument `ElementType', and count taken from
+ /// argument `Elts'. The amount of bits of the contained type must match the
+ /// number of bits of the type contained in the passed in ArrayRef.
+ /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+ /// that this can return a ConstantAggregateZero object.
+ static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
+ static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
+ static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
/// This method constructs a CDS and initializes it with a text string.
/// The default behavior (AddNull==true) causes a null terminator to
static Constant *get(LLVMContext &Context, ArrayRef<float> Elts);
static Constant *get(LLVMContext &Context, ArrayRef<double> Elts);
- /// getFP() constructors - Return a constant with vector type with an element
- /// count and element type of float with the precision matching the number of
- /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
- /// double for 64bits) Note that this can return a ConstantAggregateZero
- /// object.
- static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
- static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
- static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
+ /// getFP() constructors - Return a constant of vector type with a float
+ /// element type taken from argument `ElementType', and count taken from
+ /// argument `Elts'. The amount of bits of the contained type must match the
+ /// number of bits of the type contained in the passed in ArrayRef.
+ /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+ /// that this can return a ConstantAggregateZero object.
+ static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
+ static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
+ static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
/// Return a ConstantVector with the specified constant in each element.
/// The specified constant has to be a of a compatible type (i8/i16/
case Type::IntegerTyID:
return TypeSize::Fixed(Ty->getIntegerBitWidth());
case Type::HalfTyID:
+ case Type::BFloatTyID:
return TypeSize::Fixed(16);
case Type::FloatTyID:
return TypeSize::Fixed(32);
return Type::getHalfTy(Context);
}
+ /// Fetch the type representing a 16-bit brain floating point value.
+ Type *getBFloatTy() {
+ return Type::getBFloatTy(Context);
+ }
+
/// Fetch the type representing a 32-bit floating point value.
Type *getFloatTy() {
return Type::getFloatTy(Context);
///
enum TypeID {
// PrimitiveTypes - make sure LastPrimitiveTyID stays up to date.
- VoidTyID = 0, ///< 0: type with no size
- HalfTyID, ///< 1: 16-bit floating point type
- FloatTyID, ///< 2: 32-bit floating point type
- DoubleTyID, ///< 3: 64-bit floating point type
- X86_FP80TyID, ///< 4: 80-bit floating point type (X87)
- FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa)
- PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC)
- LabelTyID, ///< 7: Labels
- MetadataTyID, ///< 8: Metadata
- X86_MMXTyID, ///< 9: MMX vectors (64 bits, X86 specific)
- TokenTyID, ///< 10: Tokens
+ VoidTyID = 0, ///< 0: type with no size
+ HalfTyID, ///< 1: 16-bit floating point type
+ BFloatTyID, ///< 2: 16-bit floating point type (7-bit significand)
+ FloatTyID, ///< 3: 32-bit floating point type
+ DoubleTyID, ///< 4: 64-bit floating point type
+ X86_FP80TyID, ///< 5: 80-bit floating point type (X87)
+ FP128TyID, ///< 6: 128-bit floating point type (112-bit significand)
+ PPC_FP128TyID, ///< 7: 128-bit floating point type (two 64-bits, PowerPC)
+ LabelTyID, ///< 8: Labels
+ MetadataTyID, ///< 9: Metadata
+ X86_MMXTyID, ///< 10: MMX vectors (64 bits, X86 specific)
+ TokenTyID, ///< 11: Tokens
// Derived types... see DerivedTypes.h file.
// Make sure FirstDerivedTyID stays up to date!
- IntegerTyID, ///< 11: Arbitrary bit width integers
- FunctionTyID, ///< 12: Functions
- StructTyID, ///< 13: Structures
- ArrayTyID, ///< 14: Arrays
- PointerTyID, ///< 15: Pointers
- FixedVectorTyID, ///< 16: Fixed width SIMD vector type
- ScalableVectorTyID ///< 17: Scalable SIMD vector type
+ IntegerTyID, ///< 12: Arbitrary bit width integers
+ FunctionTyID, ///< 13: Functions
+ StructTyID, ///< 14: Structures
+ ArrayTyID, ///< 15: Arrays
+ PointerTyID, ///< 16: Pointers
+ FixedVectorTyID, ///< 17: Fixed width SIMD vector type
+ ScalableVectorTyID ///< 18: Scalable SIMD vector type
};
private:
/// Return true if this is 'half', a 16-bit IEEE fp type.
bool isHalfTy() const { return getTypeID() == HalfTyID; }
+ /// Return true if this is 'bfloat', a 16-bit bfloat type.
+ bool isBFloatTy() const { return getTypeID() == BFloatTyID; }
+
/// Return true if this is 'float', a 32-bit IEEE fp type.
bool isFloatTy() const { return getTypeID() == FloatTyID; }
/// Return true if this is one of the six floating-point types
bool isFloatingPointTy() const {
- return getTypeID() == HalfTyID || getTypeID() == FloatTyID ||
- getTypeID() == DoubleTyID ||
+ return getTypeID() == HalfTyID || getTypeID() == BFloatTyID ||
+ getTypeID() == FloatTyID || getTypeID() == DoubleTyID ||
getTypeID() == X86_FP80TyID || getTypeID() == FP128TyID ||
getTypeID() == PPC_FP128TyID;
}
const fltSemantics &getFltSemantics() const {
switch (getTypeID()) {
case HalfTyID: return APFloat::IEEEhalf();
+ case BFloatTyID: return APFloat::BFloat();
case FloatTyID: return APFloat::IEEEsingle();
case DoubleTyID: return APFloat::IEEEdouble();
case X86_FP80TyID: return APFloat::x87DoubleExtended();
static Type *getVoidTy(LLVMContext &C);
static Type *getLabelTy(LLVMContext &C);
static Type *getHalfTy(LLVMContext &C);
+ static Type *getBFloatTy(LLVMContext &C);
static Type *getFloatTy(LLVMContext &C);
static Type *getDoubleTy(LLVMContext &C);
static Type *getMetadataTy(LLVMContext &C);
// types as pointee.
//
static PointerType *getHalfPtrTy(LLVMContext &C, unsigned AS = 0);
+ static PointerType *getBFloatPtrTy(LLVMContext &C, unsigned AS = 0);
static PointerType *getFloatPtrTy(LLVMContext &C, unsigned AS = 0);
static PointerType *getDoublePtrTy(LLVMContext &C, unsigned AS = 0);
static PointerType *getX86_FP80PtrTy(LLVMContext &C, unsigned AS = 0);
TYPEKEYWORD("void", Type::getVoidTy(Context));
TYPEKEYWORD("half", Type::getHalfTy(Context));
+ TYPEKEYWORD("bfloat", Type::getBFloatTy(Context));
TYPEKEYWORD("float", Type::getFloatTy(Context));
TYPEKEYWORD("double", Type::getDoubleTy(Context));
TYPEKEYWORD("x86_fp80", Type::getX86_FP80Ty(Context));
/// HexFP128Constant 0xL[0-9A-Fa-f]+
/// HexPPC128Constant 0xM[0-9A-Fa-f]+
/// HexHalfConstant 0xH[0-9A-Fa-f]+
+/// HexBFloatConstant 0xR[0-9A-Fa-f]+
lltok::Kind LLLexer::Lex0x() {
CurPtr = TokStart + 2;
char Kind;
- if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H') {
+ if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' ||
+ CurPtr[0] == 'R') {
Kind = *CurPtr++;
} else {
Kind = 'J';
if (Kind == 'J') {
// HexFPConstant - Floating point constant represented in IEEE format as a
// hexadecimal number for when exponential notation is not precise enough.
- // Half, Float, and double only.
+ // Half, BFloat, Float, and double only.
APFloatVal = APFloat(APFloat::IEEEdouble(),
APInt(64, HexIntToVal(TokStart + 2, CurPtr)));
return lltok::APFloat;
APFloatVal = APFloat(APFloat::IEEEhalf(),
APInt(16,HexIntToVal(TokStart+3, CurPtr)));
return lltok::APFloat;
+ case 'R':
+ // Brain floating point
+ APFloatVal = APFloat(APFloat::BFloat(),
+ APInt(16, HexIntToVal(TokStart + 3, CurPtr)));
+ return lltok::APFloat;
}
}
!ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
return Error(ID.Loc, "floating point constant invalid for type");
- // The lexer has no type info, so builds all half, float, and double FP
- // constants as double. Fix this here. Long double does not need this.
+ // The lexer has no type info, so builds all half, bfloat, float, and double
+ // FP constants as double. Fix this here. Long double does not need this.
if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) {
bool Ignored;
if (Ty->isHalfTy())
ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
&Ignored);
+ else if (Ty->isBFloatTy())
+ ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
+ &Ignored);
else if (Ty->isFloatTy())
ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
&Ignored);
case bitc::TYPE_CODE_HALF: // HALF
ResultTy = Type::getHalfTy(Context);
break;
+ case bitc::TYPE_CODE_BFLOAT: // BFLOAT
+ ResultTy = Type::getBFloatTy(Context);
+ break;
case bitc::TYPE_CODE_FLOAT: // FLOAT
ResultTy = Type::getFloatTy(Context);
break;
if (CurTy->isHalfTy())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
APInt(16, (uint16_t)Record[0])));
+ else if (CurTy->isBFloatTy())
+ V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
+ APInt(16, (uint32_t)Record[0])));
else if (CurTy->isFloatTy())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
APInt(32, (uint32_t)Record[0])));
} else if (EltTy->isHalfTy()) {
SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
if (isa<VectorType>(CurTy))
- V = ConstantDataVector::getFP(Context, Elts);
+ V = ConstantDataVector::getFP(EltTy, Elts);
+ else
+ V = ConstantDataArray::getFP(EltTy, Elts);
+ } else if (EltTy->isBFloatTy()) {
+ SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
+ if (isa<VectorType>(CurTy))
+ V = ConstantDataVector::getFP(EltTy, Elts);
else
- V = ConstantDataArray::getFP(Context, Elts);
+ V = ConstantDataArray::getFP(EltTy, Elts);
} else if (EltTy->isFloatTy()) {
SmallVector<uint32_t, 16> Elts(Record.begin(), Record.end());
if (isa<VectorType>(CurTy))
- V = ConstantDataVector::getFP(Context, Elts);
+ V = ConstantDataVector::getFP(EltTy, Elts);
else
- V = ConstantDataArray::getFP(Context, Elts);
+ V = ConstantDataArray::getFP(EltTy, Elts);
} else if (EltTy->isDoubleTy()) {
SmallVector<uint64_t, 16> Elts(Record.begin(), Record.end());
if (isa<VectorType>(CurTy))
- V = ConstantDataVector::getFP(Context, Elts);
+ V = ConstantDataVector::getFP(EltTy, Elts);
else
- V = ConstantDataArray::getFP(Context, Elts);
+ V = ConstantDataArray::getFP(EltTy, Elts);
} else {
return error("Invalid type for value");
}
switch (T->getTypeID()) {
case Type::VoidTyID: Code = bitc::TYPE_CODE_VOID; break;
case Type::HalfTyID: Code = bitc::TYPE_CODE_HALF; break;
+ case Type::BFloatTyID: Code = bitc::TYPE_CODE_BFLOAT; break;
case Type::FloatTyID: Code = bitc::TYPE_CODE_FLOAT; break;
case Type::DoubleTyID: Code = bitc::TYPE_CODE_DOUBLE; break;
case Type::X86_FP80TyID: Code = bitc::TYPE_CODE_X86_FP80; break;
} else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
Code = bitc::CST_CODE_FLOAT;
Type *Ty = CFP->getType();
- if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) {
+ if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
+ Ty->isDoubleTy()) {
Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
} else if (Ty->isX86_FP80Ty()) {
// api needed to prevent premature destruction
}
static bool isValidHexFloatingPointPrefix(char C) {
- return C == 'H' || C == 'K' || C == 'L' || C == 'M';
+ return C == 'H' || C == 'K' || C == 'L' || C == 'M' || C == 'R';
}
static Cursor lexFloatingPointLiteral(Cursor Range, Cursor C, MIToken &Token) {
switch (Ty->getTypeID()) {
case Type::VoidTyID: OS << "void"; return;
case Type::HalfTyID: OS << "half"; return;
+ case Type::BFloatTyID: OS << "bfloat"; return;
case Type::FloatTyID: OS << "float"; return;
case Type::DoubleTyID: OS << "double"; return;
case Type::X86_FP80TyID: OS << "x86_fp80"; return;
return;
}
- // Either half, or some form of long double.
+ // Either half, bfloat or some form of long double.
// These appear as a magic letter identifying the type, then a
// fixed number of hex digits.
Out << "0x";
Out << 'H';
Out << format_hex_no_prefix(API.getZExtValue(), 4,
/*Upper=*/true);
+ } else if (&APF.getSemantics() == &APFloat::BFloat()) {
+ Out << 'R';
+ Out << format_hex_no_prefix(API.getZExtValue(), 4,
+ /*Upper=*/true);
} else
llvm_unreachable("Unsupported floating point type");
return;
case Type::HalfTyID:
return ConstantFP::get(Ty->getContext(),
APFloat::getZero(APFloat::IEEEhalf()));
+ case Type::BFloatTyID:
+ return ConstantFP::get(Ty->getContext(),
+ APFloat::getZero(APFloat::BFloat()));
case Type::FloatTyID:
return ConstantFP::get(Ty->getContext(),
APFloat::getZero(APFloat::IEEEsingle()));
APInt::getAllOnesValue(ITy->getBitWidth()));
if (Ty->isFloatingPointTy()) {
- APFloat FL = APFloat::getAllOnesValue(Ty->getPrimitiveSizeInBits(),
- !Ty->isPPC_FP128Ty());
+ APFloat FL = APFloat::getAllOnesValue(Ty->getFltSemantics(),
+ Ty->getPrimitiveSizeInBits());
return ConstantFP::get(Ty->getContext(), FL);
}
static const fltSemantics *TypeToFloatSemantics(Type *Ty) {
if (Ty->isHalfTy())
return &APFloat::IEEEhalf();
+ if (Ty->isBFloatTy())
+ return &APFloat::BFloat();
if (Ty->isFloatTy())
return &APFloat::IEEEsingle();
if (Ty->isDoubleTy())
Type *Ty;
if (&V.getSemantics() == &APFloat::IEEEhalf())
Ty = Type::getHalfTy(Context);
+ else if (&V.getSemantics() == &APFloat::BFloat())
+ Ty = Type::getBFloatTy(Context);
else if (&V.getSemantics() == &APFloat::IEEEsingle())
Ty = Type::getFloatTy(Context);
else if (&V.getSemantics() == &APFloat::IEEEdouble())
Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
else
return nullptr;
- return SequentialTy::getFP(V[0]->getContext(), Elts);
+ return SequentialTy::getFP(V[0]->getType(), Elts);
}
template <typename SequenceTy>
else if (CI->getType()->isIntegerTy(64))
return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
} else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
- if (CFP->getType()->isHalfTy())
+ if (CFP->getType()->isHalfTy() || CFP->getType()->isBFloatTy())
return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
else if (CFP->getType()->isFloatTy())
return getFPSequenceIfElementsMatch<SequenceTy, uint32_t>(V);
Val2.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &losesInfo);
return !losesInfo;
}
+ case Type::BFloatTyID: {
+ if (&Val2.getSemantics() == &APFloat::BFloat())
+ return true;
+ Val2.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &losesInfo);
+ return !losesInfo;
+ }
case Type::FloatTyID: {
if (&Val2.getSemantics() == &APFloat::IEEEsingle())
return true;
}
case Type::DoubleTyID: {
if (&Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble())
return true;
}
case Type::X86_FP80TyID:
return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble() ||
&Val2.getSemantics() == &APFloat::x87DoubleExtended();
case Type::FP128TyID:
return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble() ||
&Val2.getSemantics() == &APFloat::IEEEquad();
case Type::PPC_FP128TyID:
return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble() ||
&Val2.getSemantics() == &APFloat::PPCDoubleDouble();
}
bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
- if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true;
+ if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+ return true;
if (auto *IT = dyn_cast<IntegerType>(Ty)) {
switch (IT->getBitWidth()) {
case 8:
Next = nullptr;
}
-/// getFP() constructors - Return a constant with array type with an element
-/// count and element type of float with precision matching the number of
-/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
- ArrayRef<uint16_t> Elts) {
- Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size());
+/// getFP() constructors - Return a constant of array type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'. The amount of bits of the contained type must match the
+/// number of bits of the type contained in the passed in ArrayRef.
+/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+/// that this can return a ConstantAggregateZero object.
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint16_t> Elts) {
+ assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+ "Element type is not a 16-bit float type");
+ Type *Ty = ArrayType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 2), Ty);
}
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
- ArrayRef<uint32_t> Elts) {
- Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint32_t> Elts) {
+ assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+ Type *Ty = ArrayType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 4), Ty);
}
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
- ArrayRef<uint64_t> Elts) {
- Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint64_t> Elts) {
+ assert(ElementType->isDoubleTy() &&
+ "Element type is not a 64-bit float type");
+ Type *Ty = ArrayType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
-/// getFP() constructors - Return a constant with vector type with an element
-/// count and element type of float with the precision matching the number of
-/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+/// getFP() constructors - Return a constant of vector type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'. The amount of bits of the contained type must match the
+/// number of bits of the type contained in the passed in ArrayRef.
+/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+/// that this can return a ConstantAggregateZero object.
+Constant *ConstantDataVector::getFP(Type *ElementType,
ArrayRef<uint16_t> Elts) {
- Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+ assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+ "Element type is not a 16-bit float type");
+ Type *Ty = VectorType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 2), Ty);
}
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
ArrayRef<uint32_t> Elts) {
- Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
+ assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+ Type *Ty = VectorType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 4), Ty);
}
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
ArrayRef<uint64_t> Elts) {
- Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
+ assert(ElementType->isDoubleTy() &&
+ "Element type is not a 64-bit float type");
+ Type *Ty = VectorType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
if (CFP->getType()->isHalfTy()) {
SmallVector<uint16_t, 16> Elts(
NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
- return getFP(V->getContext(), Elts);
+ return getFP(V->getType(), Elts);
+ }
+ if (CFP->getType()->isBFloatTy()) {
+ SmallVector<uint16_t, 16> Elts(
+ NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+ return getFP(V->getType(), Elts);
}
if (CFP->getType()->isFloatTy()) {
SmallVector<uint32_t, 16> Elts(
NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
- return getFP(V->getContext(), Elts);
+ return getFP(V->getType(), Elts);
}
if (CFP->getType()->isDoubleTy()) {
SmallVector<uint64_t, 16> Elts(
NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
- return getFP(V->getContext(), Elts);
+ return getFP(V->getType(), Elts);
}
}
return ConstantVector::getSplat({NumElts, false}, V);
auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal));
}
+ case Type::BFloatTyID: {
+ auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
+ return APFloat(APFloat::BFloat(), APInt(16, EltVal));
+ }
case Type::FloatTyID: {
auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
return APFloat(APFloat::IEEEsingle(), APInt(32, EltVal));
}
Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
- if (getElementType()->isHalfTy() || getElementType()->isFloatTy() ||
- getElementType()->isDoubleTy())
+ if (getElementType()->isHalfTy() || getElementType()->isBFloatTy() ||
+ getElementType()->isFloatTy() || getElementType()->isDoubleTy())
return ConstantFP::get(getContext(), getElementAsAPFloat(Elt));
return ConstantInt::get(getElementType(), getElementAsInteger(Elt));
return LLVMVoidTypeKind;
case Type::HalfTyID:
return LLVMHalfTypeKind;
+ case Type::BFloatTyID:
+ return LLVMBFloatTypeKind;
case Type::FloatTyID:
return LLVMFloatTypeKind;
case Type::DoubleTyID:
LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C) {
return (LLVMTypeRef) Type::getHalfTy(*unwrap(C));
}
+LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C) {
+ return (LLVMTypeRef) Type::getBFloatTy(*unwrap(C));
+}
LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C) {
return (LLVMTypeRef) Type::getFloatTy(*unwrap(C));
}
LLVMTypeRef LLVMHalfType(void) {
return LLVMHalfTypeInContext(LLVMGetGlobalContext());
}
+LLVMTypeRef LLVMBFloatType(void) {
+ return LLVMBFloatTypeInContext(LLVMGetGlobalContext());
+}
LLVMTypeRef LLVMFloatType(void) {
return LLVMFloatTypeInContext(LLVMGetGlobalContext());
}
{INTEGER_ALIGN, 16, Align(2), Align(2)}, // i16
{INTEGER_ALIGN, 32, Align(4), Align(4)}, // i32
{INTEGER_ALIGN, 64, Align(4), Align(8)}, // i64
- {FLOAT_ALIGN, 16, Align(2), Align(2)}, // half
+ {FLOAT_ALIGN, 16, Align(2), Align(2)}, // half, bfloat
{FLOAT_ALIGN, 32, Align(4), Align(4)}, // float
{FLOAT_ALIGN, 64, Align(8), Align(8)}, // double
{FLOAT_ALIGN, 128, Align(16), Align(16)}, // ppcf128, quad, ...
AlignType = INTEGER_ALIGN;
break;
case Type::HalfTyID:
+ case Type::BFloatTyID:
case Type::FloatTyID:
case Type::DoubleTyID:
// PPC_FP128TyID and FP128TyID have different data contents, but the
case Type::VoidTyID: Result += "isVoid"; break;
case Type::MetadataTyID: Result += "Metadata"; break;
case Type::HalfTyID: Result += "f16"; break;
+ case Type::BFloatTyID: Result += "bf16"; break;
case Type::FloatTyID: Result += "f32"; break;
case Type::DoubleTyID: Result += "f64"; break;
case Type::X86_FP80TyID: Result += "f80"; break;
VoidTy(C, Type::VoidTyID),
LabelTy(C, Type::LabelTyID),
HalfTy(C, Type::HalfTyID),
+ BFloatTy(C, Type::BFloatTyID),
FloatTy(C, Type::FloatTyID),
DoubleTy(C, Type::DoubleTyID),
MetadataTy(C, Type::MetadataTyID),
std::unique_ptr<ConstantTokenNone> TheNoneToken;
// Basic type instances.
- Type VoidTy, LabelTy, HalfTy, FloatTy, DoubleTy, MetadataTy, TokenTy;
+ Type VoidTy, LabelTy, HalfTy, BFloatTy, FloatTy, DoubleTy, MetadataTy,
+ TokenTy;
Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy;
IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty;
switch (IDNumber) {
case VoidTyID : return getVoidTy(C);
case HalfTyID : return getHalfTy(C);
+ case BFloatTyID : return getBFloatTy(C);
case FloatTyID : return getFloatTy(C);
case DoubleTyID : return getDoubleTy(C);
case X86_FP80TyID : return getX86_FP80Ty(C);
TypeSize Type::getPrimitiveSizeInBits() const {
switch (getTypeID()) {
case Type::HalfTyID: return TypeSize::Fixed(16);
+ case Type::BFloatTyID: return TypeSize::Fixed(16);
case Type::FloatTyID: return TypeSize::Fixed(32);
case Type::DoubleTyID: return TypeSize::Fixed(64);
case Type::X86_FP80TyID: return TypeSize::Fixed(80);
return VTy->getElementType()->getFPMantissaWidth();
assert(isFloatingPointTy() && "Not a floating point type!");
if (getTypeID() == HalfTyID) return 11;
+ if (getTypeID() == BFloatTyID) return 8;
if (getTypeID() == FloatTyID) return 24;
if (getTypeID() == DoubleTyID) return 53;
if (getTypeID() == X86_FP80TyID) return 64;
Type *Type::getVoidTy(LLVMContext &C) { return &C.pImpl->VoidTy; }
Type *Type::getLabelTy(LLVMContext &C) { return &C.pImpl->LabelTy; }
Type *Type::getHalfTy(LLVMContext &C) { return &C.pImpl->HalfTy; }
+Type *Type::getBFloatTy(LLVMContext &C) { return &C.pImpl->BFloatTy; }
Type *Type::getFloatTy(LLVMContext &C) { return &C.pImpl->FloatTy; }
Type *Type::getDoubleTy(LLVMContext &C) { return &C.pImpl->DoubleTy; }
Type *Type::getMetadataTy(LLVMContext &C) { return &C.pImpl->MetadataTy; }
return getHalfTy(C)->getPointerTo(AS);
}
+PointerType *Type::getBFloatPtrTy(LLVMContext &C, unsigned AS) {
+ return getBFloatTy(C)->getPointerTo(AS);
+}
+
PointerType *Type::getFloatPtrTy(LLVMContext &C, unsigned AS) {
return getFloatTy(C)->getPointerTo(AS);
}
};
static const fltSemantics semIEEEhalf = {15, -14, 11, 16};
+ static const fltSemantics semBFloat = {127, -126, 8, 16};
static const fltSemantics semIEEEsingle = {127, -126, 24, 32};
static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64};
static const fltSemantics semIEEEquad = {16383, -16382, 113, 128};
switch (S) {
case S_IEEEhalf:
return IEEEhalf();
+ case S_BFloat:
+ return BFloat();
case S_IEEEsingle:
return IEEEsingle();
case S_IEEEdouble:
APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
if (&Sem == &llvm::APFloat::IEEEhalf())
return S_IEEEhalf;
+ else if (&Sem == &llvm::APFloat::BFloat())
+ return S_BFloat;
else if (&Sem == &llvm::APFloat::IEEEsingle())
return S_IEEEsingle;
else if (&Sem == &llvm::APFloat::IEEEdouble())
const fltSemantics &APFloatBase::IEEEhalf() {
return semIEEEhalf;
}
+ const fltSemantics &APFloatBase::BFloat() {
+ return semBFloat;
+ }
const fltSemantics &APFloatBase::IEEEsingle() {
return semIEEEsingle;
}
(mysignificand & 0x7fffff)));
}
+APInt IEEEFloat::convertBFloatAPFloatToAPInt() const {
+ assert(semantics == (const llvm::fltSemantics *)&semBFloat);
+ assert(partCount() == 1);
+
+ uint32_t myexponent, mysignificand;
+
+ if (isFiniteNonZero()) {
+ myexponent = exponent + 127; // bias
+ mysignificand = (uint32_t)*significandParts();
+ if (myexponent == 1 && !(mysignificand & 0x80))
+ myexponent = 0; // denormal
+ } else if (category == fcZero) {
+ myexponent = 0;
+ mysignificand = 0;
+ } else if (category == fcInfinity) {
+ myexponent = 0x1f;
+ mysignificand = 0;
+ } else {
+ assert(category == fcNaN && "Unknown category!");
+ myexponent = 0x1f;
+ mysignificand = (uint32_t)*significandParts();
+ }
+
+ return APInt(16, (((sign & 1) << 15) | ((myexponent & 0xff) << 7) |
+ (mysignificand & 0x7f)));
+}
+
APInt IEEEFloat::convertHalfAPFloatToAPInt() const {
assert(semantics == (const llvm::fltSemantics*)&semIEEEhalf);
assert(partCount()==1);
if (semantics == (const llvm::fltSemantics*)&semIEEEhalf)
return convertHalfAPFloatToAPInt();
+ if (semantics == (const llvm::fltSemantics *)&semBFloat)
+ return convertBFloatAPFloatToAPInt();
+
if (semantics == (const llvm::fltSemantics*)&semIEEEsingle)
return convertFloatAPFloatToAPInt();
}
}
+void IEEEFloat::initFromBFloatAPInt(const APInt &api) {
+ assert(api.getBitWidth() == 16);
+ uint32_t i = (uint32_t)*api.getRawData();
+ uint32_t myexponent = (i >> 7) & 0xff;
+ uint32_t mysignificand = i & 0x7f;
+
+ initialize(&semBFloat);
+ assert(partCount() == 1);
+
+ sign = i >> 15;
+ if (myexponent == 0 && mysignificand == 0) {
+ // exponent, significand meaningless
+ category = fcZero;
+ } else if (myexponent == 0xff && mysignificand == 0) {
+ // exponent, significand meaningless
+ category = fcInfinity;
+ } else if (myexponent == 0xff && mysignificand != 0) {
+ // sign, exponent, significand meaningless
+ category = fcNaN;
+ *significandParts() = mysignificand;
+ } else {
+ category = fcNormal;
+ exponent = myexponent - 127; // bias
+ *significandParts() = mysignificand;
+ if (myexponent == 0) // denormal
+ exponent = -126;
+ else
+ *significandParts() |= 0x80; // integer bit
+ }
+}
+
void IEEEFloat::initFromHalfAPInt(const APInt &api) {
assert(api.getBitWidth()==16);
uint32_t i = (uint32_t)*api.getRawData();
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
if (Sem == &semIEEEhalf)
return initFromHalfAPInt(api);
+ if (Sem == &semBFloat)
+ return initFromBFloatAPInt(api);
if (Sem == &semIEEEsingle)
return initFromFloatAPInt(api);
if (Sem == &semIEEEdouble)
llvm_unreachable("Unexpected semantics");
}
-APFloat APFloat::getAllOnesValue(unsigned BitWidth, bool isIEEE) {
- if (isIEEE) {
- switch (BitWidth) {
- case 16:
- return APFloat(semIEEEhalf, APInt::getAllOnesValue(BitWidth));
- case 32:
- return APFloat(semIEEEsingle, APInt::getAllOnesValue(BitWidth));
- case 64:
- return APFloat(semIEEEdouble, APInt::getAllOnesValue(BitWidth));
- case 80:
- return APFloat(semX87DoubleExtended, APInt::getAllOnesValue(BitWidth));
- case 128:
- return APFloat(semIEEEquad, APInt::getAllOnesValue(BitWidth));
- default:
- llvm_unreachable("Unknown floating bit width");
- }
- } else {
- assert(BitWidth == 128);
- return APFloat(semPPCDoubleDouble, APInt::getAllOnesValue(BitWidth));
- }
+APFloat APFloat::getAllOnesValue(const fltSemantics &Semantics,
+ unsigned BitWidth) {
+ return APFloat(Semantics, APInt::getAllOnesValue(BitWidth));
}
void APFloat::print(raw_ostream &OS) const {
}
case Type::FunctionTyID:
case Type::VoidTyID:
+ case Type::BFloatTyID:
case Type::X86_FP80TyID:
case Type::FP128TyID:
case Type::PPC_FP128TyID:
MVT LogicVT = VT;
if (EltVT == MVT::f32 || EltVT == MVT::f64) {
Zero = DAG.getConstantFP(0.0, DL, EltVT);
- AllOnes = DAG.getConstantFP(
- APFloat::getAllOnesValue(EltVT.getSizeInBits(), true), DL, EltVT);
+ APFloat AllOnesValue = APFloat::getAllOnesValue(
+ SelectionDAG::EVTToAPFloatSemantics(EltVT), EltVT.getSizeInBits());
+ AllOnes = DAG.getConstantFP(AllOnesValue, DL, EltVT);
LogicVT =
MVT::getVectorVT(EltVT == MVT::f64 ? MVT::i64 : MVT::i32, Mask.size());
} else {
--- /dev/null
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s --check-prefix=ASSEM-DISASS
+; RUN: opt < %s -O3 -S | FileCheck %s --check-prefix=OPT
+; RUN: verify-uselistorder %s
+; Basic smoke tests for bfloat type.
+
+define bfloat @check_bfloat(bfloat %A) {
+; ASSEM-DISASS: ret bfloat %A
+ ret bfloat %A
+}
+
+define bfloat @check_bfloat_literal() {
+; ASSEM-DISASS: ret bfloat 0xR3149
+ ret bfloat 0xR3149
+}
+
+define <4 x bfloat> @check_fixed_vector() {
+; ASSEM-DISASS: ret <4 x bfloat> %tmp
+ %tmp = fadd <4 x bfloat> undef, undef
+ ret <4 x bfloat> %tmp
+}
+
+define <vscale x 4 x bfloat> @check_vector() {
+; ASSEM-DISASS: ret <vscale x 4 x bfloat> %tmp
+ %tmp = fadd <vscale x 4 x bfloat> undef, undef
+ ret <vscale x 4 x bfloat> %tmp
+}
+
+define bfloat @check_bfloat_constprop() {
+ %tmp = fadd bfloat 0xR40C0, 0xR40C0
+; OPT: 0xR4140
+ ret bfloat %tmp
+}
+
+define float @check_bfloat_convert() {
+ %tmp = fpext bfloat 0xR4C8D to float
+; OPT: 0x4191A00000000000
+ ret float %tmp
+}
return LLVMVoidTypeInContext(Ctx);
case LLVMHalfTypeKind:
return LLVMHalfTypeInContext(Ctx);
+ case LLVMBFloatTypeKind:
+ return LLVMHalfTypeInContext(Ctx);
case LLVMFloatTypeKind:
return LLVMFloatTypeInContext(Ctx);
case LLVMDoubleTypeKind: