From 991040478b9de4dba2e44787b23f8fe9412e9a82 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 5 Jul 2019 03:34:49 -0700 Subject: [PATCH] Add a standard if op This CL adds an "std.if" op to represent an if-then-else construct whose condition is an arbitrary value of type i1. This is necessary to lower all the existing examples from affine and linalg to std.for + std.if. This CL introduces the op and adds the relevant positive and negative unit test. Lowering will be done in a separate followup CL. PiperOrigin-RevId: 256649138 --- mlir/include/mlir/StandardOps/Ops.h | 7 +++ mlir/include/mlir/StandardOps/Ops.td | 40 ++++++++++++++++ mlir/lib/StandardOps/Ops.cpp | 90 ++++++++++++++++++++++++++++++++---- mlir/test/IR/core-ops.mlir | 26 ++++++++++- mlir/test/IR/invalid-ops.mlir | 42 ++++++++++++++++- 5 files changed, 194 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 311ec45..4e74edc 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -24,6 +24,7 @@ #define MLIR_STANDARDOPS_OPS_H #include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" @@ -356,6 +357,12 @@ ParseResult parseDimAndSymbolList(OpAsmParser *parser, SmallVector &operands, unsigned &numDims); +// Insert `std.terminator` at the end of the only region's only block if it does +// not have a terminator already. If a new `std.terminator` is inserted, the +// location is specified by `loc`. If the region is empty, insert a new block +// first. +void ensureStdTerminator(Region ®ion, Builder &builder, Location loc); + /// The "std.for" operation represents a loop nest taking 3 SSA value as /// operands that represent the lower bound, upper bound and step respectively. /// The operation defines an SSA value for its induction variable. It has one diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 02e288a..a233b23 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -609,6 +609,46 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { let hasFolder = 1; } +// TODO(ntv): Default generated builder creates IR that does not verify. Atm it +// is the responsibility of each caller to call ensureStdTerminator on the +// then and else regions. +def IfOp : Std_Op<"if"> { + let summary = "if-then-else operation"; + let description = [{ + The "std.if" operation represents an if-then-else construct for + conditionally executing two regions of code. The operand to an if operation + is a boolean value. The operation produces no results. For example: + + std.if %b { + ... + } else { + ... + } + + The 'else' block is optional, and may be omitted. For + example: + + std.if %b { + ... + } + }]; + let arguments = (ins I1:$condition); + let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion); + + let extraClassDeclaration = [{ + OpBuilder getThenBodyBuilder() { + assert(!thenRegion().empty() && "Unexpected empty 'then' region."); + Block &body = thenRegion().front(); + return OpBuilder(&body, std::prev(body.end())); + } + OpBuilder getElseBodyBuilder() { + assert(!elseRegion().empty() && "Unexpected empty 'else' region."); + Block &body = elseRegion().front(); + return OpBuilder(&body, std::prev(body.end())); + } + }]; +} + def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> { let summary = "cast between index and integer types"; let description = [{ diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index cb2253f..e5d2240 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1638,8 +1638,9 @@ OpFoldResult ExtractElementOp::fold(ArrayRef operands) { //////////////////////////////////////////////////////////////////////////////// // StdForOp. //////////////////////////////////////////////////////////////////////////////// + // Check that if a "block" has a terminator, it is an `TerminatorOp`. -static LogicalResult checkHasTerminator(OpState &op, Block &block) { +static LogicalResult checkHasStdTerminator(OpState &op, Block &block) { if (block.empty() || isa(block.back())) return success(); @@ -1650,11 +1651,7 @@ static LogicalResult checkHasTerminator(OpState &op, Block &block) { << StdTerminatorOp::getOperationName() << "'"; } -// Insert `cf.terminator` at the end of the StdForOp only region's only block -// if it does not have a terminator already. If a new `cf.terminator` is -// inserted, the location is specified by `loc`. If the region is empty, -// insert a new block first. -static void ensureTerminator(Region ®ion, Builder &builder, Location loc) { +void mlir::ensureStdTerminator(Region ®ion, Builder &builder, Location loc) { impl::ensureRegionTerminator(region, builder, loc); } @@ -1665,7 +1662,7 @@ void StdForOp::build(Builder *builder, OperationState *result, Value *lb, Block *body = new Block(); body->addArgument(IndexType::get(builder->getContext())); bodyRegion->push_back(body); - ensureTerminator(*bodyRegion, *builder, result->location); + ensureStdTerminator(*bodyRegion, *builder, result->location); } LogicalResult StdForOp::verify() { @@ -1694,7 +1691,7 @@ LogicalResult StdForOp::verify() { !body->getArgument(0)->getType().isIndex()) return emitOpError("expected body to have a single index argument for " "the induction variable"); - if (failed(checkHasTerminator(*this, *body))) + if (failed(checkHasStdTerminator(*this, *body))) return failure(); return success(); } @@ -1731,7 +1728,7 @@ ParseResult StdForOp::parse(OpAsmParser *parser, OperationState *result) { if (parser->parseRegion(*body, inductionVariable, indexType)) return failure(); - ensureTerminator(*body, builder, result->location); + ensureStdTerminator(*body, builder, result->location); // Parse the optional attribute list. if (parser->parseOptionalAttributeDict(result->attributes)) @@ -1755,6 +1752,81 @@ StdForOp getStdForInductionVarOwner(Value *val) { } //===----------------------------------------------------------------------===// +// IfOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(IfOp op) { + // Verify that the entry block of each child region does not have arguments. + for (auto ®ion : op.getOperation()->getRegions()) { + if (region.empty()) + continue; + + // TODO(riverriddle) We currently do not allow multiple blocks in child + // regions. + if (std::next(region.begin()) != region.end()) + return op.emitOpError("expected one block per 'then' or 'else' regions"); + if (failed(checkHasStdTerminator(op, region.front()))) + return failure(); + + for (auto &b : region) + if (b.getNumArguments() != 0) + return op.emitOpError( + "requires that child entry blocks have no arguments"); + } + return success(); +} + +static ParseResult parseIfOp(OpAsmParser *parser, OperationState *result) { + // Create the regions for 'then'. + result->regions.reserve(2); + Region *thenRegion = result->addRegion(); + Region *elseRegion = result->addRegion(); + + auto &builder = parser->getBuilder(); + OpAsmParser::OperandType cond; + Type i1Type = builder.getIntegerType(1); + if (parser->parseOperand(cond) || + parser->resolveOperand(cond, i1Type, result->operands)) + return failure(); + + // Parse the 'then' region. + if (parser->parseRegion(*thenRegion, {}, {})) + return failure(); + ensureStdTerminator(*thenRegion, parser->getBuilder(), result->location); + + // If we find an 'else' keyword then parse the 'else' region. + if (!parser->parseOptionalKeyword("else")) { + if (parser->parseRegion(*elseRegion, {}, {})) + return failure(); + ensureStdTerminator(*elseRegion, parser->getBuilder(), result->location); + } + + // Parse the optional attribute list. + if (parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + + return success(); +} + +static void print(OpAsmPrinter *p, IfOp op) { + *p << IfOp::getOperationName() << " " << *op.condition(); + p->printRegion(op.thenRegion(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + + // Print the 'else' regions if it exists and has a block. + auto &elseRegion = op.elseRegion(); + if (!elseRegion.empty()) { + *p << " else"; + p->printRegion(elseRegion, + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + } + + p->printOptionalAttrDict(op.getAttrs()); +} + +//===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index ffbf4fb..653c91a 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -437,7 +437,7 @@ func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) { } return } -// CHECK-LABEL: func @std_for(%arg0: index, %arg1: index, %arg2: index) { +// CHECK-LABEL: func @std_for( // CHECK-NEXT: std.for %i0 = %arg0 to %arg1 step %arg2 { // CHECK-NEXT: std.for %i1 = %arg0 to %arg1 step %arg2 { // CHECK-NEXT: %0 = cmpi "slt", %i0, %i1 : index @@ -445,3 +445,27 @@ func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) { // CHECK-NEXT: %2 = cmpi "sge", %i0, %i1 : index // CHECK-NEXT: %3 = select %2, %i0, %i1 : index // CHECK-NEXT: std.for %i2 = %1 to %3 step %i1 { + +func @std_if(%arg0: i1, %arg1: f32) { + std.if %arg0 { + %0 = addf %arg1, %arg1 : f32 + } + return +} +// CHECK-LABEL: func @std_if( +// CHECK-NEXT: std.if %arg0 { +// CHECK-NEXT: %{{.*}} = addf %arg1, %arg1 : f32 + +func @std_if_else(%arg0: i1, %arg1: f32) { + std.if %arg0 { + %0 = addf %arg1, %arg1 : f32 + } else { + %1 = addf %arg1, %arg1 : f32 + } + return +} +// CHECK-LABEL: func @std_if_else( +// CHECK-NEXT: std.if %arg0 { +// CHECK-NEXT: %{{.*}} = addf %arg1, %arg1 : f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: %{{.*}} = addf %arg1, %arg1 : f32 diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index b4c76dc..2049d6e 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -770,4 +770,44 @@ func @std_for_single_index_argument(%arg0: index) { } ) : (index, index, index) -> () return -} \ No newline at end of file +} + +// ----- + +func @std_if_not_i1(%arg0: index) { + // expected-error@+1 {{operand #0 must be 1-bit integer}} + "std.if"(%arg0) : (index) -> () + return +} + +// ----- + +func @std_if_more_than_2_regions(%arg0: i1) { + // expected-error@+1 {{op has incorrect number of regions: expected 2}} + "std.if"(%arg0) ({}, {}, {}): (i1) -> () + return +} + +// ----- + +func @std_if_not_one_block_per_region(%arg0: i1) { + // expected-error@+1 {{region #0 ('thenRegion') failed to verify constraint: region with 1 blocks}} + "std.if"(%arg0) ({ + ^bb0: + "std.terminator"() : () -> () + ^bb1: + "std.terminator"() : () -> () + }, {}): (i1) -> () + return +} + +// ----- + +func @std_if_illegal_block_argument(%arg0: i1) { + // expected-error@+1 {{requires that child entry blocks have no arguments}} + "std.if"(%arg0) ({ + ^bb0(%0 : index): + "std.terminator"() : () -> () + }, {}): (i1) -> () + return +} -- 2.7.4