Add pattern rewrite which splits a vector TransferReadOp into slices according to...
authorAndy Davis <andydavis@google.com>
Tue, 17 Dec 2019 15:28:37 +0000 (07:28 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 17 Dec 2019 15:29:06 +0000 (07:29 -0800)
PiperOrigin-RevId: 285975613

mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
mlir/test/Dialect/VectorOps/vector-transforms.mlir

index 8d70f4a..85f306e 100644 (file)
@@ -511,7 +511,8 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
                                         resultIndex, targetShape, builder);
 }
 
-// Splits vector TransferReadOp into smaller TransferReadOps for each user.
+// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
+// scheme of its unique ExtractSlicesOp user.
 struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
@@ -521,54 +522,63 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
     // permutation maps. Repurpose code from MaterializeVectors transformation.
     if (!xferReadOp.permutation_map().isIdentity())
       return matchFailure();
-    // Gather 'xferReadOp' users.
-    SmallVector<vector::StridedSliceOp, 2> sliceUsers;
-    sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(),
-                                     xferReadOp.getResult()->use_end()));
-
-    for (auto *user : xferReadOp.getResult()->getUsers()) {
-      auto sliceOp = dyn_cast<vector::StridedSliceOp>(user);
-      // Return if any user is not a vector::StridedSliceOp.
-      if (!sliceOp)
-        return matchFailure();
-      sliceUsers.push_back(sliceOp);
-    }
-    // Make zero splat into which we will insert split xferReadOp results.
-    Location loc = xferReadOp.getLoc();
-    auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType());
+    // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
+    Value *xferReadResult = xferReadOp.getResult();
+    auto extractSlicesOp =
+        dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin());
+    if (!xferReadResult->hasOneUse() || !extractSlicesOp)
+      return matchFailure();
 
-    // Update each user in 'sliceUser' to use 'res'.
+    // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
+    auto sourceVectorType = extractSlicesOp.getSourceVectorType();
+    auto resultTupleType = extractSlicesOp.getResultTupleType();
+    SmallVector<int64_t, 4> sizes;
+    extractSlicesOp.getSizes(sizes);
+    SmallVector<int64_t, 4> strides;
+    extractSlicesOp.getStrides(strides);
+    assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
+
+    // Compute strides w.r.t. to slice counts in each dimension.
+    auto maybeDimSliceCounts = shapeRatio(sourceVectorType.getShape(), sizes);
+    assert(maybeDimSliceCounts.hasValue());
+    auto sliceDimCounts = *maybeDimSliceCounts;
+    auto basis = computeStrides(sliceDimCounts);
+
+    Location loc = xferReadOp.getLoc();
+    auto *ctx = rewriter.getContext();
+    int64_t numSlices = resultTupleType.size();
     unsigned numSliceIndices = llvm::size(xferReadOp.indices());
-    for (auto sliceUser : sliceUsers) {
-      // Gather static offsets from 'sliceUser'.
-      SmallVector<int64_t, 4> sliceOffsets;
-      sliceUser.getOffsets(sliceOffsets);
-      assert(sliceOffsets.size() == numSliceIndices);
-      auto *ctx = rewriter.getContext();
+    SmallVector<Value *, 4> vectorTupleValues(numSlices);
+    for (unsigned i = 0; i < numSlices; ++i) {
+      // De-linearize w.r.t. 'basis'.
+      auto vectorOffsets = delinearize(i, basis);
+      // Convert from unrolled vector-space offsets to element-space offsets.
+      auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+                            vectorOffsets, sizes);
       // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
       SmallVector<Value *, 4> sliceIndices(numSliceIndices);
       for (auto it : llvm::enumerate(xferReadOp.indices())) {
         auto expr = getAffineDimExpr(0, ctx) +
-                    getAffineConstantExpr(sliceOffsets[it.index()], ctx);
+                    getAffineConstantExpr(offsets[it.index()], ctx);
         auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
         SmallVector<Value *, 1> mapOperands = {it.value()};
         sliceIndices[it.index()] =
             rewriter.create<AffineApplyOp>(loc, map, mapOperands);
       }
+      // Get VectorType for slice 'i'.
+      auto sliceVectorType = resultTupleType.getType(i);
       // Create split TransferReadOp for 'sliceUser'.
-      auto sliceVectorType =
-          sliceUser.getResult()->getType().cast<VectorType>();
-      auto splitXferReadOp = rewriter.create<vector::TransferReadOp>(
+      vectorTupleValues[i] = rewriter.create<vector::TransferReadOp>(
           loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
           xferReadOp.permutation_map(), xferReadOp.padding());
-      // Create InsertStridedSlice into splat at same offsets as slice.
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, xferReadOp.getVectorType(), splitXferReadOp, res,
-          sliceUser.offsets(), sliceUser.strides());
     }
-
-    // Replace 'xferReadOp' with result 'res'.
-    rewriter.replaceOp(xferReadOp, res);
+    // Create tuple of splice xfer read operations.
+    Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
+                                                      vectorTupleValues);
+    // Replace 'xferReadOp' with result 'insertSlicesResult'.
+    rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
+        xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
+        extractSlicesOp.strides());
     return matchSuccess();
   }
 };
index 783f542..71b7b7a 100644 (file)
@@ -232,16 +232,23 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
 }
 
 // CHECK-LABEL: func @contraction4x4_ikj_xfer_read
-// TODO(andydavis) Add VTR splitting back into this test in follow up CL.
 
 // CHECK:      %[[C0:.*]] = constant 0 : index
+// CHECK:      %[[C2:.*]] = constant 2 : index
 
 // Check LHS vector.transfer read is split for each user.
+// TODO(andydavis) Connect VTR results with users in subsequent CL.
 
-//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<4x2xf32>
-// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x4xf32>
-// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<4x4xf32>
+//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
 
+//      CHECK: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
+
+//      CHECK: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
 
 func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
                                    %arg1 : memref<2x4xf32>,
@@ -270,7 +277,7 @@ func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
 
 // TODO(andydavis) Update test with VTR split transform.
 // CHECK-LABEL: func @vector_transfers
-// CHECK-COUNT-2: vector.transfer_read
+// CHECK-COUNT-8: vector.transfer_read
 // CHECK-COUNT-2: vector.extract_slices
 // CHECK-COUNT-4: addf
 // CHECK-COUNT-1: vector.insert_slices