[mlir][vector] Support 0-D vector when eliding single element reduction
authorKai Sasaki <lewuathe@gmail.com>
Wed, 8 Feb 2023 02:51:08 +0000 (11:51 +0900)
committerKai Sasaki <lewuathe@gmail.com>
Wed, 8 Feb 2023 03:01:56 +0000 (12:01 +0900)
ElideSingleElementReduction causes assertion failure when we give 0-D vector. It's possible to fold the case by using vector.extractelement op instead. It's originally reported in https://github.com/llvm/llvm-project/issues/60193.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D143242

mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index 32ae7b1..8073757 100644 (file)
@@ -530,13 +530,19 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
     if (maskableOp.isMasked())
       return failure();
 
-    if (reductionOp.getVectorType().getDimSize(0) != 1)
+    auto vectorType = reductionOp.getVectorType();
+    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
       return failure();
 
     Location loc = reductionOp.getLoc();
-    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
-                                              reductionOp.getVector(),
-                                              rewriter.getI64ArrayAttr(0));
+    Value result;
+    if (vectorType.getRank() == 0) {
+      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+    } else {
+      result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
+                                          reductionOp.getVector(),
+                                          rewriter.getI64ArrayAttr(0));
+    }
 
     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
index 8fc1834..cac24b3 100644 (file)
@@ -2157,3 +2157,13 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
   %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32>
   return %1 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_0d_vector_reduction
+func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
+  // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
+  // CHECK-NEXT: return %[[RES]] : f32
+  %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
+  return %0 : f32
+}