[mlir][python] Allow adding to existing pass manager
authorrkayaith <rkayaith@gmail.com>
Thu, 20 Oct 2022 04:27:09 +0000 (00:27 -0400)
committerrkayaith <rkayaith@gmail.com>
Fri, 4 Nov 2022 16:04:26 +0000 (12:04 -0400)
This adds a `PassManager.add` method which adds pipeline elements to the
pass manager. This allows for progressively building up a pipeline from
python without string manipulation.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D137344

mlir/lib/Bindings/Python/Pass.cpp
mlir/test/python/integration/dialects/linalg/opsrun.py
mlir/test/python/pass_manager.py

index 13f1cfa..cb3c158 100644 (file)
@@ -101,6 +101,20 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           "that can be applied on a Module. Throw a ValueError if the pipeline "
           "can't be parsed")
       .def(
+          "add",
+          [](PyPassManager &passManager, const std::string &pipeline) {
+            PyPrintAccumulator errorMsg;
+            MlirLogicalResult status = mlirOpPassManagerAddPipeline(
+                mlirPassManagerGetAsOpPassManager(passManager.get()),
+                mlirStringRefCreate(pipeline.data(), pipeline.size()),
+                errorMsg.getCallback(), errorMsg.getUserData());
+            if (mlirLogicalResultIsFailure(status))
+              throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
+          },
+          py::arg("pipeline"),
+          "Add textual pipeline elements to the pass manager. Throws a "
+          "ValueError if the pipeline can't be parsed.")
+      .def(
           "run",
           [](PyPassManager &passManager, PyModule &module) {
             MlirLogicalResult status =
index 2075ecf..585741a 100644 (file)
@@ -191,11 +191,17 @@ def transform(module, boilerplate):
   ops = module.operation.regions[0].blocks[0].operations
   mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
 
-  pm = PassManager.parse(
-      "builtin.module(func.func(convert-linalg-to-loops, lower-affine, " +
-      "convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
-      + "convert-vector-to-llvm, convert-memref-to-llvm, convert-func-to-llvm," +
-      "reconcile-unrealized-casts)")
+  pm = PassManager('builtin.module')
+  pm.add("func.func(convert-linalg-to-loops)")
+  pm.add("func.func(lower-affine)")
+  pm.add("func.func(convert-math-to-llvm)")
+  pm.add("func.func(convert-scf-to-cf)")
+  pm.add("func.func(arith-expand)")
+  pm.add("func.func(memref-expand)")
+  pm.add("convert-vector-to-llvm")
+  pm.add("convert-memref-to-llvm")
+  pm.add("convert-func-to-llvm")
+  pm.add("reconcile-unrealized-casts")
   pm.run(mod)
   return mod
 
index 04e325e..492c7e0 100644 (file)
@@ -75,6 +75,20 @@ def testParseFail():
       log("Exception not produced")
 run(testParseFail)
 
+# Check that adding to a pass manager works
+# CHECK-LABEL: TEST: testAdd
+@run
+def testAdd():
+  pm = PassManager("any", Context())
+  # CHECK: pm: 'any()'
+  log(f"pm: '{pm}'")
+  # CHECK: pm: 'any(cse)'
+  pm.add("cse")
+  log(f"pm: '{pm}'")
+  # CHECK: pm: 'any(cse,cse)'
+  pm.add("cse")
+  log(f"pm: '{pm}'")
+
 
 # Verify failure on incorrect level of nesting.
 # CHECK-LABEL: TEST: testInvalidNesting