[mlir] Add math.roundeven and llvm.intr.roundeven
authorTres Popp <tpopp@google.com>
Mon, 22 Aug 2022 13:47:44 +0000 (15:47 +0200)
committerTres Popp <tpopp@google.com>
Thu, 25 Aug 2022 11:39:01 +0000 (13:39 +0200)
This is similar to math.round, but rounds to even instead of rounding away from
zero in the case of halfway values. This CL also adds lowerings to libm and
to the LLVM intrinsic.

Differential Revision: https://reviews.llvm.org/D132375

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
mlir/test/Dialect/Math/ops.mlir

index ca989bf..219efd5 100644 (file)
@@ -61,6 +61,7 @@ def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0]> {
                    LLVM_Type:$cache);
 }
 def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">;
+def LLVM_RoundEvenOp : LLVM_UnaryIntrinsicOp<"roundeven">;
 def LLVM_RoundOp : LLVM_UnaryIntrinsicOp<"round">;
 def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
 def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
index 706e2ad..21b5db8 100644 (file)
@@ -742,6 +742,35 @@ def Math_TanhOp : Math_FloatUnaryOp<"tanh"> {
 }
 
 //===----------------------------------------------------------------------===//
+// RoundEvenOp
+//===----------------------------------------------------------------------===//
+
+def Math_RoundEvenOp : Math_FloatUnaryOp<"roundeven"> {
+  let summary = "round of the specified value with halfway cases to even";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.roundeven` ssa-use `:` type
+    ```
+
+    The `roundeven` operation returns the operand rounded to the nearest integer
+    value in floating-point format. It takes one operand of floating point type
+    (i.e., scalar, tensor or vector) and produces one result of the same type.  The
+    operation rounds the argument to the nearest integer value in floating-point
+    format, rounding halfway cases to even, regardless of the current
+    rounding direction.
+
+    Example:
+
+    ```mlir
+    // Scalar round operation.
+    %a = math.roundeven %b : f64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // RoundOp
 //===----------------------------------------------------------------------===//
 
index cb34982..8161cc5 100644 (file)
@@ -35,6 +35,8 @@ using Log10OpLowering =
 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
+using RoundEvenOpLowering =
+    VectorConvertToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
 using RoundOpLowering =
     VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
@@ -285,6 +287,7 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
     Log2OpLowering,
     LogOpLowering,
     PowFOpLowering,
+    RoundEvenOpLowering,
     RoundOpLowering,
     RsqrtOpLowering,
     SinOpLowering,
index 6e3bd2a..5071d60 100644 (file)
@@ -141,16 +141,19 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
 void mlir::populateMathToLibmConversionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit,
     llvm::Optional<PatternBenefit> log1pBenefit) {
-  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
-               VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
-               VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
-               VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
-               VecOpToScalarOp<math::TanOp>>(patterns.getContext(), benefit);
+  patterns
+      .add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
+           VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
+           VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
+           VecOpToScalarOp<math::RoundEvenOp>, VecOpToScalarOp<math::RoundOp>,
+           VecOpToScalarOp<math::AtanOp>, VecOpToScalarOp<math::TanOp>>(
+          patterns.getContext(), benefit);
   patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
                PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
                PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
-               PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
-               PromoteOpToF32<math::TanOp>>(patterns.getContext(), benefit);
+               PromoteOpToF32<math::RoundEvenOp>, PromoteOpToF32<math::RoundOp>,
+               PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>>(
+      patterns.getContext(), benefit);
   patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
                                                  "atan", benefit);
   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
@@ -163,6 +166,8 @@ void mlir::populateMathToLibmConversionPatterns(
                                                 "tan", benefit);
   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
                                                  "tanh", benefit);
+  patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(
+      patterns.getContext(), "roundevenf", "roundeven", benefit);
   patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
                                                   "roundf", "round", benefit);
   patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
index 64e2018..b87ddba 100644 (file)
@@ -190,3 +190,13 @@ func.func @round(%arg0 : f32) {
   %0 = math.round %arg0 : f32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @roundeven(
+// CHECK-SAME: f32
+func.func @roundeven(%arg0 : f32) {
+  // CHECK: "llvm.intr.roundeven"(%arg0) : (f32) -> f32
+  %0 = math.roundeven %arg0 : f32
+  func.return
+}
index b7e9dfc..641dd56 100644 (file)
@@ -14,6 +14,8 @@
 // CHECK-DAG: @tanhf(f32) -> f32
 // CHECK-DAG: @round(f64) -> f64
 // CHECK-DAG: @roundf(f32) -> f32
+// CHECK-DAG: @roundeven(f64) -> f64
+// CHECK-DAG: @roundevenf(f32) -> f32
 // CHECK-DAG: @cos(f64) -> f64
 // CHECK-DAG: @cosf(f32) -> f32
 // CHECK-DAG: @sin(f64) -> f64
@@ -213,6 +215,19 @@ func.func @round_caller(%float: f32, %double: f64) -> (f32, f64) {
   return %float_result, %double_result : f32, f64
 }
 
+// CHECK-LABEL: func @roundeven_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @roundeven_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @roundevenf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.roundeven %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @roundeven(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.roundeven %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+
 // CHECK-LABEL: func @cos_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
@@ -261,6 +276,31 @@ func.func @round_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
   return %float_result, %double_result : vector<2xf32>, vector<2xf64>
 }
 
+// CHECK-LABEL:   func @roundeven_vec_caller(
+// CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+func.func @roundeven_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+  // CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+  // CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+  // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32>
+  // CHECK:           %[[OUT0_F32:.*]] = call @roundevenf(%[[IN0_F32]]) : (f32) -> f32
+  // CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
+  // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32>
+  // CHECK:           %[[OUT1_F32:.*]] = call @roundevenf(%[[IN1_F32]]) : (f32) -> f32
+  // CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+  %float_result = math.roundeven %float : vector<2xf32>
+  // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64>
+  // CHECK:           %[[OUT0_F64:.*]] = call @roundeven(%[[IN0_F64]]) : (f64) -> f64
+  // CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
+  // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64>
+  // CHECK:           %[[OUT1_F64:.*]] = call @roundeven(%[[IN1_F64]]) : (f64) -> f64
+  // CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+  %double_result = math.roundeven %double : vector<2xf64>
+  // CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+  return %float_result, %double_result : vector<2xf32>, vector<2xf64>
+}
+
+
 // CHECK-LABEL: func @tan_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
index c25d097..1af096c 100644 (file)
@@ -233,6 +233,19 @@ func.func @round(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @roundeven(
+// CHECK-SAME:             %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @roundeven(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.roundeven %[[F]] : f32
+  %0 = math.roundeven %f : f32
+  // CHECK: %{{.*}} = math.roundeven %[[V]] : vector<4xf32>
+  %1 = math.roundeven %v : vector<4xf32>
+  // CHECK: %{{.*}} = math.roundeven %[[T]] : tensor<4x4x?xf32>
+  %2 = math.roundeven %t : tensor<4x4x?xf32>
+  return
+}
+
+
 // CHECK-LABEL: func @ipowi(
 // CHECK-SAME:             %[[I:.*]]: i32, %[[V:.*]]: vector<4xi32>, %[[T:.*]]: tensor<4x4x?xi32>)
 func.func @ipowi(%i: i32, %v: vector<4xi32>, %t: tensor<4x4x?xi32>) {