[IR][BFloat] Add BFloat IR type
authorTies Stuij <ties.stuij@arm.com>
Tue, 31 Mar 2020 22:49:38 +0000 (23:49 +0100)
committerTies Stuij <ties.stuij@arm.com>
Fri, 15 May 2020 13:43:43 +0000 (14:43 +0100)
Summary:
The BFloat IR type is introduced to provide support for, initially, the BFloat16
datatype introduced with the Armv8.6 architecture (optional from Armv8.2
onwards). It has an 8-bit exponent and a 7-bit mantissa and behaves like an IEEE
754 floating point IR type.

This is part of a patch series upstreaming Armv8.6 features. Subsequent patches
will upstream intrinsics support and C-lang support for BFloat.

Reviewers: SjoerdMeijer, rjmccall, rsmith, liutianle, RKSimon, craig.topper, jfb, LukeGeeson, sdesmalen, deadalnix, ctetreau

Subscribers: hiraditya, llvm-commits, danielkiss, arphaman, kristof.beyls, dexonsmith

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D78190

28 files changed:
clang/lib/Sema/SemaOpenMP.cpp
llvm/docs/BitCodeFormat.rst
llvm/docs/LangRef.rst
llvm/include/llvm-c/Core.h
llvm/include/llvm/ADT/APFloat.h
llvm/include/llvm/Bitcode/LLVMBitCodes.h
llvm/include/llvm/IR/Constants.h
llvm/include/llvm/IR/DataLayout.h
llvm/include/llvm/IR/IRBuilder.h
llvm/include/llvm/IR/Type.h
llvm/lib/AsmParser/LLLexer.cpp
llvm/lib/AsmParser/LLParser.cpp
llvm/lib/Bitcode/Reader/BitcodeReader.cpp
llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
llvm/lib/CodeGen/MIRParser/MILexer.cpp
llvm/lib/IR/AsmWriter.cpp
llvm/lib/IR/Constants.cpp
llvm/lib/IR/Core.cpp
llvm/lib/IR/DataLayout.cpp
llvm/lib/IR/Function.cpp
llvm/lib/IR/LLVMContextImpl.cpp
llvm/lib/IR/LLVMContextImpl.h
llvm/lib/IR/Type.cpp
llvm/lib/Support/APFloat.cpp
llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/Assembler/bfloat.ll [new file with mode: 0644]
llvm/tools/llvm-c-test/echo.cpp

index 544dc61..e03b926 100644 (file)
@@ -14936,9 +14936,9 @@ static bool actOnOMPReductionKindClause(
         if (auto *ComplexTy = OrigType->getAs<ComplexType>())
           Type = ComplexTy->getElementType();
         if (Type->isRealFloatingType()) {
-          llvm::APFloat InitValue =
-              llvm::APFloat::getAllOnesValue(Context.getTypeSize(Type),
-                                             /*isIEEE=*/true);
+          llvm::APFloat InitValue = llvm::APFloat::getAllOnesValue(
+              Context.getFloatTypeSemantics(Type),
+              Context.getTypeSize(Type));
           Init = FloatingLiteral::Create(Context, InitValue, /*isexact=*/true,
                                          Type, ELoc);
         } else if (Type->isScalarType()) {
index dce8462..4fdccc8 100644 (file)
@@ -1107,6 +1107,14 @@ TYPE_CODE_HALF Record
 The ``HALF`` record (code 10) adds a ``half`` (16-bit floating point) type to
 the type table.
 
+TYPE_CODE_BFLOAT Record
+^^^^^^^^^^^^^^^^^^^^^
+
+``[BFLOAT]``
+
+The ``BFLOAT`` record (code 23) adds a ``bfloat`` (16-bit brain floating point)
+type to the type table.
+
 TYPE_CODE_FLOAT Record
 ^^^^^^^^^^^^^^^^^^^^^^
 
index 240dbd6..07320de 100644 (file)
@@ -2963,6 +2963,12 @@ Floating-Point Types
    * - ``half``
      - 16-bit floating-point value
 
+   * - ``bfloat``
+     - 16-bit "brain" floating-point value (7-bit significand).  Provides the
+       same number of exponent bits as ``float``, so that it matches its dynamic
+       range, but with greatly reduced precision.  Used in Intel's AVX-512 BF16
+       extensions and Arm's ARMv8.6-A extensions, among others.
+
    * - ``float``
      - 32-bit floating-point value
 
@@ -2970,7 +2976,7 @@ Floating-Point Types
      - 64-bit floating-point value
 
    * - ``fp128``
-     - 128-bit floating-point value (112-bit mantissa)
+     - 128-bit floating-point value (112-bit significand)
 
    * - ``x86_fp80``
      -  80-bit floating-point value (X87)
@@ -3303,20 +3309,20 @@ number of digits. For example, NaN's, infinities, and other special
 values are represented in their IEEE hexadecimal format so that assembly
 and disassembly do not cause any bits to change in the constants.
 
-When using the hexadecimal form, constants of types half, float, and
-double are represented using the 16-digit form shown above (which
-matches the IEEE754 representation for double); half and float values
-must, however, be exactly representable as IEEE 754 half and single
-precision, respectively. Hexadecimal format is always used for long
-double, and there are three forms of long double. The 80-bit format used
-by x86 is represented as ``0xK`` followed by 20 hexadecimal digits. The
-128-bit format used by PowerPC (two adjacent doubles) is represented by
-``0xM`` followed by 32 hexadecimal digits. The IEEE 128-bit format is
-represented by ``0xL`` followed by 32 hexadecimal digits. Long doubles
-will only work if they match the long double format on your target.
-The IEEE 16-bit format (half precision) is represented by ``0xH``
-followed by 4 hexadecimal digits. All hexadecimal formats are big-endian
-(sign bit at the left).
+When using the hexadecimal form, constants of types bfloat, half, float, and
+double are represented using the 16-digit form shown above (which matches the
+IEEE754 representation for double); bfloat, half and float values must, however,
+be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single
+precision respectively. Hexadecimal format is always used for long double, and
+there are three forms of long double. The 80-bit format used by x86 is
+represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit format
+used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32
+hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed
+by 32 hexadecimal digits. Long doubles will only work if they match the long
+double format on your target.  The IEEE 16-bit format (half precision) is
+represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit
+format is represented by ``0xR`` followed by 4 hexadecimal digits. All
+hexadecimal formats are big-endian (sign bit at the left).
 
 There are no constants of type x86_mmx.
 
index 25802ed..1991dd9 100644 (file)
@@ -146,6 +146,7 @@ typedef enum {
 typedef enum {
   LLVMVoidTypeKind,        /**< type with no size */
   LLVMHalfTypeKind,        /**< 16 bit floating point type */
+  LLVMBFloatTypeKind,      /**< 16 bit brain floating point type */
   LLVMFloatTypeKind,       /**< 32 bit floating point type */
   LLVMDoubleTypeKind,      /**< 64 bit floating point type */
   LLVMX86_FP80TypeKind,    /**< 80 bit floating point type (X87) */
@@ -1164,6 +1165,11 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy);
 LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C);
 
 /**
+ * Obtain a 16-bit brain floating point type from a context.
+ */
+LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C);
+
+/**
  * Obtain a 32-bit floating point type from a context.
  */
 LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C);
@@ -1195,6 +1201,7 @@ LLVMTypeRef LLVMPPCFP128TypeInContext(LLVMContextRef C);
  * These map to the functions in this group of the same name.
  */
 LLVMTypeRef LLVMHalfType(void);
+LLVMTypeRef LLVMBFloatType(void);
 LLVMTypeRef LLVMFloatType(void);
 LLVMTypeRef LLVMDoubleType(void);
 LLVMTypeRef LLVMX86FP80Type(void);
index 1c17f10..44857f7 100644 (file)
@@ -151,6 +151,7 @@ struct APFloatBase {
   /// @{
   enum Semantics {
     S_IEEEhalf,
+    S_BFloat,
     S_IEEEsingle,
     S_IEEEdouble,
     S_x87DoubleExtended,
@@ -162,6 +163,7 @@ struct APFloatBase {
   static Semantics SemanticsToEnum(const llvm::fltSemantics &Sem);
 
   static const fltSemantics &IEEEhalf() LLVM_READNONE;
+  static const fltSemantics &BFloat() LLVM_READNONE;
   static const fltSemantics &IEEEsingle() LLVM_READNONE;
   static const fltSemantics &IEEEdouble() LLVM_READNONE;
   static const fltSemantics &IEEEquad() LLVM_READNONE;
@@ -541,6 +543,7 @@ private:
   /// @}
 
   APInt convertHalfAPFloatToAPInt() const;
+  APInt convertBFloatAPFloatToAPInt() const;
   APInt convertFloatAPFloatToAPInt() const;
   APInt convertDoubleAPFloatToAPInt() const;
   APInt convertQuadrupleAPFloatToAPInt() const;
@@ -548,6 +551,7 @@ private:
   APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   void initFromHalfAPInt(const APInt &api);
+  void initFromBFloatAPInt(const APInt &api);
   void initFromFloatAPInt(const APInt &api);
   void initFromDoubleAPInt(const APInt &api);
   void initFromQuadrupleAPInt(const APInt &api);
@@ -954,9 +958,10 @@ public:
 
   /// Returns a float which is bitcasted from an all one value int.
   ///
+  /// \param Semantics - type float semantics
   /// \param BitWidth - Select float type
-  /// \param isIEEE   - If 128 bit number, select between PPC and IEEE
-  static APFloat getAllOnesValue(unsigned BitWidth, bool isIEEE = false);
+  static APFloat getAllOnesValue(const fltSemantics &Semantics,
+                                 unsigned BitWidth);
 
   /// Used to insert APFloat objects, or objects that contain APFloat objects,
   /// into FoldingSets.
index e614337..2f09ad3 100644 (file)
@@ -166,7 +166,9 @@ enum TypeCodes {
 
   TYPE_CODE_FUNCTION = 21, // FUNCTION: [vararg, retty, paramty x N]
 
-  TYPE_CODE_TOKEN = 22 // TOKEN
+  TYPE_CODE_TOKEN = 22, // TOKEN
+
+  TYPE_CODE_BFLOAT = 23 // BRAIN FLOATING POINT
 };
 
 enum OperandBundleTagCode {
index b31bcb7..25d8f6a 100644 (file)
@@ -721,14 +721,15 @@ public:
     return getImpl(Data, Ty);
   }
 
-  /// getFP() constructors - Return a constant with array type with an element
-  /// count and element type of float with precision matching the number of
-  /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-  /// double for 64bits) Note that this can return a ConstantAggregateZero
-  /// object.
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
+  /// getFP() constructors - Return a constant of array type with a float
+  /// element type taken from argument `ElementType', and count taken from
+  /// argument `Elts'.  The amount of bits of the contained type must match the
+  /// number of bits of the type contained in the passed in ArrayRef.
+  /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+  /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
 
   /// This method constructs a CDS and initializes it with a text string.
   /// The default behavior (AddNull==true) causes a null terminator to
@@ -780,14 +781,15 @@ public:
   static Constant *get(LLVMContext &Context, ArrayRef<float> Elts);
   static Constant *get(LLVMContext &Context, ArrayRef<double> Elts);
 
-  /// getFP() constructors - Return a constant with vector type with an element
-  /// count and element type of float with the precision matching the number of
-  /// bits in the ArrayRef passed in.  (i.e. half for 16bits, float for 32bits,
-  /// double for 64bits) Note that this can return a ConstantAggregateZero
-  /// object.
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
+  /// getFP() constructors - Return a constant of vector type with a float
+  /// element type taken from argument `ElementType', and count taken from
+  /// argument `Elts'.  The amount of bits of the contained type must match the
+  /// number of bits of the type contained in the passed in ArrayRef.
+  /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+  /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
 
   /// Return a ConstantVector with the specified constant in each element.
   /// The specified constant has to be a of a compatible type (i8/i16/
index 010469c..e8fab02 100644 (file)
@@ -651,6 +651,7 @@ inline TypeSize DataLayout::getTypeSizeInBits(Type *Ty) const {
   case Type::IntegerTyID:
     return TypeSize::Fixed(Ty->getIntegerBitWidth());
   case Type::HalfTyID:
+  case Type::BFloatTyID:
     return TypeSize::Fixed(16);
   case Type::FloatTyID:
     return TypeSize::Fixed(32);
index 6e431bc..b6dca11 100644 (file)
@@ -477,6 +477,11 @@ public:
     return Type::getHalfTy(Context);
   }
 
+  /// Fetch the type representing a 16-bit brain floating point value.
+  Type *getBFloatTy() {
+    return Type::getBFloatTy(Context);
+  }
+
   /// Fetch the type representing a 32-bit floating point value.
   Type *getFloatTy() {
     return Type::getFloatTy(Context);
index 618eee0..5d6c0c6 100644 (file)
@@ -54,27 +54,28 @@ public:
   ///
   enum TypeID {
     // PrimitiveTypes - make sure LastPrimitiveTyID stays up to date.
-    VoidTyID = 0,  ///<  0: type with no size
-    HalfTyID,      ///<  1: 16-bit floating point type
-    FloatTyID,     ///<  2: 32-bit floating point type
-    DoubleTyID,    ///<  3: 64-bit floating point type
-    X86_FP80TyID,  ///<  4: 80-bit floating point type (X87)
-    FP128TyID,     ///<  5: 128-bit floating point type (112-bit mantissa)
-    PPC_FP128TyID, ///<  6: 128-bit floating point type (two 64-bits, PowerPC)
-    LabelTyID,     ///<  7: Labels
-    MetadataTyID,  ///<  8: Metadata
-    X86_MMXTyID,   ///<  9: MMX vectors (64 bits, X86 specific)
-    TokenTyID,     ///< 10: Tokens
+    VoidTyID = 0,    ///<  0: type with no size
+    HalfTyID,        ///<  1: 16-bit floating point type
+    BFloatTyID,      ///<  2: 16-bit floating point type (7-bit significand)
+    FloatTyID,       ///<  3: 32-bit floating point type
+    DoubleTyID,      ///<  4: 64-bit floating point type
+    X86_FP80TyID,    ///<  5: 80-bit floating point type (X87)
+    FP128TyID,       ///<  6: 128-bit floating point type (112-bit significand)
+    PPC_FP128TyID,   ///<  7: 128-bit floating point type (two 64-bits, PowerPC)
+    LabelTyID,       ///<  8: Labels
+    MetadataTyID,    ///<  9: Metadata
+    X86_MMXTyID,     ///< 10: MMX vectors (64 bits, X86 specific)
+    TokenTyID,       ///< 11: Tokens
 
     // Derived types... see DerivedTypes.h file.
     // Make sure FirstDerivedTyID stays up to date!
-    IntegerTyID,       ///< 11: Arbitrary bit width integers
-    FunctionTyID,      ///< 12: Functions
-    StructTyID,        ///< 13: Structures
-    ArrayTyID,         ///< 14: Arrays
-    PointerTyID,       ///< 15: Pointers
-    FixedVectorTyID,   ///< 16: Fixed width SIMD vector type
-    ScalableVectorTyID ///< 17: Scalable SIMD vector type
+    IntegerTyID,       ///< 12: Arbitrary bit width integers
+    FunctionTyID,      ///< 13: Functions
+    StructTyID,        ///< 14: Structures
+    ArrayTyID,         ///< 15: Arrays
+    PointerTyID,       ///< 16: Pointers
+    FixedVectorTyID,   ///< 17: Fixed width SIMD vector type
+    ScalableVectorTyID ///< 18: Scalable SIMD vector type
   };
 
 private:
@@ -140,6 +141,9 @@ public:
   /// Return true if this is 'half', a 16-bit IEEE fp type.
   bool isHalfTy() const { return getTypeID() == HalfTyID; }
 
+  /// Return true if this is 'bfloat', a 16-bit bfloat type.
+  bool isBFloatTy() const { return getTypeID() == BFloatTyID; }
+
   /// Return true if this is 'float', a 32-bit IEEE fp type.
   bool isFloatTy() const { return getTypeID() == FloatTyID; }
 
@@ -157,8 +161,8 @@ public:
 
   /// Return true if this is one of the six floating-point types
   bool isFloatingPointTy() const {
-    return getTypeID() == HalfTyID || getTypeID() == FloatTyID ||
-           getTypeID() == DoubleTyID ||
+    return getTypeID() == HalfTyID || getTypeID() == BFloatTyID ||
+           getTypeID() == FloatTyID || getTypeID() == DoubleTyID ||
            getTypeID() == X86_FP80TyID || getTypeID() == FP128TyID ||
            getTypeID() == PPC_FP128TyID;
   }
@@ -166,6 +170,7 @@ public:
   const fltSemantics &getFltSemantics() const {
     switch (getTypeID()) {
     case HalfTyID: return APFloat::IEEEhalf();
+    case BFloatTyID: return APFloat::BFloat();
     case FloatTyID: return APFloat::IEEEsingle();
     case DoubleTyID: return APFloat::IEEEdouble();
     case X86_FP80TyID: return APFloat::x87DoubleExtended();
@@ -387,6 +392,7 @@ public:
   static Type *getVoidTy(LLVMContext &C);
   static Type *getLabelTy(LLVMContext &C);
   static Type *getHalfTy(LLVMContext &C);
+  static Type *getBFloatTy(LLVMContext &C);
   static Type *getFloatTy(LLVMContext &C);
   static Type *getDoubleTy(LLVMContext &C);
   static Type *getMetadataTy(LLVMContext &C);
@@ -422,6 +428,7 @@ public:
   // types as pointee.
   //
   static PointerType *getHalfPtrTy(LLVMContext &C, unsigned AS = 0);
+  static PointerType *getBFloatPtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getFloatPtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getDoublePtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getX86_FP80PtrTy(LLVMContext &C, unsigned AS = 0);
index 06631fc..eb85ef7 100644 (file)
@@ -820,6 +820,7 @@ lltok::Kind LLLexer::LexIdentifier() {
 
   TYPEKEYWORD("void",      Type::getVoidTy(Context));
   TYPEKEYWORD("half",      Type::getHalfTy(Context));
+  TYPEKEYWORD("bfloat",    Type::getBFloatTy(Context));
   TYPEKEYWORD("float",     Type::getFloatTy(Context));
   TYPEKEYWORD("double",    Type::getDoubleTy(Context));
   TYPEKEYWORD("x86_fp80",  Type::getX86_FP80Ty(Context));
@@ -985,11 +986,13 @@ lltok::Kind LLLexer::LexIdentifier() {
 ///    HexFP128Constant  0xL[0-9A-Fa-f]+
 ///    HexPPC128Constant 0xM[0-9A-Fa-f]+
 ///    HexHalfConstant   0xH[0-9A-Fa-f]+
+///    HexBFloatConstant 0xR[0-9A-Fa-f]+
 lltok::Kind LLLexer::Lex0x() {
   CurPtr = TokStart + 2;
 
   char Kind;
-  if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H') {
+  if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' ||
+      CurPtr[0] == 'R') {
     Kind = *CurPtr++;
   } else {
     Kind = 'J';
@@ -1007,7 +1010,7 @@ lltok::Kind LLLexer::Lex0x() {
   if (Kind == 'J') {
     // HexFPConstant - Floating point constant represented in IEEE format as a
     // hexadecimal number for when exponential notation is not precise enough.
-    // Half, Float, and double only.
+    // Half, BFloat, Float, and double only.
     APFloatVal = APFloat(APFloat::IEEEdouble(),
                          APInt(64, HexIntToVal(TokStart + 2, CurPtr)));
     return lltok::APFloat;
@@ -1035,6 +1038,11 @@ lltok::Kind LLLexer::Lex0x() {
     APFloatVal = APFloat(APFloat::IEEEhalf(),
                          APInt(16,HexIntToVal(TokStart+3, CurPtr)));
     return lltok::APFloat;
+  case 'R':
+    // Brain floating point
+    APFloatVal = APFloat(APFloat::BFloat(),
+                         APInt(16, HexIntToVal(TokStart + 3, CurPtr)));
+    return lltok::APFloat;
   }
 }
 
index ce1e9d2..d045bcd 100644 (file)
@@ -5247,13 +5247,16 @@ bool LLParser::ConvertValIDToValue(Type *Ty, ValID &ID, Value *&V,
         !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
       return Error(ID.Loc, "floating point constant invalid for type");
 
-    // The lexer has no type info, so builds all half, float, and double FP
-    // constants as double.  Fix this here.  Long double does not need this.
+    // The lexer has no type info, so builds all half, bfloat, float, and double
+    // FP constants as double.  Fix this here.  Long double does not need this.
     if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) {
       bool Ignored;
       if (Ty->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
+      else if (Ty->isBFloatTy())
+        ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
+                              &Ignored);
       else if (Ty->isFloatTy())
         ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                               &Ignored);
index bdc0fa7..21759c5 100644 (file)
@@ -1720,6 +1720,9 @@ Error BitcodeReader::parseTypeTableBody() {
     case bitc::TYPE_CODE_HALF:     // HALF
       ResultTy = Type::getHalfTy(Context);
       break;
+    case bitc::TYPE_CODE_BFLOAT:    // BFLOAT
+      ResultTy = Type::getBFloatTy(Context);
+      break;
     case bitc::TYPE_CODE_FLOAT:     // FLOAT
       ResultTy = Type::getFloatTy(Context);
       break;
@@ -2429,6 +2432,9 @@ Error BitcodeReader::parseConstants() {
       if (CurTy->isHalfTy())
         V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
                                              APInt(16, (uint16_t)Record[0])));
+      else if (CurTy->isBFloatTy())
+        V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
+                                             APInt(16, (uint32_t)Record[0])));
       else if (CurTy->isFloatTy())
         V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
                                              APInt(32, (uint32_t)Record[0])));
@@ -2526,21 +2532,27 @@ Error BitcodeReader::parseConstants() {
       } else if (EltTy->isHalfTy()) {
         SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
-          V = ConstantDataVector::getFP(Context, Elts);
+          V = ConstantDataVector::getFP(EltTy, Elts);
+        else
+          V = ConstantDataArray::getFP(EltTy, Elts);
+      } else if (EltTy->isBFloatTy()) {
+        SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
+        if (isa<VectorType>(CurTy))
+          V = ConstantDataVector::getFP(EltTy, Elts);
         else
-          V = ConstantDataArray::getFP(Context, Elts);
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else if (EltTy->isFloatTy()) {
         SmallVector<uint32_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
-          V = ConstantDataVector::getFP(Context, Elts);
+          V = ConstantDataVector::getFP(EltTy, Elts);
         else
-          V = ConstantDataArray::getFP(Context, Elts);
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else if (EltTy->isDoubleTy()) {
         SmallVector<uint64_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
-          V = ConstantDataVector::getFP(Context, Elts);
+          V = ConstantDataVector::getFP(EltTy, Elts);
         else
-          V = ConstantDataArray::getFP(Context, Elts);
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else {
         return error("Invalid type for value");
       }
index 9e389fc..5b62a47 100644 (file)
@@ -881,6 +881,7 @@ void ModuleBitcodeWriter::writeTypeTable() {
     switch (T->getTypeID()) {
     case Type::VoidTyID:      Code = bitc::TYPE_CODE_VOID;      break;
     case Type::HalfTyID:      Code = bitc::TYPE_CODE_HALF;      break;
+    case Type::BFloatTyID:    Code = bitc::TYPE_CODE_BFLOAT;    break;
     case Type::FloatTyID:     Code = bitc::TYPE_CODE_FLOAT;     break;
     case Type::DoubleTyID:    Code = bitc::TYPE_CODE_DOUBLE;    break;
     case Type::X86_FP80TyID:  Code = bitc::TYPE_CODE_X86_FP80;  break;
@@ -2387,7 +2388,8 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
     } else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       Code = bitc::CST_CODE_FLOAT;
       Type *Ty = CFP->getType();
-      if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) {
+      if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
+          Ty->isDoubleTy()) {
         Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
       } else if (Ty->isX86_FP80Ty()) {
         // api needed to prevent premature destruction
index e4852d3..0fbedc4 100644 (file)
@@ -534,7 +534,7 @@ static Cursor maybeLexMCSymbol(Cursor C, MIToken &Token,
 }
 
 static bool isValidHexFloatingPointPrefix(char C) {
-  return C == 'H' || C == 'K' || C == 'L' || C == 'M';
+  return C == 'H' || C == 'K' || C == 'L' || C == 'M' || C == 'R';
 }
 
 static Cursor lexFloatingPointLiteral(Cursor Range, Cursor C, MIToken &Token) {
index 6f451a1..72da461 100644 (file)
@@ -588,6 +588,7 @@ void TypePrinting::print(Type *Ty, raw_ostream &OS) {
   switch (Ty->getTypeID()) {
   case Type::VoidTyID:      OS << "void"; return;
   case Type::HalfTyID:      OS << "half"; return;
+  case Type::BFloatTyID:    OS << "bfloat"; return;
   case Type::FloatTyID:     OS << "float"; return;
   case Type::DoubleTyID:    OS << "double"; return;
   case Type::X86_FP80TyID:  OS << "x86_fp80"; return;
@@ -1379,7 +1380,7 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
       return;
     }
 
-    // Either half, or some form of long double.
+    // Either half, bfloat or some form of long double.
     // These appear as a magic letter identifying the type, then a
     // fixed number of hex digits.
     Out << "0x";
@@ -1407,6 +1408,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
       Out << 'H';
       Out << format_hex_no_prefix(API.getZExtValue(), 4,
                                   /*Upper=*/true);
+    } else if (&APF.getSemantics() == &APFloat::BFloat()) {
+      Out << 'R';
+      Out << format_hex_no_prefix(API.getZExtValue(), 4,
+                                  /*Upper=*/true);
     } else
       llvm_unreachable("Unsupported floating point type");
     return;
index 5a3c6a4..88971d8 100644 (file)
@@ -332,6 +332,9 @@ Constant *Constant::getNullValue(Type *Ty) {
   case Type::HalfTyID:
     return ConstantFP::get(Ty->getContext(),
                            APFloat::getZero(APFloat::IEEEhalf()));
+  case Type::BFloatTyID:
+    return ConstantFP::get(Ty->getContext(),
+                           APFloat::getZero(APFloat::BFloat()));
   case Type::FloatTyID:
     return ConstantFP::get(Ty->getContext(),
                            APFloat::getZero(APFloat::IEEEsingle()));
@@ -386,8 +389,8 @@ Constant *Constant::getAllOnesValue(Type *Ty) {
                             APInt::getAllOnesValue(ITy->getBitWidth()));
 
   if (Ty->isFloatingPointTy()) {
-    APFloat FL = APFloat::getAllOnesValue(Ty->getPrimitiveSizeInBits(),
-                                          !Ty->isPPC_FP128Ty());
+    APFloat FL = APFloat::getAllOnesValue(Ty->getFltSemantics(),
+                                          Ty->getPrimitiveSizeInBits());
     return ConstantFP::get(Ty->getContext(), FL);
   }
 
@@ -763,6 +766,8 @@ void ConstantInt::destroyConstantImpl() {
 static const fltSemantics *TypeToFloatSemantics(Type *Ty) {
   if (Ty->isHalfTy())
     return &APFloat::IEEEhalf();
+  if (Ty->isBFloatTy())
+    return &APFloat::BFloat();
   if (Ty->isFloatTy())
     return &APFloat::IEEEsingle();
   if (Ty->isDoubleTy())
@@ -880,6 +885,8 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
     Type *Ty;
     if (&V.getSemantics() == &APFloat::IEEEhalf())
       Ty = Type::getHalfTy(Context);
+    else if (&V.getSemantics() == &APFloat::BFloat())
+      Ty = Type::getBFloatTy(Context);
     else if (&V.getSemantics() == &APFloat::IEEEsingle())
       Ty = Type::getFloatTy(Context);
     else if (&V.getSemantics() == &APFloat::IEEEdouble())
@@ -1029,7 +1036,7 @@ static Constant *getFPSequenceIfElementsMatch(ArrayRef<Constant *> V) {
       Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
     else
       return nullptr;
-  return SequentialTy::getFP(V[0]->getContext(), Elts);
+  return SequentialTy::getFP(V[0]->getType(), Elts);
 }
 
 template <typename SequenceTy>
@@ -1048,7 +1055,7 @@ static Constant *getSequenceIfElementsMatch(Constant *C,
     else if (CI->getType()->isIntegerTy(64))
       return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
   } else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
-    if (CFP->getType()->isHalfTy())
+    if (CFP->getType()->isHalfTy() || CFP->getType()->isBFloatTy())
       return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
     else if (CFP->getType()->isFloatTy())
       return getFPSequenceIfElementsMatch<SequenceTy, uint32_t>(V);
@@ -1421,6 +1428,12 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
     Val2.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &losesInfo);
     return !losesInfo;
   }
+  case Type::BFloatTyID: {
+    if (&Val2.getSemantics() == &APFloat::BFloat())
+      return true;
+    Val2.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &losesInfo);
+    return !losesInfo;
+  }
   case Type::FloatTyID: {
     if (&Val2.getSemantics() == &APFloat::IEEEsingle())
       return true;
@@ -1429,6 +1442,7 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
   }
   case Type::DoubleTyID: {
     if (&Val2.getSemantics() == &APFloat::IEEEhalf() ||
+        &Val2.getSemantics() == &APFloat::BFloat() ||
         &Val2.getSemantics() == &APFloat::IEEEsingle() ||
         &Val2.getSemantics() == &APFloat::IEEEdouble())
       return true;
@@ -1437,16 +1451,19 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
   }
   case Type::X86_FP80TyID:
     return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+           &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::x87DoubleExtended();
   case Type::FP128TyID:
     return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+           &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::IEEEquad();
   case Type::PPC_FP128TyID:
     return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+           &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::PPCDoubleDouble();
@@ -2562,7 +2579,8 @@ StringRef ConstantDataSequential::getRawDataValues() const {
 }
 
 bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
-  if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true;
+  if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+    return true;
   if (auto *IT = dyn_cast<IntegerType>(Ty)) {
     switch (IT->getBitWidth()) {
     case 8:
@@ -2680,26 +2698,29 @@ void ConstantDataSequential::destroyConstantImpl() {
   Next = nullptr;
 }
 
-/// getFP() constructors - Return a constant with array type with an element
-/// count and element type of float with precision matching the number of
-/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
-                                   ArrayRef<uint16_t> Elts) {
-  Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size());
+/// getFP() constructors - Return a constant of array type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'.  The amount of bits of the contained type must match the
+/// number of bits of the type contained in the passed in ArrayRef.
+/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+/// that this can return a ConstantAggregateZero object.
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint16_t> Elts) {
+  assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+         "Element type is not a 16-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 2), Ty);
 }
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
-                                   ArrayRef<uint32_t> Elts) {
-  Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint32_t> Elts) {
+  assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 4), Ty);
 }
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
-                                   ArrayRef<uint64_t> Elts) {
-  Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint64_t> Elts) {
+  assert(ElementType->isDoubleTy() &&
+         "Element type is not a 64-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 8), Ty);
 }
@@ -2751,26 +2772,32 @@ Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) {
   return getImpl(StringRef(Data, Elts.size() * 8), Ty);
 }
 
-/// getFP() constructors - Return a constant with vector type with an element
-/// count and element type of float with the precision matching the number of
-/// bits in the ArrayRef passed in.  (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+/// getFP() constructors - Return a constant of vector type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'.  The amount of bits of the contained type must match the
+/// number of bits of the type contained in the passed in ArrayRef.
+/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+/// that this can return a ConstantAggregateZero object.
+Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint16_t> Elts) {
-  Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+  assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+         "Element type is not a 16-bit float type");
+  Type *Ty = VectorType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 2), Ty);
 }
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint32_t> Elts) {
-  Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
+  assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+  Type *Ty = VectorType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 4), Ty);
 }
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint64_t> Elts) {
-  Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
+  assert(ElementType->isDoubleTy() &&
+         "Element type is not a 64-bit float type");
+  Type *Ty = VectorType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 8), Ty);
 }
@@ -2800,17 +2827,22 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
     if (CFP->getType()->isHalfTy()) {
       SmallVector<uint16_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-      return getFP(V->getContext(), Elts);
+      return getFP(V->getType(), Elts);
+    }
+    if (CFP->getType()->isBFloatTy()) {
+      SmallVector<uint16_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getType(), Elts);
     }
     if (CFP->getType()->isFloatTy()) {
       SmallVector<uint32_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-      return getFP(V->getContext(), Elts);
+      return getFP(V->getType(), Elts);
     }
     if (CFP->getType()->isDoubleTy()) {
       SmallVector<uint64_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-      return getFP(V->getContext(), Elts);
+      return getFP(V->getType(), Elts);
     }
   }
   return ConstantVector::getSplat({NumElts, false}, V);
@@ -2875,6 +2907,10 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
     auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
     return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal));
   }
+  case Type::BFloatTyID: {
+    auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
+    return APFloat(APFloat::BFloat(), APInt(16, EltVal));
+  }
   case Type::FloatTyID: {
     auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
     return APFloat(APFloat::IEEEsingle(), APInt(32, EltVal));
@@ -2899,8 +2935,8 @@ double ConstantDataSequential::getElementAsDouble(unsigned Elt) const {
 }
 
 Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
-  if (getElementType()->isHalfTy() || getElementType()->isFloatTy() ||
-      getElementType()->isDoubleTy())
+  if (getElementType()->isHalfTy() || getElementType()->isBFloatTy() ||
+      getElementType()->isFloatTy() || getElementType()->isDoubleTy())
     return ConstantFP::get(getContext(), getElementAsAPFloat(Elt));
 
   return ConstantInt::get(getElementType(), getElementAsInteger(Elt));
index 696c25f..3bb1937 100644 (file)
@@ -477,6 +477,8 @@ LLVMTypeKind LLVMGetTypeKind(LLVMTypeRef Ty) {
     return LLVMVoidTypeKind;
   case Type::HalfTyID:
     return LLVMHalfTypeKind;
+  case Type::BFloatTyID:
+    return LLVMBFloatTypeKind;
   case Type::FloatTyID:
     return LLVMFloatTypeKind;
   case Type::DoubleTyID:
@@ -595,6 +597,9 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy) {
 LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C) {
   return (LLVMTypeRef) Type::getHalfTy(*unwrap(C));
 }
+LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C) {
+  return (LLVMTypeRef) Type::getBFloatTy(*unwrap(C));
+}
 LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C) {
   return (LLVMTypeRef) Type::getFloatTy(*unwrap(C));
 }
@@ -617,6 +622,9 @@ LLVMTypeRef LLVMX86MMXTypeInContext(LLVMContextRef C) {
 LLVMTypeRef LLVMHalfType(void) {
   return LLVMHalfTypeInContext(LLVMGetGlobalContext());
 }
+LLVMTypeRef LLVMBFloatType(void) {
+  return LLVMBFloatTypeInContext(LLVMGetGlobalContext());
+}
 LLVMTypeRef LLVMFloatType(void) {
   return LLVMFloatTypeInContext(LLVMGetGlobalContext());
 }
index 0a25f1c..87563d9 100644 (file)
@@ -162,7 +162,7 @@ static const LayoutAlignElem DefaultAlignments[] = {
     {INTEGER_ALIGN, 16, Align(2), Align(2)},   // i16
     {INTEGER_ALIGN, 32, Align(4), Align(4)},   // i32
     {INTEGER_ALIGN, 64, Align(4), Align(8)},   // i64
-    {FLOAT_ALIGN, 16, Align(2), Align(2)},     // half
+    {FLOAT_ALIGN, 16, Align(2), Align(2)},     // half, bfloat
     {FLOAT_ALIGN, 32, Align(4), Align(4)},     // float
     {FLOAT_ALIGN, 64, Align(8), Align(8)},     // double
     {FLOAT_ALIGN, 128, Align(16), Align(16)},  // ppcf128, quad, ...
@@ -732,6 +732,7 @@ Align DataLayout::getAlignment(Type *Ty, bool abi_or_pref) const {
     AlignType = INTEGER_ALIGN;
     break;
   case Type::HalfTyID:
+  case Type::BFloatTyID:
   case Type::FloatTyID:
   case Type::DoubleTyID:
   // PPC_FP128TyID and FP128TyID have different data contents, but the
index dab1c33..7bf3ab5 100644 (file)
@@ -655,6 +655,7 @@ static std::string getMangledTypeStr(Type* Ty) {
     case Type::VoidTyID:      Result += "isVoid";   break;
     case Type::MetadataTyID:  Result += "Metadata"; break;
     case Type::HalfTyID:      Result += "f16";      break;
+    case Type::BFloatTyID:    Result += "bf16";     break;
     case Type::FloatTyID:     Result += "f32";      break;
     case Type::DoubleTyID:    Result += "f64";      break;
     case Type::X86_FP80TyID:  Result += "f80";      break;
index 68b8f8a..f197b3e 100644 (file)
@@ -26,6 +26,7 @@ LLVMContextImpl::LLVMContextImpl(LLVMContext &C)
     VoidTy(C, Type::VoidTyID),
     LabelTy(C, Type::LabelTyID),
     HalfTy(C, Type::HalfTyID),
+    BFloatTy(C, Type::BFloatTyID),
     FloatTy(C, Type::FloatTyID),
     DoubleTy(C, Type::DoubleTyID),
     MetadataTy(C, Type::MetadataTyID),
index a019f1e..9912808 100644 (file)
@@ -1342,7 +1342,8 @@ public:
   std::unique_ptr<ConstantTokenNone> TheNoneToken;
 
   // Basic type instances.
-  Type VoidTy, LabelTy, HalfTy, FloatTy, DoubleTy, MetadataTy, TokenTy;
+  Type VoidTy, LabelTy, HalfTy, BFloatTy, FloatTy, DoubleTy, MetadataTy,
+      TokenTy;
   Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy;
   IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty;
 
index 8e5b032..bb077c1 100644 (file)
@@ -40,6 +40,7 @@ Type *Type::getPrimitiveType(LLVMContext &C, TypeID IDNumber) {
   switch (IDNumber) {
   case VoidTyID      : return getVoidTy(C);
   case HalfTyID      : return getHalfTy(C);
+  case BFloatTyID    : return getBFloatTy(C);
   case FloatTyID     : return getFloatTy(C);
   case DoubleTyID    : return getDoubleTy(C);
   case X86_FP80TyID  : return getX86_FP80Ty(C);
@@ -112,6 +113,7 @@ bool Type::isEmptyTy() const {
 TypeSize Type::getPrimitiveSizeInBits() const {
   switch (getTypeID()) {
   case Type::HalfTyID: return TypeSize::Fixed(16);
+  case Type::BFloatTyID: return TypeSize::Fixed(16);
   case Type::FloatTyID: return TypeSize::Fixed(32);
   case Type::DoubleTyID: return TypeSize::Fixed(64);
   case Type::X86_FP80TyID: return TypeSize::Fixed(80);
@@ -142,6 +144,7 @@ int Type::getFPMantissaWidth() const {
     return VTy->getElementType()->getFPMantissaWidth();
   assert(isFloatingPointTy() && "Not a floating point type!");
   if (getTypeID() == HalfTyID) return 11;
+  if (getTypeID() == BFloatTyID) return 8;
   if (getTypeID() == FloatTyID) return 24;
   if (getTypeID() == DoubleTyID) return 53;
   if (getTypeID() == X86_FP80TyID) return 64;
@@ -167,6 +170,7 @@ bool Type::isSizedDerivedType(SmallPtrSetImpl<Type*> *Visited) const {
 Type *Type::getVoidTy(LLVMContext &C) { return &C.pImpl->VoidTy; }
 Type *Type::getLabelTy(LLVMContext &C) { return &C.pImpl->LabelTy; }
 Type *Type::getHalfTy(LLVMContext &C) { return &C.pImpl->HalfTy; }
+Type *Type::getBFloatTy(LLVMContext &C) { return &C.pImpl->BFloatTy; }
 Type *Type::getFloatTy(LLVMContext &C) { return &C.pImpl->FloatTy; }
 Type *Type::getDoubleTy(LLVMContext &C) { return &C.pImpl->DoubleTy; }
 Type *Type::getMetadataTy(LLVMContext &C) { return &C.pImpl->MetadataTy; }
@@ -191,6 +195,10 @@ PointerType *Type::getHalfPtrTy(LLVMContext &C, unsigned AS) {
   return getHalfTy(C)->getPointerTo(AS);
 }
 
+PointerType *Type::getBFloatPtrTy(LLVMContext &C, unsigned AS) {
+  return getBFloatTy(C)->getPointerTo(AS);
+}
+
 PointerType *Type::getFloatPtrTy(LLVMContext &C, unsigned AS) {
   return getFloatTy(C)->getPointerTo(AS);
 }
index 63114fa..78f44c5 100644 (file)
@@ -69,6 +69,7 @@ namespace llvm {
   };
 
   static const fltSemantics semIEEEhalf = {15, -14, 11, 16};
+  static const fltSemantics semBFloat = {127, -126, 8, 16};
   static const fltSemantics semIEEEsingle = {127, -126, 24, 32};
   static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64};
   static const fltSemantics semIEEEquad = {16383, -16382, 113, 128};
@@ -117,6 +118,8 @@ namespace llvm {
     switch (S) {
     case S_IEEEhalf:
       return IEEEhalf();
+    case S_BFloat:
+      return BFloat();
     case S_IEEEsingle:
       return IEEEsingle();
     case S_IEEEdouble:
@@ -135,6 +138,8 @@ namespace llvm {
   APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
     if (&Sem == &llvm::APFloat::IEEEhalf())
       return S_IEEEhalf;
+    else if (&Sem == &llvm::APFloat::BFloat())
+      return S_BFloat;
     else if (&Sem == &llvm::APFloat::IEEEsingle())
       return S_IEEEsingle;
     else if (&Sem == &llvm::APFloat::IEEEdouble())
@@ -152,6 +157,9 @@ namespace llvm {
   const fltSemantics &APFloatBase::IEEEhalf() {
     return semIEEEhalf;
   }
+  const fltSemantics &APFloatBase::BFloat() {
+    return semBFloat;
+  }
   const fltSemantics &APFloatBase::IEEEsingle() {
     return semIEEEsingle;
   }
@@ -3255,6 +3263,33 @@ APInt IEEEFloat::convertFloatAPFloatToAPInt() const {
                     (mysignificand & 0x7fffff)));
 }
 
+APInt IEEEFloat::convertBFloatAPFloatToAPInt() const {
+  assert(semantics == (const llvm::fltSemantics *)&semBFloat);
+  assert(partCount() == 1);
+
+  uint32_t myexponent, mysignificand;
+
+  if (isFiniteNonZero()) {
+    myexponent = exponent + 127; // bias
+    mysignificand = (uint32_t)*significandParts();
+    if (myexponent == 1 && !(mysignificand & 0x80))
+      myexponent = 0; // denormal
+  } else if (category == fcZero) {
+    myexponent = 0;
+    mysignificand = 0;
+  } else if (category == fcInfinity) {
+    myexponent = 0x1f;
+    mysignificand = 0;
+  } else {
+    assert(category == fcNaN && "Unknown category!");
+    myexponent = 0x1f;
+    mysignificand = (uint32_t)*significandParts();
+  }
+
+  return APInt(16, (((sign & 1) << 15) | ((myexponent & 0xff) << 7) |
+                    (mysignificand & 0x7f)));
+}
+
 APInt IEEEFloat::convertHalfAPFloatToAPInt() const {
   assert(semantics == (const llvm::fltSemantics*)&semIEEEhalf);
   assert(partCount()==1);
@@ -3290,6 +3325,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
   if (semantics == (const llvm::fltSemantics*)&semIEEEhalf)
     return convertHalfAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semBFloat)
+    return convertBFloatAPFloatToAPInt();
+
   if (semantics == (const llvm::fltSemantics*)&semIEEEsingle)
     return convertFloatAPFloatToAPInt();
 
@@ -3486,6 +3524,37 @@ void IEEEFloat::initFromFloatAPInt(const APInt &api) {
   }
 }
 
+void IEEEFloat::initFromBFloatAPInt(const APInt &api) {
+  assert(api.getBitWidth() == 16);
+  uint32_t i = (uint32_t)*api.getRawData();
+  uint32_t myexponent = (i >> 7) & 0xff;
+  uint32_t mysignificand = i & 0x7f;
+
+  initialize(&semBFloat);
+  assert(partCount() == 1);
+
+  sign = i >> 15;
+  if (myexponent == 0 && mysignificand == 0) {
+    // exponent, significand meaningless
+    category = fcZero;
+  } else if (myexponent == 0xff && mysignificand == 0) {
+    // exponent, significand meaningless
+    category = fcInfinity;
+  } else if (myexponent == 0xff && mysignificand != 0) {
+    // sign, exponent, significand meaningless
+    category = fcNaN;
+    *significandParts() = mysignificand;
+  } else {
+    category = fcNormal;
+    exponent = myexponent - 127; // bias
+    *significandParts() = mysignificand;
+    if (myexponent == 0) // denormal
+      exponent = -126;
+    else
+      *significandParts() |= 0x80; // integer bit
+  }
+}
+
 void IEEEFloat::initFromHalfAPInt(const APInt &api) {
   assert(api.getBitWidth()==16);
   uint32_t i = (uint32_t)*api.getRawData();
@@ -3524,6 +3593,8 @@ void IEEEFloat::initFromHalfAPInt(const APInt &api) {
 void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
   if (Sem == &semIEEEhalf)
     return initFromHalfAPInt(api);
+  if (Sem == &semBFloat)
+    return initFromBFloatAPInt(api);
   if (Sem == &semIEEEsingle)
     return initFromFloatAPInt(api);
   if (Sem == &semIEEEdouble)
@@ -4763,26 +4834,9 @@ APFloat::opStatus APFloat::convert(const fltSemantics &ToSemantics,
   llvm_unreachable("Unexpected semantics");
 }
 
-APFloat APFloat::getAllOnesValue(unsigned BitWidth, bool isIEEE) {
-  if (isIEEE) {
-    switch (BitWidth) {
-    case 16:
-      return APFloat(semIEEEhalf, APInt::getAllOnesValue(BitWidth));
-    case 32:
-      return APFloat(semIEEEsingle, APInt::getAllOnesValue(BitWidth));
-    case 64:
-      return APFloat(semIEEEdouble, APInt::getAllOnesValue(BitWidth));
-    case 80:
-      return APFloat(semX87DoubleExtended, APInt::getAllOnesValue(BitWidth));
-    case 128:
-      return APFloat(semIEEEquad, APInt::getAllOnesValue(BitWidth));
-    default:
-      llvm_unreachable("Unknown floating bit width");
-    }
-  } else {
-    assert(BitWidth == 128);
-    return APFloat(semPPCDoubleDouble, APInt::getAllOnesValue(BitWidth));
-  }
+APFloat APFloat::getAllOnesValue(const fltSemantics &Semantics,
+                                 unsigned BitWidth) {
+  return APFloat(Semantics, APInt::getAllOnesValue(BitWidth));
 }
 
 void APFloat::print(raw_ostream &OS) const {
index 97aee3a..0af2625 100644 (file)
@@ -323,6 +323,7 @@ unsigned HexagonTargetObjectFile::getSmallestAddressableSize(const Type *Ty,
   }
   case Type::FunctionTyID:
   case Type::VoidTyID:
+  case Type::BFloatTyID:
   case Type::X86_FP80TyID:
   case Type::FP128TyID:
   case Type::PPC_FP128TyID:
index 29ced39..249f7b2 100644 (file)
@@ -11526,8 +11526,9 @@ static SDValue lowerShuffleAsBitMask(const SDLoc &DL, MVT VT, SDValue V1,
   MVT LogicVT = VT;
   if (EltVT == MVT::f32 || EltVT == MVT::f64) {
     Zero = DAG.getConstantFP(0.0, DL, EltVT);
-    AllOnes = DAG.getConstantFP(
-        APFloat::getAllOnesValue(EltVT.getSizeInBits(), true), DL, EltVT);
+    APFloat AllOnesValue = APFloat::getAllOnesValue(
+        SelectionDAG::EVTToAPFloatSemantics(EltVT), EltVT.getSizeInBits());
+    AllOnes = DAG.getConstantFP(AllOnesValue, DL, EltVT);
     LogicVT =
         MVT::getVectorVT(EltVT == MVT::f64 ? MVT::i64 : MVT::i32, Mask.size());
   } else {
diff --git a/llvm/test/Assembler/bfloat.ll b/llvm/test/Assembler/bfloat.ll
new file mode 100644 (file)
index 0000000..c9c7b6d
--- /dev/null
@@ -0,0 +1,38 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s --check-prefix=ASSEM-DISASS
+; RUN: opt < %s -O3 -S | FileCheck %s --check-prefix=OPT
+; RUN: verify-uselistorder %s
+; Basic smoke tests for bfloat type.
+
+define bfloat @check_bfloat(bfloat %A) {
+; ASSEM-DISASS: ret bfloat %A
+    ret bfloat %A
+}
+
+define bfloat @check_bfloat_literal() {
+; ASSEM-DISASS: ret bfloat 0xR3149
+    ret bfloat 0xR3149
+}
+
+define <4 x bfloat> @check_fixed_vector() {
+; ASSEM-DISASS: ret <4 x bfloat> %tmp
+  %tmp = fadd <4 x bfloat> undef, undef
+  ret <4 x bfloat> %tmp
+}
+
+define <vscale x 4 x bfloat> @check_vector() {
+; ASSEM-DISASS: ret <vscale x 4 x bfloat> %tmp
+  %tmp = fadd <vscale x 4 x bfloat> undef, undef
+  ret <vscale x 4 x bfloat> %tmp
+}
+
+define bfloat @check_bfloat_constprop() {
+  %tmp = fadd bfloat 0xR40C0, 0xR40C0
+; OPT: 0xR4140
+  ret bfloat %tmp
+}
+
+define float @check_bfloat_convert() {
+  %tmp = fpext bfloat 0xR4C8D to float
+; OPT: 0x4191A00000000000
+  ret float %tmp
+}
index bf284da..49b9f74 100644 (file)
@@ -72,6 +72,8 @@ struct TypeCloner {
         return LLVMVoidTypeInContext(Ctx);
       case LLVMHalfTypeKind:
         return LLVMHalfTypeInContext(Ctx);
+      case LLVMBFloatTypeKind:
+        return LLVMHalfTypeInContext(Ctx);
       case LLVMFloatTypeKind:
         return LLVMFloatTypeInContext(Ctx);
       case LLVMDoubleTypeKind: