From: Zhi <5145158+zhiics@users.noreply.github.com>
Date: Thu, 6 Aug 2020 03:36:01 +0000 (-0700)
Subject: [DOCS] Update pass infra tutorial (#6193)
X-Git-Tag: upstream/0.7.0~300
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=57213879e6ccaa4c0e2ba08b0dca075b623a8742;p=platform%2Fupstream%2Ftvm.git

[DOCS] Update pass infra tutorial (#6193)

* [DOCS] Update pass infra tutorial

* update tutorial
---

diff --git a/docs/dev/index.rst b/docs/dev/index.rst
index c448cb0..2e577df 100644
--- a/docs/dev/index.rst
+++ b/docs/dev/index.rst
@@ -295,6 +295,11 @@ The following code snippet gives an example of PassContext configuration.
 Op is the common class to represent all system-defined primitive operator/intrinsics.
 Developers can register new Ops as well as their additional
 attributes(e.g. whether the Op is elementwise) to the system.
 
+.. toctree::
+   :maxdepth: 1
+
+   pass_infra
+
 tvm/target
 ----------
 
@@ -353,7 +358,6 @@ memory(for memory optimization).
 
    relay_intro
    relay_op_strategy
-   relay_pass_infra
    convert_layout
 
diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/pass_infra.rst
similarity index 67%
rename from docs/dev/relay_pass_infra.rst
rename to docs/dev/pass_infra.rst
index 15487ac..6fd150d 100644
--- a/docs/dev/relay_pass_infra.rst
+++ b/docs/dev/pass_infra.rst
@@ -15,24 +15,28 @@ specific language governing permissions and limitations
    under the License.
 
-.. _relay-pass-infra:
+.. _pass-infra:
 
-Relay Pass Infrastructure
-=========================
+Pass Infrastructure
+===================
 
-Relay features a series of optimization passes which improve performance metrics
+Both Relay and TVM IR contain a series of optimization passes which improve performance metrics
 of models such as mean inference, memory footprint, or power consumption for
 specific devices. There is a suite of standard optimizations as well as machine
 learning-specific optimizations including constant folding, dead code
-elimination, operator layout alteration, and operator fusion, etc. 
Each of these
-passes is structured as a Relay-to-Relay transformation on the abstract syntax
-tree (AST) using the analysis result collected during and/or before traversal.
+elimination, operator layout alteration, operator fusion, buffer handling, and
+loop transformation, etc. Each of these passes is structured as an IR-to-IR
+transformation using the analysis result collected during and/or before traversal.
 
-However, as Relay evolves quickly, the need for a more systematic and efficient
-way to manage these passes is becoming apparent. This doc describes the design of
-such an infra that takes the advantage of the way production compilers are used to
-manage the optimization passes and the style modern deep learning frameworks
-adopted to build up layers.
+However, as TVM evolves quickly, the need for a more systematic and efficient
+way to manage these passes is becoming apparent. In addition, a generic
+framework that manages the passes across different layers of the TVM stack (e.g.
+Relay and tir) paves the way for developers to quickly prototype and plug the
+implemented passes into the system.
+
+This doc describes the design of such an infra that takes advantage of the
+way production compilers are used to manage the optimization passes and the style
+modern deep learning frameworks adopted to build up layers.
 
 For example, many existing production compilers, such as GCC and LLVM, employ
 pass managers to effectively manage the execution of passes. Initially managing
@@ -88,10 +92,10 @@ needs to be executed when running under a user-provided optimization level. The
 
 .. code:: c++
 
-    class PassInfoNode : public RelayNode {
-      std::string name;
+    class PassInfoNode : public Object {
+      String name;
       int opt_level;
-      std::vector<std::string> required;
+      Array<String> required;
     };
 
 PassContext
 
@@ -111,17 +115,16 @@ This class is designed for users to conveniently write the Python ``with``
 syntax to perform optimizations under a certain configuration. 
In addition, the
 users can obtain the context that is available within a certain program scope
 in a thread-safe way through ``PassContext::Current()``, since a thread-local store
-``RelayPassContextThreadLocalStore`` is used to hold the created pass context
+``PassContextThreadLocalStore`` is used to hold the created pass context
 objects. Examples will be provided later to show how we can use both the C++ and
 Python APIs to create a compilation pipeline using pass context.
 
 .. code:: c++
 
-    class PassContextNode : public RelayNode {
+    class PassContextNode : public Object {
      public:
       ErrorReporter err_reporter;
       int opt_level{2};
-      int fallback_device{static_cast<int>(kDLCPU)};
       tvm::Array<tvm::Expr> required_pass;
       tvm::Array<tvm::Expr> disabled_pass;
     };
 
@@ -142,32 +145,32 @@ Python APIs to create a compilation pipeline using pass context.
       friend class tvm::With<PassContext>;
     };
 
-    struct RelayPassContextThreadLocalEntry {
+    struct PassContextThreadLocalEntry {
       /*! \brief The default pass context. */
       PassContext default_context;
       /*! \brief The current pass context. */
       std::stack<PassContext> context_stack;
 
-      RelayPassContextThreadLocalEntry() {
+      PassContextThreadLocalEntry() {
        default_context = PassContext(make_node<PassContextNode>());
       }
     };
 
     /*! \brief The thread-local store to hold the pass context. */
-    typedef dmlc::ThreadLocalStore<RelayPassContextThreadLocalEntry>
-        RelayPassContextThreadLocalStore;
+    typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
+        PassContextThreadLocalStore;
 
 Pass Constructs
 ^^^^^^^^^^^^^^^
 
 The pass infra is designed in a hierarchical manner, and it could work at
-different granularities of Relay programs. A pure virtual class ``PassNode`` is
+different granularities of Relay/tir programs. A pure virtual class ``PassNode`` is
 introduced to serve as the base of the different optimization passes. This
 class contains several virtual methods that must be implemented by the
-subclasses at the level of modules, functions, or sequences of passes..
+subclasses at the level of modules, functions, or sequences of passes.
 
 .. 
code:: c++
 
-    class PassNode : RelayNode {
+    class PassNode : Object {
       virtual PassInfo Info() const = 0;
       virtual Module operator()(const IRModule& mod,
                                 const PassContext& pass_ctx) const = 0;
 
@@ -192,7 +195,8 @@ Module level passes are geared mainly for global and inter-procedural
 optimizations (IPO), which are similar to the module pass used in LLVM. Some
 typical passes in Relay that need the global picture of a module, such as
 A-normal form conversion and lambda lifting, etc., fall into this set. At this
 level, users can even add and/or delete functions in a module.
 
 .. code:: c++
 
@@ -215,13 +219,14 @@ Function-Level Passes
 ^^^^^^^^^^^^^^^^^^^^^
 
 Function-level passes are used to implement various intra-function level
-optimizations for a given Relay module. It fetches one function at a time from
+optimizations for a given Relay/tir module. It fetches one function at a time from
 the function list of a module for optimization and yields a rewritten Relay
-function. Most of Relay's passes can be classified into this category, such as
-common subexpression elimination and inference simplification, etc.
+``Function`` or tir ``PrimFunc``. Most of the passes can be classified into this category, such as
+common subexpression elimination and inference simplification in Relay as well as vectorization
+and storage flattening in tir, etc.
 
-Note that the scope of passes at this level is a Relay function. Therefore, we
-cannot add or delete a function through these passes as they are not aware of
+Note that the scope of passes at this level is either a Relay function or a tir primitive function.
+Therefore, we cannot add or delete a function through these passes as they are not aware of
 the global information.
 
 .. code:: c++
 
@@ -312,74 +317,25 @@ favorably use Python APIs to create a specific pass object.
 
.. 
code:: c++
 
-    FunctionPass CreateFunctionPass(std::string name,
-                                    int opt_level,
-                                    PassFunc pass_func);
-
-    ModulePass CreateModulePass(std::string name,
-                                int opt_level,
-                                PassFunc pass_func);
-
-    SequentialPass CreateSequentialPass(std::string name,
-                                        int opt_level,
-                                        Array<Pass> passes,
-                                        Array<tvm::Expr> disabled);
-
-C++ Sequential Example
-^^^^^^^^^^^^^^^^^^^^^^
-
-Let's now take an example to illustrate how the pass infra works on
-``SequentialPass``. For illustrative purpose, only a code snippet is provided.
-First, we create a simple Relay program, ``y = f(x)``. Then, we build a module
-based on the function. After creating the module, we instantiate a sequential
-pass object which contains some standard Relay optimization passes, including
-type inference, dead code elimination, common subexpression elimination, and
-layout alteration.
-
-Finally, a pass context is constructed and the passes will be executed
-sequentially. During the execution of these passes, the pass dependency will be
-resolved automatically as we have encoded the dependent passes during
-registration.
-
-.. code:: c++
-
-    // Create a simple Relay program.
-    auto tensor_type = relay::TensorType({}, tvm::Bool());
-    auto x = relay::Var("x", relay::Type());
-    auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
+    Pass CreateFunctionPass(
+        const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
+        int opt_level,
+        String name,
+        Array<runtime::String> required);
 
-    auto y = relay::Var("y", tensor_type);
-    auto call = relay::Call(f, tvm::Array<relay::Expr>{ y });
-    auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
+    Pass CreatePrimFuncPass(
+        const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
+        int opt_level,
+        String name,
+        Array<runtime::String> required);
 
-    // Create a module for optimization.
-    auto mod = IRModule::FromExpr(fx);
+    Pass CreateModulePass(
+        const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
+        int opt_level,
+        String name,
+        Array<runtime::String> required);
 
-    // Create a sequential pass.
-    tvm::Array<relay::transform::Pass> pass_seqs{
-        relay::transform::InferType(),
-        relay::transform::DeadCodeElimination(),
-        relay::transform::EliminateCommonSubexpr(),
-        relay::transform::AlterOpLayout()
-    };
-    relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
-
-    // Create a pass context for the optimization.
-    auto ctx = relay::transform::PassContext::Create();
-    ctx->opt_level = 2;
-    ctx->fallback_device = kDLCPU;
-
-    // Use the Python with syntax to execute the sequence of optimizations.
-    tvm::With<relay::transform::PassContext> scope(ctx);
-    mod = seq(mod);
-
-    // View the updated module.
-    LOG(INFO) << relay::AsText(mod) << std::endl;
-
-Other types of passes should be directly invoked for execution on a module. For
-example, users can directly apply const folding pass on a given module, ``mod
-= transform::FoldConstant()(mod)``. However, it is users' responsibility to
-execute the required passes explicitly.
+    Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
 
 Pass Registration
 ~~~~~~~~~~~~~~~~~
@@ -400,7 +356,7 @@ In order to register this pass to the pass infra, we first need to decide at
 which level this pass will be performed. As const folding happens on individual
 functions, we should intuitively create a ``FunctionPass`` for it through
 ``CreateFunctionPass``. The ``pass_func`` is returned as a packed function that
-invokes the ``Expr`` to ``Expr`` API on each function in a Relay module. ``{}``
+invokes the ``Expr`` to ``Expr`` API on each function in an ``IRModule``. ``{}``
 indicates that no prerequisite is required for this pass. Otherwise, the pass
 developer has to identify and list them.
 
@@ -414,8 +370,8 @@ Python when needed.
namespace transform {
 
 Pass FoldConstant() {
-  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
-    [=](Function f, Module m, PassContext pc) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+    [=](Function f, IRModule m, PassContext pc) {
       return Downcast<Function>(FoldConstant(f));
     };
   return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
 }
 
@@ -438,7 +394,8 @@ Python Frontend
 
 Only some simple APIs are needed for the frontend side. For example, we can
 provide users the following APIs to create and execute a pass (full
-implementation is provided in `python/tvm/relay/transform.py`_). The backend
+implementation is provided in `python/tvm/relay/transform.py`_ and
+`python/tvm/ir/transform.py`_). The backend
 receives the information and decides which function it should use to create
 a Pass object.
 
@@ -452,13 +409,13 @@ a certain scope.
 
 .. code:: python
 
-    @register_relay_node
-    class PassContext(RelayNode):
+    @tvm._ffi.register_object("transform.PassContext")
+    class PassContext(tvm.runtime.Object):
         def __enter__(self):
             _transform.EnterPassContext(self)
             return self
 
         def __exit__(self, ptype, value, trace):
             _transform.ExitPassContext(self)
 
        @staticmethod
@@ -466,10 +423,19 @@
            """Return the current pass context."""
            return _transform.GetCurrentPassContext()
 
-A ``PassContext`` object can be instantiated through the ``build_config`` API
-which was used by Relay to configure the compilation options, including the
-optimization level, fallback device for heterogeneous execution, and
-required/disabled passes.
+A ``PassContext`` is used to configure the compilation options, including the
+optimization level and required/disabled passes. It can also take a dictionary
+of configs so that different passes can conveniently fetch the passed data, such
+as fallback device info and step/depth for loop unrolling, etc. In order to
+enable fetching the required config, the key must be registered through
+``TVM_REGISTER_PASS_CONFIG_OPTION``. 
For example, the following is used by the
+loop unrolling pass:
+
+.. code:: c++
+
+    TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
+
+Please refer to `src/tir/transforms/unroll_loop.cc`_ for more details.
 
 Pass Objects
 ^^^^^^^^^^^^
 
@@ -493,7 +459,8 @@ example, ``module_pass``, ``function_pass``, and ``sequential`` are provided to
 users so that they can customize their own pass or pass pipeline.
 
 For all the passes that are implemented in the C++ backend, we provide
-a corresponding Python API in `python/tvm/relay/transform.py`_. For instance,
+corresponding Python APIs in `python/tvm/ir/transform.py`_ and
+`python/tvm/relay/transform.py`_. For instance,
 const folding has a Python API like the following:
 
 .. code:: python
 
@@ -555,95 +522,9 @@ instance, an example function-level pass could be written as the following:
 
 Alternatively, users can also directly register a pass without using the
-decorators and then invoke it. Let's use ``Sequential`` to demo this scenario.
-
-Python Sequential Example
-^^^^^^^^^^^^^^^^^^^^^^^^^
-
-This example not only illustrates how users can directly create a sequential
-pass using Python APIs (this could be applied to module- and function-level
-passes as well), but also explains how we can build an optimization pipeline
-using ``Sequential`` associated with other types of passes.
-
-.. code:: python
-
-    # Create a simple Relay program.
-    shape = (1, 2, 3)
-    c_data = np.array(shape).astype("float32")
-    tp = relay.TensorType(shape, "float32")
-    c = relay.const(c_data)
-    x = relay.var("x", tp)
-    y = relay.add(c, c)
-    y = relay.multiply(y, relay.const(2, "float32"))
-    y = relay.add(x, y)
-    z = relay.add(y, c)
-    z1 = relay.add(y, c)
-    z2 = relay.add(z, z1)
-    func = relay.Function([x], z2)
-
-    # Customize the optimization pipeline.
-    seq = tvm.transform.Sequential([
-        relay.transform.InferType(),
-        relay.transform.FoldConstant(),
-        relay.transform.EliminateCommonSubexpr(),
-        relay.transform.AlterOpLayout()
-    ])
-
-    # Create a module to perform optimizations.
-    mod = relay.Module({"main": func})
-
-    # Users can disable any passes that they don't want to execute by providing
-    # a list, e.g. disabled_pass=["EliminateCommonSubexpr"].
-    with relay.build_config(opt_level=3):
-        with tvm.target.create("llvm"):
-            # Perform the optimizations.
-            mod = seq(mod)
-
-Debugging
-~~~~~~~~~
-
-The pass infra provides a special pass (``PrintIR``) to dump the IR of the
-whole module after applying a certain pass. A slightly modified version of the
-sequential pass example could be like the following to enable IR dumping for
-``FoldConstant`` optimization.
-
-.. code:: python
-
-    seq = tvm.transform.Sequential([
-        relay.transform.InferType(),
-        relay.transform.FoldConstant(),
-        ir.transform.PrintIR(),
-        relay.transform.EliminateCommonSubexpr(),
-        relay.transform.AlterOpLayout()
-    ])
-
-By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will
-dump out the module IR when ``FoldConstant`` is done. Users can plug in this
-pass after any pass they want to debug for viewing the optimization effect.
-
-There is a more flexible debugging mechanism also exposed by the build configuration
-object. One can pass a tracing function which can be used to execute arbitrary code
-before and/or after each pass. A tracing function will receive a ``IRModule``, ``PassInfo``,
-and a boolean indicating whether you are executing before, or after a pass.
-An example is below.
-
-.. code:: python
-
-    def print_ir(mod, info, is_before):
-        """Print the name of the pass, the IR, only before passes execute."""
-        if is_before:
-            print(f"Running pass: {}", info)
-            print(mod)
-
-    with relay.build_config(opt_level=3, trace=print_ir):
-        with tvm.target.create("llvm"):
-            # Perform the optimizations.
-    mod = seq(mod)
-
-
-For more pass infra related examples in Python and C++, please refer to
-`tests/python/relay/test_pass_manager.py`_ and
-`tests/cpp/relay_transform_sequential.cc`_, respectively.
+decorators and then invoke it. For more examples of how to customize your own
+optimization pipeline and debug Relay and tir passes, please refer to the
+`use pass infra`_ tutorial.
 
 .. _Sequential: https://pytorch.org/docs/stable/nn.html?highlight=sequential#torch.nn.Sequential
 
@@ -659,8 +540,10 @@ For more pass infra related examples in Python and C++, please refer to
 
 .. _python/tvm/relay/transform.py: https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/transform.py
 
-.. _tests/python/relay/test_pass_manager.py: https://github.com/apache/incubator-tvm/blob/master/tests/python/relay/test_pass_manager.py
+.. _include/tvm/relay/transform.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/relay/transform.h
+
+.. _python/tvm/ir/transform.py: https://github.com/apache/incubator-tvm/blob/master/python/tvm/ir/transform.py
 
-.. _tests/cpp/relay_transform_sequential.cc: https://github.com/apache/incubator-tvm/blob/master/tests/cpp/relay_transform_sequential.cc
+.. _src/tir/transforms/unroll_loop.cc: https://github.com/apache/incubator-tvm/blob/master/src/tir/transforms/unroll_loop.cc
 
-.. _include/tvm/relay/transform.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/relay/transform.h
+.. _use pass infra: https://github.com/apache/incubator-tvm/blob/master/tutorials/dev/use_pass_infra.py

diff --git a/docs/dev/relay_add_pass.rst b/docs/dev/relay_add_pass.rst
index fc26559..e1a5e7e 100644
--- a/docs/dev/relay_add_pass.rst
+++ b/docs/dev/relay_add_pass.rst
@@ -30,7 +30,7 @@ compiler passes.
At a high level, there are two key components to writing a pass:
 
 - Creating one or more C++ classes that traverse the program
-- Wrapping the traversal implementation and its metadata in the pass manager API so it can neatly interface with the :ref:`relay-pass-infra`
+- Wrapping the traversal implementation and its metadata in the pass manager API so it can neatly interface with the :ref:`pass-infra`
 
 To begin, we'll give an overview of the key mechanisms for writing a compiler
 pass. Then, we'll walk through a concrete example of the constant-folding
@@ -335,7 +335,7 @@ class that takes an expression and internally creates and uses a
 
 Registering a Pass with the Pass Manager
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-*Note: please see the documentation on the :ref:`relay-pass-infra` for more specific detail on this subject.*
+*Note: please see the documentation on the :ref:`pass-infra` for more specific detail on this subject.*
 
 With the AST traversers written, the pass can be registered to become a TVM
 API endpoint with the following code:
 
@@ -395,7 +395,7 @@ the below code applies both the ``FoldConstant`` and ``ToANormalForm`` passes
 
 new_mod = seq(mod)
 
 More detail about registration can be found in :ref:`tvm-runtime-system` and more
-information about the pass manager interface can be found in :ref:`relay-pass-infra`.
+information about the pass manager interface can be found in :ref:`pass-infra`.
 
 Relay's standard passes are listed in `include/tvm/relay/transform.h`_ and
 implemented in `src/relay/pass/`_.
 
diff --git a/tutorials/dev/relay_pass_infra.py b/tutorials/dev/use_pass_infra.py
similarity index 76%
rename from tutorials/dev/relay_pass_infra.py
rename to tutorials/dev/use_pass_infra.py
index ae7f544..8212334 100644
--- a/tutorials/dev/relay_pass_infra.py
+++ b/tutorials/dev/use_pass_infra.py
@@ -16,26 +16,28 @@
 # under the License.
 # pylint: disable=line-too-long
 """
-.. _tutorial-relay-pass-infra:
+.. 
_tutorial-use-pass-infra:
 
-How to Use Relay Pass Infra
-===========================
+How to Use TVM Pass Infra
+=========================
 
 **Author**: `Zhi Chen <https://github.com/zhiics>`_
 
-As the number of optimization passes increases in Relay, it becomes intractable to
+As the number of optimization passes increases in Relay/tir, it becomes intractable to
 execute them and maintain their dependencies manually. Therefore, we have
-introduced an infrastructure to manage the optimization passes.
-
-The optimizations of a Relay program could be applied at various granularity,
-namely function-level and module-level using :py:class:`tvm.relay.transform.FunctionPass`
-and :py:class:`tvm.relay.transform.ModulePass`
-respectively. Or users can rely on py:class:`tvm.transform.Sequential` to apply a sequence of passes
-on a Relay program where the dependencies between passes can be resolved by the
+introduced an infrastructure to manage the optimization passes and make it
+applicable to different layers of the IR in the TVM stack.
+
+The optimizations of a Relay/tir program could be applied at various granularity,
+namely function-level and module-level using :py:class:`tvm.relay.transform.FunctionPass`/
+:py:class:`tvm.tir.transform.PrimFuncPass` and :py:class:`tvm.transform.ModulePass`
+respectively. Or users can rely on :py:class:`tvm.transform.Sequential` to apply a sequence of passes
+on a Relay/tir program where the dependencies between passes can be resolved by the
 pass infra. For more details about each type of these passes, please refer to
-the :ref:`relay-pass-infra`
+the :ref:`pass-infra`.
 
-This tutorial demostrates how developers can use the Relay pass infra to perform
-a certain optimization and create an optimization pipeline.
+This tutorial mainly demonstrates how developers can use the pass infra to perform
+a certain optimization and create an optimization pipeline for a Relay program.
+The same approach can be used for tir as well.
"""
 
 import numpy as np
@@ -48,6 +50,7 @@ import tvm.relay as relay
 # -------------------------------
 # First of all, we create a simple Relay program for the tutorial. This program
 # will be used by various optimizations of the examples in this tutorial.
+# Similarly, users can write a tir primitive function and apply the tir passes.
 
 def example():
     shape = (1, 64, 54, 54)
@@ -153,7 +156,7 @@ print(mod1)
 
 ###############################################################################
 # From the transformed Relay program, we can see that there are still two
-# identical addition operations. This is because `EliminateCommonSubexpr`
+# identical addition operations. This is because ``EliminateCommonSubexpr``
 # was not actually performed. The reason is because only the passes that have
 # optimization level less or equal to 2 will be executed by default under
 # :py:class:`tvm.transform.Sequential`. The pass infra,
@@ -230,10 +233,10 @@ print(mod3)
 
 ##############################################################################
 # Debug a Pass
 # ------------
-# Relay provides users a plug-and-play style debugging pass that print the IR
-# after a certain pass is done. For example, we can print out the IR on the
-# completion of constant folding and fusion by adding the debugging pass after
-# them.
+# TVM provides users a plug-and-play style debugging pass (``PrintIR``) that
+# dumps the IR of the whole module after a certain pass is done. A slightly
+# modified version of the sequential pass example could be like the following
+# to enable IR dumping for ``FoldConstant`` optimization.
f = example()
 mod = tvm.IRModule.from_expr(f)
@@ -241,8 +244,39 @@ seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
                                 tvm.transform.PrintIR(),
                                 relay.transform.EliminateCommonSubexpr(),
                                 relay.transform.FuseOps(),
-                                tvm.transform.PrintIR()])
-with tvm.transform.PassContext(opt_level=3):
-    mod = seq(mod)
+                                relay.transform.AlterOpLayout()])
+
+# By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will
+# dump out the module IR when ``FoldConstant`` is done. Users can plug in this
+# pass after any pass they want to debug for viewing the optimization effect.
+#
+# There is a more flexible debugging mechanism also exposed by the build configuration
+# object. One can pass a tracing function which can be used to execute arbitrary code
+# before and/or after each pass. A tracing function will receive a :py:class:`tvm.IRModule`,
+# a :py:class:`tvm.transform.PassInfo` object,
+# and a boolean indicating whether you are executing before, or after a pass.
+# An example is below.
+
+def print_ir(mod, info, is_before):
+    """Print the name of the pass, the IR, only before passes execute."""
+    if is_before:
+        print("Running pass: {}".format(info))
+        print(mod)
+
+with tvm.transform.PassContext(opt_level=3, trace=print_ir):
+    with tvm.target.create("llvm"):
+        # Perform the optimizations.
+        mod = seq(mod)
+print(mod)
 
 print("done")
+
+##############################################################################
+# Summary
+# -------
+# This tutorial has covered how we can write and invoke passes in TVM more
+# conveniently using the pass infra. Different ways of invoking a pass are also
+# discussed. Using :py:class:`tvm.transform.Sequential` can largely help
+# users ease the work of handling multiple optimization passes and their
+# dependencies. In addition, an example is provided to illustrate
+# how we can debug a pass using ``PrintIR`` and tracing.
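The tracing hook added by this patch (a `trace` callback passed to `PassContext` that fires before and after each pass, and an `opt_level` cap that skips higher-level passes) can be modeled without TVM. The sketch below is a plain-Python mock of that behavior for illustration only; the names `PassInfoMock`, `SequentialMock`, and the list-based "module" are hypothetical and are not TVM APIs.

```python
# Minimal mock of the pass-infra behavior described in the tutorial above:
# a sequential pass runner that honors an optimization-level cap and calls
# a user-supplied trace(mod, info, is_before) hook around every pass it runs.

class PassInfoMock:
    """Metadata for one pass: its name and required optimization level."""
    def __init__(self, name, opt_level):
        self.name = name
        self.opt_level = opt_level

class SequentialMock:
    """Runs (PassInfoMock, transform) pairs in order, like tvm.transform.Sequential."""
    def __init__(self, passes, opt_level=2, trace=None):
        self.passes = passes
        self.opt_level = opt_level
        self.trace = trace

    def __call__(self, mod):
        for info, transform in self.passes:
            if info.opt_level > self.opt_level:
                continue  # skipped, mirroring the default opt_level gating
            if self.trace:
                self.trace(mod, info, True)   # before the pass executes
            mod = transform(mod)
            if self.trace:
                self.trace(mod, info, False)  # after the pass executes
        return mod

events = []

def print_ir(mod, info, is_before):
    """Record the pass name only before passes execute, as in the tutorial."""
    if is_before:
        events.append("Running pass: {}".format(info.name))

seq = SequentialMock(
    passes=[
        (PassInfoMock("FoldConstant", 2), lambda m: m + ["folded"]),
        (PassInfoMock("EliminateCommonSubexpr", 3), lambda m: m + ["cse"]),
    ],
    opt_level=2,
    trace=print_ir,
)

mod = seq([])
# With opt_level=2, only FoldConstant runs; the trace fires once before it.
```

The design choice mirrored here is that gating and tracing live in the pass runner, not in the passes themselves, so any pass gets debugging support for free once registered.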