From 62570b722fa36fddde0d24bf06a245efadda66f5 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 3 Feb 2023 07:49:38 +0000 Subject: [PATCH] [mlir][linalg] Fix crash in vectorizer when expanding affine apply Fix the insert point when expanding affine apply and handle cases with symbols. Also add missing precondition to dynamic shape vectorization. Differential Revision: https://reviews.llvm.org/D143243 --- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 14 +++++++++----- mlir/test/Dialect/Linalg/vectorization.mlir | 7 +++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 3173a44..05a2110 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -960,6 +960,10 @@ static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) { if (!isa(op)) return failure(); + // TODO: Index vectorization assumes static shape. + if (op.hasIndexSemantics()) + return failure(); + // TODO: 0-d vectors are not supported yet. if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) { return map.isEmpty() || map.getResults().empty(); @@ -1052,15 +1056,15 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp, /// Converts affine.apply Ops to arithmetic operations. static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) { - auto &newIP = linalgOp.getBlock()->front(); OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(&newIP); auto toReplace = linalgOp.getBlock()->getOps(); for (auto op : make_early_inc_range(toReplace)) { - auto expanded = - expandAffineExpr(rewriter, op->getLoc(), op.getAffineMap().getResult(0), - op.getOperands(), ValueRange{}); + rewriter.setInsertionPoint(op); + auto expanded = expandAffineExpr( + rewriter, op->getLoc(), op.getAffineMap().getResult(0), + op.getOperands().take_front(op.getAffineMap().getNumDims()), + op.getOperands().take_back(op.getAffineMap().getNumSymbols())); rewriter.replaceOp(op, expanded); } } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index f966b0e..a6c5602 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -301,7 +301,8 @@ func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor ^bb0(%arg1: f32, %arg2: i32): %2 = linalg.index 0 : index %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3) - %3 = arith.index_cast %12 : index to i32 + %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3] + %3 = arith.index_cast %13 : index to i32 linalg.yield %3 : i32 } -> tensor<32xi32> return %1 : tensor<32xi32> @@ -315,7 +316,9 @@ func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32xi32> // CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex> // CHECK: %[[ADDI:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<32xindex> -// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI]] : vector<32xindex> to vector<32xi32> +// CHECK: %[[BCAST2:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex> +// CHECK: %[[ADDI2:.*]] = arith.addi %[[ADDI]], %[[BCAST2]] : vector<32xindex> +// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI2]] : vector<32xindex> to vector<32xi32> // CHECK: vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<32xi32>, tensor<32xi32> transform.sequence failures(propagate) { -- 2.7.4