rewriter.replaceOp(op, ret);
return success();
}
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operandA = op.getOperand(0);
+ Value operandB = op.getOperand(1);
+ Type opType = operandA.getType();
+
+ Value logA = b.create<math::LogOp>(opType, operandA);
+ Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
+ Value expResult = b.create<math::ExpOp>(opType, mult);
+ rewriter.replaceOp(op, expResult);
+ return success();
+}
// exp2f(float x) -> exp(x * ln(2))
// Proof: Let's say 2^x = y
patterns.add(convertExp2fOp);
}
+void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+ patterns.add(convertPowfOp);
+}
+
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
patterns.add(convertRoundOp);
}
%ret = math.round %a : f64
return %ret : f64
}
+
+// -----
+
+// CHECK-LABEL: func @powf_func
+// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
+func.func @powf_func(%a: f64, %b: f64) ->f64 {
+ // CHECK-DAG: [[LOG:%.+]] = math.log [[ARG0]]
+ // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
+ // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
+ // CHECK: return [[EXPR]]
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
return
}
+// -------------------------------------------------------------------------- //
+// pow.
+// -------------------------------------------------------------------------- //
+func.func @func_powff64(%a : f64, %b : f64) {
+ %r = math.powf %a, %b : f64
+ vector.print %r : f64
+ return
+}
+
+func.func @powf() {
+ // CHECK: 16
+ %a = arith.constant 4.0 : f64
+ %a_p = arith.constant 2.0 : f64
+ call @func_powff64(%a, %a_p) : (f64, f64) -> ()
+
+ // CHECK: -nan
+ %b = arith.constant -3.0 : f64
+ %b_p = arith.constant 3.0 : f64
+ call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
+ // CHECK: 2.343
+ %c = arith.constant 2.343 : f64
+ %c_p = arith.constant 1.000 : f64
+ call @func_powff64(%c, %c_p) : (f64, f64) -> ()
+
+ // CHECK: 0.176171
+ %d = arith.constant 4.25 : f64
+ %d_p = arith.constant -1.2 : f64
+ call @func_powff64(%d, %d_p) : (f64, f64) -> ()
+
+ // CHECK: 1
+ %e = arith.constant 4.385 : f64
+ %e_p = arith.constant 0.00 : f64
+ call @func_powff64(%e, %e_p) : (f64, f64) -> ()
+
+ // CHECK: 6.62637
+ %f = arith.constant 4.835 : f64
+ %f_p = arith.constant 1.2 : f64
+ call @func_powff64(%f, %f_p) : (f64, f64) -> ()
+
+ // CHECK: -nan
+ %g = arith.constant 0xff80000000000000 : f64
+ call @func_powff64(%g, %g) : (f64, f64) -> ()
+
+ // CHECK: nan
+ %h = arith.constant 0x7fffffffffffffff : f64
+ call @func_powff64(%h, %h) : (f64, f64) -> ()
+
+ // CHECK: nan
+ %i = arith.constant 1.0 : f64
+ call @func_powff64(%i, %h) : (f64, f64) -> ()
+
+ // CHECK: inf
+ %j = arith.constant 29385.0 : f64
+ %j_p = arith.constant 23598.0 : f64
+ call @func_powff64(%j, %j_p) : (f64, f64) -> ()
+ return
+}
func.func @main() {
call @exp2f() : () -> ()
call @roundf() : () -> ()
+ call @powf() : () -> ()
return
}