Implement Pass and Dialect plugins for mlir-opt
authorFabian Mora <fmorac@udel.edu>
Fri, 7 Apr 2023 00:01:00 +0000 (17:01 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 7 Apr 2023 01:28:50 +0000 (18:28 -0700)
Implementation of Pass and Dialect Plugins that mirrors LLVM Pass Plugin
implementation from the new pass manager.

Currently the implementation only supports using the pass-pipeline option
for adding passes. This restriction is imposed by the `PassPipelineCLParser`
variable in mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:114 that loads the
parse options statically before parsing the cmd line args.

```
mlir-opt stanalone-plugin.mlir --load-dialect-plugin=lib/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)"
```

Reviewed By: rriddle, mehdi_amini

Differential Revision: https://reviews.llvm.org/D147053

22 files changed:
mlir/examples/standalone/CMakeLists.txt
mlir/examples/standalone/include/Standalone/CMakeLists.txt
mlir/examples/standalone/include/Standalone/StandalonePasses.h [new file with mode: 0644]
mlir/examples/standalone/include/Standalone/StandalonePasses.td [new file with mode: 0644]
mlir/examples/standalone/lib/Standalone/CMakeLists.txt
mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp [new file with mode: 0644]
mlir/examples/standalone/standalone-opt/standalone-opt.cpp
mlir/examples/standalone/standalone-plugin/CMakeLists.txt [new file with mode: 0644]
mlir/examples/standalone/standalone-plugin/standalone-plugin.cpp [new file with mode: 0644]
mlir/examples/standalone/test/Standalone/standalone-pass-plugin.mlir [new file with mode: 0644]
mlir/examples/standalone/test/Standalone/standalone-plugin.mlir [new file with mode: 0644]
mlir/examples/standalone/test/lit.cfg.py
mlir/include/mlir/Tools/Plugins/DialectPlugin.h [new file with mode: 0644]
mlir/include/mlir/Tools/Plugins/PassPlugin.h [new file with mode: 0644]
mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
mlir/lib/Tools/CMakeLists.txt
mlir/lib/Tools/Plugins/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Tools/Plugins/DialectPlugin.cpp [new file with mode: 0644]
mlir/lib/Tools/Plugins/PassPlugin.cpp [new file with mode: 0644]
mlir/lib/Tools/mlir-opt/CMakeLists.txt
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
mlir/tools/mlir-opt/CMakeLists.txt

index d36a6ba..65461c0 100644 (file)
@@ -52,4 +52,5 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
 endif()
 add_subdirectory(test)
 add_subdirectory(standalone-opt)
+add_subdirectory(standalone-plugin)
 add_subdirectory(standalone-translate)
index 8acf640..975a4ff 100644 (file)
@@ -1,3 +1,7 @@
 add_mlir_dialect(StandaloneOps standalone)
 add_mlir_doc(StandaloneDialect StandaloneDialect Standalone/ -gen-dialect-doc)
 add_mlir_doc(StandaloneOps StandaloneOps Standalone/ -gen-op-doc)
+
+set(LLVM_TARGET_DEFINITIONS StandalonePasses.td)
+mlir_tablegen(StandalonePasses.h.inc --gen-pass-decls)
+add_public_tablegen_target(MLIRStandalonePassesIncGen)
diff --git a/mlir/examples/standalone/include/Standalone/StandalonePasses.h b/mlir/examples/standalone/include/Standalone/StandalonePasses.h
new file mode 100644 (file)
index 0000000..75546d6
--- /dev/null
@@ -0,0 +1,26 @@
+//===- StandalonePasses.h - Standalone passes  ------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef STANDALONE_STANDALONEPASSES_H
+#define STANDALONE_STANDALONEPASSES_H
+
+#include "Standalone/StandaloneDialect.h"
+#include "Standalone/StandaloneOps.h"
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace standalone {
+#define GEN_PASS_DECL
+#include "Standalone/StandalonePasses.h.inc"
+
+#define GEN_PASS_REGISTRATION
+#include "Standalone/StandalonePasses.h.inc"
+} // namespace standalone
+} // namespace mlir
+
+#endif
diff --git a/mlir/examples/standalone/include/Standalone/StandalonePasses.td b/mlir/examples/standalone/include/Standalone/StandalonePasses.td
new file mode 100644 (file)
index 0000000..4cb2be0
--- /dev/null
@@ -0,0 +1,30 @@
+//===- StandalonePsss.td - Standalone dialect passes -------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef STANDALONE_PASS
+#define STANDALONE_PASS
+
+include "mlir/Pass/PassBase.td"
+
+def StandaloneSwitchBarFoo: Pass<"standalone-switch-bar-foo", "::mlir::ModuleOp"> {
+  let summary = "Switches the name of a FuncOp named `bar` to `foo` and folds.";
+  let description = [{
+    Switches the name of a FuncOp named `bar` to `foo` and folds.
+    ```
+    func.func @bar() {
+      return
+    }
+    // Gets transformed to:
+    func.func @foo() {
+      return
+    }
+    ```
+  }];
+}
+
+#endif // STANDALONE_PASS
index 599c1c5..0e2d043 100644 (file)
@@ -2,14 +2,17 @@ add_mlir_dialect_library(MLIRStandalone
         StandaloneTypes.cpp
         StandaloneDialect.cpp
         StandaloneOps.cpp
+        StandalonePasses.cpp
 
         ADDITIONAL_HEADER_DIRS
         ${PROJECT_SOURCE_DIR}/include/Standalone
 
         DEPENDS
         MLIRStandaloneOpsIncGen
+        MLIRStandalonePassesIncGen
 
        LINK_LIBS PUBLIC
        MLIRIR
         MLIRInferTypeOpInterface
+        MLIRFuncDialect
        )
diff --git a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
new file mode 100644 (file)
index 0000000..6af45c9
--- /dev/null
@@ -0,0 +1,48 @@
+//===- StandalonePasses.cpp - Standalone passes -----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "Standalone/StandalonePasses.h"
+
+namespace mlir::standalone {
+#define GEN_PASS_DEF_STANDALONESWITCHBARFOO
+#include "Standalone/StandalonePasses.h.inc"
+
+namespace {
+class StandaloneSwitchBarFooRewriter : public OpRewritePattern<func::FuncOp> {
+public:
+  using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(func::FuncOp op,
+                                PatternRewriter &rewriter) const final {
+    if (op.getSymName() == "bar") {
+      rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
+      return success();
+    }
+    return failure();
+  }
+};
+
+class StandaloneSwitchBarFoo
+    : public impl::StandaloneSwitchBarFooBase<StandaloneSwitchBarFoo> {
+public:
+  using impl::StandaloneSwitchBarFooBase<
+      StandaloneSwitchBarFoo>::StandaloneSwitchBarFooBase;
+  void runOnOperation() final {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<StandaloneSwitchBarFooRewriter>(&getContext());
+    FrozenRewritePatternSet patternSet(std::move(patterns));
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+      signalPassFailure();
+  }
+};
+} // namespace
+} // namespace mlir::standalone
index 4cf7e15..e75db35 100644 (file)
 #include "llvm/Support/ToolOutputFile.h"
 
 #include "Standalone/StandaloneDialect.h"
+#include "Standalone/StandalonePasses.h"
 
 int main(int argc, char **argv) {
   mlir::registerAllPasses();
+  mlir::standalone::registerPasses();
   // TODO: Register standalone passes here.
 
   mlir::DialectRegistry registry;
diff --git a/mlir/examples/standalone/standalone-plugin/CMakeLists.txt b/mlir/examples/standalone/standalone-plugin/CMakeLists.txt
new file mode 100644 (file)
index 0000000..961a3ea
--- /dev/null
@@ -0,0 +1,22 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+set(LIBS
+        MLIRIR
+        MLIRPass
+        MLIRPluginsLib
+        MLIRStandalone
+        MLIRTransformUtils
+        )
+
+add_mlir_dialect_library(StandalonePlugin
+        SHARED
+        standalone-plugin.cpp
+
+        DEPENDS
+        MLIRStandalone
+        )
+
+llvm_update_compile_flags(StandalonePlugin)
+target_link_libraries(StandalonePlugin PRIVATE ${LIBS})
+
+mlir_check_all_link_libraries(StandalonePlugin)
diff --git a/mlir/examples/standalone/standalone-plugin/standalone-plugin.cpp b/mlir/examples/standalone/standalone-plugin/standalone-plugin.cpp
new file mode 100644 (file)
index 0000000..129b86b
--- /dev/null
@@ -0,0 +1,39 @@
+//===- standalone-plugin.cpp ------------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/InitAllPasses.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Tools/Plugins/DialectPlugin.h"
+
+#include "Standalone/StandaloneDialect.h"
+#include "Standalone/StandalonePasses.h"
+
+using namespace mlir;
+
+/// Dialect plugin registration mechanism.
+/// Observe that it also allows to register passes.
+/// Necessary symbol to register the dialect plugin.
+extern "C" LLVM_ATTRIBUTE_WEAK DialectPluginLibraryInfo
+mlirGetDialectPluginInfo() {
+  return {MLIR_PLUGIN_API_VERSION, "Standalone", LLVM_VERSION_STRING,
+          [](DialectRegistry *registry) {
+            registry->insert<mlir::standalone::StandaloneDialect>();
+            mlir::standalone::registerPasses();
+          }};
+}
+
+/// Pass plugin registration mechanism.
+/// Necessary symbol to register the pass plugin.
+extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo mlirGetPassPluginInfo() {
+  return {MLIR_PLUGIN_API_VERSION, "StandalonePasses", LLVM_VERSION_STRING,
+          []() { mlir::standalone::registerPasses(); }};
+}
diff --git a/mlir/examples/standalone/test/Standalone/standalone-pass-plugin.mlir b/mlir/examples/standalone/test/Standalone/standalone-pass-plugin.mlir
new file mode 100644 (file)
index 0000000..5af4b3d
--- /dev/null
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --load-pass-plugin=%standalone_libs/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)" | FileCheck %s
+
+module {
+  // CHECK-LABEL: func @foo()
+  func.func @bar() {
+    return
+  }
+
+  // CHECK-LABEL: func @abar()
+  func.func @abar() {
+    return
+  }
+}
diff --git a/mlir/examples/standalone/test/Standalone/standalone-plugin.mlir b/mlir/examples/standalone/test/Standalone/standalone-plugin.mlir
new file mode 100644 (file)
index 0000000..3f935db
--- /dev/null
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --load-dialect-plugin=%standalone_libs/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)" | FileCheck %s
+
+module {
+  // CHECK-LABEL: func @foo()
+  func.func @bar() {
+    return
+  }
+
+  // CHECK-LABEL: func @standalone_types(%arg0: !standalone.custom<"10">)
+  func.func @standalone_types(%arg0: !standalone.custom<"10">) {
+    return
+  }
+}
index a6a3d24..3e4ceee 100644 (file)
@@ -44,12 +44,16 @@ config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENS
 # test_exec_root: The root path where tests should be run.
 config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
 config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin')
+config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib')
+
+config.substitutions.append(('%standalone_libs', config.standalone_libs_dir))
 
 # Tweak the PATH to include the tools dir.
 llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
 
 tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir]
 tools = [
+    'mlir-opt',
     'standalone-capi-test',
     'standalone-opt',
     'standalone-translate',
diff --git a/mlir/include/mlir/Tools/Plugins/DialectPlugin.h b/mlir/include/mlir/Tools/Plugins/DialectPlugin.h
new file mode 100644 (file)
index 0000000..1373e76
--- /dev/null
@@ -0,0 +1,106 @@
+//===- mlir/Tools/Plugins/DialectPlugin.h - Public Plugin API -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This defines the public entry point for dialect plugins.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PLUGINS_DIALECTPLUGIN_H
+#define MLIR_TOOLS_PLUGINS_DIALECTPLUGIN_H
+
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/Tools/Plugins/PassPlugin.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/DynamicLibrary.h"
+#include "llvm/Support/Error.h"
+#include <cstdint>
+#include <string>
+
+namespace mlir {
+extern "C" {
+/// Information about the plugin required to load its dialects & passes
+///
+/// This struct defines the core interface for dialect plugins and is supposed
+/// to be filled out by plugin implementors. MLIR-side users of a plugin are
+/// expected to use the \c DialectPlugin class below to interface with it.
+struct DialectPluginLibraryInfo {
+  /// The API version understood by this plugin, usually
+  /// \c MLIR_PLUGIN_API_VERSION
+  uint32_t apiVersion;
+  /// A meaningful name of the plugin.
+  const char *pluginName;
+  /// The version of the plugin.
+  const char *pluginVersion;
+
+  /// The callback for registering dialect plugin with a \c DialectRegistry
+  /// instance
+  void (*registerDialectRegistryCallbacks)(DialectRegistry *);
+};
+}
+
+/// A loaded dialect plugin.
+///
+/// An instance of this class wraps a loaded dialect plugin and gives access to
+/// its interface defined by the \c DialectPluginLibraryInfo it exposes.
+class DialectPlugin {
+public:
+  /// Attempts to load a dialect plugin from a given file.
+  ///
+  /// \returns Returns an error if either the library cannot be found or loaded,
+  /// there is no public entry point, or the plugin implements the wrong API
+  /// version.
+  static llvm::Expected<DialectPlugin> load(const std::string &filename);
+
+  /// Get the filename of the loaded plugin.
+  StringRef getFilename() const { return filename; }
+
+  /// Get the plugin name
+  StringRef getPluginName() const { return info.pluginName; }
+
+  /// Get the plugin version
+  StringRef getPluginVersion() const { return info.pluginVersion; }
+
+  /// Get the plugin API version
+  uint32_t getAPIVersion() const { return info.apiVersion; }
+
+  /// Invoke the DialectRegistry callback registration
+  void
+  registerDialectRegistryCallbacks(DialectRegistry &dialectRegistry) const {
+    info.registerDialectRegistryCallbacks(&dialectRegistry);
+  }
+
+private:
+  DialectPlugin(const std::string &filename,
+                const llvm::sys::DynamicLibrary &library)
+      : filename(filename), library(library), info() {}
+
+  std::string filename;
+  llvm::sys::DynamicLibrary library;
+  DialectPluginLibraryInfo info;
+};
+} // namespace mlir
+
+/// The public entry point for a dialect plugin.
+///
+/// When a plugin is loaded by the driver, it will call this entry point to
+/// obtain information about this plugin and about how to register its dialects.
+/// This function needs to be implemented by the plugin, see the example below:
+///
+/// ```
+/// extern "C" ::mlir::DialectPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
+/// mlirGetDialectPluginInfo() {
+///   return {
+///     MLIR_PLUGIN_API_VERSION, "MyPlugin", "v0.1", [](DialectRegistry) { ... }
+///   };
+/// }
+/// ```
+extern "C" ::mlir::DialectPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
+mlirGetDialectPluginInfo();
+
+#endif /* MLIR_TOOLS_PLUGINS_DIALECTPLUGIN_H */
diff --git a/mlir/include/mlir/Tools/Plugins/PassPlugin.h b/mlir/include/mlir/Tools/Plugins/PassPlugin.h
new file mode 100644 (file)
index 0000000..cd8cc38
--- /dev/null
@@ -0,0 +1,112 @@
+//===- mlir/Tools/Plugins/PassPlugin.h - Public Plugin API ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This defines the public entry point for pass plugins.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PLUGINS_PASSPLUGIN_H
+#define MLIR_TOOLS_PLUGINS_PASSPLUGIN_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/DynamicLibrary.h"
+#include "llvm/Support/Error.h"
+#include <cstdint>
+#include <string>
+
+namespace mlir {
+/// \macro MLIR_PLUGIN_API_VERSION
+/// Identifies the API version understood by this plugin.
+///
+/// When a plugin is loaded, the driver will check it's supported plugin version
+/// against that of the plugin. A mismatch is an error. The supported version
+/// will be incremented for ABI-breaking changes to the \c PassPluginLibraryInfo
+/// struct, i.e. when callbacks are added, removed, or reordered.
+#define MLIR_PLUGIN_API_VERSION 1
+
+extern "C" {
+/// Information about the plugin required to load its passes
+///
+/// This struct defines the core interface for pass plugins and is supposed to
+/// be filled out by plugin implementors. LLVM-side users of a plugin are
+/// expected to use the \c PassPlugin class below to interface with it.
+struct PassPluginLibraryInfo {
+  /// The API version understood by this plugin, usually \c
+  /// MLIR_PLUGIN_API_VERSION
+  uint32_t apiVersion;
+  /// A meaningful name of the plugin.
+  const char *pluginName;
+  /// The version of the plugin.
+  const char *pluginVersion;
+
+  /// The callback for registering plugin passes.
+  void (*registerPassRegistryCallbacks)();
+};
+}
+
+/// A loaded pass plugin.
+///
+/// An instance of this class wraps a loaded pass plugin and gives access to
+/// its interface defined by the \c PassPluginLibraryInfo it exposes.
+class PassPlugin {
+public:
+  /// Attempts to load a pass plugin from a given file.
+  ///
+  /// \returns Returns an error if either the library cannot be found or loaded,
+  /// there is no public entry point, or the plugin implements the wrong API
+  /// version.
+  static llvm::Expected<PassPlugin> load(const std::string &filename);
+
+  /// Get the filename of the loaded plugin.
+  StringRef getFilename() const { return filename; }
+
+  /// Get the plugin name
+  StringRef getPluginName() const { return info.pluginName; }
+
+  /// Get the plugin version
+  StringRef getPluginVersion() const { return info.pluginVersion; }
+
+  /// Get the plugin API version
+  uint32_t getAPIVersion() const { return info.apiVersion; }
+
+  /// Invoke the PassRegistry callback registration
+  void registerPassRegistryCallbacks() const {
+    info.registerPassRegistryCallbacks();
+  }
+
+private:
+  PassPlugin(const std::string &filename,
+             const llvm::sys::DynamicLibrary &library)
+      : filename(filename), library(library), info() {}
+
+  std::string filename;
+  llvm::sys::DynamicLibrary library;
+  PassPluginLibraryInfo info;
+};
+} // namespace mlir
+
+/// The public entry point for a pass plugin.
+///
+/// When a plugin is loaded by the driver, it will call this entry point to
+/// obtain information about this plugin and about how to register its passes.
+/// This function needs to be implemented by the plugin, see the example below:
+///
+/// ```
+/// extern "C" ::mlir::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
+/// mlirGetPassPluginInfo() {
+///   return {
+///     MLIR_PLUGIN_API_VERSION, "MyPlugin", "v0.1", []() { ... }
+///   };
+/// }
+/// ```
+extern "C" ::mlir::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
+mlirGetPassPluginInfo();
+
+#endif /* MLIR_TOOLS_PLUGINS_PASSPLUGIN_H */
index bcefc4f..f54c29c 100644 (file)
@@ -35,12 +35,12 @@ class PassManager;
 /// supported options.
 /// The API is fluent, and the options are sorted in alphabetical order below.
 /// The options can be exposed to the LLVM command line by registering them
-/// with `MlirOptMainConfig::registerCLOptions();` and creating a config using
-/// `auto config = MlirOptMainConfig::createFromCLOptions();`.
+/// with `MlirOptMainConfig::registerCLOptions(DialectRegistry &);` and creating
+/// a config using `auto config = MlirOptMainConfig::createFromCLOptions();`.
 class MlirOptMainConfig {
 public:
   /// Register the options as global LLVM command line options.
-  static void registerCLOptions();
+  static void registerCLOptions(DialectRegistry &dialectRegistry);
 
   /// Create a new config with the default set from the CL options.
   static MlirOptMainConfig createFromCLOptions();
index 6dab371..6175a1c 100644 (file)
@@ -6,4 +6,5 @@ add_subdirectory(mlir-reduce)
 add_subdirectory(mlir-tblgen)
 add_subdirectory(mlir-translate)
 add_subdirectory(PDLL)
+add_subdirectory(Plugins)
 add_subdirectory(tblgen-lsp-server)
diff --git a/mlir/lib/Tools/Plugins/CMakeLists.txt b/mlir/lib/Tools/Plugins/CMakeLists.txt
new file mode 100644 (file)
index 0000000..59dc9e5
--- /dev/null
@@ -0,0 +1,12 @@
+add_mlir_library(MLIRPluginsLib
+  DialectPlugin.cpp
+  PassPlugin.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/Plugins
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRSupport
+  )
diff --git a/mlir/lib/Tools/Plugins/DialectPlugin.cpp b/mlir/lib/Tools/Plugins/DialectPlugin.cpp
new file mode 100644 (file)
index 0000000..1443399
--- /dev/null
@@ -0,0 +1,53 @@
+//===- lib/Tools/Plugins/DialectPlugin.cpp - Load Dialect Plugins ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/Plugins/DialectPlugin.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+
+using namespace mlir;
+
+llvm::Expected<DialectPlugin> DialectPlugin::load(const std::string &filename) {
+  std::string error;
+  auto library =
+      llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &error);
+  if (!library.isValid())
+    return llvm::make_error<llvm::StringError>(
+        Twine("Could not load library '") + filename + "': " + error,
+        llvm::inconvertibleErrorCode());
+
+  DialectPlugin plugin{filename, library};
+
+  // mlirGetDialectPluginInfo should be resolved to the definition from the
+  // plugin we are currently loading.
+  intptr_t getDetailsFn =
+      (intptr_t)library.getAddressOfSymbol("mlirGetDialectPluginInfo");
+
+  if (!getDetailsFn)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Plugin entry point not found in '") + filename,
+        llvm::inconvertibleErrorCode());
+
+  plugin.info =
+      reinterpret_cast<decltype(mlirGetDialectPluginInfo) *>(getDetailsFn)();
+
+  if (plugin.info.apiVersion != MLIR_PLUGIN_API_VERSION)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Wrong API version on plugin '") + filename + "'. Got version " +
+            Twine(plugin.info.apiVersion) + ", supported version is " +
+            Twine(MLIR_PLUGIN_API_VERSION) + ".",
+        llvm::inconvertibleErrorCode());
+
+  if (!plugin.info.registerDialectRegistryCallbacks)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Empty entry callback in plugin '") + filename + "'.'",
+        llvm::inconvertibleErrorCode());
+
+  return plugin;
+}
diff --git a/mlir/lib/Tools/Plugins/PassPlugin.cpp b/mlir/lib/Tools/Plugins/PassPlugin.cpp
new file mode 100644 (file)
index 0000000..98e15e4
--- /dev/null
@@ -0,0 +1,53 @@
+//===- lib/Tools/Plugins/PassPlugin.cpp - Load Plugins for PR Passes ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/Plugins/PassPlugin.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+
+using namespace mlir;
+
+llvm::Expected<PassPlugin> PassPlugin::load(const std::string &filename) {
+  std::string Error;
+  auto library =
+      llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &Error);
+  if (!library.isValid())
+    return llvm::make_error<llvm::StringError>(
+        Twine("Could not load library '") + filename + "': " + Error,
+        llvm::inconvertibleErrorCode());
+
+  PassPlugin plugin{filename, library};
+
+  // mlirGetPassPluginInfo should be resolved to the definition from the plugin
+  // we are currently loading.
+  intptr_t getDetailsFn =
+      (intptr_t)library.getAddressOfSymbol("mlirGetPassPluginInfo");
+
+  if (!getDetailsFn)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Plugin entry point not found in '") + filename,
+        llvm::inconvertibleErrorCode());
+
+  plugin.info =
+      reinterpret_cast<decltype(mlirGetPassPluginInfo) *>(getDetailsFn)();
+
+  if (plugin.info.apiVersion != MLIR_PLUGIN_API_VERSION)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Wrong API version on plugin '") + filename + "'. Got version " +
+            Twine(plugin.info.apiVersion) + ", supported version is " +
+            Twine(MLIR_PLUGIN_API_VERSION) + ".",
+        llvm::inconvertibleErrorCode());
+
+  if (!plugin.info.registerPassRegistryCallbacks)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Empty entry callback in plugin '") + filename + "'.'",
+        llvm::inconvertibleErrorCode());
+
+  return plugin;
+}
index 983e855..a15677e 100644 (file)
@@ -10,5 +10,6 @@ add_mlir_library(MLIROptLib
   MLIRObservers
   MLIRPass
   MLIRParser
+  MLIRPluginsLib
   MLIRSupport
   )
index dfc5117..8f608ef 100644 (file)
@@ -30,6 +30,8 @@
 #include "mlir/Support/Timing.h"
 #include "mlir/Support/ToolUtilities.h"
 #include "mlir/Tools/ParseUtilities.h"
+#include "mlir/Tools/Plugins/DialectPlugin.h"
+#include "mlir/Tools/Plugins/PassPlugin.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FileUtilities.h"
 #include "llvm/Support/InitLLVM.h"
@@ -101,15 +103,41 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
         cl::desc("Run the verifier after each transformation pass"),
         cl::location(verifyPassesFlag), cl::init(true));
 
+    static cl::list<std::string> passPlugins(
+        "load-pass-plugin", cl::desc("Load passes from plugin library"));
+    /// Set the callback to load a pass plugin.
+    passPlugins.setCallback([&](const std::string &pluginPath) {
+      auto plugin = PassPlugin::load(pluginPath);
+      if (!plugin) {
+        errs() << "Failed to load passes from '" << pluginPath
+               << "'. Request ignored.\n";
+        return;
+      }
+      plugin.get().registerPassRegistryCallbacks();
+    });
+
+    static cl::list<std::string> dialectPlugins(
+        "load-dialect-plugin", cl::desc("Load dialects from plugin library"));
+    this->dialectPlugins = std::addressof(dialectPlugins);
+
     static PassPipelineCLParser passPipeline("", "Compiler passes to run", "p");
     setPassPipelineParser(passPipeline);
   }
+
+  /// Set the callback to load a dialect plugin.
+  void setDialectPluginsCallback(DialectRegistry &registry);
+
+  /// Pointer to static dialectPlugins variable in constructor, needed by
+  /// setDialectPluginsCallback(DialectRegistry&).
+  cl::list<std::string> *dialectPlugins = nullptr;
 };
 } // namespace
 
 ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig;
 
-void MlirOptMainConfig::registerCLOptions() { *clOptionsConfig; }
+void MlirOptMainConfig::registerCLOptions(DialectRegistry &registry) {
+  clOptionsConfig->setDialectPluginsCallback(registry);
+}
 
 MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() {
   return *clOptionsConfig;
@@ -134,6 +162,19 @@ MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser(
   return *this;
 }
 
+void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
+    DialectRegistry &registry) {
+  dialectPlugins->setCallback([&](const std::string &pluginPath) {
+    auto plugin = DialectPlugin::load(pluginPath);
+    if (!plugin) {
+      errs() << "Failed to load dialect plugin from '" << pluginPath
+             << "'. Request ignored.\n";
+      return;
+    };
+    plugin.get().registerDialectRegistryCallbacks(registry);
+  });
+}
+
 /// Set the ExecutionContext on the context and handle the observers.
 class InstallDebugHandler {
 public:
@@ -365,7 +406,7 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
   InitLLVM y(argc, argv);
 
   // Register any command line options.
-  MlirOptMainConfig::registerCLOptions();
+  MlirOptMainConfig::registerCLOptions(registry);
   registerAsmPrinterCLOptions();
   registerMLIRContextCLOptions();
   registerPassManagerCLOptions();
index c430569..2d021d5 100644 (file)
@@ -79,8 +79,10 @@ add_mlir_tool(mlir-opt
 
   DEPENDS
   ${LIBS}
+  SUPPORT_PLUGINS
   )
 target_link_libraries(mlir-opt PRIVATE ${LIBS})
 llvm_update_compile_flags(mlir-opt)
 
 mlir_check_all_link_libraries(mlir-opt)
+export_executable_symbols_for_plugins(mlir-opt)