Move the definitions for CallOp and IndirectCallOp to the Op Definition Generator.
authorRiver Riddle <riverriddle@google.com>
Fri, 10 May 2019 22:27:34 +0000 (15:27 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:30:40 +0000 (19:30 -0700)
--

PiperOrigin-RevId: 247686419

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/StandardOps/Ops.h
mlir/include/mlir/StandardOps/Ops.td
mlir/lib/StandardOps/Ops.cpp
mlir/test/mlir-tblgen/pattern-benefit.td
mlir/test/mlir-tblgen/pattern-multi-result-op.td

index b4b159e..5b186dd 100644 (file)
@@ -309,6 +309,11 @@ def F64 : F<64>;
 def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
            BuildableType<"getBF16Type()">;
 
+// Function Type
+
+// Any function type.
+def FunctionType : Type<CPred<"$_self.isa<FunctionType>()">, "function type">;
+
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                     string descr> :
index 7f3e8ab..838cd03 100644 (file)
@@ -46,77 +46,6 @@ public:
 #define GET_OP_CLASSES
 #include "mlir/StandardOps/Ops.h.inc"
 
-/// The "call" operation represents a direct call to a function.  The operands
-/// and result types of the call must match the specified function type.  The
-/// callee is encoded as a function attribute named "callee".
-///
-///   %31 = call @my_add(%0, %1)
-///            : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
-class CallOp
-    : public Op<CallOp, OpTrait::VariadicOperands, OpTrait::VariadicResults> {
-public:
-  friend Operation;
-  using Op::Op;
-
-  static StringRef getOperationName() { return "std.call"; }
-
-  static void build(Builder *builder, OperationState *result, Function *callee,
-                    ArrayRef<Value *> operands);
-
-  Function *getCallee() {
-    return getAttrOfType<FunctionAttr>("callee").getValue();
-  }
-
-  /// Get the argument operands to the called function.
-  operand_range getArgOperands() {
-    return {arg_operand_begin(), arg_operand_end()};
-  }
-
-  operand_iterator arg_operand_begin() { return operand_begin(); }
-  operand_iterator arg_operand_end() { return operand_end(); }
-
-  // Hooks to customize behavior of this op.
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-  LogicalResult verify();
-};
-
-/// The "call_indirect" operation represents an indirect call to a value of
-/// function type.  Functions are first class types in MLIR, and may be passed
-/// as arguments and merged together with block arguments.  The operands
-/// and result types of the call must match the specified function type.
-///
-///   %31 = call_indirect %15(%0, %1)
-///            : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
-///
-class CallIndirectOp : public Op<CallIndirectOp, OpTrait::VariadicOperands,
-                                 OpTrait::VariadicResults> {
-public:
-  friend Operation;
-  using Op::Op;
-  static StringRef getOperationName() { return "std.call_indirect"; }
-
-  static void build(Builder *builder, OperationState *result, Value *callee,
-                    ArrayRef<Value *> operands);
-
-  Value *getCallee() { return getOperand(0); }
-
-  /// Get the argument operands to the called function.
-  operand_range getArgOperands() {
-    return {arg_operand_begin(), arg_operand_end()};
-  }
-
-  operand_iterator arg_operand_begin() { return ++operand_begin(); }
-  operand_iterator arg_operand_end() { return operand_end(); }
-
-  // Hooks to customize behavior of this op.
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-  LogicalResult verify();
-  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                          MLIRContext *context);
-};
-
 /// The predicate indicates the type of the comparison to perform:
 /// (in)equality; (un)signed less/greater than (or equal to).
 enum class CmpIPredicate {
index cfdbf1d..16e3bf1 100644 (file)
@@ -168,6 +168,87 @@ def BranchOp : Op<Standard_Dialect, "br", [Terminator]> {
   }];
 }
 
+def CallOp : Op<Standard_Dialect, "call"> {
+  let summary = "call operation";
+  let description = [{
+    The "call" operation represents a direct call to a function.  The operands
+    and result types of the call must match the specified function type.  The
+    callee is encoded as a function attribute named "callee".
+
+      %2 = call @my_add(%0, %1) : (f32, f32) -> f32
+  }];
+
+  let arguments = (ins FunctionAttr:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>);
+
+  let parser = [{ return parseCallOp(parser, result); }];
+  let printer = [{ return printCallOp(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Function *callee,"
+    "ArrayRef<Value *> operands = {}", [{
+      result->addOperands(operands);
+      result->addAttribute("callee", builder->getFunctionAttr(callee));
+      result->addTypes(callee->getType().getResults());
+  }]>];
+
+  let extraClassDeclaration = [{
+    Function *getCallee() {
+      return getAttrOfType<FunctionAttr>("callee").getValue();
+    }
+
+    /// Get the argument operands to the called function.
+    operand_range getArgOperands() {
+      return {arg_operand_begin(), arg_operand_end()};
+    }
+
+    operand_iterator arg_operand_begin() { return operand_begin(); }
+    operand_iterator arg_operand_end() { return operand_end(); }
+  }];
+}
+
+def CallIndirectOp : Op<Standard_Dialect, "call_indirect"> {
+  let summary = "indirect call operation";
+  let description = [{
+    The "call_indirect" operation represents an indirect call to a value of
+    function type.  Functions are first class types in MLIR, and may be passed
+    as arguments and merged together with block arguments.  The operands
+    and result types of the call must match the specified function type.
+
+      %3 = call_indirect %2(%0, %1) : (f32, f32) -> f32
+  }];
+
+  let arguments = (ins FunctionType:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>);
+
+  let parser = [{ return parseCallIndirectOp(parser, result); }];
+  let printer = [{ return printCallIndirectOp(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Value *callee,"
+    "ArrayRef<Value *> operands = {}", [{
+      result->operands.push_back(callee);
+      result->addOperands(operands);
+      result->addTypes(callee->getType().cast<FunctionType>().getResults());
+  }]>];
+
+  let extraClassDeclaration = [{
+    Value *getCallee() { return getOperand(0); }
+
+    /// Get the argument operands to the called function.
+    operand_range getArgOperands() {
+      return {arg_operand_begin(), arg_operand_end()};
+    }
+
+    operand_iterator arg_operand_begin() { return ++operand_begin(); }
+    operand_iterator arg_operand_end() { return operand_end(); }
+  }];
+
+  let hasCanonicalizer = 0b1;
+}
+
 def ConstantOp : Op<Standard_Dialect, "constant", [NoSideEffect]> {
   let summary = "constant";
 
index 9ad37fd..9a3d9c8 100644 (file)
@@ -61,9 +61,8 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
 
 StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
     : Dialect(/*name=*/"std", context) {
-  addOperations<CallOp, CallIndirectOp, CmpFOp, CmpIOp, CondBranchOp,
-                DmaStartOp, DmaWaitOp, LoadOp, MemRefCastOp, ReturnOp, SelectOp,
-                StoreOp, TensorCastOp,
+  addOperations<CmpFOp, CmpIOp, CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp,
+                MemRefCastOp, ReturnOp, SelectOp, StoreOp, TensorCastOp,
 #define GET_OP_LIST
 #include "mlir/StandardOps/Ops.cpp.inc"
                 >();
@@ -402,14 +401,7 @@ void BranchOp::eraseOperand(unsigned index) {
 // CallOp
 //===----------------------------------------------------------------------===//
 
-void CallOp::build(Builder *builder, OperationState *result, Function *callee,
-                   ArrayRef<Value *> operands) {
-  result->addOperands(operands);
-  result->addAttribute("callee", builder->getFunctionAttr(callee));
-  result->addTypes(callee->getType().getResults());
-}
-
-ParseResult CallOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
   StringRef calleeName;
   llvm::SMLoc calleeLoc;
   FunctionType calleeType;
@@ -430,39 +422,37 @@ ParseResult CallOp::parse(OpAsmParser *parser, OperationState *result) {
   return success();
 }
 
-void CallOp::print(OpAsmPrinter *p) {
+static void printCallOp(OpAsmPrinter *p, CallOp op) {
   *p << "call ";
-  p->printFunctionReference(getCallee());
+  p->printFunctionReference(op.getCallee());
   *p << '(';
-  p->printOperands(getOperands());
+  p->printOperands(op.getOperands());
   *p << ')';
-  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"});
-  *p << " : " << getCallee()->getType();
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  *p << " : " << op.getCallee()->getType();
 }
 
-LogicalResult CallOp::verify() {
+static LogicalResult verify(CallOp op) {
   // Check that the callee attribute was specified.
-  auto fnAttr = getAttrOfType<FunctionAttr>("callee");
+  auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
   if (!fnAttr)
-    return emitOpError("requires a 'callee' function attribute");
+    return op.emitOpError("requires a 'callee' function attribute");
 
   // Verify that the operand and result types match the callee.
   auto fnType = fnAttr.getValue()->getType();
-  if (fnType.getNumInputs() != getNumOperands())
-    return emitOpError("incorrect number of operands for callee");
+  if (fnType.getNumInputs() != op.getNumOperands())
+    return op.emitOpError("incorrect number of operands for callee");
 
-  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
-    if (getOperand(i)->getType() != fnType.getInput(i))
-      return emitOpError("operand type mismatch");
-  }
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
+    if (op.getOperand(i)->getType() != fnType.getInput(i))
+      return op.emitOpError("operand type mismatch");
 
-  if (fnType.getNumResults() != getNumResults())
-    return emitOpError("incorrect number of results for callee");
+  if (fnType.getNumResults() != op.getNumResults())
+    return op.emitOpError("incorrect number of results for callee");
 
-  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
-    if (getResult(i)->getType() != fnType.getResult(i))
-      return emitOpError("result type mismatch");
-  }
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
+    if (op.getResult(i)->getType() != fnType.getResult(i))
+      return op.emitOpError("result type mismatch");
 
   return success();
 }
@@ -498,15 +488,8 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
 };
 } // end anonymous namespace.
 
-void CallIndirectOp::build(Builder *builder, OperationState *result,
-                           Value *callee, ArrayRef<Value *> operands) {
-  auto fnType = callee->getType().cast<FunctionType>();
-  result->operands.push_back(callee);
-  result->addOperands(operands);
-  result->addTypes(fnType.getResults());
-}
-
-ParseResult CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseCallIndirectOp(OpAsmParser *parser,
+                                       OperationState *result) {
   FunctionType calleeType;
   OpAsmParser::OperandType callee;
   llvm::SMLoc operandsLoc;
@@ -524,39 +507,37 @@ ParseResult CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
       parser->addTypesToList(calleeType.getResults(), result->types));
 }
 
-void CallIndirectOp::print(OpAsmPrinter *p) {
+static void printCallIndirectOp(OpAsmPrinter *p, CallIndirectOp op) {
   *p << "call_indirect ";
-  p->printOperand(getCallee());
+  p->printOperand(op.getCallee());
   *p << '(';
-  auto operandRange = getOperands();
+  auto operandRange = op.getOperands();
   p->printOperands(++operandRange.begin(), operandRange.end());
   *p << ')';
-  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"});
-  *p << " : " << getCallee()->getType();
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  *p << " : " << op.getCallee()->getType();
 }
 
-LogicalResult CallIndirectOp::verify() {
+static LogicalResult verify(CallIndirectOp op) {
   // The callee must be a function.
-  auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
+  auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>();
   if (!fnType)
-    return emitOpError("callee must have function type");
+    return op.emitOpError("callee must have function type");
 
   // Verify that the operand and result types match the callee.
-  if (fnType.getNumInputs() != getNumOperands() - 1)
-    return emitOpError("incorrect number of operands for callee");
+  if (fnType.getNumInputs() != op.getNumOperands() - 1)
+    return op.emitOpError("incorrect number of operands for callee");
 
-  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
-    if (getOperand(i + 1)->getType() != fnType.getInput(i))
-      return emitOpError("operand type mismatch");
-  }
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
+    if (op.getOperand(i + 1)->getType() != fnType.getInput(i))
+      return op.emitOpError("operand type mismatch");
 
-  if (fnType.getNumResults() != getNumResults())
-    return emitOpError("incorrect number of results for callee");
+  if (fnType.getNumResults() != op.getNumResults())
+    return op.emitOpError("incorrect number of results for callee");
 
-  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
-    if (getResult(i)->getType() != fnType.getResult(i))
-      return emitOpError("result type mismatch");
-  }
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
+    if (op.getResult(i)->getType() != fnType.getResult(i))
+      return op.emitOpError("result type mismatch");
 
   return success();
 }
index dfe9e67..61db84b 100644 (file)
@@ -23,12 +23,12 @@ def Z_AddOp : NS_Op<"add"> {
 }
 
 // Define rewrite patterns.
-def : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
+def bena : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
 
-// CHECK-LABEL: struct GeneratedConvert0
-// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("x.add", 2, context) {}
+// CHECK-LABEL: struct bena
+// CHECK: RewritePattern("x.add", 2, context) {}
 
-def : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
+def benb : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
 
-// CHECK-LABEL: struct GeneratedConvert1
-// CHECK: GeneratedConvert1(MLIRContext *context) : RewritePattern("x.add", 101, context) {}
+// CHECK-LABEL: struct benb
+// CHECK: RewritePattern("x.add", 101, context) {}
index 8acfeb8..4efe691 100644 (file)
@@ -23,13 +23,13 @@ def OneResultOp : NS_Op<"one_result_op", []> {
   let results = (outs I32:$r1);
 }
 
-def : Pattern<(ThreeResultOp $input), [
+def : Pattern<(ThreeResultOp $input), [
         (OneResultOp $input),
         (OneResultOp $input),
         (OneResultOp $input)
       ]>;
 
-// CHECK-LABEL: struct GeneratedConvert0
+// CHECK-LABEL: struct a
 
 // CHECK: void rewrite(
 // CHECK:      auto vOneResultOp0 = rewriter.create<OneResultOp>(
@@ -37,13 +37,13 @@ def : Pattern<(ThreeResultOp $input), [
 // CHECK:      auto vOneResultOp2 = rewriter.create<OneResultOp>(
 // CHECK:      rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp2});
 
-def : Pattern<(ThreeResultOp $input), [
+def : Pattern<(ThreeResultOp $input), [
         (OneResultOp (OneResultOp:$interm $input)),
         (OneResultOp $interm),
         (OneResultOp (OneResultOp $interm))
       ]>;
 
-// CHECK-LABEL: struct GeneratedConvert1
+// CHECK-LABEL: struct b
 
 // CHECK:      void rewrite(
 // CHECK:        auto interm = rewriter.create<OneResultOp>(
@@ -64,7 +64,7 @@ def AdditionalOp : NS_Op<"additional_one_result_op", []> {
   let arguments = (ins I32:$input);
   let results = (outs I32:$r1);
 }
-def : Pattern<(TwoResultOp $input), [
+def : Pattern<(TwoResultOp $input), [
         // Additional op generated to help build the final result but not
         // directly used to replace the source op
         (AdditionalOp:$interm $input),
@@ -73,7 +73,7 @@ def : Pattern<(TwoResultOp $input), [
         (OneResultOp $input)
       ]>;
 
-// CHECK-LABEL: struct GeneratedConvert2
+// CHECK-LABEL: struct c
 
 // CHECK:      auto interm = rewriter.create<AdditionalOp>(
 // CHECK:      auto vOneResultOp0 = rewriter.create<OneResultOp>(