[mlir][vector] Support distributing transfer op with permutation map
authorthomasraoux <thomasraoux@google.com>
Mon, 14 Jun 2021 20:25:18 +0000 (13:25 -0700)
committerthomasraoux <thomasraoux@google.com>
Mon, 21 Jun 2021 19:56:08 +0000 (12:56 -0700)
Differential Revision: https://reviews.llvm.org/D104263

mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-distribution.mlir

index baded89..6765fd4 100644 (file)
@@ -2842,6 +2842,20 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
   return ops;
 }
 
+/// Converts TransferRead op used by ExtractMap op into a smaller dimension
+/// TransferRead.
+/// Example:
+/// ```
+/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
+///   memref<64x64x64xf32>, vector<64x4x32xf32>
+/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
+/// ```
+/// to:
+/// ```
+/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
+/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
+///   memref<64x64x64xf32>, vector<2x4x1xf32>
+/// ```
 struct TransferReadExtractPattern
     : public OpRewritePattern<vector::TransferReadOp> {
   TransferReadExtractPattern(MLIRContext *context)
@@ -2858,18 +2872,23 @@ struct TransferReadExtractPattern
       return failure();
 
     SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
-    AffineMap map = extract.map();
+    AffineMap indexMap = extract.map().compose(read.permutation_map());
     unsigned idCount = 0;
     ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
-    for (auto expr : map.getResults()) {
+    for (auto it :
+         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
       AffineExpr d0, d1;
       bindDims(read.getContext(), d0, d1);
-      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+      if (!indexExpr)
+        continue;
+      unsigned indexPos = indexExpr.getPosition();
+      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
       auto scale = getAffineConstantExpr(
-          extract.getResultType().getDimSize(pos), read.getContext());
-      indices[pos] =
-          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
-                                  {indices[pos], extract.ids()[idCount++]});
+          extract.getResultType().getDimSize(vectorPos), read.getContext());
+      indices[indexPos] = makeComposedAffineApply(
+          rewriter, read.getLoc(), d0 + scale * d1,
+          {indices[indexPos], extract.ids()[idCount++]});
     }
     Value newRead = lb.create<vector::TransferReadOp>(
         extract.getType(), read.source(), indices, read.permutation_map(),
@@ -2895,18 +2914,24 @@ struct TransferWriteInsertPattern
       return failure();
     SmallVector<Value, 4> indices(write.indices().begin(),
                                   write.indices().end());
-    AffineMap map = insert.map();
+    AffineMap indexMap = insert.map().compose(write.permutation_map());
     unsigned idCount = 0;
     Location loc = write.getLoc();
-    for (auto expr : map.getResults()) {
+    for (auto it :
+         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
       AffineExpr d0, d1;
       bindDims(write.getContext(), d0, d1);
-      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+      if (!indexExpr)
+        continue;
+      unsigned indexPos = indexExpr.getPosition();
+      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
       auto scale = getAffineConstantExpr(
-          insert.getSourceVectorType().getDimSize(pos), write.getContext());
-      indices[pos] =
+          insert.getSourceVectorType().getDimSize(vectorPos),
+          write.getContext());
+      indices[indexPos] =
           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
-                                  {indices[pos], insert.ids()[idCount++]});
+                                  {indices[indexPos], insert.ids()[idCount++]});
     }
     rewriter.create<vector::TransferWriteOp>(
         loc, insert.vector(), write.source(), indices, write.permutation_map(),
index 950786e..0ad46d1 100644 (file)
@@ -123,4 +123,34 @@ func @vector_add_transfer_3d(%id0 : index, %id1 : index, %A: memref<64x64x64xf32
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
 
+//       CHECK: func @vector_add_transfer_permutation
+//  CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
+//       CHECK:    %[[C0:.*]] = constant 0 : index
+//       CHECK:    %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+//  CHECK-NEXT:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[ID2]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID_0]], %[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP2]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
+//  CHECK-NEXT:    %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
+//  CHECK-NEXT:    %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]], %[[ID_1]], %[[C0]], %[[ID3]]] {permutation_map = #[[MAP3]]} : vector<2x4x1xf32>, memref<?x?x?x?xf32>
+//  CHECK-NEXT:    return
+func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?x?x?xf32>,
+  %B: memref<?x?x?x?xf32>, %C: memref<?x?x?x?xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<64x4x32xf32>
+  %b = vector.transfer_read %B[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map1}: memref<?x?x?x?xf32>, vector<64x4x32xf32>
+  %acc = addf %a, %b: vector<64x4x32xf32>
+  vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
+  return
+}