[mlir][vector] Make write permutation lowering work with tensors.
authorgysit <gysit@google.com>
Wed, 2 Feb 2022 09:06:31 +0000 (09:06 +0000)
committergysit <gysit@google.com>
Wed, 2 Feb 2022 09:21:10 +0000 (09:21 +0000)
Use type inference when building the TransferWriteOp in the TransferWritePermutationLowering. Previously, the result type has been set to Type() which triggers an assertion if the pattern is used with tensors instead of memrefs.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D118758

mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

index 533dc20..baf6973 100644 (file)
@@ -185,8 +185,8 @@ struct TransferWritePermutationLowering
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, Type(), newVec, op.source(), op.indices(),
-        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
+        op, newVec, op.source(), op.indices(), AffineMapAttr::get(newMap),
+        newMask, newInBoundsAttr);
 
     return success();
   }
index 562870c..7983a81 100644 (file)
@@ -327,21 +327,24 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
 // -----
 
 // CHECK-LABEL: func @transfer_write_permutations
-func @transfer_write_permutations(%arg0 : memref<?x?x?x?xf32>,
-    %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> () {
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?x?x?xf32>
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+func @transfer_write_permutations(
+    %arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+    %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> tensor<?x?x?x?xf32> {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   %m = arith.constant 1 : i1
 
   %mask0 = splat %m : vector<7x14x8x16xi1>
-  vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, memref<?x?x?x?xf32>
+  %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
   // CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xi1> to vector<8x14x16x7xi1>
   // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
-  // CHECK: vector.transfer_write %[[NEW_VEC0]], %arg0[%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>
 
   vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
   // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
-  // CHECK: vector.transfer_write %[[NEW_VEC1]], %arg0[%c0, %c0, %c0, %c0] : vector<16x8xf32>, memref<?x?x?x?xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[ARG0]][%c0, %c0, %c0, %c0] : vector<16x8xf32>, memref<?x?x?x?xf32>
 
-  return
+  return %0 : tensor<?x?x?x?xf32>
 }