Note that the types of region arguments need not to match with each other.
The op expects the operand types to match with argument types of the
- "before" region"; the result types to match with the trailing operand types
+ "before" region; the result types to match with the trailing operand types
of the terminator of the "before" region, and with the argument types of the
"after" region. The following scheme can be used to share the results of
some operations executed in the "before" region with the "after" region,
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
+ let builders = [
+ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
+ "function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
+ "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
+ ];
+
let extraClassDeclaration = [{
+ using BodyBuilderFn =
+ function_ref<void(OpBuilder &, Location, ValueRange)>;
+
OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
ConditionOp getConditionOp();
YieldOp getYieldOp();
SmallVector<Type> types = {elementTy, elementTy, elementTy};
SmallVector<Location> locations = {loc, loc, loc};
- auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
- Block *before =
- rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
- Block *after =
- rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
-
- // The conditional block of the while loop.
- {
- rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
- Value input = before->getArgument(0);
- Value zero = before->getArgument(2);
-
- Value inputNotZero = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, input, zero);
- rewriter.create<scf::ConditionOp>(loc, inputNotZero,
- before->getArguments());
- }
-
- // The body of the while loop: shift right until reaching a value of 0.
- {
- rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
- Value input = after->getArgument(0);
- Value leadingZeros = after->getArgument(1);
-
- auto one =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
- auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
- auto leadingZerosMinusOne =
- rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
-
- rewriter.create<scf::YieldOp>(
- loc,
- ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
- }
+ auto whileOp = rewriter.create<scf::WhileOp>(
+ loc, types, operands,
+ [&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
+ // The conditional block of the while loop.
+ Value input = args[0];
+ Value zero = args[2];
+
+ Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, input, zero);
+ beforeBuilder.create<scf::ConditionOp>(loc, inputNotZero, args);
+ },
+ [&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
+ // The body of the while loop: shift right until reaching a value of 0.
+ Value input = args[0];
+ Value leadingZeros = args[1];
+
+ auto one = afterBuilder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(elementTy, 1));
+ auto shifted =
+ afterBuilder.create<arith::ShRUIOp>(loc, resultTy, input, one);
+ auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
+ loc, resultTy, leadingZeros, one);
+
+ afterBuilder.create<scf::YieldOp>(
+ loc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
+ });
rewriter.setInsertionPointAfter(whileOp);
rewriter.replaceOp(op, whileOp->getResult(1));
// WhileOp
//===----------------------------------------------------------------------===//
+void WhileOp::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState, TypeRange resultTypes,
+ ValueRange operands, BodyBuilderFn beforeBuilder,
+ BodyBuilderFn afterBuilder) {
+ assert(beforeBuilder && "the builder callback for 'before' must be present");
+ assert(afterBuilder && "the builder callback for 'after' must be present");
+
+ odsState.addOperands(operands);
+ odsState.addTypes(resultTypes);
+
+ OpBuilder::InsertionGuard guard(odsBuilder);
+
+ SmallVector<Location, 4> blockArgLocs;
+ for (Value operand : operands) {
+ blockArgLocs.push_back(operand.getLoc());
+ }
+
+ Region *beforeRegion = odsState.addRegion();
+ Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
+ resultTypes, blockArgLocs);
+ beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
+
+ Region *afterRegion = odsState.addRegion();
+ Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
+ resultTypes, blockArgLocs);
+ afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
+}
+
OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index == 0 &&
"WhileOp is expected to branch only to the first region");