From: River Riddle Date: Wed, 29 May 2019 23:46:17 +0000 (-0700) Subject: Move CondBranchOp to the ODG framework. X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e2af847a2eb5f7f9d6863b00a3a23b4ba635342c;p=platform%2Fupstream%2Fllvm.git Move CondBranchOp to the ODG framework. -- PiperOrigin-RevId: 250593367 --- diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 1668d84..224902d 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -93,112 +93,6 @@ enum class CmpFPredicate { #define GET_OP_CLASSES #include "mlir/StandardOps/Ops.h.inc" -/// The "cond_br" operation represents a conditional branch operation in a -/// function. The operation takes variable number of operands and produces -/// no results. The operand number and types for each successor must match the -// arguments of the block successor. For example: -/// -/// ^bb0: -/// %0 = extract_element %arg0[] : tensor -/// cond_br %0, ^bb1, ^bb2 -/// ^bb1: -/// ... -/// ^bb2: -/// ... -/// -class CondBranchOp : public Op::Impl, - OpTrait::ZeroResult, OpTrait::IsTerminator> { - // These are the indices into the dests list. - enum { trueIndex = 0, falseIndex = 1 }; - - /// The operands list of a conditional branch operation is layed out as - /// follows: - /// { condition, [true_operands], [false_operands] } -public: - using Op::Op; - - static StringRef getOperationName() { return "std.cond_br"; } - - static void build(Builder *builder, OperationState *result, Value *condition, - Block *trueDest, ArrayRef trueOperands, - Block *falseDest, ArrayRef falseOperands); - - // Hooks to customize behavior of this op. - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - - // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } - - /// Return the destination if the condition is true. - Block *getTrueDest(); - - /// Return the destination if the condition is false. - Block *getFalseDest(); - - // Accessors for operands to the 'true' destination. - Value *getTrueOperand(unsigned idx) { - assert(idx < getNumTrueOperands()); - return getOperand(getTrueDestOperandIndex() + idx); - } - - void setTrueOperand(unsigned idx, Value *value) { - assert(idx < getNumTrueOperands()); - setOperand(getTrueDestOperandIndex() + idx, value); - } - - operand_iterator true_operand_begin() { - return operand_begin() + getTrueDestOperandIndex(); - } - operand_iterator true_operand_end() { - return true_operand_begin() + getNumTrueOperands(); - } - operand_range getTrueOperands() { - return {true_operand_begin(), true_operand_end()}; - } - - unsigned getNumTrueOperands(); - - /// Erase the operand at 'index' from the true operand list. - void eraseTrueOperand(unsigned index); - - // Accessors for operands to the 'false' destination. - Value *getFalseOperand(unsigned idx) { - assert(idx < getNumFalseOperands()); - return getOperand(getFalseDestOperandIndex() + idx); - } - void setFalseOperand(unsigned idx, Value *value) { - assert(idx < getNumFalseOperands()); - setOperand(getFalseDestOperandIndex() + idx, value); - } - - operand_iterator false_operand_begin() { return true_operand_end(); } - operand_iterator false_operand_end() { - return false_operand_begin() + getNumFalseOperands(); - } - operand_range getFalseOperands() { - return {false_operand_begin(), false_operand_end()}; - } - - unsigned getNumFalseOperands(); - - /// Erase the operand at 'index' from the false operand list. - void eraseFalseOperand(unsigned index); - -private: - /// Get the index of the first true destination operand. - unsigned getTrueDestOperandIndex() { return 1; } - - /// Get the index of the first false destination operand. - unsigned getFalseDestOperandIndex() { - return getTrueDestOperandIndex() + getNumTrueOperands(); - } -}; - /// This is a refinement of the "constant" op for the case where it is /// returning a float value of FloatType. /// diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index d5a8831..e7b6087 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -369,6 +369,124 @@ def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, SameOperandsAndResu let hasFolder = 1; } +def CondBranchOp : Std_Op<"cond_br", [Terminator]> { + let summary = "conditional branch operation"; + let description = [{ + The "cond_br" operation represents a conditional branch operation in a + function. The operation takes variable number of operands and produces + no results. The operand number and types for each successor must match the + arguments of the block successor. For example: + + ^bb0: + %0 = extract_element %arg0[] : tensor + cond_br %0, ^bb1, ^bb2 + ^bb1: + ... + ^bb2: + ... + }]; + + let arguments = (ins I1:$condition, Variadic:$branchOperands); + + let builders = [OpBuilder< + "Builder *, OperationState *result, Value *condition," + "Block *trueDest, ArrayRef trueOperands," + "Block *falseDest, ArrayRef falseOperands", [{ + result->addOperands(condition); + result->addSuccessor(trueDest, trueOperands); + result->addSuccessor(falseDest, falseOperands); + }]>]; + + // CondBranchOp is fully verified by traits. + let verifier = ?; + + let extraClassDeclaration = [{ + // These are the indices into the dests list. + enum { trueIndex = 0, falseIndex = 1 }; + + // The condition operand is the first operand in the list. + Value *getCondition() { return getOperand(0); } + + /// Return the destination if the condition is true. + Block *getTrueDest() { + return getOperation()->getSuccessor(trueIndex); + } + + /// Return the destination if the condition is false. + Block *getFalseDest() { + return getOperation()->getSuccessor(falseIndex); + } + + // Accessors for operands to the 'true' destination. + Value *getTrueOperand(unsigned idx) { + assert(idx < getNumTrueOperands()); + return getOperand(getTrueDestOperandIndex() + idx); + } + + void setTrueOperand(unsigned idx, Value *value) { + assert(idx < getNumTrueOperands()); + setOperand(getTrueDestOperandIndex() + idx, value); + } + + operand_iterator true_operand_begin() { + return operand_begin() + getTrueDestOperandIndex(); + } + operand_iterator true_operand_end() { + return true_operand_begin() + getNumTrueOperands(); + } + operand_range getTrueOperands() { + return {true_operand_begin(), true_operand_end()}; + } + + unsigned getNumTrueOperands() { + return getOperation()->getNumSuccessorOperands(trueIndex); + } + + /// Erase the operand at 'index' from the true operand list. + void eraseTrueOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(trueIndex, index); + } + + // Accessors for operands to the 'false' destination. + Value *getFalseOperand(unsigned idx) { + assert(idx < getNumFalseOperands()); + return getOperand(getFalseDestOperandIndex() + idx); + } + void setFalseOperand(unsigned idx, Value *value) { + assert(idx < getNumFalseOperands()); + setOperand(getFalseDestOperandIndex() + idx, value); + } + + operand_iterator false_operand_begin() { return true_operand_end(); } + operand_iterator false_operand_end() { + return false_operand_begin() + getNumFalseOperands(); + } + operand_range getFalseOperands() { + return {false_operand_begin(), false_operand_end()}; + } + + unsigned getNumFalseOperands() { + return getOperation()->getNumSuccessorOperands(falseIndex); + } + + /// Erase the operand at 'index' from the false operand list. + void eraseFalseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(falseIndex, index); + } + + private: + /// Get the index of the first true destination operand. + unsigned getTrueDestOperandIndex() { return 1; } + + /// Get the index of the first false destination operand. + unsigned getFalseDestOperandIndex() { + return getTrueDestOperandIndex() + getNumTrueOperands(); + } + }]; + + let hasCanonicalizer = 1; +} + def ConstantOp : Std_Op<"constant", [NoSideEffect]> { let summary = "constant"; diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index c5dc4a0..94f5a67b 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -81,7 +81,7 @@ template static LogicalResult verifyCastOp(T op) { StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addOperations(); @@ -984,16 +984,8 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern { }; } // end anonymous namespace. -void CondBranchOp::build(Builder *builder, OperationState *result, - Value *condition, Block *trueDest, - ArrayRef trueOperands, Block *falseDest, - ArrayRef falseOperands) { - result->addOperands(condition); - result->addSuccessor(trueDest, trueOperands); - result->addSuccessor(falseDest, falseOperands); -} - -ParseResult CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { +static ParseResult parseCondBranchOp(OpAsmParser *parser, + OperationState *result) { SmallVector destOperands; Block *dest; OpAsmParser::OperandType condInfo; @@ -1021,19 +1013,13 @@ ParseResult CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { return success(); } -void CondBranchOp::print(OpAsmPrinter *p) { +static void print(OpAsmPrinter *p, CondBranchOp op) { *p << "cond_br "; - p->printOperand(getCondition()); + p->printOperand(op.getCondition()); *p << ", "; - p->printSuccessorAndUseList(getOperation(), trueIndex); + p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); *p << ", "; - p->printSuccessorAndUseList(getOperation(), falseIndex); -} - -LogicalResult CondBranchOp::verify() { - if (!getCondition()->getType().isInteger(1)) - return emitOpError("expected condition type was boolean (i1)"); - return success(); + p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); } void CondBranchOp::getCanonicalizationPatterns( @@ -1041,30 +1027,6 @@ void CondBranchOp::getCanonicalizationPatterns( results.push_back(llvm::make_unique(context)); } -Block *CondBranchOp::getTrueDest() { - return getOperation()->getSuccessor(trueIndex); -} - -Block *CondBranchOp::getFalseDest() { - return getOperation()->getSuccessor(falseIndex); -} - -unsigned CondBranchOp::getNumTrueOperands() { - return getOperation()->getNumSuccessorOperands(trueIndex); -} - -void CondBranchOp::eraseTrueOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(trueIndex, index); -} - -unsigned CondBranchOp::getNumFalseOperands() { - return getOperation()->getNumSuccessorOperands(falseIndex); -} - -void CondBranchOp::eraseFalseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(falseIndex, index); -} - //===----------------------------------------------------------------------===// // Constant*Op //===----------------------------------------------------------------------===//