newSteps.push_back(step);
}
}
- // Exit if all or none of the loop dimensions perform a single iteration.
- if (newLowerBounds.size() == 0 ||
- newLowerBounds.size() == op.lowerBound().size())
+ // Exit if none of the loop dimensions perform a single iteration.
+ if (newLowerBounds.size() == op.lowerBound().size())
return failure();
+
+ if (newLowerBounds.empty()) {
+ // All of the loop dimensions perform a single iteration. Inline
+ // loop body and nested ReduceOp's
+ SmallVector<Value> results;
+ results.reserve(op.initVals().size());
+ for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
+ auto reduce = dyn_cast<ReduceOp>(bodyOp);
+ if (!reduce) {
+ rewriter.clone(bodyOp, mapping);
+ continue;
+ }
+ Block &reduceBlock = reduce.reductionOperator().front();
+ auto initValIndex = results.size();
+ mapping.map(reduceBlock.getArgument(0), op.initVals()[initValIndex]);
+ mapping.map(reduceBlock.getArgument(1),
+ mapping.lookupOrDefault(reduce.operand()));
+ for (auto &reduceBodyOp : reduceBlock.without_terminator())
+ rewriter.clone(reduceBodyOp, mapping);
+
+ auto result = mapping.lookupOrDefault(
+ cast<ReduceReturnOp>(reduceBlock.getTerminator()).result());
+ results.push_back(result);
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
// Replace the parallel loop by lower-dimensional parallel loop.
auto newOp =
rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
// -----
-func @single_iteration(%A: memref<?x?x?xi32>) {
+func @single_iteration_some(%A: memref<?x?x?xi32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
return
}
-// CHECK-LABEL: func @single_iteration(
+// CHECK-LABEL: func @single_iteration_some(
// CHECK-SAME: [[ARG0:%.*]]: memref<?x?x?xi32>) {
// CHECK-DAG: [[C42:%.*]] = constant 42 : i32
// CHECK-DAG: [[C7:%.*]] = constant 7 : index
// -----
+func @single_iteration_all(%A: memref<?x?x?xi32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c7 = constant 7 : index
+ %c10 = constant 10 : index
+ scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c3, %c3) {
+ %c42 = constant 42 : i32
+ memref.store %c42, %A[%i0, %i1, %i2] : memref<?x?x?xi32>
+ scf.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func @single_iteration_all(
+// CHECK-SAME: [[ARG0:%.*]]: memref<?x?x?xi32>) {
+// CHECK-DAG: [[C42:%.*]] = constant 42 : i32
+// CHECK-DAG: [[C7:%.*]] = constant 7 : index
+// CHECK-DAG: [[C3:%.*]] = constant 3 : index
+// CHECK-DAG: [[C0:%.*]] = constant 0 : index
+// CHECK-NOT: scf.parallel
+// CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[C3]], [[C7]]] : memref<?x?x?xi32>
+// CHECK-NOT: scf.yield
+// CHECK: return
+
+// -----
+
+func @single_iteration_reduce(%A: index, %B: index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %0:2 = scf.parallel (%i0, %i1) = (%c1, %c3) to (%c2, %c6) step (%c1, %c3) init(%A, %B) -> (index, index) {
+ scf.reduce(%i0) : index {
+ ^bb0(%lhs: index, %rhs: index):
+ %1 = addi %lhs, %rhs : index
+ scf.reduce.return %1 : index
+ }
+ scf.reduce(%i1) : index {
+ ^bb0(%lhs: index, %rhs: index):
+ %2 = muli %lhs, %rhs : index
+ scf.reduce.return %2 : index
+ }
+ scf.yield
+ }
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @single_iteration_reduce(
+// CHECK-SAME: [[ARG0:%.*]]: index, [[ARG1:%.*]]: index)
+// CHECK-DAG: [[C3:%.*]] = constant 3 : index
+// CHECK-DAG: [[C1:%.*]] = constant 1 : index
+// CHECK-NOT: scf.parallel
+// CHECK-NOT: scf.reduce
+// CHECK-NOT: scf.reduce.return
+// CHECK-NOT: scf.yield
+// CHECK: [[V0:%.*]] = addi [[ARG0]], [[C1]]
+// CHECK: [[V1:%.*]] = muli [[ARG1]], [[C3]]
+// CHECK: return [[V0]], [[V1]]
+
+// -----
+
func private @side_effect()
func @one_unused(%cond: i1) -> (index) {
%c0 = constant 0 : index
%ub : index, %lb : index, %step : index) -> (i32, i32) {
// CHECK-NEXT: %[[C32:.*]] = constant 32 : i32
%cst = constant 32 : i32
- // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
+ // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
%0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
-> (i32, i32) {
%1 = addi %arg2, %cst : i32
%1 = addi %arg2, %cst : i32
scf.yield %1, %1 : i32, i32
}
-
+
// CHECK: return %[[FOR_RES]] : i32
return %0#0 : i32
}