From ac0a70f3737ecb2c0586c00240d14e46ff00644e Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 22 Apr 2021 17:32:10 +0200 Subject: [PATCH] [mlir] Split out Python bindings entry point into a separate file This will allow the bindings to be built as a library and reused in out-of-tree projects that want to provide bindings on top of MLIR bindings. Reviewed By: stellaraccident, mikeurbach Differential Revision: https://reviews.llvm.org/D101075 --- mlir/lib/Bindings/Python/CMakeLists.txt | 1 + mlir/lib/Bindings/Python/IRModule.cpp | 146 ++++++++++++++++++++++++++++++++ mlir/lib/Bindings/Python/MainModule.cpp | 129 ---------------------------- 3 files changed, 147 insertions(+), 129 deletions(-) create mode 100644 mlir/lib/Bindings/Python/IRModule.cpp diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index bbccea6..580405f 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -84,6 +84,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir IRAffine.cpp IRAttributes.cpp IRCore.cpp + IRModule.cpp IRTypes.cpp PybindUtils.cpp Pass.cpp diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp new file mode 100644 index 0000000..08ce06d --- /dev/null +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -0,0 +1,146 @@ +//===- IRModule.cpp - IR pybind module ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" +#include "Globals.h" +#include "PybindUtils.h" + +#include + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +// ----------------------------------------------------------------------------- +// PyGlobals +// ----------------------------------------------------------------------------- + +PyGlobals *PyGlobals::instance = nullptr; + +PyGlobals::PyGlobals() { + assert(!instance && "PyGlobals already constructed"); + instance = this; +} + +PyGlobals::~PyGlobals() { instance = nullptr; } + +void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { + py::gil_scoped_acquire(); + if (loadedDialectModulesCache.contains(dialectNamespace)) + return; + // Since re-entrancy is possible, make a copy of the search prefixes. + std::vector localSearchPrefixes = dialectSearchPrefixes; + py::object loaded; + for (std::string moduleName : localSearchPrefixes) { + moduleName.push_back('.'); + moduleName.append(dialectNamespace.data(), dialectNamespace.size()); + + try { + py::gil_scoped_release(); + loaded = py::module::import(moduleName.c_str()); + } catch (py::error_already_set &e) { + if (e.matches(PyExc_ModuleNotFoundError)) { + continue; + } else { + throw; + } + } + break; + } + + // Note: Iterator cannot be shared from prior to loading, since re-entrancy + // may have occurred, which may do anything. + loadedDialectModulesCache.insert(dialectNamespace); +} + +void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, + py::object pyClass) { + py::gil_scoped_acquire(); + py::object &found = dialectClassMap[dialectNamespace]; + if (found) { + throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + + dialectNamespace + + "' is already registered."); + } + found = std::move(pyClass); +} + +void PyGlobals::registerOperationImpl(const std::string &operationName, + py::object pyClass, + py::object rawOpViewClass) { + py::gil_scoped_acquire(); + py::object &found = operationClassMap[operationName]; + if (found) { + throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + + operationName + + "' is already registered."); + } + found = std::move(pyClass); + rawOpViewClassMap[operationName] = std::move(rawOpViewClass); +} + +llvm::Optional +PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { + py::gil_scoped_acquire(); + loadDialectModule(dialectNamespace); + // Fast match against the class map first (common case). + const auto foundIt = dialectClassMap.find(dialectNamespace); + if (foundIt != dialectClassMap.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + return foundIt->second; + } + + // Not found and loading did not yield a registration. Negative cache. + dialectClassMap[dialectNamespace] = py::none(); + return llvm::None; +} + +llvm::Optional +PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { + { + py::gil_scoped_acquire(); + auto foundIt = rawOpViewClassMapCache.find(operationName); + if (foundIt != rawOpViewClassMapCache.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + return foundIt->second; + } + } + + // Not found. Load the dialect namespace. + auto split = operationName.split('.'); + llvm::StringRef dialectNamespace = split.first; + loadDialectModule(dialectNamespace); + + // Attempt to find from the canonical map and cache. + { + py::gil_scoped_acquire(); + auto foundIt = rawOpViewClassMap.find(operationName); + if (foundIt != rawOpViewClassMap.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + // Positive cache. + rawOpViewClassMapCache[operationName] = foundIt->second; + return foundIt->second; + } else { + // Negative cache. + rawOpViewClassMap[operationName] = py::none(); + return llvm::None; + } + } +} + +void PyGlobals::clearImportCache() { + py::gil_scoped_acquire(); + loadedDialectModulesCache.clear(); + rawOpViewClassMapCache.clear(); +} diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 79128f2..60c282d 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -21,135 +21,6 @@ using namespace mlir; using namespace mlir::python; // ----------------------------------------------------------------------------- -// PyGlobals -// ----------------------------------------------------------------------------- - -PyGlobals *PyGlobals::instance = nullptr; - -PyGlobals::PyGlobals() { - assert(!instance && "PyGlobals already constructed"); - instance = this; -} - -PyGlobals::~PyGlobals() { instance = nullptr; } - -void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - py::gil_scoped_acquire(); - if (loadedDialectModulesCache.contains(dialectNamespace)) - return; - // Since re-entrancy is possible, make a copy of the search prefixes. - std::vector localSearchPrefixes = dialectSearchPrefixes; - py::object loaded; - for (std::string moduleName : localSearchPrefixes) { - moduleName.push_back('.'); - moduleName.append(dialectNamespace.data(), dialectNamespace.size()); - - try { - py::gil_scoped_release(); - loaded = py::module::import(moduleName.c_str()); - } catch (py::error_already_set &e) { - if (e.matches(PyExc_ModuleNotFoundError)) { - continue; - } else { - throw; - } - } - break; - } - - // Note: Iterator cannot be shared from prior to loading, since re-entrancy - // may have occurred, which may do anything. - loadedDialectModulesCache.insert(dialectNamespace); -} - -void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, - py::object pyClass) { - py::gil_scoped_acquire(); - py::object &found = dialectClassMap[dialectNamespace]; - if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + - dialectNamespace + - "' is already registered."); - } - found = std::move(pyClass); -} - -void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, - py::object rawOpViewClass) { - py::gil_scoped_acquire(); - py::object &found = operationClassMap[operationName]; - if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + - operationName + - "' is already registered."); - } - found = std::move(pyClass); - rawOpViewClassMap[operationName] = std::move(rawOpViewClass); -} - -llvm::Optional -PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { - py::gil_scoped_acquire(); - loadDialectModule(dialectNamespace); - // Fast match against the class map first (common case). - const auto foundIt = dialectClassMap.find(dialectNamespace); - if (foundIt != dialectClassMap.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - return foundIt->second; - } - - // Not found and loading did not yield a registration. Negative cache. - dialectClassMap[dialectNamespace] = py::none(); - return llvm::None; -} - -llvm::Optional -PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { - { - py::gil_scoped_acquire(); - auto foundIt = rawOpViewClassMapCache.find(operationName); - if (foundIt != rawOpViewClassMapCache.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - return foundIt->second; - } - } - - // Not found. Load the dialect namespace. - auto split = operationName.split('.'); - llvm::StringRef dialectNamespace = split.first; - loadDialectModule(dialectNamespace); - - // Attempt to find from the canonical map and cache. - { - py::gil_scoped_acquire(); - auto foundIt = rawOpViewClassMap.find(operationName); - if (foundIt != rawOpViewClassMap.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - // Positive cache. - rawOpViewClassMapCache[operationName] = foundIt->second; - return foundIt->second; - } else { - // Negative cache. - rawOpViewClassMap[operationName] = py::none(); - return llvm::None; - } - } -} - -void PyGlobals::clearImportCache() { - py::gil_scoped_acquire(); - loadedDialectModulesCache.clear(); - rawOpViewClassMapCache.clear(); -} - -// ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -- 2.7.4