resultIndex, targetShape, builder);
}
-// Splits vector TransferReadOp into smaller TransferReadOps for each user.
+// Splits a vector TransferReadOp into smaller TransferReadOps based on the
+// slicing scheme of its unique ExtractSlicesOp user.
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
// permutation maps. Repurpose code from MaterializeVectors transformation.
if (!xferReadOp.permutation_map().isIdentity())
return matchFailure();
- // Gather 'xferReadOp' users.
- SmallVector<vector::StridedSliceOp, 2> sliceUsers;
- sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(),
- xferReadOp.getResult()->use_end()));
-
- for (auto *user : xferReadOp.getResult()->getUsers()) {
- auto sliceOp = dyn_cast<vector::StridedSliceOp>(user);
- // Return if any user is not a vector::StridedSliceOp.
- if (!sliceOp)
- return matchFailure();
- sliceUsers.push_back(sliceOp);
- }
- // Make zero splat into which we will insert split xferReadOp results.
- Location loc = xferReadOp.getLoc();
- auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType());
+ // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
+ Value *xferReadResult = xferReadOp.getResult();
+ auto extractSlicesOp =
+ dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin());
+ if (!xferReadResult->hasOneUse() || !extractSlicesOp)
+ return matchFailure();
- // Update each user in 'sliceUser' to use 'res'.
+ // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
+ auto sourceVectorType = extractSlicesOp.getSourceVectorType();
+ auto resultTupleType = extractSlicesOp.getResultTupleType();
+ SmallVector<int64_t, 4> sizes;
+ extractSlicesOp.getSizes(sizes);
+ SmallVector<int64_t, 4> strides;
+ extractSlicesOp.getStrides(strides);
+ assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
+
+ // Compute strides w.r.t. slice counts in each dimension.
+ auto maybeDimSliceCounts = shapeRatio(sourceVectorType.getShape(), sizes);
+ assert(maybeDimSliceCounts.hasValue());
+ auto sliceDimCounts = *maybeDimSliceCounts;
+ auto basis = computeStrides(sliceDimCounts);
+
+ Location loc = xferReadOp.getLoc();
+ auto *ctx = rewriter.getContext();
+ int64_t numSlices = resultTupleType.size();
unsigned numSliceIndices = llvm::size(xferReadOp.indices());
- for (auto sliceUser : sliceUsers) {
- // Gather static offsets from 'sliceUser'.
- SmallVector<int64_t, 4> sliceOffsets;
- sliceUser.getOffsets(sliceOffsets);
- assert(sliceOffsets.size() == numSliceIndices);
- auto *ctx = rewriter.getContext();
+ SmallVector<Value *, 4> vectorTupleValues(numSlices);
+ for (unsigned i = 0; i < numSlices; ++i) {
+ // De-linearize w.r.t. 'basis'.
+ auto vectorOffsets = delinearize(i, basis);
+ // Convert from unrolled vector-space offsets to element-space offsets.
+ auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+ vectorOffsets, sizes);
// Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
SmallVector<Value *, 4> sliceIndices(numSliceIndices);
for (auto it : llvm::enumerate(xferReadOp.indices())) {
auto expr = getAffineDimExpr(0, ctx) +
- getAffineConstantExpr(sliceOffsets[it.index()], ctx);
+ getAffineConstantExpr(offsets[it.index()], ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
SmallVector<Value *, 1> mapOperands = {it.value()};
sliceIndices[it.index()] =
rewriter.create<AffineApplyOp>(loc, map, mapOperands);
}
+ // Get VectorType for slice 'i'.
+ auto sliceVectorType = resultTupleType.getType(i);
// Create split TransferReadOp for 'sliceUser'.
- auto sliceVectorType =
- sliceUser.getResult()->getType().cast<VectorType>();
- auto splitXferReadOp = rewriter.create<vector::TransferReadOp>(
+ vectorTupleValues[i] = rewriter.create<vector::TransferReadOp>(
loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
xferReadOp.permutation_map(), xferReadOp.padding());
- // Create InsertStridedSlice into splat at same offsets as slice.
- res = rewriter.create<vector::InsertStridedSliceOp>(
- loc, xferReadOp.getVectorType(), splitXferReadOp, res,
- sliceUser.offsets(), sliceUser.strides());
}
-
- // Replace 'xferReadOp' with result 'res'.
- rewriter.replaceOp(xferReadOp, res);
+ // Create tuple of sliced xfer read operations.
+ Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
+ vectorTupleValues);
+ // Replace 'xferReadOp' with an InsertSlicesOp built from 'tupleOp'.
+ rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
+ xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
+ extractSlicesOp.strides());
return matchSuccess();
}
};
}
// CHECK-LABEL: func @contraction4x4_ikj_xfer_read
-// TODO(andydavis) Add VTR splitting back into this test in follow up CL.
// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
// Check LHS vector.transfer read is split for each user.
+// TODO(andydavis) Connect VTR results with users in subsequent CL.
-// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<4x2xf32>
-// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x4xf32>
-// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<4x4xf32>
+// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
+// CHECK: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
+
+// CHECK: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
%arg1 : memref<2x4xf32>,
// TODO(andydavis) Update test with VTR split transform.
// CHECK-LABEL: func @vector_transfers
-// CHECK-COUNT-2: vector.transfer_read
+// CHECK-COUNT-8: vector.transfer_read
// CHECK-COUNT-2: vector.extract_slices
// CHECK-COUNT-4: addf
// CHECK-COUNT-1: vector.insert_slices