From 3941355d8fee763e99c259ecd02f6fe567583296 Mon Sep 17 00:00:00 2001 From: Kai Sasaki Date: Wed, 8 Feb 2023 11:51:08 +0900 Subject: [PATCH] [mlir][vector] Support 0-D vector when eliding single element reduction ElideSingleElementReduction causes assertion failure when we give 0-D vector. It's possible to fold the case by using vector.extractelement op instead. It's originally reported in https://github.com/llvm/llvm-project/issues/60193. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D143242 --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++++++---- mlir/test/Dialect/Vector/canonicalize.mlir | 10 ++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 32ae7b1..8073757 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -530,13 +530,19 @@ struct ElideSingleElementReduction : public OpRewritePattern { if (maskableOp.isMasked()) return failure(); - if (reductionOp.getVectorType().getDimSize(0) != 1) + auto vectorType = reductionOp.getVectorType(); + if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1) return failure(); Location loc = reductionOp.getLoc(); - Value result = rewriter.create(loc, reductionOp.getType(), - reductionOp.getVector(), - rewriter.getI64ArrayAttr(0)); + Value result; + if (vectorType.getRank() == 0) { + result = rewriter.create(loc, reductionOp.getVector()); + } else { + result = rewriter.create(loc, reductionOp.getType(), + reductionOp.getVector(), + rewriter.getI64ArrayAttr(0)); + } if (Value acc = reductionOp.getAcc()) result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 8fc1834..cac24b3 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2157,3 +2157,13 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 { %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32> return %1 : f32 } + +// ----- + +// CHECK-LABEL: func.func @fold_0d_vector_reduction +func.func @fold_0d_vector_reduction(%arg0: vector) -> f32 { + // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector + // CHECK-NEXT: return %[[RES]] : f32 + %0 = vector.reduction , %arg0 : vector into f32 + return %0 : f32 +} -- 2.7.4