[mlir][vector] Allow transposing multi_reduction when the parallel dim is in the...
authorBenjamin Kramer <benny.kra@googlemail.com>
Tue, 24 Jan 2023 17:01:26 +0000 (18:01 +0100)
committerBenjamin Kramer <benny.kra@googlemail.com>
Thu, 26 Jan 2023 17:06:42 +0000 (18:06 +0100)
The check for the outer lowering wasn't quite right.

Differential Revision: https://reviews.llvm.org/D142483

mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir

index e89059c..117fdcb 100644 (file)
@@ -77,8 +77,9 @@ public:
       return failure();
 
     if (!useInnerDimsForReduction &&
-        (parallelDims !=
-         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
+                             reductionDims.size(),
+                             parallelDims.size() + reductionDims.size()))))
       return failure();
 
     SmallVector<int64_t, 4> indices;
index ee4ab7a..5647089 100644 (file)
@@ -234,3 +234,13 @@ func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1
 // CHECK:           %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
 // CHECK:           %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
 
+// -----
+
+func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+//       CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
index 8a8bf86..9f22972 100644 (file)
@@ -162,6 +162,15 @@ func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi
 //       CHECK:   %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2x3xi32>
 
+func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+//       CHECK: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
+
 // This test is mainly to catch a bug that running
 // `InnerOuterDimReductionConversion` on this function results in an
 // infinite loop. So just check that some value is returned.