[mlir][linalg] Make Linalg vectorizer lower affine.apply
authorAndrzej Warzynski <andrzej.warzynski@arm.com>
Mon, 23 Jan 2023 19:01:04 +0000 (19:01 +0000)
committerAndrzej Warzynski <andrzej.warzynski@gmail.com>
Fri, 27 Jan 2023 08:30:50 +0000 (08:30 +0000)
It is possible that the input to the Linalg vectorizer contains
`affine.apply` ops (see the example in [1]). Such operations are not
vectarizable at the moment, but this can be fixed by simply converting
them to arithmetic operations. This is basically what this patch
introduces.

The IR change enabled in this patch could be part of a larger set of
"linalgOp pre-processing" transformations that happens right before
vectorization starts but after we know we can vectorize the op. I am
leaving this as a TODO.

[1] https://github.com/iree-org/iree/issues/10876.

Differential Revision: https://reviews.llvm.org/D142371

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index d81496e..b356de9 100644 (file)
@@ -9,6 +9,7 @@
 // This file implements the linalg dialect Vectorization transformations.
 //
 //===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Utils.h"
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -1048,6 +1049,21 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+/// Converts affine.apply Ops to arithmetic operations.
+static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
+  auto &newIP = linalgOp.getBlock()->front();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointAfter(&newIP);
+  auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();
+
+  for (auto op : make_early_inc_range(toReplace)) {
+    auto expanded =
+        expandAffineExpr(rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+                         op.getOperands(), ValueRange{});
+    rewriter.replaceOp(op, expanded);
+  }
+}
+
 /// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
 /// are used to vectorize this operation. `inputVectorSizes` must match the rank
 /// of the iteration space of the operation and the sizes must be smaller or
@@ -1084,6 +1100,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
                                              vectorizeNDExtract)))
       return failure();
     LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
+
+    // Pre-process before proceeding.
+    convertAffineApply(rewriter, linalgOp);
+
     // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
     // 'OpBuilder' when it is passed over to some methods like
     // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op
index e02702a..c45c34c 100644 (file)
@@ -164,7 +164,7 @@ bool hasOnlyScalarElementwiseOp(Region &r) {
     return false;
   for (Operation &op : r.front()) {
     if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
-              linalg::YieldOp, linalg::IndexOp>(op) ||
+              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
           OpTrait::hasElementwiseMappableTraits(&op)) ||
         llvm::any_of(op.getResultTypes(),
                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
index 171a518..f966b0e 100644 (file)
@@ -290,6 +290,43 @@ transform.sequence failures(propagate) {
 
 // -----
 
+#map0 = affine_map<(d0) -> (d0)>
+
+func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor<32xi32> {
+  %0 = tensor.empty() : tensor<32xi32>
+  %1 = linalg.generic {indexing_maps = [#map0, #map0],
+                       iterator_types = ["parallel"]}
+    ins(%arg0 : tensor<32xf32>)
+    outs(%0 : tensor<32xi32>) {
+  ^bb0(%arg1: f32, %arg2: i32):
+    %2 = linalg.index 0 : index
+    %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
+    %3 = arith.index_cast %12 : index to i32
+    linalg.yield %3 : i32
+  } -> tensor<32xi32>
+  return %1 : tensor<32xi32>
+}
+
+// CHECK-LABEL:  func.func @vectorize_affine_apply
+// CHECK-SAME: %arg0: tensor<32xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK:   %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : vector<32xindex>
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[EMPTY:.*]] = tensor.empty() : tensor<32xi32>
+// CHECK:   %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
+// CHECK:   %[[ADDI:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<32xindex>
+// CHECK:   %[[CAST:.*]] = arith.index_cast %[[ADDI]] : vector<32xindex> to vector<32xi32>
+// CHECK:   vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<32xi32>, tensor<32xi32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
 // CHECK-LABEL: func @test_vectorize_fill
 func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
   //       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>