if (!isa<linalg::GenericOp>(op))
return failure();
+ // TODO: Index vectorization assumes static shape.
+ if (op.hasIndexSemantics())
+ return failure();
+
// TODO: 0-d vectors are not supported yet.
if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) {
return map.isEmpty() || map.getResults().empty();
/// Converts affine.apply Ops to arithmetic operations.
///
/// Each AffineApplyOp found in the block returned by `linalgOp.getBlock()` is
/// replaced by the value that `expandAffineExpr` produces for result 0 of the
/// op's affine map.
///
/// NOTE: this text is a patch. Lines marked `-` are the implementation being
/// removed and lines marked `+` are its replacement; unmarked lines are
/// common to both versions.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
// Removed: the first operation of the block was captured here so that a
// single insertion point could be set right after it (see the removed
// `setInsertionPointAfter` call below).
- auto &newIP = linalgOp.getBlock()->front();
// Scopes the insertion-point changes made in this function; presumably the
// guard restores the rewriter's previous insertion point on exit.
OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPointAfter(&newIP);
auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();
// Early-increment iteration is used because the visited op is replaced
// inside the loop body.
for (auto op : make_early_inc_range(toReplace)) {
// Removed: every operand was passed in the first operand list and an empty
// ValueRange in the second one.
- auto expanded =
- expandAffineExpr(rewriter, op->getLoc(), op.getAffineMap().getResult(0),
- op.getOperands(), ValueRange{});
// Added: the expansion is now emitted at the affine.apply being replaced
// instead of at one fixed position in the block.
+ rewriter.setInsertionPoint(op);
// Added: the operands are split according to the affine map: the first
// getNumDims() operands go to one list and the last getNumSymbols() operands
// to the other (presumably the dimension and symbol values expected by
// expandAffineExpr; confirm against its declaration).
+ auto expanded = expandAffineExpr(
+ rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+ op.getOperands().take_front(op.getAffineMap().getNumDims()),
+ op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
rewriter.replaceOp(op, expanded);
}
}
^bb0(%arg1: f32, %arg2: i32):
%2 = linalg.index 0 : index
%12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
- %3 = arith.index_cast %12 : index to i32
+ %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3]
+ %3 = arith.index_cast %13 : index to i32
linalg.yield %3 : i32
} -> tensor<32xi32>
return %1 : tensor<32xi32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32xi32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<32xindex>
-// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI]] : vector<32xindex> to vector<32xi32>
+// CHECK: %[[BCAST2:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
+// CHECK: %[[ADDI2:.*]] = arith.addi %[[ADDI]], %[[BCAST2]] : vector<32xindex>
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI2]] : vector<32xindex> to vector<32xi32>
// CHECK: vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<32xi32>, tensor<32xi32>
transform.sequence failures(propagate) {