From 038ad1d8567ae2f46294e7e7fe68e09c20a309d6 Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Tue, 17 Dec 2019 07:28:37 -0800 Subject: [PATCH] Add pattern rewrite which splits a vector TransferReadOp into slices according to the unrolling/slicing scheme of its ExtractSlicesOp user. PiperOrigin-RevId: 285975613 --- mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 78 ++++++++++++---------- mlir/test/Dialect/VectorOps/vector-transforms.mlir | 17 +++-- 2 files changed, 56 insertions(+), 39 deletions(-) diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 8d70f4a..85f306e 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -511,7 +511,8 @@ Value *mlir::vector::unrollSingleResultOpMatchingType( resultIndex, targetShape, builder); } -// Splits vector TransferReadOp into smaller TransferReadOps for each user. +// Splits vector TransferReadOp into smaller TransferReadOps based on slicing +// scheme of its unique ExtractSlicesOp user. struct SplitTransferReadOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -521,54 +522,63 @@ struct SplitTransferReadOp : public OpRewritePattern { // permutation maps. Repurpose code from MaterializeVectors transformation. if (!xferReadOp.permutation_map().isIdentity()) return matchFailure(); - // Gather 'xferReadOp' users. - SmallVector sliceUsers; - sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(), - xferReadOp.getResult()->use_end())); - - for (auto *user : xferReadOp.getResult()->getUsers()) { - auto sliceOp = dyn_cast(user); - // Return if any user is not a vector::StridedSliceOp. - if (!sliceOp) - return matchFailure(); - sliceUsers.push_back(sliceOp); - } - // Make zero splat into which we will insert split xferReadOp results. 
- Location loc = xferReadOp.getLoc(); - auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType()); + // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp. + Value *xferReadResult = xferReadOp.getResult(); + auto extractSlicesOp = + dyn_cast(*xferReadResult->getUsers().begin()); + if (!xferReadResult->hasOneUse() || !extractSlicesOp) + return matchFailure(); - // Update each user in 'sliceUser' to use 'res'. + // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user. + auto sourceVectorType = extractSlicesOp.getSourceVectorType(); + auto resultTupleType = extractSlicesOp.getResultTupleType(); + SmallVector sizes; + extractSlicesOp.getSizes(sizes); + SmallVector strides; + extractSlicesOp.getStrides(strides); + assert(llvm::all_of(strides, [](int64_t s) { return s == 1; })); + + // Compute strides w.r.t. slice counts in each dimension. + auto maybeDimSliceCounts = shapeRatio(sourceVectorType.getShape(), sizes); + assert(maybeDimSliceCounts.hasValue()); + auto sliceDimCounts = *maybeDimSliceCounts; + auto basis = computeStrides(sliceDimCounts); + + Location loc = xferReadOp.getLoc(); + auto *ctx = rewriter.getContext(); + int64_t numSlices = resultTupleType.size(); unsigned numSliceIndices = llvm::size(xferReadOp.indices()); - for (auto sliceUser : sliceUsers) { - // Gather static offsets from 'sliceUser'. - SmallVector sliceOffsets; - sliceUser.getOffsets(sliceOffsets); - assert(sliceOffsets.size() == numSliceIndices); - auto *ctx = rewriter.getContext(); + SmallVector vectorTupleValues(numSlices); + for (unsigned i = 0; i < numSlices; ++i) { + // De-linearize w.r.t. 'basis'. + auto vectorOffsets = delinearize(i, basis); + // Convert from unrolled vector-space offsets to element-space offsets. + auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, + vectorOffsets, sizes); // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. 
SmallVector sliceIndices(numSliceIndices); for (auto it : llvm::enumerate(xferReadOp.indices())) { auto expr = getAffineDimExpr(0, ctx) + - getAffineConstantExpr(sliceOffsets[it.index()], ctx); + getAffineConstantExpr(offsets[it.index()], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); SmallVector mapOperands = {it.value()}; sliceIndices[it.index()] = rewriter.create(loc, map, mapOperands); } + // Get VectorType for slice 'i'. + auto sliceVectorType = resultTupleType.getType(i); // Create split TransferReadOp for 'sliceUser'. - auto sliceVectorType = - sliceUser.getResult()->getType().cast(); - auto splitXferReadOp = rewriter.create( + vectorTupleValues[i] = rewriter.create( loc, sliceVectorType, xferReadOp.memref(), sliceIndices, xferReadOp.permutation_map(), xferReadOp.padding()); - // Create InsertStridedSlice into splat at same offsets as slice. - res = rewriter.create( - loc, xferReadOp.getVectorType(), splitXferReadOp, res, - sliceUser.offsets(), sliceUser.strides()); } - - // Replace 'xferReadOp' with result 'res'. - rewriter.replaceOp(xferReadOp, res); + // Create tuple of slice xfer read operations. + Value *tupleOp = rewriter.create(loc, resultTupleType, + vectorTupleValues); + // Replace 'xferReadOp' with result 'insertSlicesResult'. + rewriter.replaceOpWithNewOp( + xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(), + extractSlicesOp.strides()); return matchSuccess(); } }; diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir index 783f542..71b7b7a 100644 --- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir @@ -232,16 +232,23 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>, } // CHECK-LABEL: func @contraction4x4_ikj_xfer_read -// TODO(andydavis) Add VTR splitting back into this test in follow up CL. 
// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : index // Check LHS vector.transfer read is split for each user. +// TODO(andydavis) Connect VTR results with users in subsequent CL. -// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<4x2xf32> -// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x4xf32> -// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<4x4xf32> +// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32> +// CHECK: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32> + +// CHECK: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32> func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>, %arg1 : memref<2x4xf32>, @@ -270,7 +277,7 @@ func 
@contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>, // TODO(andydavis) Update test with VTR split transform. // CHECK-LABEL: func @vector_transfers -// CHECK-COUNT-2: vector.transfer_read +// CHECK-COUNT-8: vector.transfer_read // CHECK-COUNT-2: vector.extract_slices // CHECK-COUNT-4: addf // CHECK-COUNT-1: vector.insert_slices -- 2.7.4