From 7becf0f6cd31ea7462c5e18a88cb2f7a2c508886 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 7 Apr 2022 12:59:09 -0400 Subject: [PATCH] [mlir][vector] Fold extract(broadcast) of same rank This case is handled in neither the folding or canonicalization patterns. The folding pattern cannot generate new broadcast ops, so it should be handled by the canonicalization pattern. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D123307 --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 +++++---- mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eda7739..07546c0 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1496,6 +1496,7 @@ public: Operation *defOp = extractOp.getVector().getDefiningOp(); if (!defOp || !isa(defOp)) return failure(); + Value source = defOp->getOperand(0); if (extractOp.getType() == source.getType()) return failure(); @@ -1504,10 +1505,10 @@ public: }; unsigned broadcastSrcRank = getRank(source.getType()); unsigned extractResultRank = getRank(extractOp.getType()); - // We only consider the case where the rank of the source is smaller than - // the rank of the extract dst. The other cases are handled in the folding - // patterns. - if (extractResultRank <= broadcastSrcRank) + // We only consider the case where the rank of the source is less than or + // equal to the rank of the extract dst. The other cases are handled in the + // folding patterns. + if (extractResultRank < broadcastSrcRank) return failure(); rewriter.replaceOpWithNewOp( extractOp, extractOp.getType(), source); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index a083851..8b6640b 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -566,6 +566,18 @@ func @fold_extract_broadcast(%a : f32) -> vector<4xf32> { // ----- +// CHECK-LABEL: fold_extract_broadcast +// CHECK-SAME: %[[A:.*]]: vector<1xf32> +// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32> +// CHECK: return %[[R]] : vector<8xf32> +func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> { + %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32> + %r = vector.extract %b[0] : vector<1x8xf32> + return %r : vector<8xf32> +} + +// ----- + // CHECK-LABEL: func @fold_extract_shapecast // CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32> // CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32> -- 2.7.4