Enable autogenerating OpInterface method declarations
authorJacques Pienaar <jpienaar@google.com>
Mon, 30 Sep 2019 19:42:31 +0000 (12:42 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 30 Sep 2019 19:42:58 +0000 (12:42 -0700)
Add DeclareOpInterfaceFunctions to enable specifying whether OpInterfaceMethods
for an OpInterface should be generated automatically. This avoids needing to
declare the extra methods, while also allowing adding function declaration by way of trait/inheritance.

Most of this change is mechanical/extracting classes to be reusable.

PiperOrigin-RevId: 272042739

mlir/g3doc/OpDefinitions.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/OpInterfaces.h [new file with mode: 0644]
mlir/include/mlir/TableGen/OpTrait.h
mlir/lib/TableGen/OpInterfaces.cpp [new file with mode: 0644]
mlir/lib/TableGen/OpTrait.cpp
mlir/test/lib/TestDialect/TestOps.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

index ff7192c..d852de7 100644 (file)
@@ -330,6 +330,10 @@ An `InterfaceMethod` is comprised of the following components:
     -   In non-static methods, a variable 'ConcreteOp op' is defined and may be
         used to refer to an instance of the derived operation.
 
+ODS also allows generating the declarations for the `InterfaceMethod` of the op
+if one specifies the interface with `DeclareOpInterfaceMethods` (see example
+below).
+
 Examples:
 
 ```tablegen
@@ -369,6 +373,13 @@ def MyInterface : OpInterface<"MyInterface"> {
     }]>,
   ];
 }
+
+// Interfaces can optionally be wrapped inside DeclareOpInterfaceMethods. This
+// would result in autogenerating declarations for members `foo`, `bar` and
+// `fooStatic`. Methods without bodies are not declared inside the op
+// declaration but instead handled by the op interface trait directly.
+def OpWithInferTypeInterfaceOp : Op<...
+    [DeclareOpInterfaceMethods<MyInterface>]> { ... }
 ```
 
 ### Custom builder methods
index 89609ed..c662576 100644 (file)
@@ -1279,16 +1279,16 @@ class InterfaceMethod<string desc, string retTy, string methodName,
   // A human-readable description of what this method does.
   string description = desc;
 
-  /// The name of the interface method.
+  // The name of the interface method.
   string name = methodName;
 
-  /// The c++ type-name of the return type.
+  // The c++ type-name of the return type.
   string returnType = retTy;
 
-  /// A dag of string that correspond to the arguments of the method.
+  // A dag of string that correspond to the arguments of the method.
   dag arguments = args;
 
-  /// An optional body to the method.
+  // An optional body to the method.
   code body = methodBody;
 }
 
@@ -1305,10 +1305,20 @@ class OpInterface<string name> : OpInterfaceTrait<name> {
   // The name given to the c++ interface class.
   string cppClassName = name;
 
-  /// The list of methods defined by this interface.
+  // The list of methods defined by this interface.
   list<InterfaceMethod> methods = [];
 }
 
+// Whether to declare the op interface methods in the op's header. This class
+// simply wraps an OpInterface but is used to indicate that the method
+// declarations should be generated.
+class DeclareOpInterfaceMethods<OpInterface interface> :
+  OpInterface<interface.cppClassName> {
+    let description = interface.description;
+    let cppClassName = interface.cppClassName;
+    let methods = interface.methods;
+}
+
 //===----------------------------------------------------------------------===//
 // Op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/OpInterfaces.h b/mlir/include/mlir/TableGen/OpInterfaces.h
new file mode 100644 (file)
index 0000000..46f43c6
--- /dev/null
@@ -0,0 +1,105 @@
+//===- OpInterfaces.h - OpInterfaces wrapper class --------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// OpInterfaces wrapper to simplify using TableGen OpInterfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_OPINTERFACES_H_
+#define MLIR_TABLEGEN_OPINTERFACES_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class Init;
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class with helper methods for accessing OpInterfaceMethod defined
+// in TableGen.
+class OpInterfaceMethod {
+public:
+  // This struct represents a single method argument.
+  struct Argument {
+    StringRef type;
+    StringRef name;
+  };
+
+  explicit OpInterfaceMethod(const llvm::Record *def);
+
+  // Return the return type of this method.
+  StringRef getReturnType() const;
+
+  // Return the name of this method.
+  StringRef getName() const;
+
+  // Return if this method is static.
+  bool isStatic() const;
+
+  // Return the body for this method if it has one.
+  llvm::Optional<StringRef> getBody() const;
+
+  // Return the description of this method if it has one.
+  llvm::Optional<StringRef> getDescription() const;
+
+  // Arguments.
+  ArrayRef<Argument> getArguments() const;
+  bool arg_empty() const;
+
+private:
+  // The TableGen definition of this method.
+  const llvm::Record *def;
+
+  // The arguments of this method.
+  SmallVector<Argument, 2> arguments;
+};
+
+//===----------------------------------------------------------------------===//
+// OpInterface
+//===----------------------------------------------------------------------===//
+
+// Wrapper class with helper methods for accessing OpInterfaces defined in
+// TableGen.
+class OpInterface {
+public:
+  explicit OpInterface(const llvm::Record *def);
+
+  // Return the name of this interface.
+  StringRef getName() const;
+
+  // Return the methods of this interface.
+  ArrayRef<OpInterfaceMethod> getMethods() const;
+
+  // Return the description of this method if it has one.
+  llvm::Optional<StringRef> getDescription() const;
+
+private:
+  // The TableGen definition of this interface.
+  const llvm::Record *def;
+
+  // The methods of this interface.
+  SmallVector<OpInterfaceMethod, 8> methods;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_OPINTERFACES_H_
index 8a3463d..cfa1d93 100644 (file)
@@ -33,6 +33,8 @@ class Record;
 namespace mlir {
 namespace tblgen {
 
+class OpInterface;
+
 // Wrapper class with helper methods for accessing OpTrait constraints defined
 // in TableGen.
 class OpTrait {
@@ -44,7 +46,9 @@ public:
     // OpTrait corresponding to predicate on operation.
     Pred,
     // OpTrait controlling op definition generator internals.
-    Internal
+    Internal,
+    // OpTrait corresponding to OpInterface.
+    Interface
   };
 
   explicit OpTrait(Kind kind, const llvm::Record *def);
@@ -92,6 +96,23 @@ public:
   }
 };
 
+// OpTrait corresponding to an OpInterface on the operation.
+class InterfaceOpTrait : public OpTrait {
+public:
+  // Returns member function defitions corresponding to the trait,
+  OpInterface getOpInterface() const;
+
+  // Returns the trait corresponding to a C++ trait class.
+  StringRef getTrait() const;
+
+  static bool classof(const OpTrait *t) {
+    return t->getKind() == Kind::Interface;
+  }
+
+  // Whether the declaration of methods for this trait should be emitted.
+  bool shouldDeclareMethods() const;
+};
+
 } // end namespace tblgen
 } // end namespace mlir
 
diff --git a/mlir/lib/TableGen/OpInterfaces.cpp b/mlir/lib/TableGen/OpInterfaces.cpp
new file mode 100644 (file)
index 0000000..e4e80e0
--- /dev/null
@@ -0,0 +1,90 @@
+//===- OpInterfaces.cpp - OpInterfaces class ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// OpInterfaces wrapper to simplify using TableGen OpInterfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/OpInterfaces.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+OpInterfaceMethod::OpInterfaceMethod(const llvm::Record *def) : def(def) {
+  llvm::DagInit *args = def->getValueAsDag("arguments");
+  for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
+    arguments.push_back(
+        {llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
+         args->getArgNameStr(i)});
+  }
+}
+
+StringRef OpInterfaceMethod::getReturnType() const {
+  return def->getValueAsString("returnType");
+}
+
+// Return the name of this method.
+StringRef OpInterfaceMethod::getName() const {
+  return def->getValueAsString("name");
+}
+
+// Return if this method is static.
+bool OpInterfaceMethod::isStatic() const {
+  return def->isSubClassOf("StaticInterfaceMethod");
+}
+
+// Return the body for this method if it has one.
+llvm::Optional<StringRef> OpInterfaceMethod::getBody() const {
+  auto value = def->getValueAsString("body");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
+// Return the description of this method if it has one.
+llvm::Optional<StringRef> OpInterfaceMethod::getDescription() const {
+  auto value = def->getValueAsString("description");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
+ArrayRef<OpInterfaceMethod::Argument> OpInterfaceMethod::getArguments() const {
+  return arguments;
+}
+
+bool OpInterfaceMethod::arg_empty() const { return arguments.empty(); }
+
+OpInterface::OpInterface(const llvm::Record *def) : def(def) {
+  auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
+  for (llvm::Init *init : listInit->getValues())
+    methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
+}
+
+// Return the name of this interface.
+StringRef OpInterface::getName() const {
+  return def->getValueAsString("cppClassName");
+}
+
+// Return the methods of this interface.
+ArrayRef<OpInterfaceMethod> OpInterface::getMethods() const { return methods; }
+
+// Return the description of this method if it has one.
+llvm::Optional<StringRef> OpInterface::getDescription() const {
+  auto value = def->getValueAsString("description");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
index 0a357ac..0e436a8 100644 (file)
@@ -20,6 +20,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/OpTrait.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/TableGen/OpInterfaces.h"
 #include "mlir/TableGen/Predicate.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
 
 using namespace mlir;
+using namespace mlir::tblgen;
 
-mlir::tblgen::OpTrait mlir::tblgen::OpTrait::create(const llvm::Init *init) {
+OpTrait OpTrait::create(const llvm::Init *init) {
   auto def = cast<llvm::DefInit>(init)->getDef();
   if (def->isSubClassOf("PredOpTrait"))
     return OpTrait(Kind::Pred, def);
   if (def->isSubClassOf("GenInternalOpTrait"))
     return OpTrait(Kind::Internal, def);
+  if (def->isSubClassOf("OpInterface"))
+    return OpTrait(Kind::Interface, def);
   assert(def->isSubClassOf("NativeOpTrait"));
   return OpTrait(Kind::Native, def);
 }
 
-mlir::tblgen::OpTrait::OpTrait(Kind kind, const llvm::Record *def)
-    : def(def), kind(kind) {}
+OpTrait::OpTrait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
 
-llvm::StringRef mlir::tblgen::NativeOpTrait::getTrait() const {
+llvm::StringRef NativeOpTrait::getTrait() const {
   return def->getValueAsString("trait");
 }
 
-llvm::StringRef mlir::tblgen::InternalOpTrait::getTrait() const {
+llvm::StringRef InternalOpTrait::getTrait() const {
   return def->getValueAsString("trait");
 }
 
-std::string mlir::tblgen::PredOpTrait::getPredTemplate() const {
+std::string PredOpTrait::getPredTemplate() const {
   auto pred = tblgen::Pred(def->getValueInit("predicate"));
   return pred.getCondition();
 }
 
-llvm::StringRef mlir::tblgen::PredOpTrait::getDescription() const {
+llvm::StringRef PredOpTrait::getDescription() const {
   return def->getValueAsString("description");
 }
+
+OpInterface InterfaceOpTrait::getOpInterface() const {
+  return OpInterface(def);
+}
+
+llvm::StringRef InterfaceOpTrait::getTrait() const {
+  return def->getValueAsString("trait");
+}
+
+bool InterfaceOpTrait::shouldDeclareMethods() const {
+  return def->isSubClassOf("DeclareOpInterfaceMethods");
+}
index 72991ce..e419b7e 100644 (file)
@@ -330,15 +330,9 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
 }
 
 def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if",
-    [InferTypeOpInterface]> {
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let arguments = (ins AnyTensor:$x, AnyTensor:$y);
   let results = (outs AnyTensor:$res);
-  // TODO(jpienaar): Remove the need to specify these here.
-  let extraClassDeclaration = [{
-    SmallVector<Type, 2> inferReturnTypes(llvm::Optional<Location> location,
-      ArrayRef<Value*> operands, ArrayRef<NamedAttribute> attributes,
-      ArrayRef<Region> regions);
-  }];
 }
 
 def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
index e3d70a1..089a492 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/Support/STLExtras.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/OpInterfaces.h"
 #include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/StringExtras.h"
@@ -541,6 +542,9 @@ private:
   // Generates the traits used by the object.
   void genTraits();
 
+  // Generate the OpInterface methods.
+  void genOpInterfaceMethods();
+
 private:
   // The TableGen record for this op.
   // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
@@ -577,6 +581,7 @@ OpEmitter::OpEmitter(const Operator &op)
   genVerifier();
   genCanonicalizerDecls();
   genFolderDecls();
+  genOpInterfaceMethods();
 }
 
 void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
@@ -1071,6 +1076,30 @@ void OpEmitter::genFolderDecls() {
   }
 }
 
+void OpEmitter::genOpInterfaceMethods() {
+  for (const auto &trait : op.getTraits()) {
+    auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait);
+    if (!opTrait || !opTrait->shouldDeclareMethods())
+      continue;
+    auto interface = opTrait->getOpInterface();
+    for (auto method : interface.getMethods()) {
+      // Don't declare if the method has a body.
+      if (method.getBody())
+        continue;
+      std::string args;
+      llvm::raw_string_ostream os(args);
+      mlir::interleaveComma(method.getArguments(), os,
+                            [&](const OpInterfaceMethod::Argument &arg) {
+                              os << arg.type << " " << arg.name;
+                            });
+      opClass.newMethod(method.getReturnType(), method.getName(), os.str(),
+                        method.isStatic() ? OpMethod::MP_Static
+                                          : OpMethod::MP_None,
+                        /*declOnly=*/true);
+    }
+  }
+}
+
 void OpEmitter::genParser() {
   if (!hasStringAttribute(def, "parser"))
     return;
@@ -1286,6 +1315,8 @@ void OpEmitter::genTraits() {
   for (const auto &trait : op.getTraits()) {
     if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
       opClass.addTrait(opTrait->getTrait());
+    else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
+      opClass.addTrait(opTrait->getTrait());
   }
 
   // Add variadic size trait and normal op traits.
index 7961b6f..4da412c 100644 (file)
@@ -22,6 +22,7 @@
 #include "DocGenUtilities.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/OpInterfaces.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace llvm;
 using namespace mlir;
-
-namespace {
-//===----------------------------------------------------------------------===//
-// OpInterfaceMethod
-//===----------------------------------------------------------------------===//
-
-// This struct represents a single method argument.
-struct MethodArgument {
-  StringRef type, name;
-};
-
-// Wrapper class around a single interface method.
-class OpInterfaceMethod {
-public:
-  explicit OpInterfaceMethod(const llvm::Record *def) : def(def) {
-    llvm::DagInit *args = def->getValueAsDag("arguments");
-    for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
-      arguments.push_back(
-          {llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
-           args->getArgNameStr(i)});
-    }
-  }
-
-  // Return the return type of this method.
-  StringRef getReturnType() const {
-    return def->getValueAsString("returnType");
-  }
-
-  // Return the name of this method.
-  StringRef getName() const { return def->getValueAsString("name"); }
-
-  // Return if this method is static.
-  bool isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); }
-
-  // Return the body for this method if it has one.
-  llvm::Optional<StringRef> getBody() const {
-    auto value = def->getValueAsString("body");
-    return value.empty() ? llvm::Optional<StringRef>() : value;
-  }
-
-  // Return the description of this method if it has one.
-  llvm::Optional<StringRef> getDescription() const {
-    auto value = def->getValueAsString("description");
-    return value.empty() ? llvm::Optional<StringRef>() : value;
-  }
-
-  // Arguments.
-  ArrayRef<MethodArgument> getArguments() const { return arguments; }
-  bool arg_empty() const { return arguments.empty(); }
-
-protected:
-  // The TableGen definition of this method.
-  const llvm::Record *def;
-
-  // The arguments of this method.
-  SmallVector<MethodArgument, 2> arguments;
-};
-
-//===----------------------------------------------------------------------===//
-// OpInterface
-//===----------------------------------------------------------------------===//
-
-// Wrapper class with helper methods for accessing OpInterfaces defined in
-// TableGen.
-class OpInterface {
-public:
-  explicit OpInterface(const llvm::Record *def) : def(def) {
-    auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
-    for (llvm::Init *init : listInit->getValues())
-      methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
-  }
-
-  // Return the name of this interface.
-  StringRef getName() const { return def->getValueAsString("cppClassName"); }
-
-  // Return the methods of this interface.
-  ArrayRef<OpInterfaceMethod> getMethods() const { return methods; }
-
-  // Return the description of this method if it has one.
-  llvm::Optional<StringRef> getDescription() const {
-    auto value = def->getValueAsString("description");
-    return value.empty() ? llvm::Optional<StringRef>() : value;
-  }
-
-protected:
-  // The TableGen definition of this interface.
-  const llvm::Record *def;
-
-  // The methods of this interface.
-  SmallVector<OpInterfaceMethod, 8> methods;
-};
-} // end anonymous namespace
+using mlir::tblgen::OpInterface;
+using mlir::tblgen::OpInterfaceMethod;
 
 // Emit the method name and argument list for the given method. If
 // 'addOperationArg' is true, then an Operation* argument is added to the
@@ -133,9 +44,10 @@ static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
   os << method.getName() << '(';
   if (addOperationArg)
     os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
-  interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) {
-    os << arg.type << " " << arg.name;
-  });
+  interleaveComma(method.getArguments(), os,
+                  [&](const OpInterfaceMethod::Argument &arg) {
+                    os << arg.type << " " << arg.name;
+                  });
   os << ')';
 }
 
@@ -155,8 +67,9 @@ static void emitInterfaceDef(OpInterface &interface, raw_ostream &os) {
     os << " {\n      return getImpl()->" << method.getName() << '(';
     if (!method.isStatic())
       os << "getOperation()" << (method.arg_empty() ? "" : ", ");
-    interleaveComma(method.getArguments(), os,
-                    [&](const MethodArgument &arg) { os << arg.name; });
+    interleaveComma(
+        method.getArguments(), os,
+        [&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });
     os << ");\n  }\n";
   }
 }
@@ -218,8 +131,9 @@ static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
 
     // Add the arguments to the call.
     os << method.getName() << '(';
-    interleaveComma(method.getArguments(), os,
-                    [&](const MethodArgument &arg) { os << arg.name; });
+    interleaveComma(
+        method.getArguments(), os,
+        [&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });
     os << ");\n    }\n";
   }
   os << "  };\n";
@@ -294,9 +208,10 @@ static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
     if (method.isStatic())
       os << "static ";
     emitCPPType(method.getReturnType(), os) << method.getName() << '(';
-    interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) {
-      emitCPPType(arg.type, os) << arg.name;
-    });
+    interleaveComma(method.getArguments(), os,
+                    [&](const OpInterfaceMethod::Argument &arg) {
+                      emitCPPType(arg.type, os) << arg.name;
+                    });
     os << ");\n```\n";
 
     // Emit the description.