From 5c990d6994559225466cb256146f6440431b229e Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Thu, 4 Jun 2020 14:08:49 -0700 Subject: [PATCH] [mlir] Add support for bf16 to StandardToLLVM conversion Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D81127 --- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 2 ++ mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp | 4 +--- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 18 ++++++++++++------ .../Conversion/StandardToLLVM/convert-to-llvmir.mlir | 9 +++++++++ 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index ccf2185..078cb1c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -62,6 +62,7 @@ public: llvm::Type *getUnderlyingType() const; /// Utilities to identify types. + bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); } bool isHalfTy() { return getUnderlyingType()->isHalfTy(); } bool isFloatTy() { return getUnderlyingType()->isFloatTy(); } bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); } @@ -99,6 +100,7 @@ public: /// Utilities used to generate floating point types. static LLVMType getDoubleTy(LLVMDialect *dialect); static LLVMType getFloatTy(LLVMDialect *dialect); + static LLVMType getBFloatTy(LLVMDialect *dialect); static LLVMType getHalfTy(LLVMDialect *dialect); static LLVMType getFP128Ty(LLVMDialect *dialect); static LLVMType getX86_FP80Ty(LLVMDialect *dialect); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 4294e00..5d3984d 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -201,9 +201,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) { case mlir::StandardTypes::F16: return LLVM::LLVMType::getHalfTy(llvmDialect); case mlir::StandardTypes::BF16: { - auto *mlirContext = llvmDialect->getContext(); - return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"), - Type(); + return LLVM::LLVMType::getBFloatTy(llvmDialect); } default: llvm_unreachable("non-float type in convertFloatType"); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 2c6478d..9fd8bfe 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -939,8 +939,9 @@ static LogicalResult verify(DialectCastOp op) { if (auto llvmType = type.dyn_cast()) { if (llvmType.isVectorTy()) llvmType = llvmType.getVectorElementType(); - if (llvmType.isIntegerTy() || llvmType.isHalfTy() || - llvmType.isFloatTy() || llvmType.isDoubleTy()) { + if (llvmType.isIntegerTy() || llvmType.isBFloatTy() || + llvmType.isHalfTy() || llvmType.isFloatTy() || + llvmType.isDoubleTy()) { return success(); } return op.emitOpError("type must be non-index integer types, float " @@ -1500,7 +1501,8 @@ static LogicalResult verify(AtomicRMWOp op) { } else if (op.bin_op() == AtomicBinOp::xchg) { if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && !valType.isIntegerTy(32) && !valType.isIntegerTy(64) && - !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy()) + !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() && + !valType.isDoubleTy()) return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); } else { if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && @@ -1561,8 +1563,8 @@ static LogicalResult verify(AtomicCmpXchgOp op) { "match type for all other operands"); if (!valType.isPointerTy() && !valType.isIntegerTy(8) && !valType.isIntegerTy(16) && !valType.isIntegerTy(32) && - !valType.isIntegerTy(64) && !valType.isHalfTy() && !valType.isFloatTy() && - !valType.isDoubleTy()) + !valType.isIntegerTy(64) && !valType.isBFloatTy() && + !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy()) return op.emitOpError("unexpected LLVM IR type"); if (op.success_ordering() < AtomicOrdering::monotonic || op.failure_ordering() < AtomicOrdering::monotonic) @@ -1630,7 +1632,7 @@ struct LLVMDialectImpl { /// A set of LLVMTypes that are cached on construction to avoid any lookups or /// locking. LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; - LLVMType doubleTy, floatTy, halfTy, fp128Ty, x86_fp80Ty; + LLVMType doubleTy, floatTy, bfloatTy, halfTy, fp128Ty, x86_fp80Ty; LLVMType voidTy; /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not @@ -1665,6 +1667,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context) /// Float Types. impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext)); impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext)); + impl->bfloatTy = LLVMType::get(context, llvm::Type::getBFloatTy(llvmContext)); impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext)); impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext)); impl->x86_fp80Ty = @@ -1827,6 +1830,9 @@ LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) { return dialect->impl->floatTy; } +LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) { + return dialect->impl->bfloatTy; +} LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) { return dialect->impl->halfTy; } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index e2c3238..ea21a6d 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -1228,3 +1228,12 @@ func @mlir_cast_from_llvm(%0 : !llvm.half) -> f16 { // CHECK-NEXT: llvm.return %[[ARG]] return %1 : f16 } + +// ----- + +// CHECK-LABEL: func @bfloat +// CHECK-SAME: !llvm.bfloat) -> !llvm.bfloat +func @bfloat(%arg0: bf16) -> bf16 { + return %arg0 : bf16 +} +// CHECK-NEXT: return %{{.*}} : !llvm.bfloat -- 2.7.4