From 7006daa548c25960dbb5a50e9b9987d4dd01798b Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Wed, 8 Apr 2020 08:39:48 -0700 Subject: [PATCH] [MLIR][Vector] Update ShapeCastOp folder to use producer-consumer value forwarding. Summary: Update ShapeCastOp folder to use producer-consumer value forwarding. Support is added for tracking sub-vectors through trivial shape cast operations, where the sub-vector shape is preserved across shape cast operations and only leading ones are added or removed. Support is preserved for cancelling shape cast operations. One unit test is added and two are updated. Reviewers: aartbik, nicolasvasilache Reviewed By: aartbik, nicolasvasilache Subscribers: frgossen, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, grosul1, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D77253 --- mlir/lib/Dialect/Vector/VectorTransforms.cpp | 65 +++++++++++++++++++++++-- mlir/test/Dialect/Vector/vector-transforms.mlir | 58 ++++++++++++++-------- 2 files changed, 99 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index dbb0bf4..7a197ef 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -676,10 +676,10 @@ struct ShapeCastOpDecomposer : public OpRewritePattern { /// Returns the producer Value of the same type as 'consumerValue', by tracking /// the tuple index and offsets of the consumer vector value through the -/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp) -/// from consumer to producer. Each operation in the chain is structured, and -/// so the tuple index and offsets can be mapped from result to input, while -/// visiting each operation in the chain. +/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp, +/// and ShapeCastOp) from consumer to producer. Each operation in the chain is +/// structured, and so the tuple index and offsets can be mapped from result to +/// input, while visiting each operation in the chain. /// Returns nullptr on failure. static Value getProducerValue(Value consumerValue) { auto consumerVectorType = consumerValue.getType().cast(); @@ -760,8 +760,57 @@ static Value getProducerValue(Value consumerValue) { // Update 'tupleIndex' and next defining 'op' to visit. tupleIndex = -1; op = value.getDefiningOp(); + } else if (auto shapeCastOp = dyn_cast(op)) { + if (shapeCastOp.source().getType().isa()) + return nullptr; + assert(tupleIndex == -1); + auto sourceVectorType = shapeCastOp.getSourceVectorType(); + auto sourceVectorShape = sourceVectorType.getShape(); + unsigned sourceVectorRank = sourceVectorType.getRank(); + auto resultVectorType = shapeCastOp.getResultVectorType(); + auto resultVectorShape = resultVectorType.getShape(); + unsigned resultVectorRank = resultVectorType.getRank(); + + int i = sourceVectorRank - 1; + int j = resultVectorRank - 1; + + // Check that source/result vector shape prefixes match while + // updating 'newOffsets'. + bool canShapeCastFold = true; + SmallVector newOffsets(sourceVectorRank, 0); + + auto apply = [&](int64_t sourceSize, int64_t resultSize) { + canShapeCastFold = sourceSize == resultSize; + newOffsets[i--] = offsets[j--]; + }; + functional::zipApply(apply, llvm::reverse(sourceVectorShape), + llvm::reverse(resultVectorShape)); + if (!canShapeCastFold) + return nullptr; + + // Check that remaining prefix of source/result vector shapes are all 1s. + // Currently we only support producer/consumer tracking through trivial + // shape cast ops. Examples: + // %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32> + // %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32> + assert(i == -1 || j == -1); + if (i >= 0 && + !std::all_of(sourceVectorShape.begin(), sourceVectorShape.begin() + i, + [](int64_t v) { return v == 1; })) + return nullptr; + if (j >= 0 && + !std::all_of(resultVectorShape.begin(), resultVectorShape.begin() + j, + [](int64_t v) { return v == 1; })) + return nullptr; + + offsets.swap(newOffsets); + op = shapeCastOp.source().getDefiningOp(); } else { - break; + // Check if 'op' produces a Value with the same type as 'consumerValue'. + if (op->getNumResults() == 1 && + op->getResult(0).getType() == consumerVectorType) + return op->getResult(0); + return nullptr; } } return nullptr; @@ -788,6 +837,12 @@ struct ShapeCastOpFolder : public OpRewritePattern { LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { + // Check if we can replace 'shapeCastOp' result with its producer. + if (auto producer = getProducerValue(shapeCastOp.getResult())) { + rewriter.replaceOp(shapeCastOp, producer); + return success(); + } + // Check if 'shapeCastOp' has vector source/result type. auto sourceVectorType = shapeCastOp.source().getType().dyn_cast_or_null(); diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir index 082afba..2e4e903 100644 --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -341,16 +341,21 @@ func @tuple_get_producer_consumer( %2 = vector.extract_slices %1, [4, 8], [1, 1] : vector<4x16xf32> into tuple, vector<4x8xf32>> // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4] - %3 = vector.tuple_get %2, 1 : tuple, vector<4x8xf32>> - // %arg7 == %3 at tupleIndex = -1, offsets = [2, 4] - %4 = vector.extract_slices %3, [2, 4], [1, 1] + %3 = vector.shape_cast %2 : tuple, vector<4x8xf32>> to + tuple, vector<1x1x4x8xf32>> + // %arg7 = %3 at tupleIndex = 1, offsets = [0, 0, 2, 4] + %4 = vector.tuple_get %3, 1 : tuple, vector<1x1x4x8xf32>> + // %arg7 == %4 at tupleIndex = -1, offsets = [0, 0, 2, 4] + %5 = vector.shape_cast %4 : vector<1x1x4x8xf32> to vector<4x8xf32> + // %arg7 == %5 at tupleIndex = -1, offsets = [2, 4] + %6 = vector.extract_slices %5, [2, 4], [1, 1] : vector<4x8xf32> into tuple, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>> - // %arg7 == %4 at tupleIndex = 3, offsets = [0, 0] - %5 = vector.tuple_get %4, 3 + // %arg7 == %6 at tupleIndex = 3, offsets = [0, 0] + %7 = vector.tuple_get %6, 3 : tuple, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>> - // %arg7 == %5 - return %5 : vector<2x4xf32> + // %arg7 == %7 + return %7 : vector<2x4xf32> } // CHECK-LABEL: func @tuple_get_producer_consumer_swizzle @@ -381,25 +386,40 @@ func @tuple_get_producer_consumer_swizzle( %2 = vector.extract_slices %1, [4, 8], [1, 1] : vector<4x16xf32> into tuple, vector<4x8xf32>> // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4] + %3= vector.shape_cast %2 : tuple, vector<4x8xf32>> to + tuple, vector<1x1x4x8xf32>> + // %arg7 = %3 at tupleIndex = 1, offsets = [0, 0, 2, 4] // Extract tuple elements. - %3 = vector.tuple_get %2, 0 : tuple, vector<4x8xf32>> - %4 = vector.tuple_get %2, 1 : tuple, vector<4x8xf32>> - // %arg7 == %4 at tupleIndex = -1, offsets = [2, 4] + %4 = vector.tuple_get %3, 0 : tuple, vector<1x1x4x8xf32>> + %5 = vector.tuple_get %3, 1 : tuple, vector<1x1x4x8xf32>> + // %arg7 == %5 at tupleIndex = -1, offsets = [0, 0, 2, 4] // Swizzle tuple elements. - %5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32> - // %arg7 == %5 at tupleIndex = 0, offsets = [2, 4] - %6 = vector.tuple_get %5, 0 : tuple, vector<4x8xf32>> - // %arg7 == %6 at tupleIndex = -1, offsets = [2, 4] - %7 = vector.extract_slices %6, [2, 4], [1, 1] + %6 = vector.tuple %5, %4 : vector<1x1x4x8xf32>, vector<1x1x4x8xf32> + // %arg7 == %6 at tupleIndex = 0, offsets = [0, 0, 2, 4] + %7 = vector.shape_cast %6 : tuple, vector<1x1x4x8xf32>> to + tuple, vector<4x8xf32>> + // %arg7 = %7 at tupleIndex = 0, offsets = [2, 4] + %8 = vector.tuple_get %7, 0 : tuple, vector<4x8xf32>> + // %arg7 == %8 at tupleIndex = -1, offsets = [2, 4] + %9 = vector.extract_slices %8, [2, 4], [1, 1] : vector<4x8xf32> into tuple, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>> - // %arg7 == %7 at tupleIndex = 3, offsets = [0, 0] - %8 = vector.tuple_get %7, 3 + // %arg7 == %9 at tupleIndex = 3, offsets = [0, 0] + %10 = vector.tuple_get %9, 3 : tuple, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>> - // %arg7 == %8 - return %8 : vector<2x4xf32> + // %arg7 == %10 + return %10 : vector<2x4xf32> +} + +// CHECK-LABEL: func @cancelling_shape_cast_ops +// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32> +// CHECK: return %[[A0]] : vector<2x4xf32> +func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> { + %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32> + %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32> + return %1 : vector<2x4xf32> } // CHECK-LABEL: func @vector_transfers_vector_element_type -- 2.7.4