From d97e8cd48239ba6f3e50f92b152e661656ea009d Mon Sep 17 00:00:00 2001 From: rkayaith Date: Thu, 20 Oct 2022 01:04:34 -0400 Subject: [PATCH] [mlir][python] Include anchor op in PassManager constructor This adds an extra argument for specifying the pass manager's anchor op, with a default of `any`. Previously the anchor was always defaulted to `builtin.module`. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D136406 --- mlir/lib/Bindings/Python/Pass.cpp | 9 ++++++--- mlir/test/python/pass_manager.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index f08a4bd..13f1cfa 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -56,11 +56,14 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { // Mapping of the top-level PassManager //---------------------------------------------------------------------------- py::class_(m, "PassManager", py::module_local()) - .def(py::init<>([](DefaultingPyMlirContext context) { - MlirPassManager passManager = - mlirPassManagerCreate(context->get()); + .def(py::init<>([](const std::string &anchorOp, + DefaultingPyMlirContext context) { + MlirPassManager passManager = mlirPassManagerCreateOnOperation( + context->get(), + mlirStringRefCreate(anchorOp.data(), anchorOp.size())); return new PyPassManager(passManager); }), + py::arg("anchor_op") = py::str("any"), py::arg("context") = py::none(), "Create a new PassManager for the current (or provided) Context.") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 99170cd..04e325e 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -28,6 +28,17 @@ def testCapsule(): assert pm1 is not None # And does not crash. run(testCapsule) +# CHECK-LABEL: TEST: testConstruct +@run +def testConstruct(): + with Context(): + # CHECK: pm1: 'any()' + # CHECK: pm2: 'builtin.module()' + pm1 = PassManager() + pm2 = PassManager("builtin.module") + log(f"pm1: '{pm1}'") + log(f"pm2: '{pm2}'") + # Verify successful round-trip. # CHECK-LABEL: TEST: testParseSuccess -- 2.7.4