Move CondBranchOp to the ODG framework.
authorRiver Riddle <riverriddle@google.com>
Wed, 29 May 2019 23:46:17 +0000 (16:46 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:09:52 +0000 (20:09 -0700)
--

PiperOrigin-RevId: 250593367

mlir/include/mlir/StandardOps/Ops.h
mlir/include/mlir/StandardOps/Ops.td
mlir/lib/StandardOps/Ops.cpp

index 1668d84..224902d 100644 (file)
@@ -93,112 +93,6 @@ enum class CmpFPredicate {
 #define GET_OP_CLASSES
 #include "mlir/StandardOps/Ops.h.inc"
 
-/// The "cond_br" operation represents a conditional branch operation in a
-/// function. The operation takes variable number of operands and produces
-/// no results. The operand number and types for each successor must match the
-//  arguments of the block successor. For example:
-///
-///   ^bb0:
-///      %0 = extract_element %arg0[] : tensor<i1>
-///      cond_br %0, ^bb1, ^bb2
-///   ^bb1:
-///      ...
-///   ^bb2:
-///      ...
-///
-class CondBranchOp : public Op<CondBranchOp, OpTrait::AtLeastNOperands<1>::Impl,
-                               OpTrait::ZeroResult, OpTrait::IsTerminator> {
-  // These are the indices into the dests list.
-  enum { trueIndex = 0, falseIndex = 1 };
-
-  /// The operands list of a conditional branch operation is layed out as
-  /// follows:
-  /// { condition, [true_operands], [false_operands] }
-public:
-  using Op::Op;
-
-  static StringRef getOperationName() { return "std.cond_br"; }
-
-  static void build(Builder *builder, OperationState *result, Value *condition,
-                    Block *trueDest, ArrayRef<Value *> trueOperands,
-                    Block *falseDest, ArrayRef<Value *> falseOperands);
-
-  // Hooks to customize behavior of this op.
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-  LogicalResult verify();
-
-  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                          MLIRContext *context);
-
-  // The condition operand is the first operand in the list.
-  Value *getCondition() { return getOperand(0); }
-
-  /// Return the destination if the condition is true.
-  Block *getTrueDest();
-
-  /// Return the destination if the condition is false.
-  Block *getFalseDest();
-
-  // Accessors for operands to the 'true' destination.
-  Value *getTrueOperand(unsigned idx) {
-    assert(idx < getNumTrueOperands());
-    return getOperand(getTrueDestOperandIndex() + idx);
-  }
-
-  void setTrueOperand(unsigned idx, Value *value) {
-    assert(idx < getNumTrueOperands());
-    setOperand(getTrueDestOperandIndex() + idx, value);
-  }
-
-  operand_iterator true_operand_begin() {
-    return operand_begin() + getTrueDestOperandIndex();
-  }
-  operand_iterator true_operand_end() {
-    return true_operand_begin() + getNumTrueOperands();
-  }
-  operand_range getTrueOperands() {
-    return {true_operand_begin(), true_operand_end()};
-  }
-
-  unsigned getNumTrueOperands();
-
-  /// Erase the operand at 'index' from the true operand list.
-  void eraseTrueOperand(unsigned index);
-
-  // Accessors for operands to the 'false' destination.
-  Value *getFalseOperand(unsigned idx) {
-    assert(idx < getNumFalseOperands());
-    return getOperand(getFalseDestOperandIndex() + idx);
-  }
-  void setFalseOperand(unsigned idx, Value *value) {
-    assert(idx < getNumFalseOperands());
-    setOperand(getFalseDestOperandIndex() + idx, value);
-  }
-
-  operand_iterator false_operand_begin() { return true_operand_end(); }
-  operand_iterator false_operand_end() {
-    return false_operand_begin() + getNumFalseOperands();
-  }
-  operand_range getFalseOperands() {
-    return {false_operand_begin(), false_operand_end()};
-  }
-
-  unsigned getNumFalseOperands();
-
-  /// Erase the operand at 'index' from the false operand list.
-  void eraseFalseOperand(unsigned index);
-
-private:
-  /// Get the index of the first true destination operand.
-  unsigned getTrueDestOperandIndex() { return 1; }
-
-  /// Get the index of the first false destination operand.
-  unsigned getFalseDestOperandIndex() {
-    return getTrueDestOperandIndex() + getNumTrueOperands();
-  }
-};
-
 /// This is a refinement of the "constant" op for the case where it is
 /// returning a float value of FloatType.
 ///
index d5a8831..e7b6087 100644 (file)
@@ -369,6 +369,124 @@ def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, SameOperandsAndResu
   let hasFolder = 1;
 }
 
+def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
+  let summary = "conditional branch operation";
+  let description = [{
+    The "cond_br" operation represents a conditional branch operation in a
+    function. The operation takes variable number of operands and produces
+    no results. The operand number and types for each successor must match the
+    arguments of the block successor. For example:
+
+      ^bb0:
+         %0 = extract_element %arg0[] : tensor<i1>
+         cond_br %0, ^bb1, ^bb2
+      ^bb1:
+         ...
+      ^bb2:
+         ...
+  }];
+
+  let arguments = (ins I1:$condition, Variadic<AnyType>:$branchOperands);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Value *condition,"
+    "Block *trueDest, ArrayRef<Value *> trueOperands,"
+    "Block *falseDest, ArrayRef<Value *> falseOperands", [{
+      result->addOperands(condition);
+      result->addSuccessor(trueDest, trueOperands);
+      result->addSuccessor(falseDest, falseOperands);
+  }]>];
+
+  // CondBranchOp is fully verified by traits.
+  let verifier = ?;
+
+  let extraClassDeclaration = [{
+    // These are the indices into the dests list.
+    enum { trueIndex = 0, falseIndex = 1 };
+
+    // The condition operand is the first operand in the list.
+    Value *getCondition() { return getOperand(0); }
+
+    /// Return the destination if the condition is true.
+    Block *getTrueDest() {
+      return getOperation()->getSuccessor(trueIndex);
+    }
+
+    /// Return the destination if the condition is false.
+    Block *getFalseDest() {
+      return getOperation()->getSuccessor(falseIndex);
+    }
+
+    // Accessors for operands to the 'true' destination.
+    Value *getTrueOperand(unsigned idx) {
+      assert(idx < getNumTrueOperands());
+      return getOperand(getTrueDestOperandIndex() + idx);
+    }
+
+    void setTrueOperand(unsigned idx, Value *value) {
+      assert(idx < getNumTrueOperands());
+      setOperand(getTrueDestOperandIndex() + idx, value);
+    }
+
+    operand_iterator true_operand_begin() {
+      return operand_begin() + getTrueDestOperandIndex();
+    }
+    operand_iterator true_operand_end() {
+      return true_operand_begin() + getNumTrueOperands();
+    }
+    operand_range getTrueOperands() {
+      return {true_operand_begin(), true_operand_end()};
+    }
+
+    unsigned getNumTrueOperands()  {
+      return getOperation()->getNumSuccessorOperands(trueIndex);
+    }
+
+    /// Erase the operand at 'index' from the true operand list.
+    void eraseTrueOperand(unsigned index)  {
+      getOperation()->eraseSuccessorOperand(trueIndex, index);
+    }
+
+    // Accessors for operands to the 'false' destination.
+    Value *getFalseOperand(unsigned idx) {
+      assert(idx < getNumFalseOperands());
+      return getOperand(getFalseDestOperandIndex() + idx);
+    }
+    void setFalseOperand(unsigned idx, Value *value) {
+      assert(idx < getNumFalseOperands());
+      setOperand(getFalseDestOperandIndex() + idx, value);
+    }
+
+    operand_iterator false_operand_begin() { return true_operand_end(); }
+    operand_iterator false_operand_end() {
+      return false_operand_begin() + getNumFalseOperands();
+    }
+    operand_range getFalseOperands() {
+      return {false_operand_begin(), false_operand_end()};
+    }
+
+    unsigned getNumFalseOperands() {
+      return getOperation()->getNumSuccessorOperands(falseIndex);
+    }
+
+    /// Erase the operand at 'index' from the false operand list.
+    void eraseFalseOperand(unsigned index) {
+      getOperation()->eraseSuccessorOperand(falseIndex, index);
+    }
+
+  private:
+    /// Get the index of the first true destination operand.
+    unsigned getTrueDestOperandIndex() { return 1; }
+
+    /// Get the index of the first false destination operand.
+    unsigned getFalseDestOperandIndex() {
+      return getTrueDestOperandIndex() + getNumTrueOperands();
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
 def ConstantOp : Std_Op<"constant", [NoSideEffect]> {
   let summary = "constant";
 
index c5dc4a0..94f5a67 100644 (file)
@@ -81,7 +81,7 @@ template <typename T> static LogicalResult verifyCastOp(T op) {
 
 StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
-  addOperations<CondBranchOp, DmaStartOp, DmaWaitOp,
+  addOperations<DmaStartOp, DmaWaitOp,
 #define GET_OP_LIST
 #include "mlir/StandardOps/Ops.cpp.inc"
                 >();
@@ -984,16 +984,8 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
 };
 } // end anonymous namespace.
 
-void CondBranchOp::build(Builder *builder, OperationState *result,
-                         Value *condition, Block *trueDest,
-                         ArrayRef<Value *> trueOperands, Block *falseDest,
-                         ArrayRef<Value *> falseOperands) {
-  result->addOperands(condition);
-  result->addSuccessor(trueDest, trueOperands);
-  result->addSuccessor(falseDest, falseOperands);
-}
-
-ParseResult CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseCondBranchOp(OpAsmParser *parser,
+                                     OperationState *result) {
   SmallVector<Value *, 4> destOperands;
   Block *dest;
   OpAsmParser::OperandType condInfo;
@@ -1021,19 +1013,13 @@ ParseResult CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
   return success();
 }
 
-void CondBranchOp::print(OpAsmPrinter *p) {
+static void print(OpAsmPrinter *p, CondBranchOp op) {
   *p << "cond_br ";
-  p->printOperand(getCondition());
+  p->printOperand(op.getCondition());
   *p << ", ";
-  p->printSuccessorAndUseList(getOperation(), trueIndex);
+  p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
   *p << ", ";
-  p->printSuccessorAndUseList(getOperation(), falseIndex);
-}
-
-LogicalResult CondBranchOp::verify() {
-  if (!getCondition()->getType().isInteger(1))
-    return emitOpError("expected condition type was boolean (i1)");
-  return success();
+  p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
 }
 
 void CondBranchOp::getCanonicalizationPatterns(
@@ -1041,30 +1027,6 @@ void CondBranchOp::getCanonicalizationPatterns(
   results.push_back(llvm::make_unique<SimplifyConstCondBranchPred>(context));
 }
 
-Block *CondBranchOp::getTrueDest() {
-  return getOperation()->getSuccessor(trueIndex);
-}
-
-Block *CondBranchOp::getFalseDest() {
-  return getOperation()->getSuccessor(falseIndex);
-}
-
-unsigned CondBranchOp::getNumTrueOperands() {
-  return getOperation()->getNumSuccessorOperands(trueIndex);
-}
-
-void CondBranchOp::eraseTrueOperand(unsigned index) {
-  getOperation()->eraseSuccessorOperand(trueIndex, index);
-}
-
-unsigned CondBranchOp::getNumFalseOperands() {
-  return getOperation()->getNumSuccessorOperands(falseIndex);
-}
-
-void CondBranchOp::eraseFalseOperand(unsigned index) {
-  getOperation()->eraseSuccessorOperand(falseIndex, index);
-}
-
 //===----------------------------------------------------------------------===//
 // Constant*Op
 //===----------------------------------------------------------------------===//