From 8e4c806ed5a481e4d2163c8330f3c3c024d61a36 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 26 Oct 2021 21:13:11 +0000 Subject: [PATCH] [mlir][Linalg] NFC - Add additional control to lower vector.shape_cast ops This also moves some code to a new patterns file. Differential Revision: https://reviews.llvm.org/D112575 --- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 60 +++-- mlir/include/mlir/Dialect/Vector/VectorOps.h | 6 - .../mlir/Dialect/Vector/VectorRewritePatterns.h | 58 +++++ mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 3 +- .../Linalg/Transforms/LinalgStrategyPasses.cpp | 42 ++-- mlir/lib/Dialect/Vector/CMakeLists.txt | 1 + ...VectorTransferPermutationMapRewritePatterns.cpp | 260 ++++++++++++++++++++ mlir/lib/Dialect/Vector/VectorTransforms.cpp | 261 +-------------------- 8 files changed, 393 insertions(+), 298 deletions(-) create mode 100644 mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index cfa38d7..27688b5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -877,49 +877,63 @@ struct LinalgEnablingOptions { /// Vector lowering options control how ops are lowered down to 1-D and scf.for /// form. struct LinalgVectorLoweringOptions { - /// Maximal transfer rank under which we do not lower further. - int64_t maxTransferRank = 1; - LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) { - maxTransferRank = val; - return *this; - } - /// Vector lowering operations may result in surprising behavior when - /// composing multiple codegen strategies and must be enabled explicitly. - bool transferLowering = true; - LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) { - transferLowering = val; - return *this; - } - /// Enable lowering of vector.transpose. - bool transposeLowering = false; - LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) { - transposeLowering = val; + /// Enable lowering of vector.contract. + /// In a progressive lowering of vectors, this would be the 1st step. + bool contractionLowering = false; + LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) { + contractionLowering = val; return *this; } /// Enable lowering of vector.multi_reduce. + /// In a progressive lowering of vectors, this would be the 2nd step. bool multiReductionLowering = false; LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) { multiReductionLowering = val; return *this; } - /// Enable lowering of vector.contract. - bool contractionLowering = false; - LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) { - contractionLowering = val; - return *this; - } /// Trigger full / partial vector.transfer splits. + /// In a progressive lowering of vectors, this would be the 3rd step. bool transferPartialRewrite = false; LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) { transferPartialRewrite = val; return *this; } /// Enable lowering of vector.transfer to scf. + /// In a progressive lowering of vectors, this would be the 4th step. bool transferToSCFConversion = false; LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) { transferToSCFConversion = val; return *this; } + /// Maximal transfer rank under which we do not lower further. + int64_t maxTransferRank = 1; + LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) { + maxTransferRank = val; + return *this; + } + /// Vector lowering operations may result in surprising behavior when + /// composing multiple codegen strategies and must be enabled explicitly. + /// In a progressive lowering of vectors, this would be the 5th step. + bool transferLowering = true; + LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) { + transferLowering = val; + return *this; + } + /// Enable lowering of vector.shape_cast to insert/extract. + /// In a progressive lowering of vectors, this would be the 6th step. + bool shapeCastLowering = true; + LinalgVectorLoweringOptions &enableShapeCastLowering(bool val = true) { + shapeCastLowering = val; + return *this; + } + /// Enable lowering of vector.transpose. + /// In a progressive lowering of vectors, this would be the 7th step. + bool transposeLowering = false; + LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) { + transposeLowering = val; + return *this; + } + /// Configure the post staged-patterns late vector.transfer to scf /// conversion. VectorTransferToSCFOptions vectorTransferToSCFOptions; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h index c6f4ba4..a296834 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -81,12 +81,6 @@ void populateVectorTransferLoweringPatterns( RewritePatternSet &patterns, llvm::Optional maxTransferRank = llvm::None); -/// Collect a set of transfer read/write lowering patterns that simplify the -/// permutation map (e.g., converting it to a minor identity map) by inserting -/// broadcasts and transposes. -void populateVectorTransferPermutationMapLoweringPatterns( - RewritePatternSet &patterns); - /// These patterns materialize masks for various vector ops such as transfers. void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool enableIndexOptimizations); diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h index 47375c5..587f334 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h @@ -169,6 +169,64 @@ void populateVectorContractLoweringPatterns( /// transpose/broadcast ops into the contract. void populateVectorReductionToContractPatterns(RewritePatternSet &patterns); +//===----------------------------------------------------------------------===// +// Vector.transfer patterns. +//===----------------------------------------------------------------------===// +/// Collect a set of transfer read/write lowering patterns that simplify the +/// permutation map (e.g., converting it to a minor identity map) by inserting +/// broadcasts and transposes. More specifically: +/// +/// [TransferReadPermutationLowering] +/// Lower transfer_read op with permutation into a transfer_read with a +/// permutation map composed of leading zeros followed by a minor identity + +/// vector.transpose op. +/// Ex: +/// vector.transfer_read ... +/// permutation_map: (d0, d1, d2) -> (0, d1) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2) -> (d1, 0) +/// vector.transpose %v, [1, 0] +/// +/// vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) +/// vector.transpose %v, [0, 1, 3, 2, 4] +/// Note that an alternative is to transform it to linalg.transpose + +/// vector.transfer_read to do the transpose in memory instead. +/// +/// [TransferWritePermutationLowering] +/// Lower transfer_write op with permutation into a transfer_write with a +/// minor identity permutation map. (transfer_write ops cannot have broadcasts.) +/// Ex: +/// vector.transfer_write %v ... +/// permutation_map: (d0, d1, d2) -> (d2, d0, d1) +/// into: +/// %tmp = vector.transpose %v, [2, 0, 1] +/// vector.transfer_write %tmp ... +/// permutation_map: (d0, d1, d2) -> (d0, d1, d2) +/// +/// vector.transfer_write %v ... +/// permutation_map: (d0, d1, d2, d3) -> (d3, d2) +/// into: +/// %tmp = vector.transpose %v, [1, 0] +/// %v = vector.transfer_write %tmp ... +/// permutation_map: (d0, d1, d2, d3) -> (d2, d3) +/// +/// [TransferOpReduceRank] +/// Lower transfer_read op with broadcast in the leading dimensions into +/// transfer_read of lower rank + vector.broadcast. +/// Ex: vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) +/// vector.broadcast %v +void populateVectorTransferPermutationMapLoweringPatterns( + RewritePatternSet &patterns); + /// Collect a set of patterns to reduce the rank of the operands of vector /// transfer ops to operate on the largest contigious vector. /// These patterns are useful when lowering to dialects with 1d vector type diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 1aacd49..2fd4959 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -20,8 +20,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" -#include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/Dialect/Vector/VectorUtils.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp index 9783186..4462bbe 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -264,34 +264,44 @@ struct LinalgStrategyLowerVectorsPass MLIRContext *context = funcOp.getContext(); RewritePatternSet patterns(context); vector::populateVectorToVectorCanonicalizationPatterns(patterns); - if (options.transferLowering) { - vector::populateVectorTransferLoweringPatterns(patterns, - options.maxTransferRank); - } - if (options.transposeLowering) { - vector::populateVectorTransposeLoweringPatterns( - patterns, options.vectorTransformOptions); - } - if (options.multiReductionLowering) { - vector::populateVectorMultiReductionLoweringPatterns( - patterns, - options.vectorTransformOptions.vectorMultiReductionLowering); - } + // In a progressive lowering of vectors, this would be the 1st step. if (options.contractionLowering) { patterns.add( options.vectorTransformOptions, context); vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); } + // In a progressive lowering of vectors, this would be the 2nd step. + if (options.multiReductionLowering) { + vector::populateVectorMultiReductionLoweringPatterns( + patterns, + options.vectorTransformOptions.vectorMultiReductionLowering); + } + // In a progressive lowering of vectors, this would be the 3rd step. if (options.transferPartialRewrite) { patterns.add( context, options.vectorTransformOptions); } + // In a progressive lowering of vectors, this would be the 4th step. + if (options.transferLowering) { + vector::populateVectorTransferLoweringPatterns(patterns, + options.maxTransferRank); + } + // In a progressive lowering of vectors, this would be the 5th step. if (options.transferToSCFConversion) { - populateVectorToSCFConversionPatterns(patterns, - options.vectorTransferToSCFOptions); + populateVectorToSCFConversionPatterns( + patterns, options.vectorTransferToSCFOptions.setTargetRank( + options.maxTransferRank)); + } + // In a progressive lowering of vectors, this would be the 6th step. + if (options.shapeCastLowering) { + vector::populateVectorShapeCastLoweringPatterns(patterns); + } + // In a progressive lowering of vectors, this would be the 7th step. + if (options.transposeLowering) { + vector::populateVectorTransposeLoweringPatterns( + patterns, options.vectorTransformOptions); } - vector::populateVectorShapeCastLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt index f620a37..abd9616 100644 --- a/mlir/lib/Dialect/Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVector VectorMultiDimReductionTransforms.cpp VectorOps.cpp VectorTransferOpTransforms.cpp + VectorTransferPermutationMapRewritePatterns.cpp VectorTransforms.cpp VectorUtils.cpp diff --git a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp new file mode 100644 index 0000000..3f5c312 --- /dev/null +++ b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp @@ -0,0 +1,260 @@ +//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements rewrite patterns for the permutation_map attribute of +// vector.transfer operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/Interfaces/VectorInterfaces.h" + +using namespace mlir; +using namespace mlir::vector; + +/// Transpose a vector transfer op's `in_bounds` attribute according to given +/// indices. +static ArrayAttr +transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, + const SmallVector &permutation) { + SmallVector newInBoundsValues; + for (unsigned pos : permutation) + newInBoundsValues.push_back( + attr.getValue()[pos].cast().getValue()); + return builder.getBoolArrayAttr(newInBoundsValues); +} +/// Lower transfer_read op with permutation into a transfer_read with a +/// permutation map composed of leading zeros followed by a minor identiy + +/// vector.transpose op. +/// Ex: +/// vector.transfer_read ... +/// permutation_map: (d0, d1, d2) -> (0, d1) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2) -> (d1, 0) +/// vector.transpose %v, [1, 0] +/// +/// vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) +/// vector.transpose %v, [0, 1, 3, 2, 4] +/// Note that an alternative is to transform it to linalg.transpose + +/// vector.transfer_read to do the transpose in memory instead. +struct TransferReadPermutationLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + SmallVector permutation; + AffineMap map = op.permutation_map(); + if (map.getNumResults() == 0) + return failure(); + if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) + return failure(); + AffineMap permutationMap = + map.getPermutationMap(permutation, op.getContext()); + if (permutationMap.isIdentity()) + return failure(); + + permutationMap = map.getPermutationMap(permutation, op.getContext()); + // Caluclate the map of the new read by applying the inverse permutation. + permutationMap = inversePermutation(permutationMap); + AffineMap newMap = permutationMap.compose(map); + // Apply the reverse transpose to deduce the type of the transfer_read. + ArrayRef originalShape = op.getVectorType().getShape(); + SmallVector newVectorShape(originalShape.size()); + for (auto pos : llvm::enumerate(permutation)) { + newVectorShape[pos.value()] = originalShape[pos.index()]; + } + + // Transpose mask operand. + Value newMask; + if (op.mask()) { + // Remove unused dims from the permutation map. E.g.: + // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2) + // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0) + auto comp = compressUnusedDims(map); + // Get positions of remaining result dims. + // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0) + // maskTransposeIndices = [ 2, 1, 0] + SmallVector maskTransposeIndices; + for (unsigned i = 0; i < comp.getNumResults(); ++i) { + if (auto expr = comp.getResult(i).dyn_cast()) + maskTransposeIndices.push_back(expr.getPosition()); + } + + newMask = rewriter.create(op.getLoc(), op.mask(), + maskTransposeIndices); + } + + // Transpose in_bounds attribute. + ArrayAttr newInBounds = + op.in_bounds() ? transposeInBoundsAttr( + rewriter, op.in_bounds().getValue(), permutation) + : ArrayAttr(); + + // Generate new transfer_read operation. + VectorType newReadType = + VectorType::get(newVectorShape, op.getVectorType().getElementType()); + Value newRead = rewriter.create( + op.getLoc(), newReadType, op.source(), op.indices(), newMap, + op.padding(), newMask, newInBounds); + + // Transpose result of transfer_read. + SmallVector transposePerm(permutation.begin(), permutation.end()); + rewriter.replaceOpWithNewOp(op, newRead, + transposePerm); + return success(); + } +}; + +/// Lower transfer_write op with permutation into a transfer_write with a +/// minor identity permutation map. (transfer_write ops cannot have broadcasts.) +/// Ex: +/// vector.transfer_write %v ... +/// permutation_map: (d0, d1, d2) -> (d2, d0, d1) +/// into: +/// %tmp = vector.transpose %v, [2, 0, 1] +/// vector.transfer_write %tmp ... +/// permutation_map: (d0, d1, d2) -> (d0, d1, d2) +/// +/// vector.transfer_write %v ... +/// permutation_map: (d0, d1, d2, d3) -> (d3, d2) +/// into: +/// %tmp = vector.transpose %v, [1, 0] +/// %v = vector.transfer_write %tmp ... +/// permutation_map: (d0, d1, d2, d3) -> (d2, d3) +struct TransferWritePermutationLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp op, + PatternRewriter &rewriter) const override { + if (op.isZeroD()) + return failure(); + + SmallVector permutation; + AffineMap map = op.permutation_map(); + if (map.isMinorIdentity()) + return failure(); + if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) + return failure(); + + // Remove unused dims from the permutation map. E.g.: + // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4) + // comp = (d0, d1, d2) -> (d2, d0, d1) + auto comp = compressUnusedDims(map); + // Get positions of remaining result dims. + SmallVector indices; + llvm::transform(comp.getResults(), std::back_inserter(indices), + [](AffineExpr expr) { + return expr.dyn_cast().getPosition(); + }); + + // Transpose mask operand. + Value newMask = op.mask() ? rewriter.create( + op.getLoc(), op.mask(), indices) + : Value(); + + // Transpose in_bounds attribute. + ArrayAttr newInBounds = + op.in_bounds() ? transposeInBoundsAttr( + rewriter, op.in_bounds().getValue(), permutation) + : ArrayAttr(); + + // Generate new transfer_write operation. + Value newVec = + rewriter.create(op.getLoc(), op.vector(), indices); + auto newMap = AffineMap::getMinorIdentityMap( + map.getNumDims(), map.getNumResults(), rewriter.getContext()); + rewriter.replaceOpWithNewOp( + op, Type(), newVec, op.source(), op.indices(), newMap, newMask, + newInBounds); + + return success(); + } +}; + +/// Lower transfer_read op with broadcast in the leading dimensions into +/// transfer_read of lower rank + vector.broadcast. +/// Ex: vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) +/// vector.broadcast %v +struct TransferOpReduceRank : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + AffineMap map = op.permutation_map(); + unsigned numLeadingBroadcast = 0; + for (auto expr : map.getResults()) { + auto dimExpr = expr.dyn_cast(); + if (!dimExpr || dimExpr.getValue() != 0) + break; + numLeadingBroadcast++; + } + // If there are no leading zeros in the map there is nothing to do. + if (numLeadingBroadcast == 0) + return failure(); + VectorType originalVecType = op.getVectorType(); + unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast; + // Calculate new map, vector type and masks without the leading zeros. + AffineMap newMap = AffineMap::get( + map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank), + op.getContext()); + // Only remove the leading zeros if the rest of the map is a minor identity + // with broadasting. Otherwise we first want to permute the map. + if (!newMap.isMinorIdentityWithBroadcasting()) + return failure(); + + // TODO: support zero-dimension vectors natively. See: + // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097. + // In the meantime, lower these to a scalar load when they pop up. + if (reducedShapeRank == 0) { + Value newRead = rewriter.create( + op.getLoc(), originalVecType.getElementType(), op.source(), + op.indices()); + rewriter.replaceOpWithNewOp(op, originalVecType, + newRead); + return success(); + } + SmallVector newShape = llvm::to_vector<4>( + originalVecType.getShape().take_back(reducedShapeRank)); + // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering. + if (newShape.empty()) + return failure(); + VectorType newReadType = + VectorType::get(newShape, originalVecType.getElementType()); + ArrayAttr newInBounds = + op.in_bounds() + ? rewriter.getArrayAttr( + op.in_boundsAttr().getValue().take_back(reducedShapeRank)) + : ArrayAttr(); + Value newRead = rewriter.create( + op.getLoc(), newReadType, op.source(), op.indices(), newMap, + op.padding(), op.mask(), newInBounds); + rewriter.replaceOpWithNewOp(op, originalVecType, + newRead); + return success(); + } +}; + +void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 4c7ef51..efb22b9 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2882,240 +2882,6 @@ struct TransferWriteToVectorStoreLowering llvm::Optional maxTransferRank; }; -/// Transpose a vector transfer op's `in_bounds` attribute according to given -/// indices. -static ArrayAttr -transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, - const SmallVector &permutation) { - SmallVector newInBoundsValues; - for (unsigned pos : permutation) - newInBoundsValues.push_back( - attr.getValue()[pos].cast().getValue()); - return builder.getBoolArrayAttr(newInBoundsValues); -} - -/// Lower transfer_read op with permutation into a transfer_read with a -/// permutation map composed of leading zeros followed by a minor identiy + -/// vector.transpose op. -/// Ex: -/// vector.transfer_read ... -/// permutation_map: (d0, d1, d2) -> (0, d1) -/// into: -/// %v = vector.transfer_read ... -/// permutation_map: (d0, d1, d2) -> (d1, 0) -/// vector.transpose %v, [1, 0] -/// -/// vector.transfer_read ... -/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) -/// into: -/// %v = vector.transfer_read ... -/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) -/// vector.transpose %v, [0, 1, 3, 2, 4] -/// Note that an alternative is to transform it to linalg.transpose + -/// vector.transfer_read to do the transpose in memory instead. -struct TransferReadPermutationLowering - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { - SmallVector permutation; - AffineMap map = op.permutation_map(); - if (map.getNumResults() == 0) - return failure(); - if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) - return failure(); - AffineMap permutationMap = - map.getPermutationMap(permutation, op.getContext()); - if (permutationMap.isIdentity()) - return failure(); - - permutationMap = map.getPermutationMap(permutation, op.getContext()); - // Caluclate the map of the new read by applying the inverse permutation. - permutationMap = inversePermutation(permutationMap); - AffineMap newMap = permutationMap.compose(map); - // Apply the reverse transpose to deduce the type of the transfer_read. - ArrayRef originalShape = op.getVectorType().getShape(); - SmallVector newVectorShape(originalShape.size()); - for (auto pos : llvm::enumerate(permutation)) { - newVectorShape[pos.value()] = originalShape[pos.index()]; - } - - // Transpose mask operand. - Value newMask; - if (op.mask()) { - // Remove unused dims from the permutation map. E.g.: - // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2) - // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0) - auto comp = compressUnusedDims(map); - // Get positions of remaining result dims. - // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0) - // maskTransposeIndices = [ 2, 1, 0] - SmallVector maskTransposeIndices; - for (unsigned i = 0; i < comp.getNumResults(); ++i) { - if (auto expr = comp.getResult(i).dyn_cast()) - maskTransposeIndices.push_back(expr.getPosition()); - } - - newMask = rewriter.create(op.getLoc(), op.mask(), - maskTransposeIndices); - } - - // Transpose in_bounds attribute. - ArrayAttr newInBounds = - op.in_bounds() ? transposeInBoundsAttr( - rewriter, op.in_bounds().getValue(), permutation) - : ArrayAttr(); - - // Generate new transfer_read operation. - VectorType newReadType = - VectorType::get(newVectorShape, op.getVectorType().getElementType()); - Value newRead = rewriter.create( - op.getLoc(), newReadType, op.source(), op.indices(), newMap, - op.padding(), newMask, newInBounds); - - // Transpose result of transfer_read. - SmallVector transposePerm(permutation.begin(), permutation.end()); - rewriter.replaceOpWithNewOp(op, newRead, - transposePerm); - return success(); - } -}; - -/// Lower transfer_write op with permutation into a transfer_write with a -/// minor identity permutation map. (transfer_write ops cannot have broadcasts.) -/// Ex: -/// vector.transfer_write %v ... -/// permutation_map: (d0, d1, d2) -> (d2, d0, d1) -/// into: -/// %tmp = vector.transpose %v, [2, 0, 1] -/// vector.transfer_write %tmp ... -/// permutation_map: (d0, d1, d2) -> (d0, d1, d2) -/// -/// vector.transfer_write %v ... -/// permutation_map: (d0, d1, d2, d3) -> (d3, d2) -/// into: -/// %tmp = vector.transpose %v, [1, 0] -/// %v = vector.transfer_write %tmp ... -/// permutation_map: (d0, d1, d2, d3) -> (d2, d3) -struct TransferWritePermutationLowering - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransferWriteOp op, - PatternRewriter &rewriter) const override { - if (op.isZeroD()) - return failure(); - - SmallVector permutation; - AffineMap map = op.permutation_map(); - if (map.isMinorIdentity()) - return failure(); - if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) - return failure(); - - // Remove unused dims from the permutation map. E.g.: - // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4) - // comp = (d0, d1, d2) -> (d2, d0, d1) - auto comp = compressUnusedDims(map); - // Get positions of remaining result dims. - SmallVector indices; - llvm::transform(comp.getResults(), std::back_inserter(indices), - [](AffineExpr expr) { - return expr.dyn_cast().getPosition(); - }); - - // Transpose mask operand. - Value newMask = op.mask() ? rewriter.create( - op.getLoc(), op.mask(), indices) - : Value(); - - // Transpose in_bounds attribute. - ArrayAttr newInBounds = - op.in_bounds() ? transposeInBoundsAttr( - rewriter, op.in_bounds().getValue(), permutation) - : ArrayAttr(); - - // Generate new transfer_write operation. - Value newVec = - rewriter.create(op.getLoc(), op.vector(), indices); - auto newMap = AffineMap::getMinorIdentityMap( - map.getNumDims(), map.getNumResults(), rewriter.getContext()); - rewriter.replaceOpWithNewOp( - op, Type(), newVec, op.source(), op.indices(), newMap, newMask, - newInBounds); - - return success(); - } -}; - -/// Lower transfer_read op with broadcast in the leading dimensions into -/// transfer_read of lower rank + vector.broadcast. -/// Ex: vector.transfer_read ... -/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) -/// into: -/// %v = vector.transfer_read ... -/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) -/// vector.broadcast %v -struct TransferOpReduceRank : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { - AffineMap map = op.permutation_map(); - unsigned numLeadingBroadcast = 0; - for (auto expr : map.getResults()) { - auto dimExpr = expr.dyn_cast(); - if (!dimExpr || dimExpr.getValue() != 0) - break; - numLeadingBroadcast++; - } - // If there are no leading zeros in the map there is nothing to do. - if (numLeadingBroadcast == 0) - return failure(); - VectorType originalVecType = op.getVectorType(); - unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast; - // Calculate new map, vector type and masks without the leading zeros. - AffineMap newMap = AffineMap::get( - map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank), - op.getContext()); - // Only remove the leading zeros if the rest of the map is a minor identity - // with broadasting. Otherwise we first want to permute the map. - if (!newMap.isMinorIdentityWithBroadcasting()) - return failure(); - - // TODO: support zero-dimension vectors natively. See: - // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097. - // In the meantime, lower these to a scalar load when they pop up. - if (reducedShapeRank == 0) { - Value newRead = rewriter.create( - op.getLoc(), originalVecType.getElementType(), op.source(), - op.indices()); - rewriter.replaceOpWithNewOp(op, originalVecType, - newRead); - return success(); - } - SmallVector newShape = llvm::to_vector<4>( - originalVecType.getShape().take_back(reducedShapeRank)); - // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering. - if (newShape.empty()) - return failure(); - VectorType newReadType = - VectorType::get(newShape, originalVecType.getElementType()); - ArrayAttr newInBounds = - op.in_bounds() - ? rewriter.getArrayAttr( - op.in_boundsAttr().getValue().take_back(reducedShapeRank)) - : ArrayAttr(); - Value newRead = rewriter.create( - op.getLoc(), newReadType, op.source(), op.indices(), newMap, - op.padding(), op.mask(), newInBounds); - rewriter.replaceOpWithNewOp(op, originalVecType, - newRead); - return success(); - } -}; - // Trims leading one dimensions from `oldType` and returns the result type. // Returns `vector<1xT>` if `oldType` only has one element. static VectorType trimLeadingOneDims(VectorType oldType) { @@ -3891,23 +3657,6 @@ void mlir::vector::populateVectorReductionToContractPatterns( CombineContractTranspose>(patterns.getContext()); } -void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( - RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); -} - -void mlir::vector::populateVectorTransferLoweringPatterns( - RewritePatternSet &patterns, llvm::Optional maxTransferRank) { - patterns.add(patterns.getContext(), - maxTransferRank); - patterns - .add( - patterns.getContext()); -} - void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options) { patterns.add(patterns.getContext()); } + +void mlir::vector::populateVectorTransferLoweringPatterns( + RewritePatternSet &patterns, llvm::Optional maxTransferRank) { + patterns.add(patterns.getContext(), + maxTransferRank); + patterns + .add( + patterns.getContext()); +} -- 2.7.4