Add VectorOp transform pattern which splits vector TransferReadOps to target vector...
authorAndy Davis <andydavis@google.com>
Wed, 11 Dec 2019 01:02:17 +0000 (17:02 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 11 Dec 2019 01:02:51 +0000 (17:02 -0800)
PiperOrigin-RevId: 284880592

mlir/include/mlir/Dialect/VectorOps/VectorOps.h
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
mlir/test/Dialect/VectorOps/vector-transforms.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp

index 8cb0d85..5b4351b 100644 (file)
@@ -43,6 +43,10 @@ public:
 void populateVectorToVectorCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Collect a set of vector-to-vector transformation patterns.
+void populateVectorToVectorTransformationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context);
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/VectorOps/VectorOps.h.inc"
 
index f6e1ae5..d87f101 100644 (file)
@@ -451,6 +451,7 @@ def Vector_StridedSliceOp :
     static StringRef getSizesAttrName() { return "sizes"; }
     static StringRef getStridesAttrName() { return "strides"; }
     VectorType getVectorType(){ return vector()->getType().cast<VectorType>(); }
+    void getOffsets(SmallVectorImpl<int64_t> &results);
   }];
   let hasCanonicalizer = 1;
 }
index 7714623..28a0322 100644 (file)
@@ -731,6 +731,12 @@ LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
   return success();
 }
 
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+                                       SmallVectorImpl<int64_t> &results) {
+  for (auto attr : arrayAttr)
+    results.push_back(attr.cast<IntegerAttr>().getInt());
+}
+
 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                   MLIRContext *context) {
   auto attrs = functional::map(
@@ -929,14 +935,12 @@ static LogicalResult verify(StridedSliceOp op) {
   return success();
 }
 
-namespace {
-
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
-                                       SmallVectorImpl<int64_t> &results) {
-  for (auto attr : arrayAttr)
-    results.push_back(attr.cast<IntegerAttr>().getInt());
+void StridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
+  populateFromInt64AttrArray(offsets(), results);
 }
 
+namespace {
+
 // Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
 class StridedSliceConstantMaskFolder final
     : public OpRewritePattern<StridedSliceOp> {
index 6b13bcf..6825709 100644 (file)
@@ -446,3 +446,71 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
   return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
                                         resultIndex, targetShape, builder);
 }
+
+// Splits vector TransferReadOp into smaller TransferReadOps for each user.
+struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp,
+                                     PatternRewriter &rewriter) const override {
+    // TODO(andydavis, ntv) Support spliting TransferReadOp with non-identity
+    // permutation maps. Repurpose code from MaterializeVectors transformation.
+    if (!xferReadOp.permutation_map().isIdentity())
+      return matchFailure();
+    // Gather 'xferReadOp' users.
+    SmallVector<vector::StridedSliceOp, 2> sliceUsers;
+    sliceUsers.reserve(std::distance(xferReadOp.getResult()->use_begin(),
+                                     xferReadOp.getResult()->use_end()));
+
+    for (auto *user : xferReadOp.getResult()->getUsers()) {
+      auto sliceOp = dyn_cast<vector::StridedSliceOp>(user);
+      // Return if any user is not a vector::StridedSliceOp.
+      if (!sliceOp)
+        return matchFailure();
+      sliceUsers.push_back(sliceOp);
+    }
+    // Make zero splat into which we will insert split xferReadOp results.
+    Location loc = xferReadOp.getLoc();
+    auto *res = makeSplatZero(loc, rewriter, xferReadOp.getVectorType());
+
+    // Update each user in 'sliceUser' to use 'res'.
+    unsigned numSliceIndices = llvm::size(xferReadOp.indices());
+    for (auto sliceUser : sliceUsers) {
+      // Gather static offsets from 'sliceUser'.
+      SmallVector<int64_t, 4> sliceOffsets;
+      sliceUser.getOffsets(sliceOffsets);
+      assert(sliceOffsets.size() == numSliceIndices);
+      auto *ctx = rewriter.getContext();
+      // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
+      SmallVector<Value *, 4> sliceIndices(numSliceIndices);
+      for (auto it : llvm::enumerate(xferReadOp.indices())) {
+        auto expr = getAffineDimExpr(0, ctx) +
+                    getAffineConstantExpr(sliceOffsets[it.index()], ctx);
+        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+        SmallVector<Value *, 1> mapOperands = {it.value()};
+        sliceIndices[it.index()] =
+            rewriter.create<AffineApplyOp>(loc, map, mapOperands);
+      }
+      // Create split TransferReadOp for 'sliceUser'.
+      auto sliceVectorType =
+          sliceUser.getResult()->getType().cast<VectorType>();
+      auto splitXferReadOp = rewriter.create<vector::TransferReadOp>(
+          loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
+          xferReadOp.permutation_map(), xferReadOp.padding());
+      // Create InsertStridedSlice into splat at same offsets as slice.
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, xferReadOp.getVectorType(), splitXferReadOp, res,
+          sliceUser.offsets(), sliceUser.strides());
+    }
+
+    // Replace 'xferReadOp' with result 'res'.
+    rewriter.replaceOp(xferReadOp, res);
+    return matchSuccess();
+  }
+};
+
+// TODO(andydavis) Add this as DRR pattern.
+void mlir::vector::populateVectorToVectorTransformationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<SplitTransferReadOp>(context);
+}
index 4fb235d..c8d92ee 100644 (file)
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
 
+// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0, d1) -> (d0, d1)
+
 // CHECK-LABEL: func @add4x2
 //      CHECK: %[[V1:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[V2:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
@@ -206,8 +208,70 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
   return %0 : vector<4x4xf32>
 }
 
+// CHECK-LABEL: func @contraction4x4_ikj_xfer_read
+
+// Capture constants used to re-index vector transfer reads.
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C2:.*]] = constant 2 : index
+
+// Check LHS vector.transfer read is split for each user.
+// CHECK:       %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
+// CHECK-NEXT:  %[[ISS0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
+// CHECK-NEXT:  %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x2xf32>, vector<2x2xf32>
+// CHECK-NEXT:  %[[ISS1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[ISS0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
+
+// Check RHS vector.transfer read is split for each user.
+// CHECK:       %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
+// CHECK-NEXT:  %[[ISS2:.*]] = vector.insert_strided_slice %[[VTR2]], %{{.*}} {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
+// CHECK-NEXT:  %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<2x4xf32>, vector<2x2xf32>
+// CHECK-NEXT:  %[[ISS3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[ISS2]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
+
+// Check ACC vector.transfer read is split for each user (should be 4).
+// CHECK:       %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT:  %[[ISS4:.*]] = vector.insert_strided_slice %[[VTR4]], %{{.*}} {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT:  %[[ISS5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[ISS4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT:  %[[ISS6:.*]] = vector.insert_strided_slice %[[VTR6]], %[[ISS5]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT:  %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT:  %[[ISS7:.*]] = vector.insert_strided_slice %[[VTR7]], %[[ISS6]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Check LHS slice uses splat of split tranfer read results.
+// CHECK: vector.strided_slice %[[ISS1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+
+// Check RHS slice uses splat of split tranfer read results.
+// CHECK: vector.strided_slice %[[ISS3]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+
+// Check ACC slice uses splat of split tranfer read results.
+// CHECK: vector.strided_slice %[[ISS7]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+
+func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
+                                   %arg1 : memref<2x4xf32>,
+                                   %arg2 : memref<4x4xf32>)
+                                   -> (vector<4x4xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0
+    { permutation_map = (d0, d1) -> (d0, d1) }
+      : memref<4x2xf32>, vector<4x2xf32>
+
+  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0
+    { permutation_map = (d0, d1) -> (d0, d1) }
+    : memref<2x4xf32>, vector<2x4xf32>
+
+  %2 = vector.transfer_read %arg2[%c0, %c0], %cf0
+    { permutation_map = (d0, d1) -> (d0, d1) }
+      : memref<4x4xf32>, vector<4x4xf32>
+
+  %3 = vector.contract #contraction_trait1 %0, %1, %2
+      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
+
+  return %3 : vector<4x4xf32>
+}
+
 // CHECK-LABEL: func @vector_transfers
-// CHECK-COUNT-2: vector.transfer_read
+// CHECK-COUNT-8: vector.transfer_read
 // CHECK-COUNT-2: vector.strided_slice
 // CHECK-COUNT-1: addf
 // CHECK-COUNT-2: vector.strided_slice
index 909fe2a..1d51306 100644 (file)
@@ -36,6 +36,7 @@ struct TestVectorToVectorConversion
     auto *context = &getContext();
     populateWithGenerated(context, &patterns);
     populateVectorToVectorCanonicalizationPatterns(patterns, context);
+    populateVectorToVectorTransformationPatterns(patterns, context);
     applyPatternsGreedily(getFunction(), patterns);
   }
 };