Revert "[mlir][linalg] Fix crash in vectorizer when expanding affine apply"
authorDiego Caballero <diegocaballero@google.com>
Sat, 4 Feb 2023 05:16:46 +0000 (05:16 +0000)
committerDiego Caballero <diegocaballero@google.com>
Sat, 4 Feb 2023 05:18:10 +0000 (05:18 +0000)
This reverts commit 62570b722fa36fddde0d24bf06a245efadda66f5.

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index 05a2110..a44bc22 100644 (file)
@@ -1056,15 +1056,15 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
 
 /// Converts affine.apply Ops to arithmetic operations.
 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
+  auto &newIP = linalgOp.getBlock()->front();
   OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointAfter(&newIP);
   auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();
 
   for (auto op : make_early_inc_range(toReplace)) {
-    rewriter.setInsertionPoint(op);
-    auto expanded = expandAffineExpr(
-        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
-        op.getOperands().take_front(op.getAffineMap().getNumDims()),
-        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
+    auto expanded =
+        expandAffineExpr(rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+                         op.getOperands(), ValueRange{});
     rewriter.replaceOp(op, expanded);
   }
 }
index a6c5602..f966b0e 100644 (file)
@@ -301,8 +301,7 @@ func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor
   ^bb0(%arg1: f32, %arg2: i32):
     %2 = linalg.index 0 : index
     %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
-    %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3]
-    %3 = arith.index_cast %13 : index to i32
+    %3 = arith.index_cast %12 : index to i32
     linalg.yield %3 : i32
   } -> tensor<32xi32>
   return %1 : tensor<32xi32>
@@ -316,9 +315,7 @@ func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor
 // CHECK:   %[[EMPTY:.*]] = tensor.empty() : tensor<32xi32>
 // CHECK:   %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
 // CHECK:   %[[ADDI:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<32xindex>
-// CHECK:   %[[BCAST2:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
-// CHECK:   %[[ADDI2:.*]] = arith.addi %[[ADDI]], %[[BCAST2]] : vector<32xindex>
-// CHECK:   %[[CAST:.*]] = arith.index_cast %[[ADDI2]] : vector<32xindex> to vector<32xi32>
+// CHECK:   %[[CAST:.*]] = arith.index_cast %[[ADDI]] : vector<32xindex> to vector<32xi32>
 // CHECK:   vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<32xi32>, tensor<32xi32>
 
 transform.sequence failures(propagate) {