[PassManager] Implement pass manager tracing API (#4782)
authorJared Roesch <jroesch@octoml.ai>
Tue, 28 Jan 2020 11:25:52 +0000 (03:25 -0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Tue, 28 Jan 2020 11:25:52 +0000 (19:25 +0800)
* Implement pass tracing API

* Set is_before correctly

* Add docs for trace function

* Fix lint

* Remove PDB

* Ensure trace_func is set before calling

* Fix conditional

docs/dev/relay_pass_infra.rst
include/tvm/ir/transform.h
python/tvm/relay/transform.py
src/ir/transform.cc
src/relay/ir/transform.cc
tests/python/relay/test_pass_manager.py

index 60d2b72..b4f3f6b 100644 (file)
@@ -621,6 +621,26 @@ By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will
 dump out the module IR when ``FoldConstant`` is done. Users can plug in this
 pass after any pass they want to debug for viewing the optimization effect.
 
+There is a more flexible debugging mechanism also exposed by the build configuration
+object. One can pass a tracing function which can be used to execute arbitrary code
+before and/or after each pass. A tracing function will receive a ``IRModule``, ``PassInfo``,
+and a boolean indicating whether you are executing before, or after a pass.
+An example is below.
+
+.. code:: python
+
+    def print_ir(mod, info, is_before):
+        """Print the name of the pass, the IR, only before passes execute."""
+        if is_before:
+            print(f"Running pass: {}", info)
+            print(mod)
+
+    with relay.build_config(opt_level=3, trace=print_ir):
+            with tvm.target.create("llvm"):
+                # Perform the optimizations.
+                mod = seq(mod)
+
+
 For more pass infra related examples in Python and C++, please refer to
 `tests/python/relay/test_pass_manager.py`_ and
 `tests/cpp/relay_transform_sequential.cc`_, respectively.
index c606b34..2afcb17 100644 (file)
 namespace tvm {
 namespace transform {
 
+// Forward declare for TraceFunc.
+class PassInfo;
+
+/*! \brief A callback for tracing passes, useful for debugging and logging.
+ *
+ */
+using TraceFunc =
+  runtime::TypedPackedFunc<void(const IRModule& ir_module,
+                                const PassInfo& ctx,
+                                bool is_before)>;
+
 /*!
  * \brief PassContextNode contains the information that a pass can rely on,
  * such as analysis results.
@@ -88,6 +99,8 @@ class PassContextNode : public Object {
   /*! \brief The list of disabled passes. */
   Array<PrimExpr> disabled_pass;
 
+  TraceFunc trace_func;
+
   PassContextNode() = default;
 
   void VisitAttrs(AttrVisitor* v) {
@@ -101,6 +114,7 @@ class PassContextNode : public Object {
   TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
 };
 
+
 /*!
  * \brief PassContext that is used to configure the pass behavior.
  *
@@ -146,6 +160,14 @@ class PassContext : public ObjectRef {
    */
   TVM_DLL static PassContext Current();
 
+  /*!
+   * \brief Apply the tracing functions of the context to the module, with the info.
+   * \param module The IRModule to trace.
+   * \param info The pass information.
+   * \param is_before Indicated whether the tracing is before or after a pass.
+   */
+  TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;
+
   // accessor.
   using ContainerType = PassContextNode;
   class Internal;
index c4fbde6..26b20e0 100644 (file)
@@ -78,7 +78,8 @@ class PassContext(RelayNode):
                  opt_level=2,
                  fallback_device=_nd.cpu(),
                  required_pass=None,
-                 disabled_pass=None):
+                 disabled_pass=None,
+                 trace=None):
         if isinstance(fallback_device, str):
             fallback_device = _nd.context(fallback_device).device_type
         elif isinstance(fallback_device, TVMContext):
@@ -99,7 +100,7 @@ class PassContext(RelayNode):
 
         self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
                                             fallback_device, required,
-                                            disabled)
+                                            disabled, trace)
 
     def __enter__(self):
         _transform.EnterPassContext(self)
@@ -117,7 +118,8 @@ class PassContext(RelayNode):
 def build_config(opt_level=2,
                  fallback_device=_nd.cpu(),
                  required_pass=None,
-                 disabled_pass=None):
+                 disabled_pass=None,
+                 trace=None):
     """Configure the build behavior by setting config variables.
 
     Parameters
@@ -151,13 +153,16 @@ def build_config(opt_level=2,
     disabled_pass: set of str, optional
         Optimization passes to be disabled during optimization.
 
+    trace: Callable[[IRModule, PassInfo, bool], None]
+        A tracing function for debugging or introspection.
+
     Returns
     -------
     pass_context: PassContext
         The pass context for optimizations.
     """
     return PassContext(opt_level, fallback_device, required_pass,
-                       disabled_pass)
+                       disabled_pass, trace)
 
 
 @register_relay_node
index 1da010c..14bd063 100644 (file)
@@ -84,6 +84,13 @@ PassContext PassContext::Create() {
   return PassContext(make_object<PassContextNode>());
 }
 
+void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
+    auto pass_ctx_node = this->operator->();
+    if (pass_ctx_node->trace_func != nullptr) {
+      pass_ctx_node->trace_func(module, info, is_before);
+    }
+}
+
 class ModulePass;
 
 /*!
@@ -231,8 +238,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod,
              << " with opt level: "
              << pass_info->opt_level;
   CHECK(mod.defined());
+  pass_ctx.Trace(mod, pass_info, true);
   IRModule updated_mod = pass_func(mod, pass_ctx);
   CHECK(updated_mod.defined());
+  pass_ctx.Trace(updated_mod, pass_info, false);
   return updated_mod;
 }
 
@@ -414,10 +423,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
   int fallback_device = args[1];
   tvm::Array<tvm::PrimExpr> required = args[2];
   tvm::Array<tvm::PrimExpr> disabled = args[3];
+  TraceFunc trace_func = args[4];
   pctx->opt_level = opt_level;
   pctx->fallback_device = fallback_device;
   pctx->required_pass = std::move(required);
   pctx->disabled_pass = std::move(disabled);
+  pctx->trace_func = std::move(trace_func);
   *ret = pctx;
 });
 
index ac0f36c..d5cd5c9 100644 (file)
@@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
              << pass_info->name
              << " with opt level: "
              << pass_info->opt_level;
-
+  pass_ctx.Trace(mod, pass_info, true);
   // Execute the pass function and return a new module.
   IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
   std::vector<std::pair<GlobalVar, Function> > updates;
@@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
   for (const auto& pair : updates) {
     updated_mod->Add(pair.first, pair.second, true);
   }
+  pass_ctx.Trace(updated_mod, pass_info, false);
   return updated_mod;
 }
 
index e02e917..bd055ee 100644 (file)
@@ -522,6 +522,36 @@ def test_print_ir(capfd):
     assert "Dumping the module IR" in out
     assert "multiply" in out
 
+__TRACE_COUNTER__ = 0
+
+def _tracer(module, info, is_before):
+    global __TRACE_COUNTER__
+    if bool(is_before):
+        __TRACE_COUNTER__ += 1
+
+def test_print_debug_callback():
+    global __TRACE_COUNTER__
+    shape = (1, 2, 3)
+    tp = relay.TensorType(shape, "float32")
+    x = relay.var("x", tp)
+    y = relay.add(x, x)
+    y = relay.multiply(y, relay.const(2, "float32"))
+    func = relay.Function([x], y)
+
+    seq = _transform.Sequential([
+        relay.transform.InferType(),
+        relay.transform.FoldConstant(),
+        relay.transform.DeadCodeElimination()
+    ])
+
+    assert __TRACE_COUNTER__ == 0
+    mod = relay.Module({"main": func})
+
+    with relay.build_config(opt_level=3, trace=_tracer):
+        mod = seq(mod)
+
+    assert __TRACE_COUNTER__ == 4
+
 
 if __name__ == "__main__":
     pytest.main()