[mlir][llvm] Add rounding intrinsics
authorLukas Sommer <lukas.sommer@codeplay.com>
Mon, 29 May 2023 15:58:50 +0000 (17:58 +0200)
committerMarkus Böck <markus.boeck02@gmail.com>
Mon, 29 May 2023 16:13:08 +0000 (18:13 +0200)
Add some of the missing libm rounding intrinsics to the LLVM dialect:
* `llvm.rint`
* `llvm.nearbyint`
* `llvm.lround`
* `llvm.llround`
* `llvm.lrint`
* `llvm.llrint`

Differential Revision: https://reviews.llvm.org/D151558

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/test/Target/LLVMIR/Import/intrinsic.ll
mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

index eb815b3..a409223 100644 (file)
@@ -130,6 +130,18 @@ def LLVM_PowIOp : LLVM_OneResultIntrOp<"powi", [], [0,1],
   let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
       "functional-type(operands, results)";
 }
+def LLVM_RintOp : LLVM_UnaryIntrOpF<"rint">;
+def LLVM_NearbyintOp : LLVM_UnaryIntrOpF<"nearbyint">;
+class LLVM_IntRoundIntrOpBase<string func> :
+        LLVM_OneResultIntrOp<func, [0], [0], [Pure]> {
+  let arguments = (ins LLVM_AnyFloat:$val);
+  let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
+      "functional-type(operands, results)";
+}
+def LLVM_LroundOp : LLVM_IntRoundIntrOpBase<"lround">;
+def LLVM_LlroundOp : LLVM_IntRoundIntrOpBase<"llround">;
+def LLVM_LrintOp : LLVM_IntRoundIntrOpBase<"lrint">;
+def LLVM_LlrintOp : LLVM_IntRoundIntrOpBase<"llrint">;
 def LLVM_BitReverseOp : LLVM_UnaryIntrOpI<"bitreverse">;
 def LLVM_ByteSwapOp : LLVM_UnaryIntrOpI<"bswap">;
 def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrOp<"ctlz">;
index 811dc44..e9b3615 100644 (file)
@@ -117,6 +117,72 @@ define void @pow_test(float %0, float %1, <8 x float> %2, <8 x float> %3) {
   %6 = call <8 x float> @llvm.pow.v8f32(<8 x float> %2, <8 x float> %3)
   ret void
 }
+
+; CHECK-LABEL: llvm.func @rint_test
+define void @rint_test(float %0, double %1, <8 x float> %2, <8 x double> %3) {
+  ; CHECK: llvm.intr.rint(%{{.*}}) : (f32) -> f32
+  %5 = call float @llvm.rint.f32(float %0)
+  ; CHECK: llvm.intr.rint(%{{.*}}) : (f64) -> f64
+  %6 = call double @llvm.rint.f64(double %1)
+  ; CHECK: llvm.intr.rint(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
+  %7 = call <8 x float> @llvm.rint.v8f32(<8 x float> %2)
+  ; CHECK: llvm.intr.rint(%{{.*}}) : (vector<8xf64>) -> vector<8xf64>
+  %8 = call <8 x double> @llvm.rint.v8f64(<8 x double> %3)
+  ret void
+}
+; CHECK-LABEL: llvm.func @nearbyint_test
+define void @nearbyint_test(float %0, double %1, <8 x float> %2, <8 x double> %3) {
+  ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (f32) -> f32
+  %5 = call float @llvm.nearbyint.f32(float %0)
+  ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (f64) -> f64
+  %6 = call double @llvm.nearbyint.f64(double %1)
+  ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
+  %7 = call <8 x float> @llvm.nearbyint.v8f32(<8 x float> %2)
+  ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (vector<8xf64>) -> vector<8xf64>
+  %8 = call <8 x double> @llvm.nearbyint.v8f64(<8 x double> %3)
+  ret void
+}
+; CHECK-LABEL: llvm.func @lround_test
+define void @lround_test(float %0, double %1) {
+  ; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i32
+  %3 = call i32 @llvm.lround.i32.f32(float %0)
+  ; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i64
+  %4 = call i64 @llvm.lround.i64.f32(float %0)
+  ; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i32
+  %5 = call i32 @llvm.lround.i32.f64(double %1)
+  ; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i64
+  %6 = call i64 @llvm.lround.i64.f64(double %1)
+  ret void
+}
+; CHECK-LABEL: llvm.func @llround_test
+define void @llround_test(float %0, double %1) {
+  ; CHECK: llvm.intr.llround(%{{.*}}) : (f32) -> i64
+  %3 = call i64 @llvm.llround.i64.f32(float %0)
+  ; CHECK: llvm.intr.llround(%{{.*}}) : (f64) -> i64
+  %4 = call i64 @llvm.llround.i64.f64(double %1)
+  ret void
+}
+; CHECK-LABEL: llvm.func @lrint_test
+define void @lrint_test(float %0, double %1) {
+  ; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i32
+  %3 = call i32 @llvm.lrint.i32.f32(float %0)
+  ; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i64
+  %4 = call i64 @llvm.lrint.i64.f32(float %0)
+  ; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i32
+  %5 = call i32 @llvm.lrint.i32.f64(double %1)
+  ; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i64
+  %6 = call i64 @llvm.lrint.i64.f64(double %1)
+  ret void
+}
+; CHECK-LABEL: llvm.func @llrint_test
+define void @llrint_test(float %0, double %1) {
+  ; CHECK: llvm.intr.llrint(%{{.*}}) : (f32) -> i64
+  %3 = call i64 @llvm.llrint.i64.f32(float %0)
+  ; CHECK: llvm.intr.llrint(%{{.*}}) : (f64) -> i64
+  %4 = call i64 @llvm.llrint.i64.f64(double %1)
+  ret void
+}
+
 ; CHECK-LABEL:  llvm.func @bitreverse_test
 define void @bitreverse_test(i32 %0, <8 x i32> %1) {
   ; CHECK:   llvm.intr.bitreverse(%{{.*}}) : (i32) -> i32
@@ -781,6 +847,26 @@ declare float @llvm.copysign.f32(float, float)
 declare <8 x float> @llvm.copysign.v8f32(<8 x float>, <8 x float>)
 declare float @llvm.pow.f32(float, float)
 declare <8 x float> @llvm.pow.v8f32(<8 x float>, <8 x float>)
+declare float @llvm.rint.f32(float)
+declare double @llvm.rint.f64(double)
+declare <8 x float> @llvm.rint.v8f32(<8 x float>)
+declare <8 x double> @llvm.rint.v8f64(<8 x double>)
+declare float @llvm.nearbyint.f32(float)
+declare double @llvm.nearbyint.f64(double)
+declare <8 x float> @llvm.nearbyint.v8f32(<8 x float>)
+declare <8 x double> @llvm.nearbyint.v8f64(<8 x double>)
+declare i32 @llvm.lround.i32.f32(float)
+declare i64 @llvm.lround.i64.f32(float)
+declare i32 @llvm.lround.i32.f64(double)
+declare i64 @llvm.lround.i64.f64(double)
+declare i64 @llvm.llround.i64.f32(float)
+declare i64 @llvm.llround.i64.f64(double)
+declare i32 @llvm.lrint.i32.f32(float)
+declare i64 @llvm.lrint.i64.f32(float)
+declare i32 @llvm.lrint.i32.f64(double)
+declare i64 @llvm.lrint.i64.f64(double)
+declare i64 @llvm.llrint.i64.f32(float)
+declare i64 @llvm.llrint.i64.f64(double)
 declare i32 @llvm.bitreverse.i32(i32)
 declare <8 x i32> @llvm.bitreverse.v8i32(<8 x i32>)
 declare i32 @llvm.bswap.i32(i32)
index c6a3c7f..ec619b9 100644 (file)
@@ -134,6 +134,76 @@ llvm.func @pow_test(%arg0: f32, %arg1: f32, %arg2: vector<8xf32>, %arg3: vector<
   llvm.return
 }
 
+// CHECK-LABEL: @rint_test
+llvm.func @rint_test(%arg0 : f32, %arg1 : f64, %arg2 : vector<8xf32>, %arg3 : vector<8xf64>) {
+  // CHECK: call float @llvm.rint.f32
+  "llvm.intr.rint"(%arg0) : (f32) -> f32
+  // CHECK: call double @llvm.rint.f64
+  "llvm.intr.rint"(%arg1) : (f64) -> f64
+  // CHECK: call <8 x float> @llvm.rint.v8f32
+  "llvm.intr.rint"(%arg2) : (vector<8xf32>) -> vector<8xf32>
+  // CHECK: call <8 x double> @llvm.rint.v8f64
+  "llvm.intr.rint"(%arg3) : (vector<8xf64>) -> vector<8xf64>
+  llvm.return
+}
+
+// CHECK-LABEL: @nearbyint_test
+llvm.func @nearbyint_test(%arg0 : f32, %arg1 : f64, %arg2 : vector<8xf32>, %arg3 : vector<8xf64>) {
+  // CHECK: call float @llvm.nearbyint.f32
+  "llvm.intr.nearbyint"(%arg0) : (f32) -> f32
+  // CHECK: call double @llvm.nearbyint.f64
+  "llvm.intr.nearbyint"(%arg1) : (f64) -> f64
+  // CHECK: call <8 x float> @llvm.nearbyint.v8f32
+  "llvm.intr.nearbyint"(%arg2) : (vector<8xf32>) -> vector<8xf32>
+  // CHECK: call <8 x double> @llvm.nearbyint.v8f64
+  "llvm.intr.nearbyint"(%arg3) : (vector<8xf64>) -> vector<8xf64>
+  llvm.return
+}
+
+// CHECK-LABEL: @lround_test
+llvm.func @lround_test(%arg0 : f32, %arg1 : f64) {
+  // CHECK: call i32 @llvm.lround.i32.f32
+  "llvm.intr.lround"(%arg0) : (f32) -> i32
+  // CHECK: call i64 @llvm.lround.i64.f32
+  "llvm.intr.lround"(%arg0) : (f32) -> i64
+  // CHECK: call i32 @llvm.lround.i32.f64
+  "llvm.intr.lround"(%arg1) : (f64) -> i32
+  // CHECK: call i64 @llvm.lround.i64.f64
+  "llvm.intr.lround"(%arg1) : (f64) -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @llround_test
+llvm.func @llround_test(%arg0 : f32, %arg1 : f64) {
+  // CHECK: call i64 @llvm.llround.i64.f32
+  "llvm.intr.llround"(%arg0) : (f32) -> i64
+  // CHECK: call i64 @llvm.llround.i64.f64
+  "llvm.intr.llround"(%arg1) : (f64) -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @lrint_test
+llvm.func @lrint_test(%arg0 : f32, %arg1 : f64) {
+  // CHECK: call i32 @llvm.lrint.i32.f32
+  "llvm.intr.lrint"(%arg0) : (f32) -> i32
+  // CHECK: call i64 @llvm.lrint.i64.f32
+  "llvm.intr.lrint"(%arg0) : (f32) -> i64
+  // CHECK: call i32 @llvm.lrint.i32.f64
+  "llvm.intr.lrint"(%arg1) : (f64) -> i32
+  // CHECK: call i64 @llvm.lrint.i64.f64
+  "llvm.intr.lrint"(%arg1) : (f64) -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @llrint_test
+llvm.func @llrint_test(%arg0 : f32, %arg1 : f64) {
+  // CHECK: call i64 @llvm.llrint.i64.f32
+  "llvm.intr.llrint"(%arg0) : (f32) -> i64
+  // CHECK: call i64 @llvm.llrint.i64.f64
+  "llvm.intr.llrint"(%arg1) : (f64) -> i64
+  llvm.return
+}
+
 // CHECK-LABEL: @bitreverse_test
 llvm.func @bitreverse_test(%arg0: i32, %arg1: vector<8xi32>) {
   // CHECK: call i32 @llvm.bitreverse.i32
@@ -865,6 +935,26 @@ llvm.func @lifetime(%p: !llvm.ptr) {
 // CHECK-DAG: declare float @llvm.cos.f32(float)
 // CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0
 // CHECK-DAG: declare float @llvm.copysign.f32(float, float)
+// CHECK-DAG: declare float @llvm.rint.f32(float)
+// CHECK-DAG: declare double @llvm.rint.f64(double)
+// CHECK-DAG: declare <8 x float> @llvm.rint.v8f32(<8 x float>)
+// CHECK-DAG: declare <8 x double> @llvm.rint.v8f64(<8 x double>)
+// CHECK-DAG: declare float @llvm.nearbyint.f32(float)
+// CHECK-DAG: declare double @llvm.nearbyint.f64(double)
+// CHECK-DAG: declare <8 x float> @llvm.nearbyint.v8f32(<8 x float>)
+// CHECK-DAG: declare <8 x double> @llvm.nearbyint.v8f64(<8 x double>)
+// CHECK-DAG: declare i32 @llvm.lround.i32.f32(float)
+// CHECK-DAG: declare i64 @llvm.lround.i64.f32(float)
+// CHECK-DAG: declare i32 @llvm.lround.i32.f64(double)
+// CHECK-DAG: declare i64 @llvm.lround.i64.f64(double)
+// CHECK-DAG: declare i64 @llvm.llround.i64.f32(float)
+// CHECK-DAG: declare i64 @llvm.llround.i64.f64(double)
+// CHECK-DAG: declare i32 @llvm.lrint.i32.f32(float)
+// CHECK-DAG: declare i64 @llvm.lrint.i64.f32(float)
+// CHECK-DAG: declare i32 @llvm.lrint.i32.f64(double)
+// CHECK-DAG: declare i64 @llvm.lrint.i64.f64(double)
+// CHECK-DAG: declare i64 @llvm.llrint.i64.f32(float)
+// CHECK-DAG: declare i64 @llvm.llrint.i64.f64(double)
 // CHECK-DAG: declare <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float>, <48 x float>, i32 immarg, i32 immarg, i32 immarg)
 // CHECK-DAG: declare <48 x float> @llvm.matrix.transpose.v48f32(<48 x float>, i32 immarg, i32 immarg)
 // CHECK-DAG: declare <48 x float> @llvm.matrix.column.major.load.v48f32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg)