return loops;
}
+/// Replace the index operations in the body of the loop nest by the matching
+/// induction variables. If available use the interchange vector to map the
+/// interchanged induction variables to the dimension of the index operation.
+static void replaceIndexOpsByInductionVariables(
+ LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef<Operation *> loopOps,
+ ArrayRef<unsigned> interchangeVector) {
+ // Extract the induction variables of the loop nest from outer to inner.
+ SmallVector<Value> allIvs;
+ for (Operation *loopOp : loopOps) {
+ llvm::TypeSwitch<Operation *>(loopOp)
+ // A parallel loop contributes one induction variable per dimension;
+ // append them all to keep `allIvs` aligned with the loop dimensions.
+ .Case([&](scf::ParallelOp parallelOp) {
+ allIvs.append(parallelOp.getInductionVars().begin(),
+ parallelOp.getInductionVars().end());
+ })
+ .Case([&](scf::ForOp forOp) {
+ allIvs.push_back(forOp.getInductionVar());
+ })
+ .Case([&](AffineForOp affineForOp) {
+ allIvs.push_back(affineForOp.getInductionVar());
+ })
+ .Default([&](Operation *op) { assert(false && "unexpected op"); });
+ }
+ assert(linalgOp.getNumLoops() == allIvs.size() &&
+ "expected the number of loops and induction variables to match");
+ // Replace the index operations in the body of the innermost loop op.
+ if (!loopOps.empty()) {
+ LoopLikeOpInterface loopOp = loopOps.back();
+ // Early-increment iteration is required since replaceOp erases the
+ // IndexOp currently being visited.
+ for (IndexOp indexOp :
+ llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>())) {
+ // Search the indexing dimension in the interchange vector if available.
+ assert(interchangeVector.empty() ||
+ interchangeVector.size() == linalgOp.getNumLoops());
+ const auto *it = llvm::find(interchangeVector, indexOp.dim());
+ // An empty interchange vector (find misses) implies the identity
+ // mapping: the index dimension selects the induction variable directly.
+ uint64_t dim = it != interchangeVector.end()
+ ? std::distance(interchangeVector.begin(), it)
+ : indexOp.dim();
+ rewriter.replaceOp(indexOp, allIvs[dim]);
+ }
+ }
+}
+
namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
+ // `linalgOp` is the checked-cast result; it is reused below, so test it
+ // directly instead of re-casting with a redundant isa<> on `op`.
auto linalgOp = dyn_cast<LinalgOp>(op);
- // TODO: remove hasIndexSemantics check once index ops are supported.
- if (!linalgOp || linalgOp.hasIndexSemantics())
+ if (!linalgOp)
return failure();
- if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
+ // Lower the linalg op to an explicit loop nest; on failure the op is left
+ // untouched for other patterns.
+ Optional<LinalgLoops> loopOps =
+ linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector);
+ if (!loopOps.hasValue())
return failure();
+ // Rewrite linalg.index ops in terms of the generated induction variables.
+ replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue(),
+ interchangeVector);
rewriter.eraseOp(op);
return success();
}
// -----
-func @index_op(%arg0: memref<4x8xindex>) {
- linalg.generic {
- indexing_maps = [affine_map<(i, j) -> (i, j)>],
- iterator_types = ["parallel", "parallel"]}
- outs(%arg0 : memref<4x8xindex>) {
- ^bb0(%arg1: index): // no predecessors
- %0 = linalg.index 1 : index
- linalg.yield %0 : index
+// Checks that linalg.index results are replaced by the loop induction
+// variables of the generated loop nest.
+// NOTE(review): the loop order (m, i, l, j, k) in the CHECK lines suggests an
+// interchange is applied by the RUN invocation -- not visible here, confirm.
+#map = affine_map<(i, j, k, l, m) -> (i, j, k, l, m)>
+func @generic(%output: memref<1x2x3x4x5xindex>) {
+ linalg.generic {indexing_maps = [#map],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ outs(%output : memref<1x2x3x4x5xindex>) {
+ ^bb0(%arg0 : index):
+ %i = linalg.index 0 : index
+ %j = linalg.index 1 : index
+ %k = linalg.index 2 : index
+ %l = linalg.index 3 : index
+ %m = linalg.index 4 : index
+ %0 = addi %i, %j : index
+ %1 = addi %0, %k : index
+ %2 = addi %1, %l : index
+ %3 = addi %2, %m : index
+ linalg.yield %3: index
}
return
}
-// LOOP-LABEL: @index_op
-// LOOP: linalg.generic
-// PARALLEL-LABEL: @index_op
-// PARALLEL: linalg.generic
+// LOOP: scf.for %[[m:.*]] = %c0 to %c5 step %c1
+// LOOP: scf.for %[[i:.*]] = %c0 to %c1 step %c1
+// LOOP: scf.for %[[l:.*]] = %c0 to %c4 step %c1
+// LOOP: scf.for %[[j:.*]] = %c0 to %c2 step %c1
+// LOOP: scf.for %[[k:.*]] = %c0 to %c3 step %c1
+// LOOP: %{{.*}} = addi %[[i]], %[[j]] : index
+// LOOP: %{{.*}} = addi %{{.*}}, %[[k]] : index
+// LOOP: %{{.*}} = addi %{{.*}}, %[[l]] : index
+// LOOP: %{{.*}} = addi %{{.*}}, %[[m]] : index
-// AFFINE-LABEL: @index_op
-// AFFINE: linalg.generic
+// PARALLEL: scf.parallel (%[[m:.*]], %[[i:.*]], %[[l:.*]], %[[j:.*]], %[[k:.*]]) =
+// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3)
+// PARALLEL: %{{.*}} = addi %[[i]], %[[j]] : index
+// PARALLEL: %{{.*}} = addi %{{.*}}, %[[k]] : index
+// PARALLEL: %{{.*}} = addi %{{.*}}, %[[l]] : index
+// PARALLEL: %{{.*}} = addi %{{.*}}, %[[m]] : index
+
+// AFFINE: affine.for %[[m:.*]] = 0 to 5
+// AFFINE: affine.for %[[i:.*]] = 0 to 1
+// AFFINE: affine.for %[[l:.*]] = 0 to 4
+// AFFINE: affine.for %[[j:.*]] = 0 to 2
+// AFFINE: affine.for %[[k:.*]] = 0 to 3
+// AFFINE: %{{.*}} = addi %[[i]], %[[j]] : index
+// AFFINE: %{{.*}} = addi %{{.*}}, %[[k]] : index
+// AFFINE: %{{.*}} = addi %{{.*}}, %[[l]] : index
+// AFFINE: %{{.*}} = addi %{{.*}}, %[[m]] : index
library_call = "some_external_function_name_2",
doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"
}
+// Lowers linalg.index ops inside a #trait4 generic region (#trait4 is
+// declared above -- not visible in this hunk). The indices feed an integer
+// sum that is cast to f32 and added into the second output.
+func @generic_index_region(
+ %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+ %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+ linalg.generic #trait4
+ ins(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %i = linalg.index 0 : index
+ %j = linalg.index 1 : index
+ %k = linalg.index 2 : index
+ %result_1 = mulf %a, %b : f32
+
+ %ij = addi %i, %j : index
+ %ijk = addi %ij, %k : index
+ %ijk_int = index_cast %ijk : index to i32
+ %ijk_float = sitofp %ijk_int : i32 to f32
+
+ %result_2 = addf %c, %ijk_float : f32
+ linalg.yield %result_1, %result_2 : f32, f32
+ }
+ return
+}
+
+// CHECKLOOP-LABEL: @generic_index_region
+// CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
+// CHECKLOOP: scf.for %[[j:.*]] = {{.*}}
+// CHECKLOOP: scf.for %[[k:.*]] = {{.*}}
+// CHECKLOOP: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
+// CHECKLOOP: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
+// CHECKLOOP: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
+// CHECKLOOP: %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
+// CHECKLOOP: %[[ij:.*]] = addi %[[i]], %[[j]] : index
+// CHECKLOOP: %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
+// CHECKLOOP: %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
+// CHECKLOOP: %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
+// CHECKLOOP: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
+// CHECKLOOP: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
+// CHECKLOOP: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
+
+// CHECKPARALLEL-LABEL: @generic_index_region
+// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
+// CHECKPARALLEL: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
+// CHECKPARALLEL: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
+// CHECKPARALLEL: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
+// CHECKPARALLEL: %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
+// CHECKPARALLEL: %[[ij:.*]] = addi %[[i]], %[[j]] : index
+// CHECKPARALLEL: %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
+// CHECKPARALLEL: %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
+// CHECKPARALLEL: %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
+// CHECKPARALLEL: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
+// CHECKPARALLEL: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
+// CHECKPARALLEL: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
+
+
func @indexed_generic_region(
%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][]
// CHECKPARALLEL: store %[[a]], %[[ARG1]][%[[i]], %[[j]]]
+// Zero-rank broadcast: the scalar input %arg0 is combined with i + j computed
+// from linalg.index ops before being stored into the 3x4 output.
+func @generic_index_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
+{
+ linalg.generic #trait_broadcast
+ ins(%arg0 : memref<i32>)
+ outs(%arg1 : memref<3x4xi32>) {
+ ^bb(%a: i32, %b: i32) :
+ %i = linalg.index 0 : index
+ %j = linalg.index 1 : index
+ %ij = addi %i, %j : index
+ %ij_int = index_cast %ij : index to i32
+ %result = addi %a, %ij_int : i32
+ linalg.yield %result : i32
+ }
+ return
+}
+
+// CHECKLOOP-LABEL: @generic_index_op_zero_rank
+// CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
+// CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
+// CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
+// CHECKLOOP: scf.for %[[j:.*]] = {{.*}}
+// CHECKLOOP: %[[a:.*]] = memref.load %[[ARG0]][
+// CHECKLOOP: %[[ij:.*]] = addi %[[i]], %[[j]] : index
+// CHECKLOOP: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
+// CHECKLOOP: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
+// CHECKLOOP: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
+
+// CHECKPARALLEL-LABEL: @generic_index_op_zero_rank
+// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
+// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
+// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]])
+// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][
+// CHECKPARALLEL: %[[ij:.*]] = addi %[[i]], %[[j]] : index
+// CHECKPARALLEL: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
+// CHECKPARALLEL: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
+// CHECKPARALLEL: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
+
+
func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
{
linalg.indexed_generic #trait_broadcast
library_call = "some_reduce_external_fn"
}
+// 1D reduction: at iteration zero (cmpi eq against constant 0 + select) the
+// init value %b is used, otherwise the accumulator %c.
+func @generic_index_op_1D_reduce(%arg0: memref<?xf32>,
+ %arg1: memref<f32>,
+ %arg2: memref<f32>)
+{
+ linalg.generic #trait_reduce_init_1D
+ ins(%arg0, %arg1 : memref<?xf32>, memref<f32>)
+ outs(%arg2 : memref<f32>) {
+ ^bb(%a: f32, %b: f32, %c: f32) :
+ %i = linalg.index 0 : index
+ %0 = constant 0 : index
+ %1 = cmpi eq, %0, %i : index
+ %2 = select %1, %b, %c : f32
+ %3 = addf %a, %2 : f32
+ linalg.yield %3 : f32
+ }
+ return
+}
+// CHECKLOOP-LABEL: @generic_index_op_1D_reduce
+// CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECKLOOP-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
+// CHECKLOOP: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
+// CHECKLOOP: %[[b:.*]] = memref.load %[[ARG1]][]
+// CHECKLOOP: %[[c:.*]] = memref.load %[[ARG2]][]
+// CHECKLOOP: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
+// CHECKLOOP: %[[e:.*]] = addf %[[a]], %[[d]]
+// CHECKLOOP: store %[[e]], %[[ARG2]][]
+
+// CHECKPARALLEL-LABEL: @generic_index_op_1D_reduce
+// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECKPARALLEL-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECKPARALLEL: scf.for %[[i:.*]] = {{.*}}
+// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
+// CHECKPARALLEL: %[[b:.*]] = memref.load %[[ARG1]][]
+// CHECKPARALLEL: %[[c:.*]] = memref.load %[[ARG2]][]
+// CHECKPARALLEL: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
+// CHECKPARALLEL: %[[e:.*]] = addf %[[a]], %[[d]]
+// CHECKPARALLEL: store %[[e]], %[[ARG2]][]
+
+
func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
%arg1: memref<f32>,
%arg2: memref<f32>)