[MLIR] Add return type inference to scf.if builder
authorFrederik Gossen <frgossen@google.com>
Tue, 17 Jan 2023 19:07:33 +0000 (14:07 -0500)
committerFrederik Gossen <frgossen@google.com>
Tue, 17 Jan 2023 19:09:22 +0000 (14:09 -0500)
Differential Revision: https://reviews.llvm.org/D141928

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/SCF/IR/SCF.cpp

index a610562..9e1752b 100644 (file)
@@ -670,6 +670,8 @@ def IfOp : SCF_Op<"if",
     OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>,
     OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
       "bool":$withElseRegion)>,
+    // TODO: Remove builder when it is no longer used to create invalid `if` ops
+    // (with a type mispatch between the op and it's inner `yield` op).
     OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
       CArg<"function_ref<void(OpBuilder &, Location)>",
            "buildTerminatedBody">:$thenBuilder,
index fc7ce76..8699f1d 100644 (file)
@@ -1490,19 +1490,19 @@ void IfOp::build(OpBuilder &builder, OperationState &result,
                  function_ref<void(OpBuilder &, Location)> thenBuilder,
                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
   assert(thenBuilder && "the builder callback for 'then' must be present");
-
   result.addOperands(cond);
   result.addTypes(resultTypes);
 
+  // Build then region.
   OpBuilder::InsertionGuard guard(builder);
   Region *thenRegion = result.addRegion();
   builder.createBlock(thenRegion);
   thenBuilder(builder, result.location);
 
+  // Build else region.
   Region *elseRegion = result.addRegion();
   if (!elseBuilder)
     return;
-
   builder.createBlock(elseRegion);
   elseBuilder(builder, result.location);
 }
@@ -1510,7 +1510,25 @@ void IfOp::build(OpBuilder &builder, OperationState &result,
 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                  function_ref<void(OpBuilder &, Location)> thenBuilder,
                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
-  build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
+  assert(thenBuilder && "the builder callback for 'then' must be present");
+  result.addOperands(cond);
+
+  // Build then region.
+  OpBuilder::InsertionGuard guard(builder);
+  Region *thenRegion = result.addRegion();
+  Block *thenBlock = builder.createBlock(thenRegion);
+  thenBuilder(builder, result.location);
+
+  // Infer types if there are any.
+  if (auto yieldOp = llvm::dyn_cast<YieldOp>(thenBlock->getTerminator()))
+    result.addTypes(yieldOp.getOperandTypes());
+
+  // Build else region.
+  Region *elseRegion = result.addRegion();
+  if (!elseBuilder)
+    return;
+  builder.createBlock(elseRegion);
+  elseBuilder(builder, result.location);
 }
 
 LogicalResult IfOp::verify() {