[mlir][python] Include pipeline parse errors in exception message
authorrkayaith <rkayaith@gmail.com>
Thu, 20 Oct 2022 20:40:32 +0000 (16:40 -0400)
committerrkayaith <rkayaith@gmail.com>
Thu, 27 Oct 2022 17:05:38 +0000 (13:05 -0400)
Currently any errors during pipeline parsing are reported to stderr.
This adds a new pipeline parsing function to the C api that reports
errors through a callback, and updates the python bindings to use it.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D136402

mlir/include/mlir-c/Pass.h
mlir/lib/Bindings/Python/Pass.cpp
mlir/lib/CAPI/IR/Pass.cpp
mlir/test/CAPI/pass.c
mlir/test/python/pass_manager.py

index b66bdfe..6f281b6 100644 (file)
@@ -105,6 +105,13 @@ MLIR_CAPI_EXPORTED void mlirPassManagerAddOwnedPass(MlirPassManager passManager,
 MLIR_CAPI_EXPORTED void
 mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass);
 
+/// Parse a sequence of textual MLIR pass pipeline elements and add them to the
+/// provided OpPassManager. If parsing fails an error message is reported using
+/// the provided callback.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirOpPassManagerAddPipeline(
+    MlirOpPassManager passManager, MlirStringRef pipelineElements,
+    MlirStringCallback callback, void *userData);
+
 /// Print a textual MLIR pass pipeline by sending chunks of the string
 /// representation and forwarding `userData to `callback`. Note that the
 /// callback may be called several times with consecutive chunks of the string.
index 3278d3a..99d6758 100644 (file)
@@ -82,15 +82,15 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           py::arg("enable"), "Enable / disable verify-each.")
       .def_static(
           "parse",
-          [](const std::string pipeline, DefaultingPyMlirContext context) {
+          [](const std::string &pipeline, DefaultingPyMlirContext context) {
             MlirPassManager passManager = mlirPassManagerCreate(context->get());
-            MlirLogicalResult status = mlirParsePassPipeline(
+            PyPrintAccumulator errorMsg;
+            MlirLogicalResult status = mlirOpPassManagerAddPipeline(
                 mlirPassManagerGetAsOpPassManager(passManager),
-                mlirStringRefCreate(pipeline.data(), pipeline.size()));
+                mlirStringRefCreate(pipeline.data(), pipeline.size()),
+                errorMsg.getCallback(), errorMsg.getUserData());
             if (mlirLogicalResultIsFailure(status))
-              throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("invalid pass pipeline '") +
-                                   pipeline + "'.");
+              throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
             return new PyPassManager(passManager);
           },
           py::arg("pipeline"), py::arg("context") = py::none(),
index a299893..398abfe 100644 (file)
@@ -65,6 +65,15 @@ void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
 }
 
+MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
+                                               MlirStringRef pipelineElements,
+                                               MlirStringCallback callback,
+                                               void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
+                                stream));
+}
+
 void mlirPrintPassPipeline(MlirOpPassManager passManager,
                            MlirStringCallback callback, void *userData) {
   detail::CallbackOstream stream(callback, userData);
index 4c68a3e..966bcaf 100644 (file)
@@ -133,6 +133,11 @@ static void printToStderr(MlirStringRef str, void *userData) {
   fwrite(str.data, 1, str.length, stderr);
 }
 
+static void dontPrint(MlirStringRef str, void *userData) {
+  (void)str;
+  (void)userData;
+}
+
 void testPrintPassPipeline() {
   MlirContext ctx = mlirContextCreate();
   MlirPassManager pm = mlirPassManagerCreate(ctx);
@@ -176,8 +181,7 @@ void testParsePassPipeline() {
   MlirLogicalResult status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
       mlirStringRefCreateFromCString(
-          "builtin.module(func.func(print-op-stats{json=false}),"
-          " func.func(print-op-stats{json=false}))"));
+          "builtin.module(func.func(print-op-stats{json=false}))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsSuccess(status)) {
     fprintf(
@@ -190,8 +194,7 @@ void testParsePassPipeline() {
   status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
       mlirStringRefCreateFromCString(
-          "builtin.module(func.func(print-op-stats{json=false}),"
-          " func.func(print-op-stats{json=false}))"));
+          "builtin.module(func.func(print-op-stats{json=false}))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsFailure(status)) {
     fprintf(stderr,
@@ -199,14 +202,61 @@ void testParsePassPipeline() {
     exit(EXIT_FAILURE);
   }
 
-  //      CHECK: Round-trip: builtin.module(builtin.module(
-  // CHECK-SAME:   func.func(print-op-stats{json=false}),
-  // CHECK-SAME:   func.func(print-op-stats{json=false})
-  // CHECK-SAME: ))
+  //      CHECK: Round-trip: builtin.module(
+  // CHECK-SAME:   builtin.module(func.func(print-op-stats{json=false}))
+  // CHECK-SAME: )
   fprintf(stderr, "Round-trip: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
   fprintf(stderr, "\n");
+
+  // Try appending a pass:
+  status = mlirOpPassManagerAddPipeline(
+      mlirPassManagerGetAsOpPassManager(pm),
+      mlirStringRefCreateFromCString("func.func(print-op-stats{json=false})"),
+      printToStderr, NULL);
+  if (mlirLogicalResultIsFailure(status)) {
+    fprintf(stderr, "Unexpected failure appending pipeline\n");
+    exit(EXIT_FAILURE);
+  }
+  //      CHECK: Appended: builtin.module(
+  // CHECK-SAME:   builtin.module(func.func(print-op-stats{json=false})),
+  // CHECK-SAME:   func.func(print-op-stats{json=false})
+  // CHECK-SAME: )
+  fprintf(stderr, "Appended: ");
+  mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
+                        NULL);
+  fprintf(stderr, "\n");
+
+  mlirPassManagerDestroy(pm);
+  mlirContextDestroy(ctx);
+}
+
+void testParseErrorCapture() {
+  // CHECK-LABEL: testParseErrorCapture:
+  fprintf(stderr, "\nTEST: testParseErrorCapture:\n");
+
+  MlirContext ctx = mlirContextCreate();
+  MlirPassManager pm = mlirPassManagerCreate(ctx);
+  MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(pm);
+  MlirStringRef invalidPipeline = mlirStringRefCreateFromCString("invalid");
+
+  // CHECK: mlirOpPassManagerAddPipeline:
+  // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
+  fprintf(stderr, "mlirOpPassManagerAddPipeline:\n");
+  if (mlirLogicalResultIsSuccess(mlirOpPassManagerAddPipeline(
+          opm, invalidPipeline, printToStderr, NULL)))
+    exit(EXIT_FAILURE);
+  fprintf(stderr, "\n");
+
+  // Make sure all output is going through the callback.
+  // CHECK: dontPrint: <>
+  fprintf(stderr, "dontPrint: <");
+  if (mlirLogicalResultIsSuccess(
+          mlirOpPassManagerAddPipeline(opm, invalidPipeline, dontPrint, NULL)))
+    exit(EXIT_FAILURE);
+  fprintf(stderr, ">\n");
+
   mlirPassManagerDestroy(pm);
   mlirContextDestroy(ctx);
 }
@@ -534,6 +584,7 @@ int main() {
   testRunPassOnNestedModule();
   testPrintPassPipeline();
   testParsePassPipeline();
+  testParseErrorCapture();
   testExternalPass();
   return 0;
 }
index df55f20..a2d56a1 100644 (file)
@@ -36,10 +36,8 @@ def testParseSuccess():
     # An unregistered pass should not parse.
     try:
       pm = PassManager.parse("builtin.module(func.func(not-existing-pass{json=false}))")
-      # TODO: this error should be propagate to Python but the C API does not help right now.
-      # CHECK: error: 'not-existing-pass' does not refer to a registered pass or pass pipeline
     except ValueError as e:
-      # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(not-existing-pass{json=false}))'.
+      # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
       log("ValueError exception:", e)
     else:
       log("Exception not produced")
@@ -57,7 +55,10 @@ def testParseFail():
     try:
       pm = PassManager.parse("unknown-pass")
     except ValueError as e:
-      # CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'.
+      #      CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
+      # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
+      #      CHECK: unknown-pass
+      #      CHECK: ^
       log("ValueError exception:", e)
     else:
       log("Exception not produced")
@@ -71,8 +72,7 @@ def testInvalidNesting():
     try:
       pm = PassManager.parse("func.func(normalize-memrefs)")
     except ValueError as e:
-      # CHECK: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
-      # CHECK: ValueError exception: invalid pass pipeline 'func.func(normalize-memrefs)'.
+      # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
       log("ValueError exception:", e)
     else:
       log("Exception not produced")