[mlir] Broadcast scalars when vectorising tensor.extract
authorAndrzej Warzynski <andrzej.warzynski@arm.com>
Fri, 30 Dec 2022 15:22:22 +0000 (15:22 +0000)
committerAndrzej Warzynski <andrzej.warzynski@arm.com>
Thu, 12 Jan 2023 16:34:11 +0000 (16:34 +0000)
When vectorizing tensor.extract embedded within linalg.generic, the
default option is to rewrite it as vector.gather. When doing so, we need
to make sure that the corresponding indices are vectorized accordingly.
However, the Linalg vectorizer will not vectorize constants like in the
following example. This is fixed by simply broadcasting %c0 and %c1.

```
  func.func @example(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
    %c0 = arith.constant 1 : index
    %c1 = arith.constant 2 : index
    %1 = linalg.generic {
      (...)
    } outs(...) {
    ^bb0(...):
      %2 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
      linalg.yield %2 : f32
    } -> tensor<1x1x3xf32>
    return %1 : tensor<1x1x3xf32>
  }
```

This patch makes sure that in the case above (and other similar cases),
the vectorizer broadcasts %c0 and %c1.

Differential Revision: https://reviews.llvm.org/D140781

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index a17663b1019126f04f726e4225769ff4ff31d73c..58d746088ad913d88cd466e50c53e9ac4098f32a 100644 (file)
@@ -626,23 +626,19 @@ calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
-    auto dimSizeBcast = b.create<vector::BroadcastOp>(
-        loc, indexVecType,
+    auto dimSize = broadcastIfNeeded(
+        b,
         b.create<arith::ConstantIndexOp>(
             loc,
-            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)));
-    offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast);
-
-    auto originalIndexBcast = bvm.lookup(extractOp.getIndices()[i]);
-    if (i == numIndices - 1) {
-      // We only need an additional broadcast for the trailing index. All other
-      // indices have already been broadcast by `vectorizeLinalgIndex` to match
-      // the output size.
-      originalIndexBcast = b.create<vector::BroadcastOp>(
-          loc, indexVecType, bvm.lookup(extractOp.getIndices()[i]));
-    }
+            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+        indexVecType.getShape());
+
+    offset = b.create<arith::MulIOp>(loc, offset, dimSize);
+
+    auto extractOpIndex = broadcastIfNeeded(
+        b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape());
 
-    offset = b.create<arith::AddIOp>(loc, originalIndexBcast, offset);
+    offset = b.create<arith::AddIOp>(loc, extractOpIndex, offset);
   }
 
   return offset;
index 9de78468f3e6e086e8d1fc64a35e6b9f9e08d14c..2c7d34066a4bcaabfb51d58b5337a511bc94ddd6 100644 (file)
@@ -1494,10 +1494,83 @@ transform.sequence failures(propagate) {
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %3 = linalg.index 2 : index
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<1x1x3xf32>
+  return %2 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
+// CHECK:    %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK:    %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK:    %[[C0:.*]] = arith.constant 0 : index
+// Magic "5" below comes from (1 * 3 + 2) (1: index into dim 1, 2: index into dim 2)
+// CHECK:    %[[IDX:.*]] = arith.constant dense<5> : vector<1x1x3xindex>
+// CHECK:    %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[IDX]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK:    vector.transfer_write %[[GATHER]]
+// CHECK:  }
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+ }
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %1 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = linalg.index 2 : index
+    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<1x1x3xf32>
+  return %1 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
+// CHECK:   %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+// CHECK:   %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK:   %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
+// CHECK:   %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK:   vector.transfer_write %[[GATHER]]
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
   %2 = linalg.generic {
     indexing_maps = [#map0, #map0, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -1510,7 +1583,7 @@ func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3
   } -> tensor<4x7x3x2xf32>
   return %2 : tensor<4x7x3x2xf32>
 }
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_index_from_tensor
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
 // CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32>
 // CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32>