[mlir][math] Expand math.round to truncate, compare and increment.
authorBalaji V. Iyer <bviyer@gmail.com>
Thu, 13 Apr 2023 17:58:14 +0000 (17:58 +0000)
committerRobert Suderman <suderman@google.com>
Thu, 13 Apr 2023 18:02:10 +0000 (18:02 +0000)
Round functions are pushed directly to libm. This is problematic for
situations where libm is not available. This patch will decompose the
roundf function by adding 0.5 to positive number to input
(subtracting for negative) following by a truncate.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D148026

mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

index 3ac18c3..6cd5b0a 100644 (file)
@@ -20,6 +20,7 @@ void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
+void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
index e9447dc..bc35263 100644 (file)
@@ -174,6 +174,28 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
   return success();
 }
 
+static LogicalResult convertRoundOp(math::RoundOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  // Creating constants for later use.
+  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+  Value negHalf = createFloatConst(op->getLoc(), opType, -0.5, rewriter);
+
+  Value posCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, zero);
+  Value incrValue =
+      b.create<arith::SelectOp>(op->getLoc(), posCheck, half, negHalf);
+  Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
+
+  Value fpFixedConvert = createTruncatedFPValue(add, b);
+  rewriter.replaceOp(op, fpFixedConvert);
+  return success();
+}
+
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -242,6 +264,10 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
   patterns.add(convertExp2fOp);
 }
 
+void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertRoundOp);
+}
+
 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFloorOp);
 }
index 5098696..b3a5668 100644 (file)
@@ -189,3 +189,21 @@ func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> {
   %ret = math.exp2 %a : tensor<1xf32>
   return %ret : tensor<1xf32>
 }
+
+// -----
+
+// CHECK-LABEL:      func @roundf_func
+// CHECK-SAME:      ([[ARG0:%.+]]: f64) -> f64
+func.func @roundf_func(%a: f64) -> f64 {
+  // CHECK-DAG:   [[CST:%.+]] = arith.constant 0.000
+  // CHECK-DAG:   [[CST_0:%.+]] = arith.constant 5.000000e-01
+  // CHECK-DAG:   [[CST_1:%.+]] = arith.constant -5.000000e-01
+  // CHECK-DAG:  [[COMP:%.+]] = arith.cmpf oge, [[ARG0]], [[CST]]
+  // CHECK-DAG:  [[SEL:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST_1]]
+  // CHECK-DAG:  [[ADDF:%.+]] = arith.addf [[ARG0]], [[SEL]]
+  // CHECK-DAG:   [[CVTI:%.+]] = arith.fptosi [[ADDF]]
+  // CHECK-DAG:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
+  // CHECK:   return [[CVTF]]
+  %ret = math.round %a : f64
+  return %ret : f64
+}
index 29eff99..5692ecf 100644 (file)
@@ -43,6 +43,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
+  populateExpandRoundFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
index 3fb3b2b..f3c7a2c 100644 (file)
@@ -55,7 +55,54 @@ func.func @exp2f() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// round.
+// -------------------------------------------------------------------------- //
+func.func @func_roundf(%a : f32) {
+  %r = math.round %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @roundf() {
+  // CHECK: 4
+  %a = arith.constant 3.8 : f32
+  call @func_roundf(%a) : (f32) -> ()
+
+  // CHECK: -4
+  %b = arith.constant -3.8 : f32
+  call @func_roundf(%b) : (f32) -> ()
+
+  // CHECK: 0
+  %c = arith.constant 0.0 : f32
+  call @func_roundf(%c) : (f32) -> ()
+
+  // CHECK: -4
+  %d = arith.constant -4.2 : f32
+  call @func_roundf(%d) : (f32) -> ()
+
+  // CHECK: -495
+  %e = arith.constant -495.0 : f32
+  call @func_roundf(%e) : (f32) -> ()
+
+  // CHECK: 495
+  %f = arith.constant 495.0 : f32
+  call @func_roundf(%f) : (f32) -> ()
+
+  // CHECK: 9
+  %g = arith.constant 8.5 : f32
+  call @func_roundf(%g) : (f32) -> ()
+
+  // CHECK: -9
+  %h = arith.constant -8.5 : f32
+  call @func_roundf(%h) : (f32) -> ()
+
+  return
+}
+
+
 func.func @main() {
   call @exp2f() : () -> ()
+  call @roundf() : () -> ()
   return
 }