[mlir][vector] Add extra lowering for more transfer_write maps
authorThomas Raoux <thomasraoux@google.com>
Tue, 17 Jan 2023 17:05:11 +0000 (17:05 +0000)
committerThomas Raoux <thomasraoux@google.com>
Tue, 17 Jan 2023 17:06:00 +0000 (17:06 +0000)
Add pattern to lower transfer_write with permutation map that are not
permutation of minor identity map.

Differential Revision: https://reviews.llvm.org/D141815

mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

index df8ba7b..68d9a34 100644 (file)
@@ -33,6 +33,19 @@ inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
   return builder.getBoolArrayAttr(newInBoundsValues);
 }
 
+/// Extend the rank of a vector Value by `addedRanks` by adding outer unit
+/// dimensions.
+static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
+                              int64_t addedRank) {
+  auto originalVecType = vec.getType().cast<VectorType>();
+  SmallVector<int64_t> newShape(addedRank, 1);
+  newShape.append(originalVecType.getShape().begin(),
+                  originalVecType.getShape().end());
+  VectorType newVecType =
+      VectorType::get(newShape, originalVecType.getElementType());
+  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
+}
+
 /// Lower transfer_read op with permutation into a transfer_read with a
 /// permutation map composed of leading zeros followed by a minor identiy +
 /// vector.transpose op.
@@ -170,6 +183,77 @@ struct TransferWritePermutationLowering
   }
 };
 
+/// Convert a transfer.write op with a map which isn't the permutation of a
+/// minor identity into a vector.broadcast + transfer_write with permutation of
+/// minor identity map by adding unit dim on inner dimension. Ex:
+/// ```
+///   vector.transfer_write %v
+///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
+///     vector<8x16xf32>
+/// ```
+/// into:
+/// ```
+///   %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
+///   vector.transfer_write %v1
+///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
+///     vector<1x8x16xf32>
+/// ```
+struct TransferWriteNonPermutationLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getTransferRank() == 0)
+      return failure();
+    SmallVector<unsigned> permutation;
+    AffineMap map = op.getPermutationMap();
+    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+
+    // Missing outer dimensions are allowed, find the most outer existing
+    // dimension then deduce the missing inner dimensions.
+    SmallVector<bool> foundDim(map.getNumDims(), false);
+    for (AffineExpr exp : map.getResults()) {
+      foundDim[exp.cast<AffineDimExpr>().getPosition()] = true;
+    }
+    SmallVector<AffineExpr> exprs;
+    bool foundFirstDim = false;
+    SmallVector<int64_t> missingInnerDim;
+    for (size_t i = 0; i < foundDim.size(); i++) {
+      if (foundDim[i]) {
+        foundFirstDim = true;
+        continue;
+      }
+      if (!foundFirstDim)
+        continue;
+      // Once we found one outer dimension existing in the map keep track of all
+      // the missing dimensions after that.
+      missingInnerDim.push_back(i);
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    }
+    // Add unit dims at the beginning of the shape.
+    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
+                                    missingInnerDim.size());
+    exprs.append(map.getResults().begin(), map.getResults().end());
+    AffineMap newMap =
+        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
+    ArrayAttr newInBoundsAttr;
+    if (op.getInBounds()) {
+      // All the new dimensions added are inbound.
+      SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
+      for (Attribute attr : op.getInBounds().value().getValue()) {
+        newInBoundsValues.push_back(attr.cast<BoolAttr>().getValue());
+      }
+      newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
+    }
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
+        op.getMask(), newInBoundsAttr);
+    return success();
+  }
+};
+
 /// Lower transfer_read op with broadcast in the leading dimensions into
 /// transfer_read of lower rank + vector.broadcast.
 /// Ex: vector.transfer_read ...
@@ -250,7 +334,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
 
 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<TransferReadPermutationLowering,
-               TransferWritePermutationLowering, TransferOpReduceRank>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
+           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
+          patterns.getContext(), benefit);
 }
index 0d56781..ca353a0 100644 (file)
@@ -149,34 +149,32 @@ func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
 
 // CHECK-LABEL:func @materialize_write(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
 func.func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
-  // CHECK-DAG:  %{{.*}} = arith.constant dense<1.000000e+00> : vector<5x4x3xf32>
+  // CHECK-DAG:  %{{.*}} = arith.constant dense<1.000000e+00> : vector<3x4x1x5xf32>
   // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG:  %[[C3:.*]] = arith.constant 3 : index
   // CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
-  // CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
   // CHECK:      %{{.*}} = memref.alloc(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?x?x?xf32>
   // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %{{.*}} step 3 {
   // CHECK-NEXT:   affine.for %[[I1:.*]] = 0 to %{{.*}} step 4 {
   // CHECK-NEXT:     affine.for %[[I2:.*]] = 0 to %{{.*}} {
   // CHECK-NEXT:       affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
-  // CHECK:              %[[ALLOC:.*]] = memref.alloca() : memref<vector<5x4x3xf32>>
-  // CHECK:              memref.store %{{.*}}, %[[ALLOC]][] : memref<vector<5x4x3xf32>>
-  // CHECK:              %[[VECTOR_VIEW1:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<5x4x3xf32>> to memref<5xvector<4x3xf32>>
-  // CHECK:              scf.for %[[I4:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
+  // CHECK:              %[[ALLOC:.*]] = memref.alloca() : memref<vector<3x4x1x5xf32>>
+  // CHECK:              memref.store %{{.*}}, %[[ALLOC]][] : memref<vector<3x4x1x5xf32>>
+  // CHECK:              %[[VECTOR_VIEW1:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<3x4x1x5xf32>> to memref<3xvector<4x1x5xf32>>
+  // CHECK:              scf.for %[[I4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
   // CHECK:                scf.if
-  // CHECK:                  %[[S3:.*]] = affine.apply #[[$ADD]](%[[I3]], %[[I4]])
-  // CHECK:                  %[[VECTOR_VIEW2:.*]] = vector.type_cast %[[VECTOR_VIEW1]] : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
+  // CHECK:                  %[[S3:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I4]])
+  // CHECK:                  %[[VECTOR_VIEW2:.*]] = vector.type_cast %[[VECTOR_VIEW1]] : memref<3xvector<4x1x5xf32>> to memref<3x4xvector<1x5xf32>>
   // CHECK:                  scf.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
   // CHECK:                    scf.if
   // CHECK:                      %[[S1:.*]] = affine.apply #[[$ADD]](%[[I1]], %[[I5]])
-  // CHECK:                      %[[VEC:.*]] = memref.load %[[VECTOR_VIEW2]][%[[I4]], %[[I5]]] : memref<5x4xvector<3xf32>>
-  // CHECK:                      scf.for %[[I6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-  // CHECK:                        %[[S0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
+  // CHECK:                      %[[VECTOR_VIEW3:.*]] = vector.type_cast %[[VECTOR_VIEW2]] : memref<3x4xvector<1x5xf32>> to memref<3x4x1xvector<5xf32>>
+  // CHECK:                      scf.for %[[I6:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
   // CHECK:                        scf.if
-  // CHECK:                          %[[SCAL:.*]] = vector.extractelement %[[VEC]][%[[I6]] : index] : vector<3xf32>
-  // CHECK:                          memref.store %[[SCAL]], {{.*}}[%[[S0]], %[[S1]], %[[I2]], %[[S3]]] : memref<?x?x?x?xf32>
-  // CHECK:                        }
+  // CHECK:                          %[[S0:.*]] = affine.apply #[[$ADD]](%[[I2]], %[[I6]])
+  // CHECK:                          %[[VEC:.*]] = memref.load %[[VECTOR_VIEW3]][%[[I4]], %[[I5]], %[[I6]]] : memref<3x4x1xvector<5xf32>>
+  // CHECK:                          vector.transfer_write %[[VEC]], %{{.*}}[%[[S3]], %[[S1]], %[[S0]], %[[I3]]] : vector<5xf32>, memref<?x?x?x?xf32>
   // CHECK:                      }
   // CHECK:                    }
   // CHECK:                  }
index 779b84f..6911cd5 100644 (file)
@@ -178,14 +178,12 @@ func.func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %i : inde
 // CHECK-SAME:                                 %[[IDX:.*]]: index) -> vector<4xf32> {
 // CHECK-NEXT:      %[[CF0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:      %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {in_bounds = [true], permutation_map = #{{.*}}} : memref<8x8xf32>, vector<4xf32>
-// CHECK-NEXT:      vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {in_bounds = [true], permutation_map = #{{.*}}} : vector<4xf32>, memref<8x8xf32>
 // CHECK-NEXT:      return %[[RES]] : vector<4xf32>
 // CHECK-NEXT:    }
 
 func.func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
   %cf0 = arith.constant 0.0 : f32
   %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<8x8xf32>, vector<4xf32>
-  vector.transfer_write %res, %mem[%i, %i] {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>
   return %res : vector<4xf32>
 }
 
@@ -349,3 +347,30 @@ func.func @transfer_write_permutations(
 
   return %0 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_broadcast_unit_dim
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?x?x?xf32>
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:      %[[ARG2:.*]]: vector<14x8x16xf32>
+// CHECK-SAME:      %[[ARG3:.*]]: vector<8x16xf32>
+// CHECK-SAME:      %[[M:.*]]: i1
+func.func @transfer_write_broadcast_unit_dim(
+    %arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+    %v1 : vector<14x8x16xf32>, %v2 : vector<8x16xf32>, %m: i1) -> tensor<?x?x?x?xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+
+  %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
+  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
+
+  vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC2:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
+  // CHECK: %[[NEW_VEC3:.*]] = vector.transpose %[[NEW_VEC2]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC3]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x16x1xf32>, memref<?x?x?x?xf32>
+
+  return %0 : tensor<?x?x?x?xf32>
+}