void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
void populateExpandExp2FPattern(RewritePatternSet &patterns);
+void populateExpandRoundFPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
return success();
}
+static LogicalResult convertRoundOp(math::RoundOp op,
+ PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type opType = operand.getType();
+
+ // Creating constants for later use.
+ Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+ Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+ Value negHalf = createFloatConst(op->getLoc(), opType, -0.5, rewriter);
+
+ Value posCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, zero);
+ Value incrValue =
+ b.create<arith::SelectOp>(op->getLoc(), posCheck, half, negHalf);
+ Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
+
+ Value fpFixedConvert = createTruncatedFPValue(add, b);
+ rewriter.replaceOp(op, fpFixedConvert);
+ return success();
+}
+
// Converts math.ctlz to scf and arith operations. This is done
// by performing a binary search on the bits.
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
patterns.add(convertExp2fOp);
}
+void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
+ patterns.add(convertRoundOp);
+}
+
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
patterns.add(convertFloorOp);
}
%ret = math.exp2 %a : tensor<1xf32>
return %ret : tensor<1xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @roundf_func
+// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
+func.func @roundf_func(%a: f64) -> f64 {
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
+ // CHECK-DAG: [[CST_0:%.+]] = arith.constant 5.000000e-01
+ // CHECK-DAG: [[CST_1:%.+]] = arith.constant -5.000000e-01
+ // CHECK-DAG: [[COMP:%.+]] = arith.cmpf oge, [[ARG0]], [[CST]]
+ // CHECK-DAG: [[SEL:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST_1]]
+ // CHECK-DAG: [[ADDF:%.+]] = arith.addf [[ARG0]], [[SEL]]
+ // CHECK-DAG: [[CVTI:%.+]] = arith.fptosi [[ADDF]]
+ // CHECK-DAG: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
+ // CHECK: return [[CVTF]]
+ %ret = math.round %a : f64
+ return %ret : f64
+}
populateExpandFmaFPattern(patterns);
populateExpandFloorFPattern(patterns);
populateExpandCeilFPattern(patterns);
+ populateExpandRoundFPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
return
}
+// -------------------------------------------------------------------------- //
+// round.
+// -------------------------------------------------------------------------- //
+func.func @func_roundf(%a : f32) {
+ %r = math.round %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @roundf() {
+ // CHECK: 4
+ %a = arith.constant 3.8 : f32
+ call @func_roundf(%a) : (f32) -> ()
+
+ // CHECK: -4
+ %b = arith.constant -3.8 : f32
+ call @func_roundf(%b) : (f32) -> ()
+
+ // CHECK: 0
+ %c = arith.constant 0.0 : f32
+ call @func_roundf(%c) : (f32) -> ()
+
+ // CHECK: -4
+ %d = arith.constant -4.2 : f32
+ call @func_roundf(%d) : (f32) -> ()
+
+ // CHECK: -495
+ %e = arith.constant -495.0 : f32
+ call @func_roundf(%e) : (f32) -> ()
+
+ // CHECK: 495
+ %f = arith.constant 495.0 : f32
+ call @func_roundf(%f) : (f32) -> ()
+
+ // CHECK: 9
+ %g = arith.constant 8.5 : f32
+ call @func_roundf(%g) : (f32) -> ()
+
+ // CHECK: -9
+ %h = arith.constant -8.5 : f32
+ call @func_roundf(%h) : (f32) -> ()
+
+ return
+}
+
+
func.func @main() {
call @exp2f() : () -> ()
+ call @roundf() : () -> ()
return
}