[mlir][ods] Allow filtering of ops
authorJacques Pienaar <jpienaar@google.com>
Mon, 22 Jun 2020 21:56:54 +0000 (14:56 -0700)
committerJacques Pienaar <jpienaar@google.com>
Mon, 22 Jun 2020 21:56:54 +0000 (14:56 -0700)
Add option to filter which op the OpDefinitionsGen run on. This enables having multiple ops together in the same TD file but generating different CC files for them (useful if one wants to use multiclasses or split out 1 dialect into multiple different libraries). There is probably more general query here (e.g., split out all ops that don't have a verify method, or that are commutative) but filtering based on op name (e.g., test.a_op) seemed a reasonable start and didn't require inventing a query specification mechanism here.

Differential Revision: https://reviews.llvm.org/D82319

mlir/test/mlir-tblgen/op-decl.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index f5bf03e..8d58d5b 100644 (file)
@@ -1,4 +1,5 @@
 // RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck  %s
+// RUN: mlir-tblgen -gen-op-decls -op-regex="test.a_op" -I %S/../../include %s | FileCheck  %s --check-prefix=REDUCE
 
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -195,3 +196,5 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;
 // CHECK-LABEL: _BOp declarations
 // CHECK: class _BOp : public Op<_BOp
 
+// REDUCE-LABEL: NS::AOp declarations
+// REDUCE-NOT: NS::BOp declarations
index 21dccd4..6aa7b01 100644 (file)
@@ -21,6 +21,8 @@
 #include "mlir/TableGen/SideEffects.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Regex.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -32,6 +34,13 @@ using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
 
+cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
+
+static cl::opt<std::string>
+    opFilter("op-regex",
+             cl::desc("Regex of name of op's to filter (no filter if empty)"),
+             cl::cat(opDefGenCat));
+
 static const char *const tblgenNamePrefix = "tblgen_";
 static const char *const generatedArgName = "odsArg";
 static const char *const builderOpState = "odsState";
@@ -2081,10 +2090,37 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
       [&os]() { os << ",\n"; });
 }
 
+static std::string getOperationName(const Record &def) {
+  auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
+  auto opName = def.getValueAsString("opName");
+  if (prefix.empty())
+    return std::string(opName);
+  return std::string(llvm::formatv("{0}.{1}", prefix, opName));
+}
+
+static std::vector<Record *>
+getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
+                         StringRef className) {
+  Record *classDef = recordKeeper.getClass(className);
+  if (!classDef)
+    PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
+
+  llvm::Regex includeRegex(opFilter);
+  std::vector<Record *> defs;
+  for (const auto &def : recordKeeper.getDefs()) {
+    if (def.second->isSubClassOf(classDef)) {
+      if (opFilter.empty() || includeRegex.match(getOperationName(*def.second)))
+        defs.push_back(def.second.get());
+    }
+  }
+
+  return defs;
+}
+
 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Declarations", os);
 
-  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+  const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
   emitOpClasses(defs, os, /*emitDecl=*/true);
 
   return false;
@@ -2093,7 +2129,7 @@ static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Definitions", os);
 
-  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+  const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
   emitOpList(defs, os);
   emitOpClasses(defs, os, /*emitDecl=*/false);