From d2f136920b9247a9e5874d4d3a00a880db6e2827 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 17 Jan 2023 14:07:33 -0500 Subject: [PATCH] [MLIR] Add return type inference to scf.if builder Differential Revision: https://reviews.llvm.org/D141928 --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 2 ++ mlir/lib/Dialect/SCF/IR/SCF.cpp | 24 +++++++++++++++++++++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index a610562..9e1752b 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -670,6 +670,8 @@ def IfOp : SCF_Op<"if", OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>, OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond, "bool":$withElseRegion)>, + // TODO: Remove builder when it is no longer used to create invalid `if` ops + // (with a type mispatch between the op and it's inner `yield` op). OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond, CArg<"function_ref", "buildTerminatedBody">:$thenBuilder, diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index fc7ce76..8699f1d 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1490,19 +1490,19 @@ void IfOp::build(OpBuilder &builder, OperationState &result, function_ref thenBuilder, function_ref elseBuilder) { assert(thenBuilder && "the builder callback for 'then' must be present"); - result.addOperands(cond); result.addTypes(resultTypes); + // Build then region. OpBuilder::InsertionGuard guard(builder); Region *thenRegion = result.addRegion(); builder.createBlock(thenRegion); thenBuilder(builder, result.location); + // Build else region. Region *elseRegion = result.addRegion(); if (!elseBuilder) return; - builder.createBlock(elseRegion); elseBuilder(builder, result.location); } @@ -1510,7 +1510,25 @@ void IfOp::build(OpBuilder &builder, OperationState &result, void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, function_ref thenBuilder, function_ref elseBuilder) { - build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder); + assert(thenBuilder && "the builder callback for 'then' must be present"); + result.addOperands(cond); + + // Build then region. + OpBuilder::InsertionGuard guard(builder); + Region *thenRegion = result.addRegion(); + Block *thenBlock = builder.createBlock(thenRegion); + thenBuilder(builder, result.location); + + // Infer types if there are any. + if (auto yieldOp = llvm::dyn_cast(thenBlock->getTerminator())) + result.addTypes(yieldOp.getOperandTypes()); + + // Build else region. + Region *elseRegion = result.addRegion(); + if (!elseBuilder) + return; + builder.createBlock(elseRegion); + elseBuilder(builder, result.location); } LogicalResult IfOp::verify() { -- 2.7.4