From 4d8ba886103b0022b019671bf27547d55a902b54 Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Tue, 10 Dec 2019 17:02:17 -0800 Subject: [PATCH] Add VectorOp transform pattern which splits vector TransferReadOps to target vector unroll size. PiperOrigin-RevId: 284880592 --- mlir/include/mlir/Dialect/VectorOps/VectorOps.h | 4 ++ mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 1 + mlir/lib/Dialect/VectorOps/VectorOps.cpp | 16 +++-- mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 68 ++++++++++++++++++++++ mlir/test/Dialect/VectorOps/vector-transforms.mlir | 66 ++++++++++++++++++++- mlir/test/lib/Transforms/TestVectorTransforms.cpp | 1 + 6 files changed, 149 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h index 8cb0d85..5b4351b 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -43,6 +43,10 @@ public: void populateVectorToVectorCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context); +/// Collect a set of vector-to-vector transformation patterns. 
+void populateVectorToVectorTransformationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context); + #define GET_OP_CLASSES #include "mlir/Dialect/VectorOps/VectorOps.h.inc" diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index f6e1ae5..d87f101 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -451,6 +451,7 @@ def Vector_StridedSliceOp : static StringRef getSizesAttrName() { return "sizes"; } static StringRef getStridesAttrName() { return "strides"; } VectorType getVectorType(){ return vector()->getType().cast(); } + void getOffsets(SmallVectorImpl &results); }]; let hasCanonicalizer = 1; } diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 7714623..28a0322 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -731,6 +731,12 @@ LogicalResult isSumOfIntegerArrayAttrConfinedToShape( return success(); } +static void populateFromInt64AttrArray(ArrayAttr arrayAttr, + SmallVectorImpl &results) { + for (auto attr : arrayAttr) + results.push_back(attr.cast().getInt()); +} + static ArrayAttr makeI64ArrayAttr(ArrayRef values, MLIRContext *context) { auto attrs = functional::map( @@ -929,14 +935,12 @@ static LogicalResult verify(StridedSliceOp op) { return success(); } -namespace { - -static void populateFromInt64AttrArray(ArrayAttr arrayAttr, - SmallVectorImpl &results) { - for (auto attr : arrayAttr) - results.push_back(attr.cast().getInt()); +void StridedSliceOp::getOffsets(SmallVectorImpl &results) { + populateFromInt64AttrArray(offsets(), results); } +namespace { + // Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp. 
class StridedSliceConstantMaskFolder final : public OpRewritePattern { diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 6b13bcf..6825709 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -446,3 +446,71 @@ Value *mlir::vector::unrollSingleResultOpMatchingType( return unrollSingleResultStructuredOp(op, iterationBounds, vectors, resultIndex, targetShape, builder); } + +// Splits vector TransferReadOp into smaller TransferReadOps for each user. +struct SplitTransferReadOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp, + PatternRewriter &rewriter) const override { + // TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity + // permutation maps. Repurpose code from MaterializeVectors transformation. + if (!xferReadOp.permutation_map().isIdentity()) + return matchFailure(); + // Gather 'xferReadOp' users. + SmallVector sliceUsers; + sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(), + xferReadOp.getResult()->use_end())); + + for (auto *user : xferReadOp.getResult()->getUsers()) { + auto sliceOp = dyn_cast(user); + // Return if any user is not a vector::StridedSliceOp. + if (!sliceOp) + return matchFailure(); + sliceUsers.push_back(sliceOp); + } + // Make zero splat into which we will insert split xferReadOp results. + Location loc = xferReadOp.getLoc(); + auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType()); + + // Update each user in 'sliceUsers' to use 'res'. + unsigned numSliceIndices = llvm::size(xferReadOp.indices()); + for (auto sliceUser : sliceUsers) { + // Gather static offsets from 'sliceUser'. 
+ SmallVector sliceOffsets; + sliceUser.getOffsets(sliceOffsets); + assert(sliceOffsets.size() == numSliceIndices); + auto *ctx = rewriter.getContext(); + // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. + SmallVector sliceIndices(numSliceIndices); + for (auto it : llvm::enumerate(xferReadOp.indices())) { + auto expr = getAffineDimExpr(0, ctx) + + getAffineConstantExpr(sliceOffsets[it.index()], ctx); + auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); + SmallVector mapOperands = {it.value()}; + sliceIndices[it.index()] = + rewriter.create(loc, map, mapOperands); + } + // Create split TransferReadOp for 'sliceUser'. + auto sliceVectorType = + sliceUser.getResult()->getType().cast(); + auto splitXferReadOp = rewriter.create( + loc, sliceVectorType, xferReadOp.memref(), sliceIndices, + xferReadOp.permutation_map(), xferReadOp.padding()); + // Create InsertStridedSlice into splat at same offsets as slice. + res = rewriter.create( + loc, xferReadOp.getVectorType(), splitXferReadOp, res, + sliceUser.offsets(), sliceUser.strides()); + } + + // Replace 'xferReadOp' with result 'res'. + rewriter.replaceOp(xferReadOp, res); + return matchSuccess(); + } +}; + +// TODO(andydavis) Add this as DRR pattern. 
+void mlir::vector::populateVectorToVectorTransformationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + patterns.insert(context); +} diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir index 4fb235d..c8d92ee 100644 --- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s +// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0, d1) -> (d0, d1) + // CHECK-LABEL: func @add4x2 // CHECK: %[[V1:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> // CHECK-NEXT: %[[V2:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> @@ -206,8 +208,70 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>, return %0 : vector<4x4xf32> } +// CHECK-LABEL: func @contraction4x4_ikj_xfer_read + +// Capture constants used to re-index vector transfer reads. +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : index + +// Check LHS vector.transfer read is split for each user. +// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[ISS0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[ISS1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[ISS0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32> + +// Check RHS vector.transfer read is split for each user. 
+// CHECK: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[ISS2:.*]] = vector.insert_strided_slice %[[VTR2]], %{{.*}} {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[ISS3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[ISS2]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32> + +// Check ACC vector.transfer read is split for each user (should be 4). +// CHECK: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[ISS4:.*]] = vector.insert_strided_slice %[[VTR4]], %{{.*}} {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[ISS5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[ISS4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[ISS6:.*]] = vector.insert_strided_slice %[[VTR6]], %[[ISS5]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[ISS7:.*]] = vector.insert_strided_slice %[[VTR7]], %[[ISS6]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> + +// Check LHS slice uses splat of split transfer read results. 
+// CHECK: vector.strided_slice %[[ISS1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> + +// Check RHS slice uses splat of split transfer read results. +// CHECK: vector.strided_slice %[[ISS3]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32> + +// Check ACC slice uses splat of split transfer read results. +// CHECK: vector.strided_slice %[[ISS7]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> + +func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>, + %arg1 : memref<2x4xf32>, + %arg2 : memref<4x4xf32>) + -> (vector<4x4xf32>) { + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 + { permutation_map = (d0, d1) -> (d0, d1) } + : memref<4x2xf32>, vector<4x2xf32> + + %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 + { permutation_map = (d0, d1) -> (d0, d1) } + : memref<2x4xf32>, vector<2x4xf32> + + %2 = vector.transfer_read %arg2[%c0, %c0], %cf0 + { permutation_map = (d0, d1) -> (d0, d1) } + : memref<4x4xf32>, vector<4x4xf32> + + %3 = vector.contract #contraction_trait1 %0, %1, %2 + : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32> + + return %3 : vector<4x4xf32> +} + // CHECK-LABEL: func @vector_transfers -// CHECK-COUNT-2: vector.transfer_read +// CHECK-COUNT-8: vector.transfer_read // CHECK-COUNT-2: vector.strided_slice // CHECK-COUNT-1: addf // CHECK-COUNT-2: vector.strided_slice diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 909fe2a..1d51306 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -36,6 +36,7 @@ struct TestVectorToVectorConversion auto *context = &getContext(); populateWithGenerated(context, &patterns); populateVectorToVectorCanonicalizationPatterns(patterns, context); + 
populateVectorToVectorTransformationPatterns(patterns, context); applyPatternsGreedily(getFunction(), patterns); } }; -- 2.7.4