From 196d89740c5e8bf238200b7f95e6173b231aa5d2 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Mon, 29 May 2023 17:58:50 +0200 Subject: [PATCH] [mlir][llvm] Add rounding intrinsics Add some of the missing libm rounding intrinsics to the LLVM dialect: * `llvm.rint` * `llvm.nearbyint` * `llvm.lround` * `llvm.llround` * `llvm.lrint` * `llvm.llrint` Differential Revision: https://reviews.llvm.org/D151558 --- .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 12 +++ mlir/test/Target/LLVMIR/Import/intrinsic.ll | 86 +++++++++++++++++++++ mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir | 90 ++++++++++++++++++++++ 3 files changed, 188 insertions(+) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index eb815b3..a409223 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -130,6 +130,18 @@ def LLVM_PowIOp : LLVM_OneResultIntrOp<"powi", [], [0,1], let assemblyFormat = "`(` operands `)` custom(attr-dict) `:` " "functional-type(operands, results)"; } +def LLVM_RintOp : LLVM_UnaryIntrOpF<"rint">; +def LLVM_NearbyintOp : LLVM_UnaryIntrOpF<"nearbyint">; +class LLVM_IntRoundIntrOpBase : + LLVM_OneResultIntrOp { + let arguments = (ins LLVM_AnyFloat:$val); + let assemblyFormat = "`(` operands `)` custom(attr-dict) `:` " + "functional-type(operands, results)"; +} +def LLVM_LroundOp : LLVM_IntRoundIntrOpBase<"lround">; +def LLVM_LlroundOp : LLVM_IntRoundIntrOpBase<"llround">; +def LLVM_LrintOp : LLVM_IntRoundIntrOpBase<"lrint">; +def LLVM_LlrintOp : LLVM_IntRoundIntrOpBase<"llrint">; def LLVM_BitReverseOp : LLVM_UnaryIntrOpI<"bitreverse">; def LLVM_ByteSwapOp : LLVM_UnaryIntrOpI<"bswap">; def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrOp<"ctlz">; diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index 811dc44..e9b3615 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -117,6 +117,72 @@ define void @pow_test(float %0, float %1, <8 x float> %2, <8 x float> %3) { %6 = call <8 x float> @llvm.pow.v8f32(<8 x float> %2, <8 x float> %3) ret void } + +; CHECK-LABEL: llvm.func @rint_test +define void @rint_test(float %0, double %1, <8 x float> %2, <8 x double> %3) { + ; CHECK: llvm.intr.rint(%{{.*}}) : (f32) -> f32 + %5 = call float @llvm.rint.f32(float %0) + ; CHECK: llvm.intr.rint(%{{.*}}) : (f64) -> f64 + %6 = call double @llvm.rint.f64(double %1) + ; CHECK: llvm.intr.rint(%{{.*}}) : (vector<8xf32>) -> vector<8xf32> + %7 = call <8 x float> @llvm.rint.v8f32(<8 x float> %2) + ; CHECK: llvm.intr.rint(%{{.*}}) : (vector<8xf64>) -> vector<8xf64> + %8 = call <8 x double> @llvm.rint.v8f64(<8 x double> %3) + ret void +} +; CHECK-LABEL: llvm.func @nearbyint_test +define void @nearbyint_test(float %0, double %1, <8 x float> %2, <8 x double> %3) { + ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (f32) -> f32 + %5 = call float @llvm.nearbyint.f32(float %0) + ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (f64) -> f64 + %6 = call double @llvm.nearbyint.f64(double %1) + ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (vector<8xf32>) -> vector<8xf32> + %7 = call <8 x float> @llvm.nearbyint.v8f32(<8 x float> %2) + ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (vector<8xf64>) -> vector<8xf64> + %8 = call <8 x double> @llvm.nearbyint.v8f64(<8 x double> %3) + ret void +} +; CHECK-LABEL: llvm.func @lround_test +define void @lround_test(float %0, double %1) { + ; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i32 + %3 = call i32 @llvm.lround.i32.f32(float %0) + ; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i64 + %4 = call i64 @llvm.lround.i64.f32(float %0) + ; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i32 + %5 = call i32 @llvm.lround.i32.f64(double %1) + ; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i64 + %6 = call i64 @llvm.lround.i64.f64(double %1) + ret void +} +; CHECK-LABEL: llvm.func @llround_test +define void @llround_test(float %0, double %1) { + ; CHECK: llvm.intr.llround(%{{.*}}) : (f32) -> i64 + %3 = call i64 @llvm.llround.i64.f32(float %0) + ; CHECK: llvm.intr.llround(%{{.*}}) : (f64) -> i64 + %4 = call i64 @llvm.llround.i64.f64(double %1) + ret void +} +; CHECK-LABEL: llvm.func @lrint_test +define void @lrint_test(float %0, double %1) { + ; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i32 + %3 = call i32 @llvm.lrint.i32.f32(float %0) + ; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i64 + %4 = call i64 @llvm.lrint.i64.f32(float %0) + ; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i32 + %5 = call i32 @llvm.lrint.i32.f64(double %1) + ; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i64 + %6 = call i64 @llvm.lrint.i64.f64(double %1) + ret void +} +; CHECK-LABEL: llvm.func @llrint_test +define void @llrint_test(float %0, double %1) { + ; CHECK: llvm.intr.llrint(%{{.*}}) : (f32) -> i64 + %3 = call i64 @llvm.llrint.i64.f32(float %0) + ; CHECK: llvm.intr.llrint(%{{.*}}) : (f64) -> i64 + %4 = call i64 @llvm.llrint.i64.f64(double %1) + ret void +} + ; CHECK-LABEL: llvm.func @bitreverse_test define void @bitreverse_test(i32 %0, <8 x i32> %1) { ; CHECK: llvm.intr.bitreverse(%{{.*}}) : (i32) -> i32 @@ -781,6 +847,26 @@ declare float @llvm.copysign.f32(float, float) declare <8 x float> @llvm.copysign.v8f32(<8 x float>, <8 x float>) declare float @llvm.pow.f32(float, float) declare <8 x float> @llvm.pow.v8f32(<8 x float>, <8 x float>) +declare float @llvm.rint.f32(float) +declare double @llvm.rint.f64(double) +declare <8 x float> @llvm.rint.v8f32(<8 x float>) +declare <8 x double> @llvm.rint.v8f64(<8 x double>) +declare float @llvm.nearbyint.f32(float) +declare double @llvm.nearbyint.f64(double) +declare <8 x float> @llvm.nearbyint.v8f32(<8 x float>) +declare <8 x double> @llvm.nearbyint.v8f64(<8 x double>) +declare i32 @llvm.lround.i32.f32(float) +declare i64 @llvm.lround.i64.f32(float) +declare i32 @llvm.lround.i32.f64(double) +declare i64 @llvm.lround.i64.f64(double) +declare i64 @llvm.llround.i64.f32(float) +declare i64 @llvm.llround.i64.f64(double) +declare i32 @llvm.lrint.i32.f32(float) +declare i64 @llvm.lrint.i64.f32(float) +declare i32 @llvm.lrint.i32.f64(double) +declare i64 @llvm.lrint.i64.f64(double) +declare i64 @llvm.llrint.i64.f32(float) +declare i64 @llvm.llrint.i64.f64(double) declare i32 @llvm.bitreverse.i32(i32) declare <8 x i32> @llvm.bitreverse.v8i32(<8 x i32>) declare i32 @llvm.bswap.i32(i32) diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index c6a3c7f..ec619b9 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -134,6 +134,76 @@ llvm.func @pow_test(%arg0: f32, %arg1: f32, %arg2: vector<8xf32>, %arg3: vector< llvm.return } +// CHECK-LABEL: @rint_test +llvm.func @rint_test(%arg0 : f32, %arg1 : f64, %arg2 : vector<8xf32>, %arg3 : vector<8xf64>) { + // CHECK: call float @llvm.rint.f32 + "llvm.intr.rint"(%arg0) : (f32) -> f32 + // CHECK: call double @llvm.rint.f64 + "llvm.intr.rint"(%arg1) : (f64) -> f64 + // CHECK: call <8 x float> @llvm.rint.v8f32 + "llvm.intr.rint"(%arg2) : (vector<8xf32>) -> vector<8xf32> + // CHECK: call <8 x double> @llvm.rint.v8f64 + "llvm.intr.rint"(%arg3) : (vector<8xf64>) -> vector<8xf64> + llvm.return +} + +// CHECK-LABEL: @nearbyint_test +llvm.func @nearbyint_test(%arg0 : f32, %arg1 : f64, %arg2 : vector<8xf32>, %arg3 : vector<8xf64>) { + // CHECK: call float @llvm.nearbyint.f32 + "llvm.intr.nearbyint"(%arg0) : (f32) -> f32 + // CHECK: call double @llvm.nearbyint.f64 + "llvm.intr.nearbyint"(%arg1) : (f64) -> f64 + // CHECK: call <8 x float> @llvm.nearbyint.v8f32 + "llvm.intr.nearbyint"(%arg2) : (vector<8xf32>) -> vector<8xf32> + // CHECK: call <8 x double> @llvm.nearbyint.v8f64 + "llvm.intr.nearbyint"(%arg3) : (vector<8xf64>) -> vector<8xf64> + llvm.return +} + +// CHECK-LABEL: @lround_test +llvm.func @lround_test(%arg0 : f32, %arg1 : f64) { + // CHECK: call i32 @llvm.lround.i32.f32 + "llvm.intr.lround"(%arg0) : (f32) -> i32 + // CHECK: call i64 @llvm.lround.i64.f32 + "llvm.intr.lround"(%arg0) : (f32) -> i64 + // CHECK: call i32 @llvm.lround.i32.f64 + "llvm.intr.lround"(%arg1) : (f64) -> i32 + // CHECK: call i64 @llvm.lround.i64.f64 + "llvm.intr.lround"(%arg1) : (f64) -> i64 + llvm.return +} + +// CHECK-LABEL: @llround_test +llvm.func @llround_test(%arg0 : f32, %arg1 : f64) { + // CHECK: call i64 @llvm.llround.i64.f32 + "llvm.intr.llround"(%arg0) : (f32) -> i64 + // CHECK: call i64 @llvm.llround.i64.f64 + "llvm.intr.llround"(%arg1) : (f64) -> i64 + llvm.return +} + +// CHECK-LABEL: @lrint_test +llvm.func @lrint_test(%arg0 : f32, %arg1 : f64) { + // CHECK: call i32 @llvm.lrint.i32.f32 + "llvm.intr.lrint"(%arg0) : (f32) -> i32 + // CHECK: call i64 @llvm.lrint.i64.f32 + "llvm.intr.lrint"(%arg0) : (f32) -> i64 + // CHECK: call i32 @llvm.lrint.i32.f64 + "llvm.intr.lrint"(%arg1) : (f64) -> i32 + // CHECK: call i64 @llvm.lrint.i64.f64 + "llvm.intr.lrint"(%arg1) : (f64) -> i64 + llvm.return +} + +// CHECK-LABEL: @llrint_test +llvm.func @llrint_test(%arg0 : f32, %arg1 : f64) { + // CHECK: call i64 @llvm.llrint.i64.f32 + "llvm.intr.llrint"(%arg0) : (f32) -> i64 + // CHECK: call i64 @llvm.llrint.i64.f64 + "llvm.intr.llrint"(%arg1) : (f64) -> i64 + llvm.return +} + // CHECK-LABEL: @bitreverse_test llvm.func @bitreverse_test(%arg0: i32, %arg1: vector<8xi32>) { // CHECK: call i32 @llvm.bitreverse.i32 @@ -865,6 +935,26 @@ llvm.func @lifetime(%p: !llvm.ptr) { // CHECK-DAG: declare float @llvm.cos.f32(float) // CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0 // CHECK-DAG: declare float @llvm.copysign.f32(float, float) +// CHECK-DAG: declare float @llvm.rint.f32(float) +// CHECK-DAG: declare double @llvm.rint.f64(double) +// CHECK-DAG: declare <8 x float> @llvm.rint.v8f32(<8 x float>) +// CHECK-DAG: declare <8 x double> @llvm.rint.v8f64(<8 x double>) +// CHECK-DAG: declare float @llvm.nearbyint.f32(float) +// CHECK-DAG: declare double @llvm.nearbyint.f64(double) +// CHECK-DAG: declare <8 x float> @llvm.nearbyint.v8f32(<8 x float>) +// CHECK-DAG: declare <8 x double> @llvm.nearbyint.v8f64(<8 x double>) +// CHECK-DAG: declare i32 @llvm.lround.i32.f32(float) +// CHECK-DAG: declare i64 @llvm.lround.i64.f32(float) +// CHECK-DAG: declare i32 @llvm.lround.i32.f64(double) +// CHECK-DAG: declare i64 @llvm.lround.i64.f64(double) +// CHECK-DAG: declare i64 @llvm.llround.i64.f32(float) +// CHECK-DAG: declare i64 @llvm.llround.i64.f64(double) +// CHECK-DAG: declare i32 @llvm.lrint.i32.f32(float) +// CHECK-DAG: declare i64 @llvm.lrint.i64.f32(float) +// CHECK-DAG: declare i32 @llvm.lrint.i32.f64(double) +// CHECK-DAG: declare i64 @llvm.lrint.i64.f64(double) +// CHECK-DAG: declare i64 @llvm.llrint.i64.f32(float) +// CHECK-DAG: declare i64 @llvm.llrint.i64.f64(double) // CHECK-DAG: declare <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float>, <48 x float>, i32 immarg, i32 immarg, i32 immarg) // CHECK-DAG: declare <48 x float> @llvm.matrix.transpose.v48f32(<48 x float>, i32 immarg, i32 immarg) // CHECK-DAG: declare <48 x float> @llvm.matrix.column.major.load.v48f32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) -- 2.7.4