From a7c2102d988b2ae2214f1483d2b4066955b4dc98 Mon Sep 17 00:00:00 2001 From: "Balaji V. Iyer" Date: Fri, 7 Apr 2023 21:47:20 +0000 Subject: [PATCH] [mlir][math]Expand Fused math.fmaf to a multiply-add Fused multiply and add are being pushed directly to the libm. This is problematic for situations where libm is not available. This patch will break down a fused multiply and add into a multiply followed by an add. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D147811 --- mlir/include/mlir/Dialect/Math/Transforms/Passes.h | 1 + mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 16 ++++++++++++++++ mlir/test/Dialect/Math/expand-math.mlir | 12 ++++++++++++ mlir/test/lib/Dialect/Math/TestExpandMath.cpp | 1 + 4 files changed, 30 insertions(+) diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index a1801dd..5976180 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -16,6 +16,7 @@ class RewritePatternSet; void populateExpandCtlzPattern(RewritePatternSet &patterns); void populateExpandTanPattern(RewritePatternSet &patterns); void populateExpandTanhPattern(RewritePatternSet &patterns); +void populateExpandFmaFPattern(RewritePatternSet &patterns); void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 91aef84..f3e807e 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -90,6 +90,18 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { return success(); } +static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Value operandA = op.getOperand(0); + Value operandB = op.getOperand(1); + Value operandC = op.getOperand(2); + Type type = op.getType(); + Value mult = b.create(type, operandA, operandB); + Value add = b.create(type, mult, operandC); + rewriter.replaceOp(op, add); + return success(); +} + // Converts math.ctlz to scf and arith operations. This is done // by performing a binary search on the bits. static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, @@ -145,3 +157,7 @@ void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { patterns.add(convertTanhOp); } + +void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { + patterns.add(convertFmaFOp); +} diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index a66ea08..cc6c401 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -119,3 +119,15 @@ func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> { // CHECK-LABEL: @ctlz_vector // CHECK-NOT: math.ctlz + +// ----- + +// CHECK-LABEL: func @fmaf_func +// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64, [[ARG2:%.+]]: f64) -> f64 +func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 { + // CHECK-NEXT: [[MULF:%.+]] = arith.mulf [[ARG0]], [[ARG1]] + // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[MULF]], [[ARG2]] + // CHECK-NEXT: return [[ADDF]] + %ret = math.fma %a, %b, %c : f64 + return %ret : f64 +} diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp index 29b862e..12bc3af 100644 --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -39,6 +39,7 @@ void TestExpandMathPass::runOnOperation() { populateExpandCtlzPattern(patterns); populateExpandTanPattern(patterns); populateExpandTanhPattern(patterns); + populateExpandFmaFPattern(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } -- 2.7.4