#include "mlir/TableGen/SideEffects.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Regex.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
+cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
+
+static cl::opt<std::string>
+ opFilter("op-regex",
+ cl::desc("Regex of name of op's to filter (no filter if empty)"),
+ cl::cat(opDefGenCat));
+
static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "odsArg";
static const char *const builderOpState = "odsState";
[&os]() { os << ",\n"; });
}
+static std::string getOperationName(const Record &def) {
+ auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
+ auto opName = def.getValueAsString("opName");
+ if (prefix.empty())
+ return std::string(opName);
+ return std::string(llvm::formatv("{0}.{1}", prefix, opName));
+}
+
+static std::vector<Record *>
+getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
+ StringRef className) {
+ Record *classDef = recordKeeper.getClass(className);
+ if (!classDef)
+ PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
+
+ llvm::Regex includeRegex(opFilter);
+ std::vector<Record *> defs;
+ for (const auto &def : recordKeeper.getDefs()) {
+ if (def.second->isSubClassOf(classDef)) {
+ if (opFilter.empty() || includeRegex.match(getOperationName(*def.second)))
+ defs.push_back(def.second.get());
+ }
+ }
+
+ return defs;
+}
+
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op Declarations", os);
- const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+ const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
emitOpClasses(defs, os, /*emitDecl=*/true);
return false;
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op Definitions", os);
- const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+ const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
emitOpList(defs, os);
emitOpClasses(defs, os, /*emitDecl=*/false);