[mlir][math] Expand math.ceilf to truncate, compares and increments
authorBalaji V. Iyer <bviyer@gmail.com>
Tue, 11 Apr 2023 13:43:15 +0000 (13:43 +0000)
committerRobert Suderman <suderman@google.com>
Tue, 11 Apr 2023 13:52:45 +0000 (13:52 +0000)
Ceilf are pushed directly to libm. This is problematic for
situations where libm is not available. This patch will break down
a ceilf function to truncate followed by an increment if the
truncated value is smaller than the input value.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D147974

mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

index f933748..1b32de2 100644 (file)
@@ -18,6 +18,7 @@ void populateExpandTanPattern(RewritePatternSet &patterns);
 void populateExpandTanhPattern(RewritePatternSet &patterns);
 void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
+void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
index 2dab48d..b70ac4e 100644 (file)
@@ -46,6 +46,13 @@ static Value createIntConst(Location loc, Type type, int64_t value,
   return b.create<arith::ConstantOp>(loc, attr);
 }
 
+static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
+  Type opType = operand.getType();
+  Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand);
+  Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+  return fpFixedConvert;
+}
+
 /// Expands tanh op into
 ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
 ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
@@ -112,8 +119,7 @@ static LogicalResult convertFloorOp(math::FloorOp op,
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
-  Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand);
-  Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+  Value fpFixedConvert = createTruncatedFPValue(operand, b);
 
   // Creating constants for later use.
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
@@ -128,6 +134,30 @@ static LogicalResult convertFloorOp(math::FloorOp op,
   return success();
 }
 
+// Converts a ceilf() function to the following:
+// ceilf(float x) ->
+//      y = (float)(int) x
+//      if (x > y) then incr = 1 else incr = 0
+//      y = y + incr   <= replace this op with the ceilf op.
+static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+  Value fpFixedConvert = createTruncatedFPValue(operand, b);
+
+  // Creating constants for later use.
+  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+
+  Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
+                                          fpFixedConvert);
+  Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
+
+  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+  rewriter.replaceOp(op, ret);
+  return success();
+}
+
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -187,6 +217,11 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFmaFOp);
 }
+
+void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertCeilOp);
+}
+
 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFloorOp);
 }
index d67193c..8ab6449 100644 (file)
@@ -148,3 +148,20 @@ func.func @floorf_func(%a: f64) -> f64 {
   %ret = math.floor %a : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL:     func @ceilf_func
+// CHECK-SAME:      ([[ARG0:%.+]]: f64) -> f64
+func.func @ceilf_func(%a: f64) -> f64 {
+  // CHECK-DAG:   [[CST:%.+]] = arith.constant 0.000
+  // CHECK-DAG:   [[CST_0:%.+]] = arith.constant 1.000
+  // CHECK-NEXT:   [[CVTI:%.+]] = arith.fptosi [[ARG0]]
+  // CHECK-NEXT:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
+  // CHECK-NEXT:   [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[CVTF]]
+  // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
+  // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]]
+  // CHECK-NEXT:   return [[ADDF]]
+  %ret = math.ceil %a : f64
+  return %ret : f64
+}
index e6a4489..c670617 100644 (file)
@@ -41,6 +41,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandTanhPattern(patterns);
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
+  populateExpandCeilFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
index 0fff84d..130147b 100644 (file)
@@ -647,6 +647,43 @@ func.func @floorf() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// ceil.
+// -------------------------------------------------------------------------- //
+func.func @func_ceilf32(%a : f32) {
+  %r = math.ceil %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @ceilf() {
+  // CHECK: 4
+  %a = arith.constant 3.8 : f32
+  call @func_ceilf32(%a) : (f32) -> ()
+
+  // CHECK: -3
+  %b = arith.constant -3.8 : f32
+  call @func_ceilf32(%b) : (f32) -> ()
+
+  // CHECK: 0
+  %c = arith.constant 0.0 : f32
+  call @func_ceilf32(%c) : (f32) -> ()
+
+  // CHECK: -4
+  %d = arith.constant -4.2 : f32
+  call @func_ceilf32(%d) : (f32) -> ()
+
+  // CHECK: -495
+  %e = arith.constant -495.0 : f32
+  call @func_ceilf32(%e) : (f32) -> ()
+
+  // CHECK: 495
+  %f = arith.constant 495.0 : f32
+  call @func_ceilf32(%f) : (f32) -> ()
+
+  return
+}
+
 func.func @main() {
   call @tanh(): () -> ()
   call @log(): () -> ()
@@ -661,6 +698,7 @@ func.func @main() {
   call @atan2() : () -> ()
   call @cbrt() : () -> ()
   call @floorf() : () -> ()
+  call @ceilf() : () -> ()
   return
 }