From 92e83afe44fbfd81ffd428bb41b7f760eee712f9 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Tue, 6 Oct 2020 09:56:35 -0700
Subject: [PATCH] [mlir][vector] Fold extractOp coming from broadcastOp

Combine ExtractOp with scalar result with BroadcastOp source. This is useful to
be able to incrementally convert degenerated vector of one element into scalar.

Differential Revision: https://reviews.llvm.org/D88751
---
Review notes (ignored by git am):
- Restored the line structure and the template/type arguments that had been
  stripped from this patch (getDefiningOp<vector::BroadcastOp>,
  isa/cast<VectorType>, extractVector<int64_t>, ArrayRef<Attribute>,
  memref<?x8xf32>). The two context-line restorations (ArrayRef<Attribute>,
  memref<?x8xf32>) should be confirmed against the tree before applying.
- foldExtractFromBroadcast now rewrites indices that address a stretched unit
  source dimension to 0. Hunk headers and the diffstat were updated for the
  extra lines; the blob hashes on the `index` lines predate these changes.

 mlir/lib/Dialect/Vector/VectorOps.cpp      | 48 ++++++++++++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir | 62 +++++++++++++++++++++++++++++++
 2 files changed, 110 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 672ad40..b71102c 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -812,6 +812,52 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
   return Value();
 }
 
+/// Fold an ExtractOp whose source vector is produced by a BroadcastOp.
+/// If the extracted type equals the broadcast source type, the fold returns
+/// the broadcast source itself. If the extract result has lower rank than the
+/// broadcast source, `extractOp` is updated in place to read directly from the
+/// broadcast source and its own result is returned. Returns a null Value when
+/// neither case applies.
+static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcastOp)
+    return Value();
+  if (extractOp.getType() == broadcastOp.getSourceType())
+    return broadcastOp.source();
+  auto getRank = [](Type type) {
+    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
+  };
+  unsigned broadcastSrcRank = getRank(broadcastOp.getSourceType());
+  unsigned extractResultRank = getRank(extractOp.getType());
+  if (extractResultRank < broadcastSrcRank) {
+    auto extractPos = extractVector<int64_t>(extractOp.position());
+    unsigned rankDiff = broadcastSrcRank - extractResultRank;
+    // Drop the leading indices: they address dimensions that the broadcast
+    // added in front of the source shape.
+    extractPos.erase(
+        extractPos.begin(),
+        std::next(extractPos.begin(), extractPos.size() - rankDiff));
+    // A unit source dimension may have been stretched by the broadcast, in
+    // which case the incoming index can exceed the source extent. Every
+    // element along such a dimension comes from index 0 of the source.
+    ArrayRef<int64_t> srcShape =
+        broadcastOp.getSourceType().cast<VectorType>().getShape();
+    for (unsigned i = 0, e = extractPos.size(); i < e; ++i)
+      if (srcShape[i] == 1)
+        extractPos[i] = 0;
+    extractOp.setOperand(broadcastOp.source());
+    // OpBuilder is only used as a helper to build an I64ArrayAttr.
+    OpBuilder b(extractOp.getContext());
+    extractOp.setAttr(ExtractOp::getPositionAttrName(),
+                      b.getI64ArrayAttr(extractPos));
+    return extractOp.getResult();
+  }
+  // TODO: In case the rank of the extract result is greater than or equal to
+  // the rank of the broadcast source this can be combined into a new broadcast
+  // op. This needs to be added as a canonicalization pattern if needed.
+  return Value();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -819,6 +865,8 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
     return getResult();
   if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
     return val;
+  if (auto val = foldExtractFromBroadcast(*this))
+    return val;
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9c36f76..2f927a1 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -348,6 +348,68 @@ func @fold_extract_transpose(
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast
+// CHECK-SAME: %[[A:.*]]: f32
+// CHECK: return %[[A]] : f32
+func @fold_extract_broadcast(%a : f32) -> f32 {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_vector
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>
+// CHECK: return %[[A]] : vector<4xf32>
+func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_scalar_from_broadcast_vector
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : vector<4xf32>
+// CHECK: return %[[R]] : f32
+func @fold_extract_scalar_from_broadcast_vector(%a : vector<4xf32>) -> f32 {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// The unit source dimension is stretched by the broadcast, so the extract
+// index along it must be rewritten to 0 when folding onto the source.
+// CHECK-LABEL: fold_extract_broadcast_unit_dim
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[A]][0] : vector<1xf32>
+// CHECK: return %[[R]] : f32
+func @fold_extract_broadcast_unit_dim(%a : vector<1xf32>) -> f32 {
+  %b = vector.broadcast %a : vector<1xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// Negative test for extract_op folding when the type of broadcast source
+// doesn't match the type of vector.extract.
+// CHECK-LABEL: fold_extract_broadcast_negative
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<1x2x4xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[B]][0, 1] : vector<1x2x4xf32>
+// CHECK: return %[[R]] : vector<4xf32>
+func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = constant 0 : index
-- 
2.7.4