[mlir] Add support for bf16 to StandardToLLVM conversion
authorDiego Caballero <diego.caballero@intel.com>
Thu, 4 Jun 2020 21:08:49 +0000 (14:08 -0700)
committerDiego Caballero <diego.caballero@intel.com>
Thu, 4 Jun 2020 21:36:36 +0000 (14:36 -0700)
Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D81127

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

index ccf2185..078cb1c 100644 (file)
@@ -62,6 +62,7 @@ public:
   llvm::Type *getUnderlyingType() const;
 
   /// Utilities to identify types.
+  bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); }
   bool isHalfTy() { return getUnderlyingType()->isHalfTy(); }
   bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
   bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); }
@@ -99,6 +100,7 @@ public:
   /// Utilities used to generate floating point types.
   static LLVMType getDoubleTy(LLVMDialect *dialect);
   static LLVMType getFloatTy(LLVMDialect *dialect);
+  static LLVMType getBFloatTy(LLVMDialect *dialect);
   static LLVMType getHalfTy(LLVMDialect *dialect);
   static LLVMType getFP128Ty(LLVMDialect *dialect);
   static LLVMType getX86_FP80Ty(LLVMDialect *dialect);
index 4294e00..5d3984d 100644 (file)
@@ -201,9 +201,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
   case mlir::StandardTypes::F16:
     return LLVM::LLVMType::getHalfTy(llvmDialect);
   case mlir::StandardTypes::BF16: {
-    auto *mlirContext = llvmDialect->getContext();
-    return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
-           Type();
+    return LLVM::LLVMType::getBFloatTy(llvmDialect);
   }
   default:
     llvm_unreachable("non-float type in convertFloatType");
index 2c6478d..9fd8bfe 100644 (file)
@@ -939,8 +939,9 @@ static LogicalResult verify(DialectCastOp op) {
     if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
       if (llvmType.isVectorTy())
         llvmType = llvmType.getVectorElementType();
-      if (llvmType.isIntegerTy() || llvmType.isHalfTy() ||
-          llvmType.isFloatTy() || llvmType.isDoubleTy()) {
+      if (llvmType.isIntegerTy() || llvmType.isBFloatTy() ||
+          llvmType.isHalfTy() || llvmType.isFloatTy() ||
+          llvmType.isDoubleTy()) {
         return success();
       }
       return op.emitOpError("type must be non-index integer types, float "
@@ -1500,7 +1501,8 @@ static LogicalResult verify(AtomicRMWOp op) {
   } else if (op.bin_op() == AtomicBinOp::xchg) {
     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
         !valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
-        !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
+        !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() &&
+        !valType.isDoubleTy())
       return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
   } else {
     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
@@ -1561,8 +1563,8 @@ static LogicalResult verify(AtomicCmpXchgOp op) {
                           "match type for all other operands");
   if (!valType.isPointerTy() && !valType.isIntegerTy(8) &&
       !valType.isIntegerTy(16) && !valType.isIntegerTy(32) &&
-      !valType.isIntegerTy(64) && !valType.isHalfTy() && !valType.isFloatTy() &&
-      !valType.isDoubleTy())
+      !valType.isIntegerTy(64) && !valType.isBFloatTy() &&
+      !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
     return op.emitOpError("unexpected LLVM IR type");
   if (op.success_ordering() < AtomicOrdering::monotonic ||
       op.failure_ordering() < AtomicOrdering::monotonic)
@@ -1630,7 +1632,7 @@ struct LLVMDialectImpl {
   /// A set of LLVMTypes that are cached on construction to avoid any lookups or
   /// locking.
   LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
-  LLVMType doubleTy, floatTy, halfTy, fp128Ty, x86_fp80Ty;
+  LLVMType doubleTy, floatTy, bfloatTy, halfTy, fp128Ty, x86_fp80Ty;
   LLVMType voidTy;
 
   /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
@@ -1665,6 +1667,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
   /// Float Types.
   impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext));
   impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext));
+  impl->bfloatTy = LLVMType::get(context, llvm::Type::getBFloatTy(llvmContext));
   impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext));
   impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext));
   impl->x86_fp80Ty =
@@ -1827,6 +1830,9 @@ LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
 LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
   return dialect->impl->floatTy;
 }
+LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) {
+  return dialect->impl->bfloatTy;
+}
 LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
   return dialect->impl->halfTy;
 }
index e2c3238..ea21a6d 100644 (file)
@@ -1228,3 +1228,12 @@ func @mlir_cast_from_llvm(%0 : !llvm.half) -> f16 {
   // CHECK-NEXT: llvm.return %[[ARG]]
   return %1 : f16
 }
+
+// -----
+
+// CHECK-LABEL: func @bfloat
+// CHECK-SAME: !llvm.bfloat) -> !llvm.bfloat
+func @bfloat(%arg0: bf16) -> bf16 {
+  return %arg0 : bf16
+}
+// CHECK-NEXT: return %{{.*}} : !llvm.bfloat