[mlir][Linalg] Relax vectorization condition to allow transposed output.
authorHanhan Wang <hanchung@google.com>
Fri, 27 May 2022 02:20:36 +0000 (19:20 -0700)
committerHanhan Wang <hanchung@google.com>
Fri, 27 May 2022 02:20:36 +0000 (19:20 -0700)
Reviewed By: ThomasRaoux, dcaballe

Differential Revision: https://reviews.llvm.org/D126454

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index 00c08af..bbb9dd7 100644 (file)
@@ -441,7 +441,7 @@ static bool isElementwise(Operation *op) {
     return false;
   // TODO: relax the restrictions on indexing map.
   for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
-    if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
+    if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
       return false;
   }
   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
index 3414450..301e8f0 100644 (file)
@@ -121,6 +121,26 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK: func @generic_interchanged_transpose
+func.func @generic_interchanged_transpose(%arg0: tensor<12x128x32xf32>) -> tensor<128x12x32xf32> {
+  // CHECK: %[[IN:.+]] = vector.transfer_read
+  // CHECK: vector.transfer_write %[[IN]], {{.+}} permutation_map = #[[MAP]]
+  %0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32>
+  %1 = linalg.generic {indexing_maps = [#map0, #map1],
+                       iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<12x128x32xf32>)
+    outs(%0 : tensor<128x12x32xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<128x12x32xf32>
+  return %1 : tensor<128x12x32xf32>
+}
+
+// -----
+
 #matmul_trait = {
   args_in = 2,
   args_out = 1,