return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
resultIndex, targetShape, builder);
}
+
+// Rewrite pattern that splits a vector TransferReadOp into one smaller
+// TransferReadOp per vector::StridedSliceOp user. The per-user reads are
+// combined into a zero splat via InsertStridedSliceOps, and the combined
+// value replaces the original read result.
+struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp,
+ PatternRewriter &rewriter) const override {
+ // TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity
+ // permutation maps. Repurpose code from MaterializeVectors transformation.
+ if (!xferReadOp.permutation_map().isIdentity())
+ return matchFailure();
+ // Gather 'xferReadOp' users. The pattern applies only when every user is
+ // a vector::StridedSliceOp.
+ SmallVector<vector::StridedSliceOp, 2> sliceUsers;
+ sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(),
+ xferReadOp.getResult()->use_end()));
+
+ for (auto *user : xferReadOp.getResult()->getUsers()) {
+ auto sliceOp = dyn_cast<vector::StridedSliceOp>(user);
+ // Fail the match if any user is not a vector::StridedSliceOp.
+ if (!sliceOp)
+ return matchFailure();
+ sliceUsers.push_back(sliceOp);
+ }
+ // Make zero splat into which we will insert split xferReadOp results.
+ Location loc = xferReadOp.getLoc();
+ auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType());
+
+ // Build one split read + insert per user in 'sliceUsers', chaining each
+ // InsertStridedSliceOp result through 'res'.
+ unsigned numSliceIndices = llvm::size(xferReadOp.indices());
+ for (auto sliceUser : sliceUsers) {
+ // Gather static offsets from 'sliceUser'.
+ SmallVector<int64_t, 4> sliceOffsets;
+ sliceUser.getOffsets(sliceOffsets);
+ assert(sliceOffsets.size() == numSliceIndices);
+ auto *ctx = rewriter.getContext();
+ // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'
+ // via an affine.apply of the map (d0) -> (d0 + offset).
+ SmallVector<Value *, 4> sliceIndices(numSliceIndices);
+ for (auto it : llvm::enumerate(xferReadOp.indices())) {
+ auto expr = getAffineDimExpr(0, ctx) +
+ getAffineConstantExpr(sliceOffsets[it.index()], ctx);
+ auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+ SmallVector<Value *, 1> mapOperands = {it.value()};
+ sliceIndices[it.index()] =
+ rewriter.create<AffineApplyOp>(loc, map, mapOperands);
+ }
+ // Create split TransferReadOp for 'sliceUser', reading only the slice's
+ // vector type at the adjusted indices (same map and padding as the
+ // original read).
+ auto sliceVectorType =
+ sliceUser.getResult()->getType().cast<VectorType>();
+ auto splitXferReadOp = rewriter.create<vector::TransferReadOp>(
+ loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
+ xferReadOp.permutation_map(), xferReadOp.padding());
+ // Create InsertStridedSlice into splat at same offsets as slice.
+ res = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, xferReadOp.getVectorType(), splitXferReadOp, res,
+ sliceUser.offsets(), sliceUser.strides());
+ }
+
+ // Replace 'xferReadOp' with result 'res'.
+ rewriter.replaceOp(xferReadOp, res);
+ return matchSuccess();
+ }
+};
+
+// Populates 'patterns' with vector-to-vector transformation patterns
+// (currently only SplitTransferReadOp).
+// TODO(andydavis) Add this as DRR pattern.
+void mlir::vector::populateVectorToVectorTransformationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<SplitTransferReadOp>(context);
+}
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0, d1) -> (d0, d1)
+
// CHECK-LABEL: func @add4x2
// CHECK: %[[V1:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[V2:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
return %0 : vector<4x4xf32>
}
+// CHECK-LABEL: func @contraction4x4_ikj_xfer_read
+
+// Capture constants used to re-index vector transfer reads.
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+
+// Check LHS vector.transfer read is split for each user.
+// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[ISS0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[ISS1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[ISS0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
+
+// Check RHS vector.transfer read is split for each user.
+// CHECK: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[ISS2:.*]] = vector.insert_strided_slice %[[VTR2]], %{{.*}} {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[ISS3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[ISS2]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
+
+// Check ACC vector.transfer read is split for each user (should be 4).
+// CHECK: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[ISS4:.*]] = vector.insert_strided_slice %[[VTR4]], %{{.*}} {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[ISS5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[ISS4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[ISS6:.*]] = vector.insert_strided_slice %[[VTR6]], %[[ISS5]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[ISS7:.*]] = vector.insert_strided_slice %[[VTR7]], %[[ISS6]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Check LHS slice uses splat of split transfer read results.
+// CHECK: vector.strided_slice %[[ISS1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+
+// Check RHS slice uses splat of split transfer read results.
+// CHECK: vector.strided_slice %[[ISS3]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+
+// Check ACC slice uses splat of split transfer read results.
+// CHECK: vector.strided_slice %[[ISS7]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+
+// Test: three vector.transfer_reads with identity permutation maps feed a
+// 4x4 vector.contract; their results are returned via the contraction.
+func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
+ %arg1 : memref<2x4xf32>,
+ %arg2 : memref<4x4xf32>)
+ -> (vector<4x4xf32>) {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cf0
+ { permutation_map = (d0, d1) -> (d0, d1) }
+ : memref<4x2xf32>, vector<4x2xf32>
+
+ %1 = vector.transfer_read %arg1[%c0, %c0], %cf0
+ { permutation_map = (d0, d1) -> (d0, d1) }
+ : memref<2x4xf32>, vector<2x4xf32>
+
+ %2 = vector.transfer_read %arg2[%c0, %c0], %cf0
+ { permutation_map = (d0, d1) -> (d0, d1) }
+ : memref<4x4xf32>, vector<4x4xf32>
+
+ %3 = vector.contract #contraction_trait1 %0, %1, %2
+ : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
+
+ return %3 : vector<4x4xf32>
+}
+
// CHECK-LABEL: func @vector_transfers
-// CHECK-COUNT-2: vector.transfer_read
+// CHECK-COUNT-8: vector.transfer_read
// CHECK-COUNT-2: vector.strided_slice
// CHECK-COUNT-1: addf
// CHECK-COUNT-2: vector.strided_slice