[mlir][linalg] Fix crash in vectorizer when expanding affine apply
authorThomas Raoux <thomasraoux@google.com>
Fri, 3 Feb 2023 07:49:38 +0000 (07:49 +0000)
committerThomas Raoux <thomasraoux@google.com>
Fri, 3 Feb 2023 08:16:49 +0000 (08:16 +0000)
Fix the insert point when expanding affine apply and handle cases with
symbols. Also add missing precondition to dynamic shape vectorization.

Differential Revision: https://reviews.llvm.org/D143243

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index 3173a44..05a2110 100644 (file)
@@ -960,6 +960,10 @@ static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
   if (!isa<linalg::GenericOp>(op))
     return failure();
 
+  // TODO: Index vectorization assumes static shape.
+  if (op.hasIndexSemantics())
+    return failure();
+
   // TODO: 0-d vectors are not supported yet.
   if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) {
         return map.isEmpty() || map.getResults().empty();
@@ -1052,15 +1056,15 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
 
 /// Converts affine.apply Ops to arithmetic operations.
 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
-  auto &newIP = linalgOp.getBlock()->front();
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPointAfter(&newIP);
   auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();
 
   for (auto op : make_early_inc_range(toReplace)) {
-    auto expanded =
-        expandAffineExpr(rewriter, op->getLoc(), op.getAffineMap().getResult(0),
-                         op.getOperands(), ValueRange{});
+    rewriter.setInsertionPoint(op);
+    auto expanded = expandAffineExpr(
+        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+        op.getOperands().take_front(op.getAffineMap().getNumDims()),
+        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
     rewriter.replaceOp(op, expanded);
   }
 }
index f966b0e..a6c5602 100644 (file)
@@ -301,7 +301,8 @@ func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor
   ^bb0(%arg1: f32, %arg2: i32):
     %2 = linalg.index 0 : index
     %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
-    %3 = arith.index_cast %12 : index to i32
+    %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3]
+    %3 = arith.index_cast %13 : index to i32
     linalg.yield %3 : i32
   } -> tensor<32xi32>
   return %1 : tensor<32xi32>
@@ -315,7 +316,9 @@ func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor
 // CHECK:   %[[EMPTY:.*]] = tensor.empty() : tensor<32xi32>
 // CHECK:   %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
 // CHECK:   %[[ADDI:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<32xindex>
-// CHECK:   %[[CAST:.*]] = arith.index_cast %[[ADDI]] : vector<32xindex> to vector<32xi32>
+// CHECK:   %[[BCAST2:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
+// CHECK:   %[[ADDI2:.*]] = arith.addi %[[ADDI]], %[[BCAST2]] : vector<32xindex>
+// CHECK:   %[[CAST:.*]] = arith.index_cast %[[ADDI2]] : vector<32xindex> to vector<32xi32>
 // CHECK:   vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<32xi32>, tensor<32xi32>
 
 transform.sequence failures(propagate) {