[mlir] Improve debug flag management in Python bindings
authorAlex Zinenko <zinenko@google.com>
Mon, 19 Apr 2021 11:37:01 +0000 (13:37 +0200)
committerAlex Zinenko <zinenko@google.com>
Mon, 19 Apr 2021 12:45:43 +0000 (14:45 +0200)
Expose the debug flag as a readable and assignable property of a
dedicated class instead of a write-only function. Actually test the fact
of setting the flag. Move test to a dedicated file, it has zero relation
to context_managers.py where it was added.

Arguably, it should be promoted from mlir.ir to mlir module, but we are
not re-exporting the latter and this functionality is purposefully
hidden so can stay in IR for now. Drop unnecessary export code.

Refactor C API and put Debug into a separate library, fix it to actually
set the flag to the given value.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D100757

mlir/include/mlir-c/Debug.h [new file with mode: 0644]
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/mlir/ir.py
mlir/lib/CAPI/CMakeLists.txt
mlir/lib/CAPI/Debug/CMakeLists.txt [new file with mode: 0644]
mlir/lib/CAPI/Debug/Debug.cpp [new file with mode: 0644]
mlir/lib/CAPI/IR/IR.cpp
mlir/test/Bindings/Python/context_managers.py
mlir/test/Bindings/Python/debug.py [new file with mode: 0644]

diff --git a/mlir/include/mlir-c/Debug.h b/mlir/include/mlir-c/Debug.h
new file mode 100644 (file)
index 0000000..2502f2f
--- /dev/null
@@ -0,0 +1,30 @@
+//===-- mlir-c/Debug.h - C API for MLIR/LLVM debugging functions --*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Support.h"
+
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Sets the global debugging flag.
+MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable);
+
+/// Retuns `true` if the global debugging flag is set, false otherwise.
+MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled();
+
+#ifdef __cplusplus
+}
+#endif
+
+#ifndef MLIR_C_DEBUG_H
+#define MLIR_C_DEBUG_H
+#endif // MLIR_C_DEBUG_H
index c64ec17..8e92510 100644 (file)
@@ -76,13 +76,6 @@ struct MlirNamedAttribute {
 typedef struct MlirNamedAttribute MlirNamedAttribute;
 
 //===----------------------------------------------------------------------===//
-// Global API.
-//===----------------------------------------------------------------------===//
-
-/// Set the global debugging flag.
-MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable);
-
-//===----------------------------------------------------------------------===//
 // Context API.
 //===----------------------------------------------------------------------===//
 
index 0f3a1c0..a2655d9 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/Debug.h"
 #include "mlir-c/Registration.h"
 #include "llvm/ADT/SmallVector.h"
 #include <pybind11/stl.h>
@@ -129,7 +130,7 @@ equivalent to printing the operation that produced it.
 // Utilities.
 //------------------------------------------------------------------------------
 
-// Helper for creating an @classmethod.
+/// Helper for creating an @classmethod.
 template <class Func, typename... Args>
 py::object classmethod(Func f, Args... args) {
   py::object cf = py::cpp_function(f, args...);
@@ -153,6 +154,20 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
   return mlirStringRefCreate(s.data(), s.size());
 }
 
+/// Wrapper for the global LLVM debugging flag.
+struct PyGlobalDebugFlag {
+  static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
+
+  static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
+
+  static void bind(py::module &m) {
+    // Debug flags.
+    py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+        .def_property_static("flag", &PyGlobalDebugFlag::get,
+                             &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
+  }
+};
+
 //------------------------------------------------------------------------------
 // Collections.
 //------------------------------------------------------------------------------
@@ -1713,12 +1728,7 @@ private:
 
 void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
-  // Mapping of Global functions
-  //----------------------------------------------------------------------------
-  m.def("_enable_debug", [](bool enable) { mlirEnableGlobalDebug(enable); });
-
-  //----------------------------------------------------------------------------
-  // Mapping of MlirContext
+  // Mapping of MlirContext.
   //----------------------------------------------------------------------------
   py::class_<PyMlirContext>(m, "Context")
       .def(py::init<>(&PyMlirContext::createNewContextForInit))
@@ -2384,4 +2394,7 @@ void mlir::python::populateIRCore(py::module &m) {
   PyOpResultList::bind(m);
   PyRegionIterator::bind(m);
   PyRegionList::bind(m);
+
+  // Debug bindings.
+  PyGlobalDebugFlag::bind(m);
 }
index e2c785c..2b42051 100644 (file)
@@ -7,7 +7,3 @@ from ._cext_loader import _reexport_cext
 _reexport_cext("ir", __name__)
 del _reexport_cext
 
-# Extra functions that are not visible to _reexport_cext.
-# TODO: is this really necessary?
-from _mlir.ir import _enable_debug
-_enable_debug = _enable_debug
\ No newline at end of file
index ba58d99..db77cc1 100644 (file)
@@ -1,3 +1,4 @@
+add_subdirectory(Debug)
 add_subdirectory(Dialect)
 add_subdirectory(Conversion)
 add_subdirectory(ExecutionEngine)
diff --git a/mlir/lib/CAPI/Debug/CMakeLists.txt b/mlir/lib/CAPI/Debug/CMakeLists.txt
new file mode 100644 (file)
index 0000000..fdffe30
--- /dev/null
@@ -0,0 +1,6 @@
+add_mlir_public_c_api_library(MLIRCAPIDebug
+  Debug.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRSupport
+)
diff --git a/mlir/lib/CAPI/Debug/Debug.cpp b/mlir/lib/CAPI/Debug/Debug.cpp
new file mode 100644 (file)
index 0000000..288ecd6
--- /dev/null
@@ -0,0 +1,18 @@
+//===- Debug.cpp - C Interface for MLIR/LLVM Debugging Functions ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Debug.h"
+#include "mlir-c/Support.h"
+
+#include "mlir/CAPI/Support.h"
+
+#include "llvm/Support/Debug.h"
+
+void mlirEnableGlobalDebug(bool enable) { llvm::DebugFlag = enable; }
+
+bool mlirIsGlobalDebugEnabled() { return llvm::DebugFlag; }
index 616caae..000b8f5 100644 (file)
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
-// Global API.
-//===----------------------------------------------------------------------===//
-
-void mlirEnableGlobalDebug(bool enable) { ::llvm::DebugFlag = true; }
-
-//===----------------------------------------------------------------------===//
 // Context API.
 //===----------------------------------------------------------------------===//
 
index 9fde95a..b93fcf7 100644 (file)
@@ -10,13 +10,6 @@ def run(f):
   assert Context._get_live_count() == 0
 
 
-# CHECK-LABEL: TEST: testExports
-def testExports():
-  from mlir.ir import _enable_debug
-
-run(testExports)
-
-
 # CHECK-LABEL: TEST: testContextEnterExit
 def testContextEnterExit():
   with Context() as ctx:
diff --git a/mlir/test/Bindings/Python/debug.py b/mlir/test/Bindings/Python/debug.py
new file mode 100644 (file)
index 0000000..3268d9f
--- /dev/null
@@ -0,0 +1,39 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+
+
+# CHECK-LABEL: TEST: testNameIsPrivate
+def testNameIsPrivate():
+  # `import *` ignores private names starting with an understore, so the debug
+  # flag shouldn't be visible unless explicitly imported.
+  try:
+    _GlobalDebug.flag = True
+  except NameError:
+    pass
+  else:
+    assert False, "_GlobalDebug must not be available by default"
+
+run(testNameIsPrivate)
+
+
+# CHECK-LABEL: TEST: testDebugDlag
+def testDebugDlag():
+  # Private names must be imported expilcitly.
+  from mlir.ir import _GlobalDebug
+
+  # CHECK: False
+  print(_GlobalDebug.flag)
+  _GlobalDebug.flag = True
+  # CHECK: True
+  print(_GlobalDebug.flag)
+  _GlobalDebug.flag = False
+  # CHECK: False
+  print(_GlobalDebug.flag)
+
+run(testDebugDlag)
+