[mlir][LLVM] Introduce reduction intrinsics for minimum/maximum
authorDaniil Dudkin <unterumarmung@yandex.ru>
Sat, 22 Jul 2023 13:25:26 +0000 (16:25 +0300)
committerDaniil Dudkin <unterumarmung@yandex.ru>
Sat, 22 Jul 2023 13:25:32 +0000 (16:25 +0300)
This patch adds supports for the reduction intrinsic
for floating point minimum and maximum that have
been added to LLVM by https://reviews.llvm.org/D152370.

Related to: #63969

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D155869

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir
mlir/test/Target/LLVMIR/Import/fastmath.ll
mlir/test/Target/LLVMIR/Import/intrinsic.ll
mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
mlir/test/Target/LLVMIR/llvmir.mlir

index 94c3687..c144afb 100644 (file)
@@ -718,6 +718,8 @@ def LLVM_vector_reduce_xor : LLVM_VecReductionI<"xor">;
 
 def LLVM_vector_reduce_fmax : LLVM_VecReductionF<"fmax">;
 def LLVM_vector_reduce_fmin : LLVM_VecReductionF<"fmin">;
+def LLVM_vector_reduce_fmaximum : LLVM_VecReductionF<"fmaximum">;
+def LLVM_vector_reduce_fminimum : LLVM_VecReductionF<"fminimum">;
 
 def LLVM_vector_reduce_fadd : LLVM_VecReductionAccF<"fadd">;
 def LLVM_vector_reduce_fmul : LLVM_VecReductionAccF<"fmul">;
index ecf73a2..b9ce693 100644 (file)
@@ -516,6 +516,10 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f
   %13 = llvm.intr.vector.reduce.fmin(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
 // CHECK: {{.*}} = llvm.intr.vector.reduce.fmax(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
   %14 = llvm.intr.vector.reduce.fmax(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+// CHECK: {{.*}} = llvm.intr.vector.reduce.fminimum(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %15 = llvm.intr.vector.reduce.fminimum(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+// CHECK: {{.*}} = llvm.intr.vector.reduce.fmaximum(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %16 = llvm.intr.vector.reduce.fmaximum(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
   return
 }
 
index f2e59e1..d2431fc 100644 (file)
@@ -35,6 +35,18 @@ module {
     llvm.call @printNewline() : () -> ()
     // CHECK: 1
 
+    %maximum = llvm.intr.vector.reduce.fmaximum(%v)
+        : (vector<4xf32>) -> f32
+    llvm.call @printF32(%maximum) : (f32) -> ()
+    llvm.call @printNewline() : () -> ()
+    // CHECK: 4
+
+    %minimum = llvm.intr.vector.reduce.fminimum(%v)
+        : (vector<4xf32>) -> f32
+    llvm.call @printF32(%minimum) : (f32) -> ()
+    llvm.call @printNewline() : () -> ()
+    // CHECK: 1
+
     %add1 = "llvm.intr.vector.reduce.fadd"(%0, %v)
         : (f32, vector<4xf32>) -> f32
     llvm.call @printF32(%add1) : (f32) -> ()
index 9b30c32..0c6a74c 100644 (file)
@@ -45,6 +45,8 @@ declare float @llvm.pow.f32(float, float)
 declare float @llvm.fmuladd.f32(float, float, float)
 declare float @llvm.vector.reduce.fmin.v2f32(<2 x float>)
 declare float @llvm.vector.reduce.fmax.v2f32(<2 x float>)
+declare float @llvm.vector.reduce.fminimum.v2f32(<2 x float>)
+declare float @llvm.vector.reduce.fmaximum.v2f32(<2 x float>)
 
 ; CHECK-LABEL: @fastmath_intr
 define void @fastmath_intr(float %arg1, i32 %arg2, <2 x float> %arg3) {
@@ -60,6 +62,10 @@ define void @fastmath_intr(float %arg1, i32 %arg2, <2 x float> %arg3) {
   %5 = call nnan float @llvm.vector.reduce.fmin.v2f32(<2 x float> %arg3)
   ; CHECK: %{{.*}} = llvm.intr.vector.reduce.fmax({{.*}}) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
   %6 = call nnan float @llvm.vector.reduce.fmax.v2f32(<2 x float> %arg3)
+  ; CHECK: %{{.*}} = llvm.intr.vector.reduce.fminimum({{.*}}) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %7 = call nnan float @llvm.vector.reduce.fminimum.v2f32(<2 x float> %arg3)
+  ; CHECK: %{{.*}} = llvm.intr.vector.reduce.fmaximum({{.*}}) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %8 = call nnan float @llvm.vector.reduce.fmaximum.v2f32(<2 x float> %arg3)
 
   ret void
 }
index a0aa4ae..543de93 100644 (file)
@@ -364,6 +364,10 @@ define void @vector_reductions(float %0, <8 x float> %1, <8 x i32> %2) {
   %17 = call reassoc float @llvm.vector.reduce.fmul.v8f32(float %0, <8 x float> %1)
   ; CHECK:  "llvm.intr.vector.reduce.xor"(%{{.*}}) : (vector<8xi32>) -> i32
   %18 = call i32 @llvm.vector.reduce.xor.v8i32(<8 x i32> %2)
+  ; CHECK: llvm.intr.vector.reduce.fmaximum(%{{.*}}) : (vector<8xf32>) -> f32
+  %19 = call float @llvm.vector.reduce.fmaximum.v8f32(<8 x float> %1)
+  ; CHECK: llvm.intr.vector.reduce.fminimum(%{{.*}}) : (vector<8xf32>) -> f32
+  %20 = call float @llvm.vector.reduce.fminimum.v8f32(<8 x float> %1)
   ret void
 }
 
@@ -944,6 +948,8 @@ declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
 declare i32 @llvm.vector.reduce.and.v8i32(<8 x i32>)
 declare float @llvm.vector.reduce.fmax.v8f32(<8 x float>)
 declare float @llvm.vector.reduce.fmin.v8f32(<8 x float>)
+declare float @llvm.vector.reduce.fmaximum.v8f32(<8 x float>)
+declare float @llvm.vector.reduce.fminimum.v8f32(<8 x float>)
 declare i32 @llvm.vector.reduce.mul.v8i32(<8 x i32>)
 declare i32 @llvm.vector.reduce.or.v8i32(<8 x i32>)
 declare i32 @llvm.vector.reduce.smax.v8i32(<8 x i32>)
index d6c6855..d666017 100644 (file)
@@ -354,6 +354,10 @@ llvm.func @vector_reductions(%arg0: f32, %arg1: vector<8xf32>, %arg2: vector<8xi
   llvm.intr.vector.reduce.fmax(%arg1) : (vector<8xf32>) -> f32
   // CHECK: call float @llvm.vector.reduce.fmin.v8f32
   llvm.intr.vector.reduce.fmin(%arg1) : (vector<8xf32>) -> f32
+  // CHECK: call float @llvm.vector.reduce.fmaximum.v8f32
+  llvm.intr.vector.reduce.fmaximum(%arg1) : (vector<8xf32>) -> f32
+  // CHECK: call float @llvm.vector.reduce.fminimum.v8f32
+  llvm.intr.vector.reduce.fminimum(%arg1) : (vector<8xf32>) -> f32
   // CHECK: call i32 @llvm.vector.reduce.mul.v8i32
   "llvm.intr.vector.reduce.mul"(%arg2) : (vector<8xi32>) -> i32
   // CHECK: call i32 @llvm.vector.reduce.or.v8i32
index 3f97ebd..fbe0d30 100644 (file)
@@ -2011,10 +2011,14 @@ llvm.func @fastmathFlags(%arg0: f32, %arg1 : vector<2xf32>) {
   %21 = llvm.intr.vector.reduce.fmax(%arg1) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
   %22 = llvm.intr.vector.reduce.fmin(%arg1) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
 
+// CHECK: call nnan float @llvm.vector.reduce.fmaximum.v2f32(<2 x float> {{.*}})
+// CHECK: call nnan float @llvm.vector.reduce.fminimum.v2f32(<2 x float> {{.*}})
+  %23 = llvm.intr.vector.reduce.fmaximum(%arg1) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %24 = llvm.intr.vector.reduce.fminimum(%arg1) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
 
-  %23 = llvm.mlir.constant(true) : i1
+  %25 = llvm.mlir.constant(true) : i1
 // CHECK: select contract i1
-  %24 = llvm.select %23, %arg0, %20 {fastmathFlags = #llvm.fastmath<contract>} : i1, f32
+  %26 = llvm.select %25, %arg0, %20 {fastmathFlags = #llvm.fastmath<contract>} : i1, f32
   llvm.return
 }