// are split out into a separate continuation (exit) block. A condition block is
// created before the continuation block. It checks the exit condition of the
// loop and branches either to the continuation block, or to the first block of
-// the body. Induction variable modification is appended to the last block of
-// the body (which is the exit block from the body subgraph thanks to the
+// the body. The condition block takes as arguments the values of the induction
+// variable followed by loop-carried values. Since it dominates both the body
+// blocks and the continuation block, loop-carried values are visible in all of
+// those blocks. Induction variable modification is appended to the last block
+// of the body (which is the exit block from the body subgraph thanks to the
// invariant we maintain) along with a branch that loops back to the condition
-// block.
+// block. Loop-carried values are the loop terminator operands, which are
+// forwarded to the branch.
//
// +---------------------------------+
// | <code before the ForOp> |
+// | <definitions of %init...> |
// | <compute initial %iv value> |
-// | br cond(%iv) |
+// | br cond(%iv, %init...) |
// +---------------------------------+
// |
// -------| |
// | v v
// | +--------------------------------+
-// | | cond(%iv): |
+// | | cond(%iv, %init...): |
// | | <compare %iv to upper bound> |
// | | cond_br %r, body, end |
// | +--------------------------------+
// | v |
// | +--------------------------------+ |
// | | body-first: | |
+// | | <%init visible by dominance> | |
// | | <body contents> | |
// | +--------------------------------+ |
// | | |
// | +--------------------------------+ |
// | | body-last: | |
// | | <body contents> | |
+// | | <operands of yield = %yields>| |
// | | %new_iv =<add step to %iv> | |
-// | | br cond(%new_iv) | |
+// | | br cond(%new_iv, %yields) | |
// | +--------------------------------+ |
// | | |
// |----------- |--------------------
// v
// +--------------------------------+
// | end: |
-// | <code after the ForOp> |
+// | <code after the ForOp> |
+// | <%init visible by dominance> |
// +--------------------------------+
//
struct ForLowering : public OpRewritePattern<ForOp> {
// v v
// +--------------------------------+
// | continue: |
-// | <code after the IfOp> |
+// | <code after the IfOp> |
// +--------------------------------+
//
struct IfLowering : public OpRewritePattern<IfOp> {
auto initPosition = rewriter.getInsertionPoint();
auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
- // Use the first block of the loop body as the condition block since it is
- // the block that has the induction variable as its argument. Split out
- // all operations from the first block into a new block. Move all body
- // blocks from the loop body region to the region containing the loop.
+ // Use the first block of the loop body as the condition block since it is the
+ // block that has the induction variable and loop-carried values as arguments.
+ // Split out all operations from the first block into a new block. Move all
+ // body blocks from the loop body region to the region containing the loop.
auto *conditionBlock = &forOp.region().front();
auto *firstBodyBlock =
rewriter.splitBlock(conditionBlock, conditionBlock->begin());
auto iv = conditionBlock->getArgument(0);
// Append the induction variable stepping logic to the last body block and
- // branch back to the condition block. Construct an expression f :
- // (x -> x+step) and apply this expression to the induction variable.
- rewriter.eraseOp(lastBodyBlock->getTerminator());
+ // branch back to the condition block. Loop-carried values are taken from
+ // operands of the loop terminator.
+ Operation *terminator = lastBodyBlock->getTerminator();
rewriter.setInsertionPointToEnd(lastBodyBlock);
auto step = forOp.step();
auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
if (!stepped)
return matchFailure();
- rewriter.create<BranchOp>(loc, conditionBlock, stepped);
+
+ SmallVector<Value, 8> loopCarried;
+ loopCarried.push_back(stepped);
+ loopCarried.append(terminator->operand_begin(), terminator->operand_end());
+ rewriter.create<BranchOp>(loc, conditionBlock, loopCarried);
+ rewriter.eraseOp(terminator);
// Compute loop bounds before branching to the condition.
rewriter.setInsertionPointToEnd(initBlock);
Value upperBound = forOp.upperBound();
if (!lowerBound || !upperBound)
return matchFailure();
- rewriter.create<BranchOp>(loc, conditionBlock, lowerBound);
+
+ // The initial values of loop-carried values is obtained from the operands
+ // of the loop operation.
+ SmallVector<Value, 8> destOperands;
+ destOperands.push_back(lowerBound);
+ auto iterOperands = forOp.getIterOperands();
+ destOperands.append(iterOperands.begin(), iterOperands.end());
+ rewriter.create<BranchOp>(loc, conditionBlock, destOperands);
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
ArrayRef<Value>(), endBlock, ArrayRef<Value>());
- // Ok, we're done!
- rewriter.eraseOp(forOp);
+ // The result of the loop operation is the values of the condition block
+ // arguments except the induction variable on the last iteration.
+ rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
return matchSuccess();
}
}
return
}
+
+// CHECK-LABEL: @for_yield
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK: %[[INIT0:.*]] = constant 0
+// CHECK: %[[INIT1:.*]] = constant 1
+// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32)
+//
+// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG0:.*]]: f32, %[[ITER_ARG1:.*]]: f32):
+// CHECK: %[[CMP:.*]] = cmpi "slt", %[[ITER]], %[[UB]] : index
+// CHECK: cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+//
+// CHECK: ^[[BODY]]:
+// CHECK: %[[SUM:.*]] = addf %[[ITER_ARG0]], %[[ITER_ARG1]] : f32
+// CHECK: %[[STEPPED:.*]] = addi %[[ITER]], %[[STEP]] : index
+// CHECK: br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32)
+//
+// CHECK: ^[[CONTINUE]]:
+// CHECK: return %[[ITER_ARG0]], %[[ITER_ARG1]] : f32, f32
+func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) {
+ %s0 = constant 0.0 : f32
+ %s1 = constant 1.0 : f32
+ %result:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+ %sn = addf %si, %sj : f32
+ loop.yield %sn, %sn : f32, f32
+ }
+ return %result#0, %result#1 : f32, f32
+}
+
+// CHECK-LABEL: @nested_for_yield
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK: %[[INIT:.*]] = constant
+// CHECK: br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32)
+// CHECK: ^[[COND_OUT]](%[[ITER_OUT:.*]]: index, %[[ARG_OUT:.*]]: f32):
+// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+// CHECK: ^[[BODY_OUT]]:
+// CHECK: br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32)
+// CHECK: ^[[COND_IN]](%[[ITER_IN:.*]]: index, %[[ARG_IN:.*]]: f32):
+// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+// CHECK: ^[[BODY_IN]]
+// CHECK: %[[RES:.*]] = addf
+// CHECK: br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32)
+// CHECK: ^[[CONT_IN]]:
+// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32)
+// CHECK: ^[[CONT_OUT]]:
+// CHECK: return %[[ARG_OUT]] : f32
+func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
+ %s0 = constant 1.0 : f32
+ %r = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iter = %s0) -> (f32) {
+ %result = loop.for %i1 = %arg0 to %arg1 step %arg2 iter_args(%si = %iter) -> (f32) {
+ %sn = addf %si, %si : f32
+ loop.yield %sn : f32
+ }
+ loop.yield %result : f32
+ }
+ return %r : f32
+}