Instead, use the builder and infer the return type based on the inner `yield` ops.
Also, fix uses that do not create the terminator as required for the callback builders.
Differential Revision: https://reviews.llvm.org/D142056
let skipDefaultBuilders = 1;
let builders = [
+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond)>,
OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>,
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
"bool":$withElseRegion)>,
- // TODO: Remove builder when it is no longer used to create invalid `if` ops
- // (with a type mispatch between the op and it's inner `yield` op).
- OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
- CArg<"function_ref<void(OpBuilder &, Location)>",
- "buildTerminatedBody">:$thenBuilder,
- CArg<"function_ref<void(OpBuilder &, Location)>",
- "nullptr">:$elseBuilder)>,
OpBuilder<(ins "Value":$cond,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$thenBuilder,
CArg<"function_ref<void(OpBuilder &, Location)>",
- "nullptr">:$elseBuilder)>
+ "nullptr">:$elseBuilder)>,
];
let extraClassDeclaration = [{
Type indexTy = lb.getIndexType();
broadcastedDim =
lb.create<IfOp>(
- TypeRange{indexTy}, outOfBounds,
+ outOfBounds,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, broadcastedDim);
},
loc, arith::CmpIPredicate::ult, iv, rankDiff);
broadcastable =
b.create<IfOp>(
- loc, TypeRange{i1Ty}, outOfBounds,
+ loc, outOfBounds,
[&](OpBuilder &b, Location loc) {
// Non existent dimensions are always broadcastable
b.create<scf::YieldOp>(loc, broadcastable);
Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
firstRank, rank);
auto same = rewriter.create<IfOp>(
- loc, i1Ty, eqRank,
+ loc, eqRank,
[&](OpBuilder &b, Location loc) {
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value init =
// If the condition is non-empty, generate an SCF::IfOp.
if (cond) {
auto check = lb.create<scf::IfOp>(
- resultTypes, cond,
+ cond,
/*thenBuilder=*/
[&](OpBuilder &b, Location loc) {
maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
};
// Dispatch either single block compute function, or launch async dispatch.
- b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
+ b.create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
}
// Dispatch parallel compute functions by submitting all async compute tasks
Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sge, blockSize, numIters);
- b.create<scf::IfOp>(TypeRange(), useBlockAlignedComputeFn,
- dispatchBlockAligned, dispatchDefault);
+ b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
+ dispatchDefault);
b.create<scf::YieldOp>();
} else {
dispatchDefault(b, loc);
};
// Replace the `scf.parallel` operation with the parallel compute function.
- b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch);
+ b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);
// Parallel operation was replaced with a block iteration loop.
rewriter.eraseOp(op);
return success();
}
+void IfOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTypes, Value cond) {
+ result.addTypes(resultTypes);
+ result.addOperands(cond);
+
+ // Build regions.
+ OpBuilder::InsertionGuard guard(builder);
+ result.addRegion();
+ result.addRegion();
+}
+
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
bool withElseRegion) {
- build(builder, result, /*resultTypes=*/std::nullopt, cond, withElseRegion);
+ build(builder, result, TypeRange{}, cond, withElseRegion);
}
void IfOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value cond, bool withElseRegion) {
- auto addTerminator = [&](OpBuilder &nested, Location loc) {
- if (resultTypes.empty())
- IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
- loc);
- };
-
- build(builder, result, resultTypes, cond, addTerminator,
- withElseRegion ? addTerminator
- : function_ref<void(OpBuilder &, Location)>());
-}
-
-void IfOp::build(OpBuilder &builder, OperationState &result,
- TypeRange resultTypes, Value cond,
- function_ref<void(OpBuilder &, Location)> thenBuilder,
- function_ref<void(OpBuilder &, Location)> elseBuilder) {
- assert(thenBuilder && "the builder callback for 'then' must be present");
- result.addOperands(cond);
result.addTypes(resultTypes);
+ result.addOperands(cond);
// Build then region.
OpBuilder::InsertionGuard guard(builder);
Region *thenRegion = result.addRegion();
builder.createBlock(thenRegion);
- thenBuilder(builder, result.location);
+ if (resultTypes.empty())
+ IfOp::ensureTerminator(*thenRegion, builder, result.location);
// Build else region.
Region *elseRegion = result.addRegion();
- if (!elseBuilder)
- return;
- builder.createBlock(elseRegion);
- elseBuilder(builder, result.location);
+ if (withElseRegion) {
+ builder.createBlock(elseRegion);
+ if (resultTypes.empty())
+ IfOp::ensureTerminator(*elseRegion, builder, result.location);
+ }
}
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
[](OpResult result) { return result.getType(); });
// Create a replacement operation with empty then and else regions.
- auto emptyBuilder = [](OpBuilder &, Location) {};
- auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
- emptyBuilder, emptyBuilder);
+ auto newOp =
+ rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
+ rewriter.createBlock(&newOp.getThenRegion());
+ rewriter.createBlock(&newOp.getElseRegion());
// Move the bodies and replace the terminators (note there is a then and
// an else region since the operation returns results).
if (nonHoistable.size() == op->getNumResults())
return failure();
- IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
+ IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
+ /*withElseRegion=*/false);
if (replacement.thenBlock())
rewriter.eraseBlock(replacement.thenBlock());
replacement.getThenRegion().takeBody(op.getThenRegion());
Value newCondition = rewriter.create<arith::AndIOp>(
loc, op.getCondition(), nestedIf.getCondition());
auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+ Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
SmallVector<Value> results;
llvm::append_range(results, newIf.getResults());
results[idx] = rewriter.create<arith::SelectOp>(
op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
- Block *newIfBlock = newIf.thenBlock();
- if (newIfBlock)
- rewriter.eraseOp(newIfBlock->getTerminator());
- else
- newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
rewriter.setInsertionPointToEnd(newIf.thenBlock());
rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
// creating SliceOps with result dimensions of size 0 at runtime.
if (generateZeroSliceGuard && dynHasZeroLenCond) {
auto result = b.create<scf::IfOp>(
- loc, resultType, dynHasZeroLenCond,
+ loc, dynHasZeroLenCond,
/*thenBuilder=*/
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
Operation *newOp = builder.clone(*padOp);
builder.create<scf::YieldOp>(loc, newOp->getResults());
};
- rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond,
- thenBuilder, elseBuilder);
+ rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, ifCond, thenBuilder,
+ elseBuilder);
return success();
}
};
Value newResult =
rewriter
.create<scf::IfOp>(
- loc, distrType, isInsertingLane,
+ loc, isInsertingLane,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
Value newInsert = builder.create<vector::InsertElementOp>(
builder.create<scf::YieldOp>(loc, distributedDest);
};
newResult = rewriter
- .create<scf::IfOp>(loc, distrDestType, isInsertingLane,
+ .create<scf::IfOp>(loc, isInsertingLane,
/*thenBuilder=*/insertingBuilder,
/*elseBuilder=*/nonInsertingBuilder)
.getResult(0);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value memref = xferOp.getSource();
return b.create<scf::IfOp>(
- loc, returnTypes, inBoundsCond,
+ loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value memref = xferOp.getSource();
return b.create<scf::IfOp>(
- loc, returnTypes, inBoundsCond,
+ loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
Value memref = xferOp.getSource();
return b
.create<scf::IfOp>(
- loc, returnTypes, inBoundsCond,
+ loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())