[MLIR] Remove scf.if builder with explicit result types and callbacks
authorFrederik Gossen <frgossen@google.com>
Fri, 20 Jan 2023 15:51:16 +0000 (10:51 -0500)
committerFrederik Gossen <frgossen@google.com>
Fri, 20 Jan 2023 15:52:08 +0000 (10:52 -0500)
Instead, use the builder and infer the return type based on the inner `yield` ops.
Also, fix uses that do not create the terminator as required for the callback builders.

Differential Revision: https://reviews.llvm.org/D142056

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp

index 05adc85..b3b8825 100644 (file)
@@ -667,21 +667,15 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
 
   let skipDefaultBuilders = 1;
   let builders = [
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond)>,
     OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>,
     OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
       "bool":$withElseRegion)>,
-    // TODO: Remove builder when it is no longer used to create invalid `if` ops
-    // (with a type mispatch between the op and it's inner `yield` op).
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
-      CArg<"function_ref<void(OpBuilder &, Location)>",
-           "buildTerminatedBody">:$thenBuilder,
-      CArg<"function_ref<void(OpBuilder &, Location)>",
-           "nullptr">:$elseBuilder)>,
     OpBuilder<(ins "Value":$cond,
       CArg<"function_ref<void(OpBuilder &, Location)>",
            "buildTerminatedBody">:$thenBuilder,
       CArg<"function_ref<void(OpBuilder &, Location)>",
-           "nullptr">:$elseBuilder)>
+           "nullptr">:$elseBuilder)>,
   ];
 
   let extraClassDeclaration = [{
index 4e746ea..16cbfca 100644 (file)
@@ -92,7 +92,7 @@ Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
     Type indexTy = lb.getIndexType();
     broadcastedDim =
         lb.create<IfOp>(
-              TypeRange{indexTy}, outOfBounds,
+              outOfBounds,
               [&](OpBuilder &b, Location loc) {
                 b.create<scf::YieldOp>(loc, broadcastedDim);
               },
@@ -293,7 +293,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
               loc, arith::CmpIPredicate::ult, iv, rankDiff);
           broadcastable =
               b.create<IfOp>(
-                   loc, TypeRange{i1Ty}, outOfBounds,
+                   loc, outOfBounds,
                    [&](OpBuilder &b, Location loc) {
                      // Non existent dimensions are always broadcastable
                      b.create<scf::YieldOp>(loc, broadcastable);
@@ -522,7 +522,7 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
     Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                   firstRank, rank);
     auto same = rewriter.create<IfOp>(
-        loc, i1Ty, eqRank,
+        loc, eqRank,
         [&](OpBuilder &b, Location loc) {
           Value one = b.create<arith::ConstantIndexOp>(loc, 1);
           Value init =
index dfaa8b3..401e227 100644 (file)
@@ -192,7 +192,7 @@ static Value generateInBoundsCheck(
   // If the condition is non-empty, generate an SCF::IfOp.
   if (cond) {
     auto check = lb.create<scf::IfOp>(
-        resultTypes, cond,
+        cond,
         /*thenBuilder=*/
         [&](OpBuilder &b, Location loc) {
           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
index 3cd4677..880a8ca 100644 (file)
@@ -645,7 +645,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
   };
 
   // Dispatch either single block compute function, or launch async dispatch.
-  b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
+  b.create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
 }
 
 // Dispatch parallel compute functions by submitting all async compute tasks
@@ -910,8 +910,8 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
       Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
           arith::CmpIPredicate::sge, blockSize, numIters);
 
-      b.create<scf::IfOp>(TypeRange(), useBlockAlignedComputeFn,
-                          dispatchBlockAligned, dispatchDefault);
+      b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
+                          dispatchDefault);
       b.create<scf::YieldOp>();
     } else {
       dispatchDefault(b, loc);
@@ -919,7 +919,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   };
 
   // Replace the `scf.parallel` operation with the parallel compute function.
-  b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch);
+  b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);
 
   // Parallel operation was replaced with a block iteration loop.
   rewriter.eraseOp(op);
index b870330..15e1a68 100644 (file)
@@ -1485,44 +1485,41 @@ IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
   return success();
 }
 
+void IfOp::build(OpBuilder &builder, OperationState &result,
+                 TypeRange resultTypes, Value cond) {
+  result.addTypes(resultTypes);
+  result.addOperands(cond);
+
+  // Build regions.
+  OpBuilder::InsertionGuard guard(builder);
+  result.addRegion();
+  result.addRegion();
+}
+
 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                  bool withElseRegion) {
-  build(builder, result, /*resultTypes=*/std::nullopt, cond, withElseRegion);
+  build(builder, result, TypeRange{}, cond, withElseRegion);
 }
 
 void IfOp::build(OpBuilder &builder, OperationState &result,
                  TypeRange resultTypes, Value cond, bool withElseRegion) {
-  auto addTerminator = [&](OpBuilder &nested, Location loc) {
-    if (resultTypes.empty())
-      IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
-                             loc);
-  };
-
-  build(builder, result, resultTypes, cond, addTerminator,
-        withElseRegion ? addTerminator
-                       : function_ref<void(OpBuilder &, Location)>());
-}
-
-void IfOp::build(OpBuilder &builder, OperationState &result,
-                 TypeRange resultTypes, Value cond,
-                 function_ref<void(OpBuilder &, Location)> thenBuilder,
-                 function_ref<void(OpBuilder &, Location)> elseBuilder) {
-  assert(thenBuilder && "the builder callback for 'then' must be present");
-  result.addOperands(cond);
   result.addTypes(resultTypes);
+  result.addOperands(cond);
 
   // Build then region.
   OpBuilder::InsertionGuard guard(builder);
   Region *thenRegion = result.addRegion();
   builder.createBlock(thenRegion);
-  thenBuilder(builder, result.location);
+  if (resultTypes.empty())
+    IfOp::ensureTerminator(*thenRegion, builder, result.location);
 
   // Build else region.
   Region *elseRegion = result.addRegion();
-  if (!elseBuilder)
-    return;
-  builder.createBlock(elseRegion);
-  elseBuilder(builder, result.location);
+  if (withElseRegion) {
+    builder.createBlock(elseRegion);
+    if (resultTypes.empty())
+      IfOp::ensureTerminator(*elseRegion, builder, result.location);
+  }
 }
 
 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
@@ -1730,9 +1727,10 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
                     [](OpResult result) { return result.getType(); });
 
     // Create a replacement operation with empty then and else regions.
-    auto emptyBuilder = [](OpBuilder &, Location) {};
-    auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
-                                       emptyBuilder, emptyBuilder);
+    auto newOp =
+        rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
+    rewriter.createBlock(&newOp.getThenRegion());
+    rewriter.createBlock(&newOp.getElseRegion());
 
     // Move the bodies and replace the terminators (note there is a then and
     // an else region since the operation returns results).
@@ -1796,7 +1794,8 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
     if (nonHoistable.size() == op->getNumResults())
       return failure();
 
-    IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
+    IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
+                                             /*withElseRegion=*/false);
     if (replacement.thenBlock())
       rewriter.eraseBlock(replacement.thenBlock());
     replacement.getThenRegion().takeBody(op.getThenRegion());
@@ -2249,6 +2248,7 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
     Value newCondition = rewriter.create<arith::AndIOp>(
         loc, op.getCondition(), nestedIf.getCondition());
     auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+    Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
 
     SmallVector<Value> results;
     llvm::append_range(results, newIf.getResults());
@@ -2258,11 +2258,6 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
       results[idx] = rewriter.create<arith::SelectOp>(
           op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
 
-    Block *newIfBlock = newIf.thenBlock();
-    if (newIfBlock)
-      rewriter.eraseOp(newIfBlock->getTerminator());
-    else
-      newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
     rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
     rewriter.setInsertionPointToEnd(newIf.thenBlock());
     rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
index 3fc470e..0ab41b6 100644 (file)
@@ -632,7 +632,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
   // creating SliceOps with result dimensions of size 0 at runtime.
   if (generateZeroSliceGuard && dynHasZeroLenCond) {
     auto result = b.create<scf::IfOp>(
-        loc, resultType, dynHasZeroLenCond,
+        loc, dynHasZeroLenCond,
         /*thenBuilder=*/
         [&](OpBuilder &b, Location loc) {
           b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
index 662ba6c..9536f32 100644 (file)
@@ -81,8 +81,8 @@ struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
       Operation *newOp = builder.clone(*padOp);
       builder.create<scf::YieldOp>(loc, newOp->getResults());
     };
-    rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond,
-                                           thenBuilder, elseBuilder);
+    rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, ifCond, thenBuilder,
+                                           elseBuilder);
     return success();
   }
 };
index c8b0fc4..48995af 100644 (file)
@@ -1126,7 +1126,7 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Value newResult =
         rewriter
             .create<scf::IfOp>(
-                loc, distrType, isInsertingLane,
+                loc, isInsertingLane,
                 /*thenBuilder=*/
                 [&](OpBuilder &builder, Location loc) {
                   Value newInsert = builder.create<vector::InsertElementOp>(
@@ -1257,7 +1257,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
         builder.create<scf::YieldOp>(loc, distributedDest);
       };
       newResult = rewriter
-                      .create<scf::IfOp>(loc, distrDestType, isInsertingLane,
+                      .create<scf::IfOp>(loc, isInsertingLane,
                                          /*thenBuilder=*/insertingBuilder,
                                          /*elseBuilder=*/nonInsertingBuilder)
                       .getResult(0);
index c4aad0f..ee23b54 100644 (file)
@@ -252,7 +252,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
   Value memref = xferOp.getSource();
   return b.create<scf::IfOp>(
-      loc, returnTypes, inBoundsCond,
+      loc, inBoundsCond,
       [&](OpBuilder &b, Location loc) {
         Value res = memref;
         if (compatibleMemRefType != xferOp.getShapedType())
@@ -307,7 +307,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
   Value memref = xferOp.getSource();
   return b.create<scf::IfOp>(
-      loc, returnTypes, inBoundsCond,
+      loc, inBoundsCond,
       [&](OpBuilder &b, Location loc) {
         Value res = memref;
         if (compatibleMemRefType != xferOp.getShapedType())
@@ -358,7 +358,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
   Value memref = xferOp.getSource();
   return b
       .create<scf::IfOp>(
-          loc, returnTypes, inBoundsCond,
+          loc, inBoundsCond,
           [&](OpBuilder &b, Location loc) {
             Value res = memref;
             if (compatibleMemRefType != xferOp.getShapedType())