[mlir][SCF] Adding custom builder to SCF::WhileOp.
authorMohammed Anany <manany@google.com>
Tue, 15 Nov 2022 17:10:17 +0000 (18:10 +0100)
committerBenjamin Kramer <benny.kra@googlemail.com>
Tue, 15 Nov 2022 17:16:49 +0000 (18:16 +0100)
This is a similar builder to the one for SCF::IfOp which allows users to pass region builders to it. Refer to the builders for IfOp.

Reviewed By: tpopp

Differential Revision: https://reviews.llvm.org/D137709

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp

index 67a8e43..2d880ac 100644 (file)
@@ -935,7 +935,7 @@ def WhileOp : SCF_Op<"while",
 
     Note that the types of region arguments need not to match with each other.
     The op expects the operand types to match with argument types of the
-    "before" region"; the result types to match with the trailing operand types
+    "before" region; the result types to match with the trailing operand types
     of the terminator of the "before" region, and with the argument types of the
     "after" region. The following scheme can be used to share the results of
     some operations executed in the "before" region with the "after" region,
@@ -983,7 +983,16 @@ def WhileOp : SCF_Op<"while",
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
 
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
+  ];
+
   let extraClassDeclaration = [{
+    using BodyBuilderFn =
+        function_ref<void(OpBuilder &, Location, ValueRange)>;
+
     OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
     ConditionOp getConditionOp();
     YieldOp getYieldOp();
index ac055a7..98d76f3 100644 (file)
@@ -71,40 +71,32 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
   SmallVector<Type> types = {elementTy, elementTy, elementTy};
   SmallVector<Location> locations = {loc, loc, loc};
 
-  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
-  Block *before =
-      rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
-  Block *after =
-      rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
-
-  // The conditional block of the while loop.
-  {
-    rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
-    Value input = before->getArgument(0);
-    Value zero = before->getArgument(2);
-
-    Value inputNotZero = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ne, input, zero);
-    rewriter.create<scf::ConditionOp>(loc, inputNotZero,
-                                      before->getArguments());
-  }
-
-  // The body of the while loop: shift right until reaching a value of 0.
-  {
-    rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
-    Value input = after->getArgument(0);
-    Value leadingZeros = after->getArgument(1);
-
-    auto one =
-        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
-    auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
-    auto leadingZerosMinusOne =
-        rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
-
-    rewriter.create<scf::YieldOp>(
-        loc,
-        ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
-  }
+  auto whileOp = rewriter.create<scf::WhileOp>(
+      loc, types, operands,
+      [&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
+        // The conditional block of the while loop.
+        Value input = args[0];
+        Value zero = args[2];
+
+        Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::ne, input, zero);
+        beforeBuilder.create<scf::ConditionOp>(loc, inputNotZero, args);
+      },
+      [&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
+        // The body of the while loop: shift right until reaching a value of 0.
+        Value input = args[0];
+        Value leadingZeros = args[1];
+
+        auto one = afterBuilder.create<arith::ConstantOp>(
+            loc, IntegerAttr::get(elementTy, 1));
+        auto shifted =
+            afterBuilder.create<arith::ShRUIOp>(loc, resultTy, input, one);
+        auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
+            loc, resultTy, leadingZeros, one);
+
+        afterBuilder.create<scf::YieldOp>(
+            loc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
+      });
 
   rewriter.setInsertionPointAfter(whileOp);
   rewriter.replaceOp(op, whileOp->getResult(1));
index b39edc0..118452a 100644 (file)
@@ -2669,6 +2669,34 @@ LogicalResult ReduceReturnOp::verify() {
 // WhileOp
 //===----------------------------------------------------------------------===//
 
+void WhileOp::build(::mlir::OpBuilder &odsBuilder,
+                    ::mlir::OperationState &odsState, TypeRange resultTypes,
+                    ValueRange operands, BodyBuilderFn beforeBuilder,
+                    BodyBuilderFn afterBuilder) {
+  assert(beforeBuilder && "the builder callback for 'before' must be present");
+  assert(afterBuilder && "the builder callback for 'after' must be present");
+
+  odsState.addOperands(operands);
+  odsState.addTypes(resultTypes);
+
+  OpBuilder::InsertionGuard guard(odsBuilder);
+
+  SmallVector<Location, 4> blockArgLocs;
+  for (Value operand : operands) {
+    blockArgLocs.push_back(operand.getLoc());
+  }
+
+  Region *beforeRegion = odsState.addRegion();
+  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
+                                              resultTypes, blockArgLocs);
+  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
+
+  Region *afterRegion = odsState.addRegion();
+  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
+                                             resultTypes, blockArgLocs);
+  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
+}
+
 OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
   assert(index && *index == 0 &&
          "WhileOp is expected to branch only to the first region");