[MLIR][Math] Add round operation
authorlorenzo chelini <l.chelini@icloud.com>
Thu, 2 Jun 2022 14:49:23 +0000 (16:49 +0200)
committerlorenzo chelini <l.chelini@icloud.com>
Wed, 8 Jun 2022 11:07:39 +0000 (13:07 +0200)
Introduce RoundOp in the math dialect. The operation rounds the operand to the
nearest integer value in floating-point format. RoundOp lowers to LLVM
intrinsics 'llvm.intr.round' or as a function call to libm (round or roundf).

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D127286

mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
mlir/test/Dialect/Math/ops.mlir

index 1378135..58cf55f 100644 (file)
@@ -652,4 +652,30 @@ def Math_TanhOp : Math_FloatUnaryOp<"tanh"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RoundOp
+//===----------------------------------------------------------------------===//
+
+def Math_RoundOp : Math_FloatUnaryOp<"round"> {
+  let summary = "round of the specified value";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.round` ssa-use `:` type
+    ```
+
+    The `round` operation returns the operand rounded to the nearest integer
+    value in floating-point format. It takes one operand of floating point type
+    (i.e., scalar, tensor or vector) and produces one result of the same type.
+
+    Example:
+
+    ```mlir
+    // Scalar round operation.
+    %a = math.round %b : f64
+    ```
+  }];
+}
+
 #endif // MATH_OPS
index 189680c..510540d 100644 (file)
@@ -37,6 +37,8 @@ using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
+using RoundOpLowering =
+    VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
 
 // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
 template <typename MathOp, typename LLVMOp>
@@ -285,7 +287,8 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
     PowFOpLowering,
     RsqrtOpLowering,
     SinOpLowering,
-    SqrtOpLowering
+    SqrtOpLowering,
+    RoundOpLowering
   >(converter);
   // clang-format on
 }
index 6c9d02c..78835e1 100644 (file)
@@ -152,6 +152,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
                                                   "expm1f", "expm1", benefit);
   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
                                                  "tanh", benefit);
+  patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
+                                                  "roundf", "round", benefit);
 }
 
 namespace {
index af32271..6378ea6 100644 (file)
@@ -172,3 +172,12 @@ func.func @powf(%arg0 : f64) {
   func.return
 }
 
+// -----
+
+// CHECK-LABEL: func @round(
+// CHECK-SAME: f32
+func.func @round(%arg0 : f32) {
+  // CHECK: "llvm.intr.round"(%arg0) : (f32) -> f32
+  %0 = math.round %arg0 : f32
+  func.return
+}
index 7cdb56e..cb09988 100644 (file)
@@ -8,6 +8,8 @@
 // CHECK-DAG: @atan2f(f32, f32) -> f32
 // CHECK-DAG: @tanh(f64) -> f64
 // CHECK-DAG: @tanhf(f32) -> f32
+// CHECK-DAG: @round(f64) -> f64
+// CHECK-DAG: @roundf(f32) -> f32
 
 // CHECK-LABEL: func @tanh_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
@@ -21,7 +23,6 @@ func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64)  {
   return %float_result, %double_result : f32, f64
 }
 
-
 // CHECK-LABEL: func @atan2_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
@@ -116,3 +117,15 @@ func.func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32
 // CHECK:           %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32>
 // CHECK:           return %[[VAL_4]] : vector<2x2xf32>
 // CHECK:         }
+
+// CHECK-LABEL: func @round_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @round_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @roundf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.round %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @round(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.round %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
index a1bb9af..7acd893 100644 (file)
@@ -194,3 +194,15 @@ func.func @tanh(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   %2 = math.tanh %t : tensor<4x4x?xf32>
   return
 }
+
+// CHECK-LABEL: func @round(
+// CHECK-SAME:             %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @round(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.round %[[F]] : f32
+  %0 = math.round %f : f32
+  // CHECK: %{{.*}} = math.round %[[V]] : vector<4xf32>
+  %1 = math.round %v : vector<4xf32>
+  // CHECK: %{{.*}} = math.round %[[T]] : tensor<4x4x?xf32>
+  %2 = math.round %t : tensor<4x4x?xf32>
+  return
+}