[Relay][Transform] merge PassContext and BuildConfig (#3234)
authorZhi <5145158+zhiics@users.noreply.github.com>
Fri, 24 May 2019 19:05:00 +0000 (12:05 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 24 May 2019 19:05:00 +0000 (12:05 -0700)
docs/api/python/relay/build_module.rst
docs/api/python/relay/transform.rst [new file with mode: 0644]
include/tvm/relay/transform.h
python/tvm/relay/__init__.py
python/tvm/relay/build_module.py
python/tvm/relay/quantize/quantize.py
python/tvm/relay/transform.py
src/relay/pass/pass_manager.cc
tests/python/frontend/coreml/test_forward.py
tests/python/frontend/keras/test_forward.py
tutorials/frontend/from_tflite.py

index 28dadea..26164bf 100644 (file)
@@ -22,17 +22,9 @@ tvm.relay.build_module
 
 .. autofunction:: tvm.relay.build_module.build
 
-.. autofunction:: tvm.relay.build_module.build_config
-
 .. autofunction:: tvm.relay.build_module.optimize
 
 .. autofunction:: tvm.relay.build_module.create_executor
 
-.. autoclass:: tvm.relay.build_module.BuildConfig
-    :members:
-
-.. autofunction:: tvm.relay.build_module.build_config
-    :members:
-
 .. autoclass:: tvm.relay.build_module.GraphExecutor
     :members:
diff --git a/docs/api/python/relay/transform.rst b/docs/api/python/relay/transform.rst
new file mode 100644 (file)
index 0000000..4eb7f9d
--- /dev/null
@@ -0,0 +1,45 @@
+..  Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+..    http://www.apache.org/licenses/LICENSE-2.0
+
+..  Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+tvm.relay.transform
+----------------------
+
+.. automodule:: tvm.relay.transform
+
+.. autofunction:: tvm.relay.transform.build_config
+
+.. autofunction:: tvm.relay.transform.module_pass
+
+.. autofunction:: tvm.relay.transform.function_pass
+
+.. autoclass:: tvm.relay.transform.Pass
+    :members:
+
+.. autoclass:: tvm.relay.transform.PassInfo
+    :members:
+
+.. autoclass:: tvm.relay.transform.PassContext
+    :members:
+
+.. autoclass:: tvm.relay.transform.ModulePass
+    :members:
+
+.. autoclass:: tvm.relay.transform.FunctionPass
+    :members:
+
+.. autoclass:: tvm.relay.transform.Sequential
+    :members:
index ba25483..5123f3a 100644 (file)
 #ifndef TVM_RELAY_TRANSFORM_H_
 #define TVM_RELAY_TRANSFORM_H_
 
+#include <tvm/base.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/module.h>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 namespace tvm {
@@ -83,18 +85,69 @@ class PassContextNode : public RelayNode {
    */
   ErrorReporter err_reporter;
 
+  /*! \brief The default optimization level. */
+  int opt_level{2};
+
+  /*! \brief CPU is the default fallback device for heterogeneous execution. */
+  int fallback_device{static_cast<int>(kDLCPU)};
+
+  /*! \brief The list of required passes. */
+  tvm::Array<tvm::Expr> required_pass;
+  /*! \brief The list of disabled passes. */
+  tvm::Array<tvm::Expr> disabled_pass;
+
   PassContextNode() = default;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("opt_level", &opt_level);
+    v->Visit("fallback_device", &fallback_device);
+    v->Visit("required_pass", &required_pass);
+    v->Visit("disabled_pass", &disabled_pass);
   }
 
-  TVM_DLL static PassContext make();
-
   static constexpr const char* _type_key = "relay.PassContext";
   TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
 };
 
-TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
+class PassContext : public NodeRef {
+ public:
+  PassContext() {}
+  explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {}
+
+  /*
+   * \brief Constructor of a `PassContext` object.
+   *
+   * \param opt_level The optimization level that will be applied.
+   * \param fallback_device The fallback device used for heterogeneous
+   *        execution.
+   * \param required_pass The passes that are required for a context to execute
+   *        other passes.
+   * \param required_pass The passes that will be disabled during the
+   *        optimization under a context.
+   */
+  TVM_DLL PassContext(int opt_level,
+                      int fallback_device,
+                      tvm::Array<tvm::Expr> required_pass,
+                      tvm::Array<tvm::Expr> disabled_pass);
+
+  // Get the currently used pass context.
+  TVM_DLL static PassContext Current();
+
+  const PassContextNode* operator->() const;
+
+  using ContainerType = PassContextNode;
+  class Internal;
+
+ private:
+  // The entry of a pass context scope.
+  TVM_DLL void EnterWithScope();
+  // The exit of a pass context scope.
+  TVM_DLL void ExitWithScope();
+
+  // Classes to get the Python `with` like syntax.
+  friend class Internal;
+  friend class tvm::With<PassContext>;
+};
 
 /*
  * \brief The meta data of a pass.
@@ -150,20 +203,28 @@ class PassNode : public RelayNode {
   virtual PassInfo Info() const = 0;
 
   /*!
-   * \brief Set the context information for a pass.
+   * \brief Execute the optimization pass using a functor. This functor
+   * internally uses a current pass context.
+   *
+   * \param mod The module that an optimization pass runs on.
    *
-   * \param pass_ctx The context information for a certain pass.
+   * \return The updated module.
    */
-  virtual void SetContext(const PassContext& pass_ctx) = 0;
+  Module operator()(const Module& mod) const {
+    return this->operator()(mod, PassContext::Current());
+  }
 
   /*!
-   * \brief Execute the optimization pass using a functor.
+   * \brief Execute the optimization pass using a functor under a given pass context.
    *
    * \param mod The module that an optimization pass runs on.
+   * \param pass_ctx The pass context that will be used to help the execution of
+   *        optimizations.
    *
    * \return The updated module.
    */
-  virtual Module operator()(const Module& mod) const = 0;
+  virtual Module operator()(const Module& mod,
+                            const PassContext& pass_ctx) const = 0;
 
   void VisitAttrs(tvm::AttrVisitor* v) override {}
 
@@ -189,13 +250,22 @@ class Sequential : public Pass {
  public:
   /*!
    * \brief The constructor of `Sequential`.
+   *
    * \param passes The passes to apply.
    * \param pass_info The pass metadata.
-   * \param disabled The passes that will not be applied.
    */
   TVM_DLL Sequential(tvm::Array<Pass> passes,
-                     PassInfo pass_info,
-                     tvm::Array<tvm::Expr> disabled);
+                     PassInfo pass_info);
+/*!
+   * \brief The constructor of `Sequential`.
+   *
+   * \param passes The passes to apply.
+   * \param name The name of a sequential pass. It's defaulted to "sequential".
+   *        This allows users to only provide a list of passes and execute them
+   *        under a given context.
+   */
+  TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
+
   Sequential() = default;
   explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}
 
index d832c89..1c8f5d6 100644 (file)
@@ -26,7 +26,8 @@ from . import module
 from . import adt
 from . import ir_pass
 from . import transform
-from .build_module import build, build_config, create_executor
+from .build_module import build, create_executor
+from .transform import build_config
 from . import prelude
 from . import parser
 from . import debug
index d0ad78f..6cee393 100644 (file)
@@ -28,81 +28,10 @@ from . import _build_module
 from . import ir_pass
 from . import ty as _ty
 from . import expr as _expr
+from . import transform as _transform
 from .backend import interpreter as _interpreter
 from .backend.vm import VMExecutor
 
-class BuildConfig(object):
-    """Configuration scope to set a build config option.
-
-    Parameters
-    ----------
-    kwargs
-        Keyword arguments of configurations to set.
-    """
-    current = None
-    defaults = {
-        "opt_level": 2,
-        "add_pass": None,
-        "disable_pass": None,
-        "fallback_device": None,
-    }
-
-    def __init__(self, **kwargs):
-        self._old_scope = None
-        for k, _ in kwargs.items():
-            if k not in BuildConfig.defaults:
-                raise ValueError("invalid argument %s, candidates are %s" %
-                                 (k, BuildConfig.defaults.keys()))
-        self._attr = kwargs
-
-    def __getattr__(self, name):
-        if name not in self._attr:
-            return BuildConfig.defaults[name]
-        return self._attr[name]
-
-    def __enter__(self):
-        # pylint: disable=protected-access
-        self._old_scope = BuildConfig.current
-        attr = BuildConfig.current._attr.copy()
-        attr.update(self._attr)
-        self._attr = attr
-        BuildConfig.current = self
-        return self
-
-    def __exit__(self, ptype, value, trace):
-        assert self._old_scope
-        BuildConfig.current = self._old_scope
-
-
-BuildConfig.current = BuildConfig()
-
-
-def build_config(**kwargs):
-    """Configure the build behavior by setting config variables.
-
-    Parameters
-    ----------
-    opt_level: int, default=2
-        Optimization level. See OPT_PASS_LEVEL for level of each pass.
-
-    add_pass: set of str
-        Optimization pass to be added regardless of optimization level.
-
-    disable_pass: set of str
-        Optimization pass to be disabled during optimization.
-
-    fallback_device : str or tvm.TVMContext
-        The fallback device. It is also used as the default device for
-        operators without specified device during heterogeneous execution.
-
-    Returns
-    -------
-    config: BuildConfig
-        The build configuration
-    """
-    return BuildConfig(**kwargs)
-
-
 def _update_target(target):
     target = target if target else _target.current_target()
     if target is None:
@@ -189,7 +118,7 @@ class BuildModule(object):
         return graph_json, mod, params
 
     def _setup_build_config(self, params):
-        cfg = BuildConfig.current
+        cfg = _transform.PassContext.current()
 
         # Set opt_level.
         self.set_opt_level(cfg.opt_level)
@@ -199,24 +128,24 @@ class BuildModule(object):
             self.set_fallback_device(cfg.fallback_device)
 
         # Add required passes.
-        if cfg.add_pass:
+        if cfg.required_pass:
             passes = set()
-            if isinstance(cfg.add_pass, (list, tuple, set)):
-                passes = set(cfg.add_pass)
+            if isinstance(cfg.required_pass, (list, tuple, set)):
+                passes = set(cfg.required_pass)
             else:
                 raise TypeError("add_pass must be list, tuple, or set, but " +
-                                "got {}".format(type(cfg.add_pass)))
+                                "got {}".format(type(cfg.required_pass)))
             for pass_name in passes:
                 self.add_pass(pass_name)
 
         # Add disabled passes.
-        if cfg.disable_pass:
+        if cfg.disabled_pass:
             passes = set()
-            if isinstance(cfg.disable_pass, (list, tuple, set)):
-                passes = set(cfg.disable_pass)
+            if isinstance(cfg.disabled_pass, (list, tuple, set)):
+                passes = set(cfg.disabled_pass)
             else:
                 raise TypeError("disable_pass must be list, tuple, or set, " +
-                                "but got {}".format(type(cfg.disable_pass)))
+                                "but got {}".format(type(cfg.disabled_pass)))
             for pass_name in passes:
                 self.disable_pass(pass_name)
 
@@ -287,12 +216,11 @@ class BuildModule(object):
         fallback_device : str or tvm.TVMContext
             The fallback device used for heterogeneous execution.
         """
-        if isinstance(fallback_device, str):
+        if isinstance(fallback_device, (int, str)):
             fallback_device = _nd.context(fallback_device)
         if not isinstance(fallback_device, TVMContext):
-            raise TypeError("fallback_device is expected to be str " +
-                            "TVMContext, or dict of device name to target, " +
-                            "but received: {}".format(type(fallback_device)))
+            raise TypeError("fallback_device is expected to be str, int, or " +
+                            "TVMContext but received: {}".format(type(fallback_device)))
 
         self._set_fallback_device(fallback_device.device_type)
 
index 7fd0099..2423e76 100644 (file)
@@ -22,7 +22,7 @@ import numpy as np
 from . import _quantize
 from .. import expr as _expr
 from .. import ir_pass as _ir_pass
-from .. import build_module as _build
+from .. import transform as _transform
 from .. import op as _op
 from ... import make as _make
 from ..base import NodeBase, register_relay_node
@@ -301,7 +301,7 @@ def optimize(func, params=None):
                   "FoldConstant",
                   "CanonicalizeOps"]
 
-    cfg = _build.build_config(add_pass=opt_passes)
+    cfg = _transform.build_config(required_pass=opt_passes)
 
     if params:
         name_dict = {}
@@ -321,25 +321,25 @@ def optimize(func, params=None):
             bind_dict[arg] = _expr.const(v)
         func = _expr.bind(func, bind_dict)
 
-    if "SimplifyInference" in cfg.add_pass:
+    if "SimplifyInference" in cfg.required_pass:
         func = _ir_pass.infer_type(func)
         func = _ir_pass.simplify_inference(func)
 
-    if "FoldConstant" in cfg.add_pass:
+    if "FoldConstant" in cfg.required_pass:
         func = _ir_pass.fold_constant(func)
 
-    if "FoldScaleAxis" in cfg.add_pass:
+    if "FoldScaleAxis" in cfg.required_pass:
         func = _ir_pass.infer_type(func)
         func = _ir_pass.backward_fold_scale_axis(func)
         func = _ir_pass.infer_type(func)
         func = _ir_pass.forward_fold_scale_axis(func)
         func = _ir_pass.fold_constant(func)
 
-    if "CanonicalizeOps" in cfg.add_pass:
+    if "CanonicalizeOps" in cfg.required_pass:
         func = _ir_pass.infer_type(func)
         func = _ir_pass.canonicalize_ops(func)
 
-    if "FoldConstant" in cfg.add_pass:
+    if "FoldConstant" in cfg.required_pass:
         func = _ir_pass.fold_constant(func)
 
     return func
index 877538a..a7887c6 100644 (file)
@@ -23,8 +23,10 @@ conveniently.
 """
 import types
 
+from tvm._ffi.runtime_ctypes import TVMContext
 from . import _transform
 from .base import RelayNode, register_relay_node
+from .. import nd as _nd
 
 
 @register_relay_node
@@ -57,10 +59,102 @@ class PassContext(RelayNode):
     Each pass context contains a number of auxiliary information that is used
     to help an optimization pass. Such information includes the error reporter
     to record the errors of during the optimization, etc.
+
+    opt_level : Optional[int]
+        The optimization level of this pass.
+
+    fallback_device : Optional[Union[int, str, TVMContext]]
+        The fallback device type. It is also used as the default device for
+        operators that are not annotated during heterogeneous execution.
+
+    required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
+        The list of passes that are required by a certain pass.
+
+    disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
+        The list of passes that are disabled.
     """
+    def __init__(self,
+                 opt_level=2,
+                 fallback_device=_nd.cpu(),
+                 required_pass=None,
+                 disabled_pass=None):
+        if isinstance(fallback_device, str):
+            fallback_device = _nd.context(fallback_device).device_type
+        elif isinstance(fallback_device, TVMContext):
+            fallback_device = fallback_device.device_type
+        if not isinstance(fallback_device, int):
+            raise TypeError("required_pass is expected to be the type of " +
+                            "int/str/TVMContext.")
+
+        required = list(required_pass) if required_pass else []
+        if not isinstance(required, (list, tuple)):
+            raise TypeError("required_pass is expected to be the type of " +
+                            "list/tuple/set.")
 
-    def __init__(self):
-        self.__init_handle_by_constructor__(_transform.PassContext)
+        disabled = list(disabled_pass) if disabled_pass else []
+        if not isinstance(disabled, (list, tuple)):
+            raise TypeError("disabled_pass is expected to be the type of " +
+                            "list/tuple/set.")
+
+        self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
+                                            fallback_device, required,
+                                            disabled)
+
+    def __enter__(self):
+        _transform.EnterPassContext(self)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        _transform.ExitPassContext(self)
+
+    @staticmethod
+    def current():
+        """Return the current pass context."""
+        return _transform.GetCurrentPassContext()
+
+
+def build_config(opt_level=2,
+                 fallback_device=_nd.cpu(),
+                 required_pass=None,
+                 disabled_pass=None):
+    """Configure the build behavior by setting config variables.
+
+    Parameters
+    ----------
+    opt_level: int, optional
+        Optimization level. The optimization pass name and level are as the
+        following:
+
+        .. code-block:: python
+
+            OPT_PASS_LEVEL = {
+                "SimplifyInference": 0,
+                "OpFusion": 1,
+                "FoldConstant": 2,
+                "CombineParallelConv2D": 3,
+                "FoldScaleAxis": 3,
+                "AlterOpLayout": 3,
+                "CanonicalizeOps": 3,
+                "EliminateCommonSubexpr": 3,
+            }
+
+    fallback_device : int, str, or tvm.TVMContext, optional
+        The fallback device. It is also used as the default device for
+        operators without specified device during heterogeneous execution.
+
+    required_pass: set of str, optional
+        Optimization passes that are required regardless of optimization level.
+
+    disabled_pass: set of str, optional
+        Optimization passes to be disabled during optimization.
+
+    Returns
+    -------
+    pass_context: PassContext
+        The pass context for optimizations.
+    """
+    return PassContext(opt_level, fallback_device, required_pass,
+                       disabled_pass)
 
 
 @register_relay_node
@@ -70,20 +164,6 @@ class Pass(RelayNode):
     conveniently interact with the base class.
     """
 
-    def set_pass_context(self, pass_ctx):
-        """Setup the pass context for analysis and optimizations. This context
-        could be shared by different passes for sequential passes.
-
-        Parameters
-        ----------
-        pass_ctx : PassContext
-            The context that is used to help perform a certain pass or a series
-            of passes.
-        """
-        if not isinstance(pass_ctx, PassContext):
-            raise TypeError("pass_ctx is expected to be the PassContext type")
-        _transform.SetContext(self, pass_ctx)
-
     @property
     def info(self):
         """Get the pass meta."""
@@ -150,32 +230,23 @@ class Sequential(Pass):
 
     required : Optional[List[str]]
         The list of passes that the sequential pass is dependent on.
-
-    disabled : Optional[List[str]]
-        A list of disabled passes.
     """
 
     def __init__(self,
                  passes=None,
                  opt_level=2,
                  name="sequential",
-                 required=None,
-                 disabled=None):
+                 required=None):
         passes = passes if passes else []
         if not isinstance(passes, (list, tuple)):
             raise TypeError("passes must be a list of Pass objects.")
 
-        disabled = disabled if disabled else []
-        if not isinstance(disabled, (list, tuple)):
-            raise TypeError("disabled must be a list or tuple of pass names")
-
         required = required if required else []
         if not isinstance(required, (list, tuple)):
             raise TypeError("Required is expected to be the type of list/tuple.")
 
         self.__init_handle_by_constructor__(_transform.Sequential,
-                                            passes, opt_level, name, required,
-                                            disabled)
+                                            passes, opt_level, name, required)
 
 
 def module_pass(pass_func=None, opt_level=None, name=None, required=None):
index a105b69..4bcc0bb 100644 (file)
  * \file src/relay/pass/pass_manager.cc
  * \brief Relay pass manager implementation.
  */
+#include <dmlc/thread_local.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
+#include <tvm/runtime/device_api.h>
+
+#include <algorithm>
+#include <stack>
+#include <unordered_set>
 
 namespace tvm {
 namespace relay {
@@ -31,6 +37,98 @@ namespace transform {
 
 using tvm::IRPrinter;
 
+/*!
+ * \brief A data structure to map the names of specific optimizations to
+ *        numeric optimization levels
+ */
+class OptPassLevel {
+ public:
+  /*!
+   * \brief Get level for an optimization pass
+   *
+   * \param key pass name
+   * \return int level
+   */
+  int operator[](const std::string& key) const {
+    const auto data = CreateMap();
+    auto it = data.find(key);
+    if (it == data.end()) {
+      return -1;
+    }
+    return it->second;
+  }
+
+ private:
+  static const std::unordered_map<std::string, int> CreateMap() {
+    const std::unordered_map<std::string, int> m = {
+      {"SimplifyInference", 0},
+      {"OpFusion", 1},
+      {"FoldConstant", 2},
+      {"CombineParallelConv2D", 3},
+      {"FoldScaleAxis", 3},
+      {"AlterOpLayout", 3},
+      {"CanonicalizeOps", 3},
+      {"EliminateCommonSubexpr", 3}
+    };
+    return m;
+  }
+};
+
+PassContext::PassContext(int opt_level, int fallback_device,
+                         tvm::Array<tvm::Expr> required_pass,
+                         tvm::Array<tvm::Expr> disabled_pass) {
+  auto ctx = make_node<PassContextNode>();
+  ctx->opt_level = opt_level;
+  ctx->fallback_device = fallback_device;
+  ctx->required_pass = std::move(required_pass);
+  ctx->disabled_pass = std::move(disabled_pass);
+  node_ = std::move(ctx);
+}
+
+const PassContextNode* PassContext::operator->() const {
+  return static_cast<const PassContextNode*>(node_.get());
+}
+
+struct RelayPassContextThreadLocalEntry {
+  /*! \brief The default pass context. */
+  PassContext default_context;
+
+  /*! \brief The current pass context. */
+  std::stack<PassContext> context_stack;
+
+  RelayPassContextThreadLocalEntry() {
+    default_context = PassContext(make_node<PassContextNode>());
+  }
+};
+
+/*! \brief Thread local store to hold the pass context. */
+typedef dmlc::ThreadLocalStore<RelayPassContextThreadLocalEntry>
+    RelayPassContextThreadLocalStore;
+
+void PassContext::EnterWithScope() {
+  RelayPassContextThreadLocalEntry* entry =
+      RelayPassContextThreadLocalStore::Get();
+  entry->context_stack.push(*this);
+}
+
+void PassContext::ExitWithScope() {
+  RelayPassContextThreadLocalEntry* entry =
+      RelayPassContextThreadLocalStore::Get();
+  CHECK(!entry->context_stack.empty());
+  CHECK(entry->context_stack.top().same_as(*this));
+  entry->context_stack.pop();
+}
+
+PassContext PassContext::Current() {
+  RelayPassContextThreadLocalEntry* entry =
+      RelayPassContextThreadLocalStore::Get();
+  if (!entry->context_stack.empty()) {
+    return entry->context_stack.top();
+  } else {
+    return entry->default_context;
+  }
+}
+
 class ModulePass;
 
 /*!
@@ -58,38 +156,26 @@ class ModulePassNode : public PassNode {
   }
 
   /*!
-   * \brief Run a module pass on a certain module.
+   * \brief Run a module pass on given pass context.
    *
-   * \param mod The module that an optimization pass runs on.
+   * \param mod The module that an optimization pass is applied on.
+   * \param mod The context that an optimization pass executes on.
    *
    * \return Return the updated module.
    */
-  Module operator()(const Module& mod) const final;
+  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
 
   /*!
    * \brief Get the pass information/meta data.
    */
   PassInfo Info() const { return pass_info; }
 
-  /*!
-   * \brief Set the context information for a module pass.
-   *
-   * \param pass_ctx The context information for a module pass.
-   */
-  void SetContext(const PassContext& pass_ctx) final;
-
   TVM_DLL static ModulePass make(
       runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
       PassInfo pass_info);
 
   static constexpr const char* _type_key = "relay.ModulePass";
   TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode);
-
- private:
-  /*!
-   * \brief The context information that is used to help perform a module pass.
-   */
-  PassContext pass_ctx_;
 };
 
 RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass);
@@ -124,26 +210,20 @@ class FunctionPassNode : public PassNode {
   }
 
   /*!
-   * \brief Run a function pass on a certain module.
+   * \brief Run a function pass on given pass context.
    *
-   * \param mod The module that an optimization pass runs on.
+   * \param mod The module that an optimization pass is applied on.
+   * \param mod The context that an optimization pass executes on.
    *
    * \return Return the updated module.
    */
-  Module operator()(const Module& mod) const final;
+  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
 
   /*!
    * \brief Get the pass information/meta data.
    */
   PassInfo Info() const { return pass_info; }
 
-  /*!
-   * \brief Set the context information for a function-level pass.
-   *
-   * \param pass_ctx The context information for a function-level pass.
-   */
-  void SetContext(const PassContext& pass_ctx) final;
-
   TVM_DLL static FunctionPass make(
       runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
       PassInfo pass_info);
@@ -160,11 +240,6 @@ class FunctionPassNode : public PassNode {
    * \return Return true if the function will be skipped, otherwise false.
    */
   bool SkipFunction(const Function& func) const;
-
-  /*!
-   * \brief The context information that is used to help perform a module pass.
-   */
-  PassContext pass_ctx_;
 };
 
 RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);
@@ -182,18 +257,17 @@ class SequentialNode : public PassNode {
   /* \brief The pass meta data.*/
   PassInfo pass_info;
 
-  /*! \brief A list of passes that used to compose a sequential pass. */
-  tvm::Array<Pass> passes;
   /*!
-   * \brief A list of disabled passes that should be excluded when executing the
-   * sequential pass.
+   * \brief A helper struct to get the optimization pass name to opt level
+   * mapping.
    */
-  tvm::Array<tvm::Expr> disabled;
+  OptPassLevel opt_pass_level;
 
+  /*! \brief A list of passes that used to compose a sequential pass. */
+  tvm::Array<Pass> passes;
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("pass_info", &pass_info);
     v->Visit("passes", &passes);
-    v->Visit("disabled", &disabled);
   }
 
   /*!
@@ -211,6 +285,15 @@ class SequentialNode : public PassNode {
   }
 
   /*!
+   * \brief Check if a pass is enabled.
+   *
+   * \param pass_name The name of an optimization/analysis pass.
+   *
+   * \return true if the pass is enabled. Otherwise, false.
+   */
+  bool pass_enabled(const std::string& pass_name) const;
+
+  /*!
    * \brief Resolve the pass dependency. It globs all required passes by
    *        a given pass and executes them.
    *
@@ -224,7 +307,11 @@ class SequentialNode : public PassNode {
    */
   void ResolveDependency(const Module& mod);
 
-  TVM_DLL std::vector<std::string> DisabledPasses() const;
+  std::unordered_set<std::string> DisabledPasses(
+      const Array<tvm::Expr>& disabled) const;
+
+  std::unordered_set<std::string> RequiredPasses(
+      const Array<tvm::Expr>& disabled) const;
 
   /*!
    * \brief Perform optimizations on a series of passes. The aforementioned
@@ -232,27 +319,15 @@ class SequentialNode : public PassNode {
    *        be overloaded to focus on different metrics, i.e. performance,
    *        memory footprint, etc.
    *
-   * \param mod The module that an optimization pass runs on.
+   * \param mod The module that these passes are applied on.
+   * \param pass_ctx The context that these passes execute on.
    *
    * \return Return the updated module.
    */
-  Module operator()(const Module& mod) const final;
-
-  /*!
-   * \brief Set the context information for a sequential pass.
-   *
-   * \param pass_ctx The context information for a sequential pass.
-   */
-  void SetContext(const PassContext& pass_ctx) final;
+  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
 
   static constexpr const char* _type_key = "relay.Sequential";
   TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
-
- private:
-  /*!
-   * \brief The context information that is used to help perform a module pass.
-   */
-  PassContext pass_ctx_;
 };
 
 PassInfo PassInfoNode::make(int opt_level, std::string name,
@@ -264,11 +339,6 @@ PassInfo PassInfoNode::make(int opt_level, std::string name,
   return PassInfo(pass_info);
 }
 
-PassContext PassContextNode::make() {
-  auto ctx = make_node<PassContextNode>();
-  return PassContext(ctx);
-}
-
 ModulePass ModulePassNode::make(
     runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
     PassInfo pass_info) {
@@ -279,23 +349,19 @@ ModulePass ModulePassNode::make(
 }
 
 // Module -> Module optimizations.
-// TODO(zhiics) 1. Check and handle the required passes.
-//              2. Probably use CoW for all places that use module instead of
-//              returning the updated one.
-Module ModulePassNode::operator()(const Module& mod) const {
+// TODO(zhiics) Check and handle the required passes.
+Module ModulePassNode::operator()(const Module& mod,
+                                  const PassContext& pass_ctx) const {
   PassInfo pass_info = Info();
   LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name
             << " with opt level: " << pass_info.operator->()->opt_level << "\n";
+
   CHECK(mod.defined());
-  auto updated_mod = pass_func(mod, pass_ctx_);
+  auto updated_mod = pass_func(mod, pass_ctx);
   CHECK(updated_mod.defined());
   return updated_mod;
 }
 
-void ModulePassNode::SetContext(const PassContext& pass_ctx) {
-  pass_ctx_ = pass_ctx;
-}
-
 FunctionPass FunctionPassNode::make(
     runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
     PassInfo pass_info) {
@@ -307,31 +373,22 @@ FunctionPass FunctionPassNode::make(
 
 // Perform Module -> Module optimizations at the Function level.
 // TODO(zhiics) Check and handle the required passes.
-Module FunctionPassNode::operator()(const Module& mod) const {
+Module FunctionPassNode::operator()(const Module& mod,
+                                    const PassContext& pass_ctx) const {
   PassInfo pass_info = Info();
   LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name
             << " with opt level: " << pass_info.operator->()->opt_level << "\n";
   CHECK(mod.defined());
-  std::vector<std::pair<GlobalVar, Function>> updated_funcs;
-  ModuleNode* mod_node = mod.operator->();
-  for (const auto& it : mod_node->functions) {
-    if (!SkipFunction(it.second)) {
-      auto updated_func = pass_func(it.second, pass_ctx_);
-      CHECK(updated_func.defined());
-      updated_funcs.push_back({std::move(it.first), std::move(updated_func)});
-    }
-  }
+  Module new_mod = ModuleNode::make({}, mod->type_definitions);
 
-  // Update the optimized functions.
-  for (const auto& it : updated_funcs) {
-    mod_node->Update(it.first, it.second);
+  // Execute the pass function and return a new module.
+  for (const auto& it : mod->functions) {
+    auto updated_func =
+        SkipFunction(it.second) ? it.second : pass_func(it.second, pass_ctx);
+    new_mod->Add(it.first, updated_func);
   }
 
-  return GetRef<Module>(mod_node);
-}
-
-void FunctionPassNode::SetContext(const PassContext& pass_ctx) {
-  pass_ctx_ = pass_ctx;
+  return new_mod;
 }
 
 // TODO(zhiics) Create an enum attribute for FunctionNode
@@ -342,31 +399,23 @@ bool FunctionPassNode::SkipFunction(const Function& func) const {
   return pval && pval->value != 0;
 }
 
-Sequential::Sequential(tvm::Array<Pass> passes,
-                       PassInfo pass_info,
-                       tvm::Array<tvm::Expr> disabled) {
+Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
   auto n = make_node<SequentialNode>();
   n->passes = std::move(passes);
   n->pass_info = std::move(pass_info);
-  n->disabled = std::move(disabled);
   node_ = std::move(n);
 }
 
-const SequentialNode* Sequential::operator->() const {
-  return static_cast<const SequentialNode*>(this->node_.get());
+Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
+  auto n = make_node<SequentialNode>();
+  n->passes = std::move(passes);
+  PassInfo pass_info = PassInfoNode::make(2, std::move(name), {});
+  n->pass_info = std::move(pass_info);
+  node_ = std::move(n);
 }
 
-// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in
-// a Sequential without the consideration of their orders. The phase
-// ordering problem needed to be handled in the future.
-Module SequentialNode::operator()(const Module& module) const {
-  Module mod = module;
-  for (const Pass& pass : passes) {
-    CHECK(pass.defined()) << "Found undefined pass for optimization.";
-    const auto* pn = pass.operator->();
-    mod = (*pn)(mod);
-  }
-  return mod;
+const SequentialNode* Sequential::operator->() const {
+  return static_cast<const SequentialNode*>(this->node_.get());
 }
 
 void SequentialNode::ResolveDependency(const Module& mod) {
@@ -378,18 +427,68 @@ void SequentialNode::ResolveDependency(const Module& mod) {
              << "\n";
 }
 
-std::vector<std::string> SequentialNode::DisabledPasses() const {
-  std::vector<std::string> ret;
+std::unordered_set<std::string> SequentialNode::DisabledPasses(
+    const Array<tvm::Expr>& disabled) const {
+  std::unordered_set<std::string> ret;
   for (const auto& it : disabled) {
     const auto* str = it.as<tvm::ir::StringImm>();
     CHECK(str) << "disabled passes must be string.";
-    ret.push_back(str->value);
+    ret.emplace(str->value);
   }
   return ret;
 }
 
-void SequentialNode::SetContext(const PassContext& pass_ctx) {
-  pass_ctx_ = pass_ctx;
+std::unordered_set<std::string> SequentialNode::RequiredPasses(
+    const Array<tvm::Expr>& required) const {
+  std::unordered_set<std::string> ret;
+  for (const auto& it : required) {
+    const auto* str = it.as<tvm::ir::StringImm>();
+    CHECK(str) << "disabled passes must be string.";
+    ret.emplace(str->value);
+  }
+  return ret;
+}
+
+bool SequentialNode::pass_enabled(const std::string& pass_name) const {
+  PassContext ctx = PassContext::Current();
+
+  const PassContextNode* ctx_node = ctx.operator->();
+  auto required = RequiredPasses(ctx_node->required_pass);
+  auto disabled = DisabledPasses(ctx_node->required_pass);
+
+  if (disabled.count(pass_name)) {
+    return false;
+  }
+
+  if (required.count(pass_name)) {
+    return true;
+  }
+  return ctx_node->opt_level >= opt_pass_level[pass_name];
+}
+
+// TODO(zhiics): we currenlty only sequentially execute each pass in
+// a Sequential without the consideration of their orders. The phase
+// ordering problem needed to be handled in the future.
+Module SequentialNode::operator()(const Module& module,
+                                  const PassContext& pass_ctx) const {
+  const auto* ctx_node = pass_ctx.operator->();
+  int opt_level = ctx_node->opt_level;
+  auto disabled = DisabledPasses(ctx_node->disabled_pass);
+  Module mod = module;
+  for (const Pass& pass : passes) {
+    CHECK(pass.defined()) << "Found undefined pass for optimization.";
+    PassInfo info = pass->Info();
+    const auto& pass_name = info.operator->()->name;
+    const auto& pass_opt_level = info.operator->()->opt_level;
+    // Skip the pass if its optimization level is higher that the  one of in the
+    // pass context or if this pass is disabled.
+    if (pass_opt_level > opt_level || disabled.count(pass_name)) {
+      continue;
+    }
+    const auto* pn = pass.operator->();
+    mod = (*pn)(mod, pass_ctx);
+  }
+  return mod;
 }
 
 Pass CreateModulePass(
@@ -481,9 +580,8 @@ TVM_REGISTER_API("relay._transform.Sequential")
   int opt_level = args[1];
   std::string name = args[2];
   tvm::Array<tvm::Expr> required = args[3];
-  tvm::Array<tvm::Expr> disabled = args[4];
   PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
-  *ret = Sequential(passes, pass_info, disabled);
+  *ret = Sequential(passes, pass_info);
 });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
@@ -501,26 +599,58 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
   p->stream << "]";
 });
 
-TVM_REGISTER_API("relay._transform.SetContext")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  Pass pass = args[0];
-  PassContext pass_ctx = args[1];
-  pass->SetContext(pass_ctx);
-});
-
 TVM_REGISTER_NODE_TYPE(PassContextNode);
 
 TVM_REGISTER_API("relay._transform.PassContext")
-.set_body_typed(PassContextNode::make);
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  int opt_level = args[0];
+  int fallback_device = args[1];
+  tvm::Array<tvm::Expr> required = args[2];
+  tvm::Array<tvm::Expr> disabled = args[3];
+  *ret = PassContext(opt_level, fallback_device, required, disabled);
+});
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<PassContextNode>([](const PassContextNode* node,
-                                tvm::IRPrinter* p) {
-    p->stream << "TODO(zhiics): printing context";
-    LOG(FATAL) << "PassContext printer has not been implemented yet."
-               << "\n";
+                               tvm::IRPrinter* p) {
+  p->stream << "Pass context information: " << "\n";
+  p->stream << "\topt_level: " << node->opt_level << "\n";
+  p->stream << "\tfallback device: " << runtime::DeviceName(node->opt_level)
+            << "\n";
+
+  p->stream << "\trequired passes: [" << node->opt_level;
+  for (const auto& it : node->required_pass) {
+    p->stream << it << " ";
+  }
+  p->stream << "]\n";
+
+  p->stream << "\tdisabled passes: [" << node->opt_level;
+  for (const auto& it : node->disabled_pass) {
+    p->stream << it << " ";
+  }
+  p->stream << "]";
 });
 
+class PassContext::Internal {
+ public:
+  static void EnterScope(PassContext pass_ctx) {
+    pass_ctx.EnterWithScope();
+  }
+
+  static void ExitScope(PassContext pass_ctx) {
+    pass_ctx.ExitWithScope();
+  }
+};
+
+TVM_REGISTER_API("relay._transform.GetCurrentPassContext")
+.set_body_typed(PassContext::Current);
+
+TVM_REGISTER_API("relay._transform.EnterPassContext")
+.set_body_typed(PassContext::Internal::EnterScope);
+
+TVM_REGISTER_API("relay._transform.ExitPassContext")
+.set_body_typed(PassContext::Internal::ExitScope);
+
 }  // namespace transform
 }  // namespace relay
 }  // namespace tvm
index 0fed490..da78e96 100644 (file)
@@ -31,7 +31,7 @@ import model_zoo
 
 def get_tvm_output(func, x, params, target, ctx,
                    out_shape=(1, 1000), input_name='image', dtype='float32'):
-    with relay.build_module.build_config(opt_level=3):
+    with relay.transform.build_config(opt_level=3):
         graph, lib, params = relay.build(func, target, params=params)
     m = graph_runtime.create(graph, lib, ctx)
     # set inputs
@@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
         dtype_dict = {input_name: input_data.dtype}
 
     func, params = relay.frontend.from_coreml(coreml_model, shape_dict)
-    with relay.build_module.build_config(opt_level=3):
+    with relay.transform.build_config(opt_level=3):
         graph, lib, params = relay.build(func, target, params=params)
 
     from tvm.contrib import graph_runtime
index 35a9229..8817d4f 100644 (file)
@@ -43,7 +43,7 @@ def verify_keras_frontend(keras_model, need_transpose=True):
     def get_tvm_output(xs, target, ctx, dtype='float32'):
         shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
         func, params = relay.frontend.from_keras(keras_model, shape_dict)
-        with relay.build_module.build_config(opt_level=2):
+        with relay.transform.build_config(opt_level=2):
             graph, lib, params = relay.build(func, target, params=params)
         m = graph_runtime.create(graph, lib, ctx)
         for name, x in zip(keras_model.input_names, xs):
index f8686e9..0166981 100644 (file)
@@ -144,7 +144,7 @@ func, params = relay.frontend.from_tflite(tflite_model,
 
 # target x86 CPU
 target = "llvm"
-with relay.build_module.build_config(opt_level=3):
+with relay.transform.build_config(opt_level=3):
     graph, lib, params = relay.build(func, target, params=params)
 
 ######################################################################