From 657f60a07b2d382dd8580dd8a6111ea5c6e2d889 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 2 Mar 2023 17:37:43 -0800 Subject: [PATCH] [mlir][vector] add support for printing f16 and bf16 Love or hate it, but the vector.print operation was the very first operation that actually made "end-to-end" CHECK integration testing possible for MLIR. This revision adds support for the -until recently- less common but important floating-point types f16 and bf16. This will become useful for accelerator specific testing (e.g. NVidia GPUs) Reviewed By: wrengr Differential Revision: https://reviews.llvm.org/D145207 --- .../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 2 + mlir/include/mlir/ExecutionEngine/CRunnerUtils.h | 2 + .../VectorToLLVM/ConvertVectorToLLVM.cpp | 46 +++++++++++----------- mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 14 +++++++ mlir/lib/ExecutionEngine/Float16bits.cpp | 12 ++++++ .../Dialect/Vector/CPU/test-print-fp.mlir | 27 +++++++++++++ 6 files changed, 81 insertions(+), 22 deletions(-) create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 39b35fe..17aa9a3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -34,6 +34,8 @@ class LLVMFuncOp; /// of the libc). LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp, diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h index e7798b2..7b7e894 100644 --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -469,6 +469,8 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline(); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printF16(uint16_t bits); // bits! +extern "C" MLIR_CRUNNERUTILS_EXPORT void printBF16(uint16_t bits); // bits! //===----------------------------------------------------------------------===// // Small runtime support library for timing execution and printing GFLOPS diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index d1b78bf..f705284 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1466,16 +1466,20 @@ public: PrintConversion conversion = PrintConversion::None; VectorType vectorType = printType.dyn_cast(); Type eltType = vectorType ? vectorType.getElementType() : printType; + auto parent = printOp->getParentOfType(); Operation *printer; if (eltType.isF32()) { - printer = - LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintF32Fn(parent); } else if (eltType.isF64()) { - printer = - LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintF64Fn(parent); + } else if (eltType.isF16()) { + conversion = PrintConversion::Bitcast16; // bits! + printer = LLVM::lookupOrCreatePrintF16Fn(parent); + } else if (eltType.isBF16()) { + conversion = PrintConversion::Bitcast16; // bits! + printer = LLVM::lookupOrCreatePrintBF16Fn(parent); } else if (eltType.isIndex()) { - printer = - LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintU64Fn(parent); } else if (auto intTy = eltType.dyn_cast()) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or @@ -1485,8 +1489,7 @@ public: if (width <= 64) { if (width < 64) conversion = PrintConversion::ZeroExt64; - printer = LLVM::lookupOrCreatePrintU64Fn( - printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintU64Fn(parent); } else { return failure(); } @@ -1499,8 +1502,7 @@ public: conversion = PrintConversion::ZeroExt64; else if (width < 64) conversion = PrintConversion::SignExt64; - printer = LLVM::lookupOrCreatePrintI64Fn( - printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintI64Fn(parent); } else { return failure(); } @@ -1515,8 +1517,7 @@ public: emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank, conversion); emitCall(rewriter, printOp->getLoc(), - LLVM::lookupOrCreatePrintNewlineFn( - printOp->getParentOfType())); + LLVM::lookupOrCreatePrintNewlineFn(parent)); rewriter.eraseOp(printOp); return success(); } @@ -1526,7 +1527,8 @@ private: // clang-format off None, ZeroExt64, - SignExt64 + SignExt64, + Bitcast16 // clang-format on }; @@ -1546,6 +1548,10 @@ private: value = rewriter.create( loc, IntegerType::get(rewriter.getContext(), 64), value); break; + case PrintConversion::Bitcast16: + value = rewriter.create( + loc, IntegerType::get(rewriter.getContext(), 16), value); + break; case PrintConversion::None: break; } @@ -1553,10 +1559,9 @@ private: return; } - emitCall(rewriter, loc, - LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType())); - Operation *printComma = - LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType()); + auto parent = op->getParentOfType(); + emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent)); + Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent); if (rank <= 1) { auto reducedType = vectorType.getElementType(); @@ -1570,9 +1575,7 @@ private: if (d != dim - 1) emitCall(rewriter, loc, printComma); } - emitCall( - rewriter, loc, - LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType())); + emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent)); return; } @@ -1587,8 +1590,7 @@ private: if (d != dim - 1) emitCall(rewriter, loc, printComma); } - emitCall(rewriter, loc, - LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType())); + emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent)); } // Helper to emit a call. diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 237e576..aef3a5a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -26,6 +26,8 @@ using namespace mlir::LLVM; /// part of the libc). static constexpr llvm::StringRef kPrintI64 = "printI64"; static constexpr llvm::StringRef kPrintU64 = "printU64"; +static constexpr llvm::StringRef kPrintF16 = "printF16"; +static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; static constexpr llvm::StringRef kPrintStr = "puts"; @@ -67,6 +69,18 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) { LLVM::LLVMVoidType::get(moduleOp->getContext())); } +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) { + return lookupOrCreateFn(moduleOp, kPrintF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) { + return lookupOrCreateFn(moduleOp, kPrintBF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) { return lookupOrCreateFn(moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), diff --git a/mlir/lib/ExecutionEngine/Float16bits.cpp b/mlir/lib/ExecutionEngine/Float16bits.cpp index 189286b..38a05fe 100644 --- a/mlir/lib/ExecutionEngine/Float16bits.cpp +++ b/mlir/lib/ExecutionEngine/Float16bits.cpp @@ -192,4 +192,16 @@ extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) { return __truncsfbf2(static_cast(d)); } +// Provide these to the CRunner with the local float16 knowledge. +extern "C" void printF16(uint16_t bits) { + f16 f; + std::memcpy(&f, &bits, sizeof(f16)); + std::cout << f; +} +extern "C" void printBF16(uint16_t bits) { + bf16 f; + std::memcpy(&f, &bits, sizeof(bf16)); + std::cout << f; +} + #endif // MLIR_FLOAT16_DEFINE_FUNCTIONS diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir new file mode 100644 index 0000000..eeee363 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s + +// +// Test various floating-point types. +// +func.func @entry() { + %0 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf64> + vector.print %0 : vector<5xf64> + // CHECK: ( -1000, -1.1, 0, 1.1, 1000 ) + + %1 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf32> + vector.print %1 : vector<5xf32> + // CHECK: ( -1000, -1.1, 0, 1.1, 1000 ) + + %2 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf16> + vector.print %2 : vector<5xf16> + // CHECK: ( -1000, -1.09961, 0, 1.09961, 1000 ) + + %3 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xbf16> + vector.print %3 : vector<5xbf16> + // CHECK: ( -1000, -1.10156, 0, 1.10156, 1000 ) + + return +} -- 2.7.4