[REFACTOR][TIR][API-Change] Migrate BuildConfig to PassContext. (#5668)
authorTianqi Chen <tqchen@users.noreply.github.com>
Tue, 26 May 2020 03:50:00 +0000 (20:50 -0700)
committerGitHub <noreply@github.com>
Tue, 26 May 2020 03:50:00 +0000 (20:50 -0700)
* [REFACTOR][TIR] Migrate BuildConfig to PassContext.

This PR migrates the TIR configurations from BuildConfig to the
PassContext used by the unified IR.
Moving forward, PassContext will be the unified way to configure passes in the TVM stack.

Changes

- Refactored TVM_PASS_REGISTER_CONFIG_OPTION to take in the reference type.
- Removed BuildConfig.
- Migrated the passes to use PassContext.

* Update include/tvm/ir/attrs.h

Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com>
Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com>
58 files changed:
apps/lldb/tvm.py
include/tvm/driver/driver_api.h
include/tvm/ir/attrs.h
include/tvm/ir/transform.h
include/tvm/target/target.h
include/tvm/tir/transform.h
python/tvm/autotvm/measure/measure_methods.py
python/tvm/autotvm/task/relay_integration.py
python/tvm/driver/build_module.py
python/tvm/ir/container.py
python/tvm/ir/transform.py
python/tvm/target/__init__.py
python/tvm/target/build_config.py [deleted file]
python/tvm/tir/transform/transform.py
src/driver/driver_api.cc
src/relay/backend/build_module.cc
src/relay/backend/compile_engine.cc
src/relay/backend/interpreter.cc
src/relay/backend/vm/compiler.cc
src/relay/transforms/fold_constant.cc
src/relay/transforms/partial_eval.cc
src/target/codegen.cc
src/target/target.cc
src/tir/transforms/inject_double_buffer.cc
src/tir/transforms/loop_partition.cc
src/tir/transforms/unroll_loop.cc
tests/cpp/build_module_test.cc
tests/micro/test_runtime_micro_on_arm.py
tests/python/relay/test_pass_fold_constant.py
tests/python/relay/test_pass_manager.py
tests/python/unittest/test_runtime_micro.py
tests/python/unittest/test_target_codegen_c_host.py
tests/python/unittest/test_target_codegen_cuda.py
tests/python/unittest/test_target_codegen_llvm.py
tests/python/unittest/test_te_schedule_bound_inference.py
tests/python/unittest/test_tir_analysis_verify_gpu_code.py
tests/python/unittest/test_tir_transform_inject_double_buffer.py
tests/python/unittest/test_tir_transform_instrument_bound_checkers.py
tests/python/unittest/test_tir_transform_loop_partition.py
tests/python/unittest/test_tir_transform_storage_flatten.py
tests/python/unittest/test_tir_transform_unroll_loop.py
topi/python/topi/arm_cpu/bitserial_conv2d.py
topi/python/topi/arm_cpu/cortex_m7/micro_kernel/gemm.py
topi/python/topi/arm_cpu/tensor_intrin.py
topi/recipe/conv/depthwise_conv2d_test.py
topi/recipe/conv/test_conv2d_hwcn_map.py
topi/recipe/gemm/cuda_gemm_square.py
topi/recipe/reduce/test_reduce_map.py
topi/recipe/rnn/lstm.py
topi/recipe/rnn/matexp.py
tutorials/dev/low_level_custom_pass.py
tutorials/language/tensorize.py
tutorials/optimize/opt_conv_tensorcore.py
tutorials/optimize/opt_matmul_auto_tensorcore.py
vta/python/vta/build_module.py
vta/scripts/tune_resnet.py
vta/tutorials/autotvm/tune_relay_vta.py
vta/tutorials/frontend/deploy_classification.py

index 135aeff..fb5c4de 100644 (file)
@@ -36,7 +36,6 @@ def __lldb_init_module(debugger, _):
         "tvm::Attrs",
         "tvm::BijectiveLayout",
         "tvm::Buffer",
-        "tvm::BuildConfig",
         "tvm::Channel",
         "tvm::EnvFunc",
         "tvm::Expr",
index 1d4d493..547d982 100644 (file)
@@ -48,23 +48,20 @@ namespace tvm {
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
- * \param config The build configuration.
  * \return The result module.
  */
 TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-                       const BuildConfig& config);
+                       const std::unordered_map<te::Tensor, tir::Buffer>& binds);
 
 /*!
  * \brief Build a device and host module for a specific target from an IRModule.
  * \param funcs The functions to be built.
  * \param target The target device to build for.
  * \param target_host The target for building host code. To use the default, pass Target()
- * \param config The build configuration.
  * \return The built module.
  */
 TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target,
-                              const Target& target_host, const BuildConfig& config);
+                              const Target& target_host);
 
 /*!
  * \brief Build a device and host module for a specific target from a map
@@ -73,11 +70,9 @@ TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target,
  * \param input The map contains target to an IRModule.
  * \param target_host The target for building host code. To use the default,
  *        pass Target().
- * \param config The build configuration.
  * \return The built module that contains code for different processors.
  */
-TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host,
-                              const BuildConfig& config);
+TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host);
 
 /*!
  * \brief Build a device and host module for a specific target from a map
@@ -86,11 +81,9 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
  * \param input The map contains target string to an  IRModule.
  * \param target_host The target for building host code. To use the default,
  *        pass Target().
- * \param config The build configuration.
  * \return The built module that contains code for different processors.
  */
-TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input, const Target& target_host,
-                              const BuildConfig& config);
+TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input, const Target& target_host);
 }  // namespace tvm
 
 #endif  // TVM_DRIVER_DRIVER_API_H_
index 4e2e183..0946b2e 100644 (file)
@@ -236,6 +236,19 @@ class DictAttrs : public Attrs {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
 };
 
+/*!
+ * \brief Create an Attr object with all default values.
+ * \tparam TAttrNode the type to be created.
+ * \return A instance that will represent None.
+ */
+template <typename TAttrs>
+inline TAttrs AttrsWithDefaultValues() {
+  static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes");
+  auto n = make_object<typename TAttrs::ContainerType>();
+  n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false);
+  return TAttrs(n);
+}
+
 // Namespace containing detail implementations
 namespace detail {
 using runtime::TVMArgValue;
index dc29b82..6c8bad2 100644 (file)
@@ -208,10 +208,11 @@ class PassContext : public ObjectRef {
    * \brief Register a valid configuration option and its ValueType for validation.
    *
    * \param key The configuration key.
-   * \tparam ValueNodeType The value type to be registered
+   * \tparam ValueType The value type to be registered
    */
-  template <typename ValueNodeType>
+  template <typename ValueType>
   static uint32_t RegisterConfigOption(const char* key) {
+    using ValueNodeType = typename ValueType::ContainerType;
     // NOTE: we could further update the function later.
     uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
     RegisterConfigOption(key, tindex);
index de48ac2..c85349d 100644 (file)
@@ -172,109 +172,5 @@ TVM_DLL Target ext_dev(const std::vector<std::string>& options = std::vector<std
 TVM_DLL Target hexagon(const std::vector<std::string>& options = std::vector<std::string>());
 }  // namespace target
 
-/*!
- * \brief Container for build configuration options
- */
-class BuildConfigNode : public Object {
- public:
-  /*!
-   * \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
-   * done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
-   */
-  int double_buffer_split_loop = 1;
-  /*! \brief Threshold of number of steps in the loop to be automatically unrolled */
-  int auto_unroll_max_step = 0;
-  /*! \brief The maximum nested level of loops that can be automatically unrolled */
-  int auto_unroll_max_depth = 8;
-  /*! \brief The maximum extent of loop that will be unrolled */
-  int auto_unroll_max_extent = 0;
-  /*!
-   * \brief Whether to explicitly unroll the loop. If set to false, the unroll hint will
-   * be passed to the CodeGen phase. Set to true if CodeGen supports unroll pragma.
-   */
-  bool unroll_explicit = true;
-
-  /*! \brief Set to true if buffer arguments do not overlap. This enables more optimization. */
-  bool restricted_func = true;
-
-  /*! \brief Whether to detect global barrier */
-  bool detect_global_barrier = false;
-
-  /*! \brief Whether to partition const loop */
-  bool partition_const_loop = false;
-
-  /*! \brief List of passes to be injected into the low-level pipeline. */
-  std::vector<std::pair<int, transform::Pass>> add_lower_pass;
-
-  /*! \brief Whether to instrument loads and stores with check for out of the bounds. */
-  bool instrument_bound_checkers = false;
-
-  /*! \brief Whether to disable select rewriting. */
-  bool disable_select_rewriting = false;
-
-  /*! \brief Whether to disable loop vectorization. */
-  bool disable_vectorize = false;
-
-  /*! \brief Whether to disable assert stmt generation. */
-  bool disable_assert = false;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
-    v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
-    v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
-    v->Visit("auto_unroll_max_extent", &auto_unroll_max_extent);
-    v->Visit("unroll_explicit", &unroll_explicit);
-    v->Visit("restricted_func", &restricted_func);
-    v->Visit("detect_global_barrier", &detect_global_barrier);
-    v->Visit("partition_const_loop", &partition_const_loop);
-    v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
-    v->Visit("disable_select_rewriting", &disable_select_rewriting);
-    v->Visit("disable_vectorize", &disable_vectorize);
-    v->Visit("disable_assert", &disable_assert);
-  }
-
-  static constexpr const char* _type_key = "BuildConfig";
-  TVM_DECLARE_FINAL_OBJECT_INFO(BuildConfigNode, Object);
-};
-
-/*!
- * \brief Build configuration for compilations.
- */
-class BuildConfig : public ::tvm::ObjectRef {
- public:
-  BuildConfig() {}
-  explicit BuildConfig(ObjectPtr<Object> n) : ObjectRef(n) {}
-  const BuildConfigNode* operator->() const { return static_cast<const BuildConfigNode*>(get()); }
-  BuildConfigNode* operator->() { return static_cast<BuildConfigNode*>(get_mutable()); }
-  /*!
-   * \brief Construct a BuildConfig containing a empty build config node.
-   * \return The new BuildConfig
-   */
-  TVM_DLL static BuildConfig Create();
-  /*!
-   * \brief Get the current BuildConfig context from thread local storage, or a default
-   * configuration if a BuildConfig scope has not been entered.
-   * \return The configuration that is the current context.
-   */
-  TVM_DLL static BuildConfig Current();
-
-  using ContainerType = BuildConfigNode;
-  class Internal;
-
- private:
-  // Enable with syntax.
-  friend class With<BuildConfig>;
-  /*!
-   * \brief Push a new BuildConfig context onto the thread local stack.
-   */
-  TVM_DLL void EnterWithScope();
-
-  /*!
-   * \brief Pop a build config off the thread local context stack,
-   * restoring the previous configuration as the current context.
-   */
-  TVM_DLL void ExitWithScope();
-};
-
 }  // namespace tvm
 #endif  // TVM_TARGET_TARGET_H_
index 13e1e25..371277a 100644 (file)
@@ -108,11 +108,9 @@ TVM_DLL Pass LiftAttrScope(std::string attr_key);
 /*!
  * \brief partition loops in the stmt.
  *
- * \param split_const_loop flag to enable partition for const loop
- *
  * \return The pass.
  */
-TVM_DLL Pass LoopPartition(bool split_const_loop);
+TVM_DLL Pass LoopPartition();
 
 /*!
  * \brief Lower vectorization loops.
@@ -133,10 +131,9 @@ TVM_DLL Pass InjectVirtualThread();
 /*!
  * \brief Inject double buffer statements.
  *
- * \param split_loop_factor Loop splitting factor.
  * \return The pass.
  */
-TVM_DLL Pass InjectDoubleBuffer(int split_loop_factor);
+TVM_DLL Pass InjectDoubleBuffer();
 
 /*!
  * \brief Rewrite storage allocation pattern.
@@ -152,15 +149,9 @@ TVM_DLL Pass StorageRewrite();
  * \brief unroll the constant loop marked by unroll.
  * This pass also automatically attach pragma unroll tag to loops which meets the standard.
  *
- * \param auto_max_step The maximum step before stop attach automatic unroll
- * \param auto_max_depth The maximum depth before stop attach automatic unroll
- * \param auto_max_extent The maximum extent of the loop we can unroll,
- *        this is an legacy option that do not take the loop total steps into account.
- * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
  * \return The pass.
  */
-TVM_DLL Pass UnrollLoop(int auto_max_step, int auto_max_depth, int auto_max_extent,
-                        bool explicit_unroll);
+TVM_DLL Pass UnrollLoop();
 
 /*!
  * \brief Remove No Op from the Stmt.
index 8f11a17..4db28d5 100644 (file)
@@ -34,9 +34,9 @@ import tempfile
 import numpy as np
 
 import tvm._ffi
+import tvm.ir.transform
 from tvm import nd, rpc as _rpc, target as _target
 from tvm.error import TVMError
-from tvm.target import build_config
 from tvm.driver import build
 from tvm.contrib import nvcc, ndk, tar
 
@@ -246,7 +246,7 @@ class RPCRunner(Runner):
             if 'cuda' in self.task.target.keys:
                 kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
         if self.task.target.device_name == 'micro_dev':
-            kwargs.setdefault('build_option', {})['disable_vectorize'] = True
+            kwargs.setdefault('build_option', {})['tir.disable_vectorize'] = True
 
         return kwargs
 
@@ -360,7 +360,7 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
 
         opts = build_option or {}
         if check_gpu:  # Add verify pass to filter out invalid configs in advance.
-            opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
+            opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
         if cuda_arch:
             set_cuda_target_arch(cuda_arch)
 
@@ -371,7 +371,7 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
             import vta
             func = vta.build(s, args, target_host=task.target_host)
         else:
-            with build_config(**opts):
+            with tvm.ir.transform.PassContext(config=opts):
                 func = build(s, args, target_host=task.target_host)
     return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
 
index f3edfb0..ecc0112 100644 (file)
@@ -41,13 +41,12 @@ def _lower(mod,
     from tvm.relay.backend import graph_runtime_codegen
 
     if hasattr(target, 'device_name') and target.device_name == "vta":
-        with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
-            import vta
-            with vta.build_config():
-                mod, _ = relay.optimize(mod, target, params)
-                grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
-                grc.codegen(mod["main"])
-                return
+        import vta
+        with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
+            mod, _ = relay.optimize(mod, target, params)
+            grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+            grc.codegen(mod["main"])
+            return
 
     # default case
     # Try graph codegen first to extract autotvm tasks.
index b1c15fb..a19b097 100644 (file)
@@ -25,7 +25,8 @@ import tvm.tir
 from tvm.runtime import ndarray
 from tvm.ir import container
 from tvm.ir import CallingConv
-from tvm.target import codegen, BuildConfig
+from tvm.ir.transform import PassContext
+from tvm.target import codegen
 from tvm.te import tensor
 from tvm.te import schedule
 from tvm import target as _target
@@ -102,7 +103,7 @@ def form_irmodule(sch, args, name, binds):
     The body formed according to the given schedule
     """
     # normalize schedule first
-    cfg = BuildConfig.current()
+    pass_ctx = PassContext.current()
     sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
@@ -114,7 +115,8 @@ def form_irmodule(sch, args, name, binds):
     func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
 
     func = func.with_attr("global_symbol", name)
-    if cfg.restricted_func:
+
+    if pass_ctx.config.get("tir.noalias", True):
         func = func.with_attr("tir.noalias", True)
     return tvm.IRModule({name: func})
 
@@ -152,8 +154,12 @@ def lower(sch,
        The result IRModule, if simple_mode=False
        Then the Stmt before make api is returned.
     """
-    cfg = BuildConfig.current()
-    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
+    # config setup
+    pass_ctx = PassContext.current()
+    instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False))
+    disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False))
+    add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", [])
+
     lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
     lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
     lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
@@ -169,7 +175,7 @@ def lower(sch,
     # Phase 1
     pass_list += [
         tvm.tir.transform.InjectPrefetch(),
-        tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
+        tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
         tvm.tir.transform.NarrowDataType(32),
         tvm.tir.transform.Simplify(),
     ]
@@ -177,18 +183,14 @@ def lower(sch,
 
     # Phase 2
     if not simple_mode:
-        pass_list += [(tvm.tir.transform.LoopPartition(cfg.partition_const_loop))]
+        pass_list += [(tvm.tir.transform.LoopPartition())]
 
     pass_list += [
-        tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize),
+        tvm.tir.transform.VectorizeLoop(not disable_vectorize),
         tvm.tir.transform.InjectVirtualThread(),
-        tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop),
+        tvm.tir.transform.InjectDoubleBuffer(),
         tvm.tir.transform.StorageRewrite(),
-        tvm.tir.transform.UnrollLoop(
-            cfg.auto_unroll_max_step,
-            cfg.auto_unroll_max_depth,
-            cfg.auto_unroll_max_extent,
-            cfg.unroll_explicit),
+        tvm.tir.transform.UnrollLoop()
     ]
     pass_list += lower_phase2
 
@@ -198,12 +200,11 @@ def lower(sch,
         tvm.tir.transform.RemoveNoOp(),
     ]
 
-    if not cfg.disable_select_rewriting:
-        pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
+    pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
     pass_list += lower_phase3
 
     # Instrument BoundCheckers
-    if cfg.instrument_bound_checkers:
+    if instrument_bound_checkers:
         pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]
 
     optimize = tvm.transform.Sequential(pass_list)
@@ -244,7 +245,8 @@ def _build_for_device(input_mod, target, target_host):
     opt_mixed = [tvm.tir.transform.VerifyMemory()]
     if len(mod_mixed.functions) == 1:
         opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
-    if BuildConfig.current().detect_global_barrier:
+
+    if PassContext.current().config.get("tir.detect_global_barrier", False):
         opt_mixed += [tvm.tir.transform.ThreadSync("global")]
     opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
                   tvm.tir.transform.ThreadSync("warp"),
index c415454..26233f1 100644 (file)
@@ -61,6 +61,24 @@ class Map(Object):
     def __len__(self):
         return _ffi_node_api.MapSize(self)
 
+    def get(self, key, default=None):
+        """Get an element with a default value.
+
+        Parameters
+        ----------
+        key : object
+            The attribute key.
+
+        default : object
+            The default object.
+
+        Returns
+        -------
+        value: object
+            The result value.
+        """
+        return self[key] if key in self else default
+
 
 @tvm._ffi.register_object
 class StrMap(Map):
index eb57e34..5f49092 100644 (file)
@@ -99,6 +99,7 @@ class PassContext(tvm.runtime.Object):
             raise TypeError("disabled_pass is expected to be the type of " +
                             "list/tuple/set.")
 
+        config = config if config else None
         self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level,
                                             fallback_device, required,
                                             disabled, trace, config)
index 6b86ff0..eac939b 100644 (file)
@@ -61,4 +61,3 @@ from .generic_func import generic_func, get_native_generic_func, override_native
 from . import datatype
 from . import codegen
 from .intrin import register_intrin_rule
-from .build_config import BuildConfig, build_config
diff --git a/python/tvm/target/build_config.py b/python/tvm/target/build_config.py
deleted file mode 100644 (file)
index a99797a..0000000
+++ /dev/null
@@ -1,156 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Target dependent BuildConfig for low-level passes."""
-# TODO(tvm-team) consolidate with PassContext
-import tvm._ffi
-import tvm.ir
-
-from tvm.runtime import Object
-from . import _ffi_api
-
-
-@tvm._ffi.register_object
-class BuildConfig(Object):
-    """Configuration scope to set a build config option.
-
-    Note
-    ----
-    This object is backed by object protocol in C++, with arguments that can be
-    exchanged between python and C++.
-
-    Do not construct directly, use build_config instead.
-
-    The fields that are backed by the C++ object are immutable once an instance
-    is constructed. See _object_defaults for the fields.
-    """
-
-    _object_defaults = {
-        "auto_unroll_max_step": 0,
-        "auto_unroll_max_depth": 8,
-        "auto_unroll_max_extent": 0,
-        "unroll_explicit": True,
-        "detect_global_barrier": False,
-        "partition_const_loop": False,
-        "restricted_func": True,
-        "double_buffer_split_loop": 1,
-        "instrument_bound_checkers": False,
-        "disable_select_rewriting": False,
-        "disable_vectorize": False,
-        "disable_assert": False
-    }
-
-    # pylint: disable=no-member
-    def __init__(self, handle):
-        """Initialize the function with handle
-
-        Parameters
-        ----------
-        handle : SymbolHandle
-            the handle to the underlying C++ Symbol
-        """
-        super(BuildConfig, self).__init__(handle)
-        self.handle = handle
-
-    @property
-    def add_lower_pass(self):
-        size = _ffi_api.BuildConfigGetAddLowerPassInfo(self)
-        result = []
-        for i in range(size):
-            phase = _ffi_api.BuildConfigGetAddLowerPassInfo(self, i, True)
-            func = _ffi_api.BuildConfigGetAddLowerPassInfo(self, i, False)
-            result += [(phase, func)]
-        return result
-
-    @add_lower_pass.setter
-    def add_lower_pass(self, value):
-        add_lower_pass_args = []
-        for x in value:
-            add_lower_pass_args += [x[0], x[1]]
-        _ffi_api.BuildConfigSetAddLowerPass(self, *add_lower_pass_args)
-
-    def __enter__(self):
-        # pylint: disable=protected-access
-        _ffi_api.EnterBuildConfigScope(self)
-        return self
-
-    def __exit__(self, ptype, value, trace):
-        _ffi_api.ExitBuildConfigScope(self)
-
-    def __setattr__(self, name, value):
-        if name in BuildConfig._object_defaults:
-            raise AttributeError(
-                "'%s' object cannot set attribute '%s'" % (str(type(self)), name))
-        return super(BuildConfig, self).__setattr__(name, value)
-
-    @staticmethod
-    def current():
-        """Get the current build configuration."""
-        return _ffi_api.GetCurrentBuildConfig()
-
-
-def build_config(**kwargs):
-    """Configure the build behavior by setting config variables.
-
-    Parameters
-    ----------
-    auto_unroll_max_step: int, default=0
-        Threshold of number of steps in the loop to be automatically unrolled.
-        This takes inner loop count into consideration.
-
-    auto_unroll_max_depth: int, default=8
-        The maximum nested level of loops that can be automatically unrolled.
-
-    unroll_explicit: bool, default=True
-        Whether explicitly unroll the loop, if set false, the unroll hint will
-        be passed to the CodeGen phase, which may generate pragma unroll hint.
-        Set this to be true if CodeGen support unroll pragma and
-        when we want to be more readable.
-
-    detect_global_barrier: bool, default=True
-        Whether detect global barrier.
-
-    partition_const_loop: bool, default=False
-        Whether partition const loop
-
-    restricted_func: bool, default=True
-        Whether build restricted function.
-        That is each buffer argument to the function are guaranteed
-        not to overlap. This enables more optimization.
-        Corresponds to restricted keyword in C99
-
-    double_buffer_split_loop: int, default=2
-        Whether split the loop with factor. If it is zero, no splitting will happen.
-        It it is bigger than one, the logic will do a split with factor equals the integer
-        and unroll the inner loop. This allows the buffer fetching won't contain condition.
-
-    add_lower_pass: list of tuple (phase, function(Stmt->Stmt)), default=None
-        phase contains an integer on which optimization pass we apply the pass.
-        Additional lowering passes to be applied before make_api.
-
-    Returns
-    -------
-    config: BuildConfig
-        The build configuration
-    """
-    node_args = {k: v if k not in kwargs else kwargs[k]
-                 for k, v in BuildConfig._object_defaults.items()}
-    config = tvm.ir.make_node("BuildConfig", **node_args)
-
-    if "add_lower_pass" in kwargs:
-        config.add_lower_pass = kwargs["add_lower_pass"]
-
-    return config
index 6d797f8..a5af353 100644 (file)
@@ -138,20 +138,15 @@ def LiftAttrScope(attr_key):
     return _ffi_api.LiftAttrScope(attr_key)
 
 
-def LoopPartition(split_const_loop):
+def LoopPartition():
     """Inject virtual thread loops.
 
-    Parameters
-    ----------
-    split_const_loop : bool
-        Flag to enable partition for const loop.
-
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.LoopPartition(split_const_loop)
+    return _ffi_api.LoopPartition()
 
 
 def VectorizeLoop(enable_vectorize=True):
@@ -182,20 +177,15 @@ def InjectVirtualThread():
     return _ffi_api.InjectVirtualThread()
 
 
-def InjectDoubleBuffer(split_loop_factor):
+def InjectDoubleBuffer():
     """Inject double buffer statements.
 
-    Parameters
-    ----------
-    split_loop_factor : int
-        Loop splitting factor.
-
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.InjectDoubleBuffer(split_loop_factor)
+    return _ffi_api.InjectDoubleBuffer()
 
 
 def StorageRewrite():
@@ -213,36 +203,17 @@ def StorageRewrite():
     return _ffi_api.StorageRewrite()
 
 
-def UnrollLoop(auto_max_step,
-               auto_max_depth,
-               auto_max_extent,
-               explicit_unroll):
+def UnrollLoop():
     """Unroll the constant loop marked by unroll.
 
     This pass also automatically attach pragma unroll tag to loops which meets the standard.
 
-    Parameters
-    ----------
-    auto_max_step : int
-        The maximum step before stop attach automatic unroll
-
-    auto_max_depth : int
-        The maximum depth before stop attach automatic unroll
-
-     auto_max_extent : int
-        The maximum extent of the loop we can unroll.
-        This is an legacy option that do not take the loop total steps into account.
-
-    explicit_unroll : bool
-        Whether explicitly unroll the loop, or leave unroll annotation to codegen.
-
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.UnrollLoop(
-        auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)
+    return _ffi_api.UnrollLoop()
 
 
 def RemoveNoOp():
index ca1e122..62b8bba 100644 (file)
@@ -23,6 +23,7 @@
  */
 #include <dmlc/thread_local.h>
 #include <tvm/driver/driver_api.h>
+#include <tvm/ir/transform.h>
 #include <tvm/runtime/container.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/target/codegen.h>
 
 namespace tvm {
 
+// Register build pipeline related options
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
+
 using runtime::PackedFunc;
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
@@ -85,8 +94,7 @@ tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std
 
 void GetBinds(const Array<te::Tensor>& args, bool compact,
               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list,
-              const BuildConfig& config) {
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
   *out_binds = binds;
 
   for (const auto& x : args) {
@@ -120,9 +128,9 @@ transform::Pass Filter(FCond fcond) {
 }
 
 IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-               const BuildConfig& config) {
+               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
   Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
 
   sch = sch.normalize();
 
@@ -132,13 +140,19 @@ IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::strin
   bool compact = te::VerifyCompactBuffer(stmt);
 
   Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list, config);
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
 
   // build the function
   tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
   f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-  if (config->restricted_func) {
-    f = WithAttr(std::move(f), "tir.noalias", Integer(1));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+  bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
+  bool instrument_bound_checkers =
+      pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
   }
 
   auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
@@ -146,25 +160,21 @@ IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::strin
 
   // Phase 0
   pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, config->instrument_bound_checkers));
+  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
   // Phase 1
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop));
-  pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize));
+  pass_list.push_back(tir::transform::LoopPartition());
+  pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
-  pass_list.push_back(tir::transform::InjectDoubleBuffer(config->double_buffer_split_loop));
+  pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
-  pass_list.push_back(
-      tir::transform::UnrollLoop(config->auto_unroll_max_step, config->auto_unroll_max_depth,
-                                 config->auto_unroll_max_extent, config->unroll_explicit));
+  pass_list.push_back(tir::transform::UnrollLoop());
   // Phase 2
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
-  if (!(config->disable_select_rewriting)) {
-    pass_list.push_back(tir::transform::RewriteUnsafeSelect());
-  }
-  if (config->instrument_bound_checkers) {
+  pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
   // run
@@ -173,12 +183,13 @@ IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::strin
   return mod;
 }
 
-std::pair<IRModule, IRModule> split_dev_host_funcs(IRModule mod_mixed, const Target& target,
-                                                   const Target& target_host,
-                                                   const BuildConfig& config) {
+std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target& target,
+                                                const Target& target_host,
+                                                const transform::PassContext& pass_ctx) {
   Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
                                                  tir::transform::VerifyMemory()};
-  if (config->detect_global_barrier) {
+
+  if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value()) {
     mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
   }
   mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
@@ -237,10 +248,10 @@ std::pair<IRModule, IRModule> split_dev_host_funcs(IRModule mod_mixed, const Tar
 }
 
 // Build for heterogeneous execution.
-runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_host,
-                      const BuildConfig& config) {
-  std::vector<runtime::Module> device_modules;
+runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_host) {
+  auto pass_ctx = transform::PassContext::Current();
 
+  std::vector<runtime::Module> device_modules;
   Target target_host_val = target_host;
   if (!target_host.defined()) {
     for (const auto& it : inputs) {
@@ -258,7 +269,7 @@ runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_
   IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());
 
   for (const auto& it : inputs) {
-    auto pair = split_dev_host_funcs(it.second, it.first, target_host_val, config);
+    auto pair = SplitDevHostFuncs(it.second, it.first, target_host_val, pass_ctx);
     auto& mhost = pair.first;
     auto& mdevice = pair.second;
 
@@ -279,8 +290,7 @@ runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_
 }
 
 // Build for heterogeneous execution when target is a string.
-runtime::Module build(const Map<std::string, IRModule>& inputs, const Target& target_host,
-                      const BuildConfig& config) {
+runtime::Module build(const Map<std::string, IRModule>& inputs, const Target& target_host) {
   Map<Target, IRModule> updated_input;
   for (const auto& it : inputs) {
     auto target = Target::Create(it.first);
@@ -289,14 +299,13 @@ runtime::Module build(const Map<std::string, IRModule>& inputs, const Target& ta
     }
     updated_input.Set(target, it.second);
   }
-  return build(updated_input, target_host, config);
+  return build(updated_input, target_host);
 }
 
 // Build for homogeneous execution.
-runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host,
-                      const BuildConfig& config) {
+runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host) {
   Map<Target, IRModule> inputs = {{target, funcs}};
-  return build(inputs, target_host, config);
+  return build(inputs, target_host);
 }
 
 }  // namespace tvm
index ef273c3..fe53336 100644 (file)
@@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode {
         ret_.mod = tvm::codegen::CSourceModuleCreate(";", "");
       }
     } else {
-      ret_.mod = tvm::build(lowered_funcs, target_host_, BuildConfig::Current());
+      ret_.mod = tvm::build(lowered_funcs, target_host_);
     }
 
     Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
index 421b032..182207a 100644 (file)
@@ -561,7 +561,7 @@ class CompileEngineImpl : public CompileEngineNode {
     if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
       m = (*f)(value->cached_func->funcs, key->target);
     } else {
-      m = build(value->cached_func->funcs, key->target, Target(nullptr), BuildConfig::Current());
+      m = build(value->cached_func->funcs, key->target, Target(nullptr));
     }
     value->packed_func = m.GetFunction(value->cached_func->func_name);
     return value->packed_func;
@@ -688,9 +688,11 @@ class CompileEngineImpl : public CompileEngineNode {
     if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
       cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func);
     } else {
-      tvm::BuildConfig bcfg = BuildConfig::Create();
+      using tvm::transform::PassContext;
+      With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
       std::unordered_map<te::Tensor, tir::Buffer> binds;
-      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds, bcfg);
+      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds);
     }
     value->cached_func = CachedFunc(cache_node);
     return value;
@@ -722,9 +724,12 @@ class CompileEngineImpl : public CompileEngineNode {
     for (te::Tensor arg : cache_node->outputs) {
       all_args.push_back(arg);
     }
-    tvm::BuildConfig bcfg = BuildConfig::Create();
+
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
     std::unordered_map<te::Tensor, tir::Buffer> binds;
-    cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
+    cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds);
     value->cached_func = CachedFunc(cache_node);
     return value;
   }
index c529997..d9be91d 100644 (file)
@@ -380,7 +380,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
     if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
       m = (*f)(cfunc->funcs, cfunc->target);
     } else {
-      m = build(cfunc->funcs, cfunc->target, Target(nullptr), BuildConfig::Current());
+      m = build(cfunc->funcs, cfunc->target, Target(nullptr));
     }
     shape_func = m.GetFunction(cfunc->func_name);
     shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
index 810664e..702d9b4 100644 (file)
@@ -1008,7 +1008,7 @@ void VMCompiler::Codegen() {
   auto ext_mods = compile_engine->LowerExternalFunctions();
   runtime::Module mod;
   if (funcs.size() > 0) {
-    mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current());
+    mod = tvm::build(funcs, target_host_);
     CHECK(mod.operator->());
   } else {
     CHECK_EQ(ext_mods.size(), 1U)
index e2ab35b..b41284b 100644 (file)
@@ -268,13 +268,14 @@ class ConstantFolder : public ExprMutator {
 };
 
 Expr FoldConstant(const Expr& expr, const IRModule& mod) {
+  using tvm::transform::PassContext;
   DLContext ctx;
   ctx.device_type = kDLCPU;
   ctx.device_id = 0;
   Target target = Target::Create("llvm");
   // use a fresh build context
   // in case we are already in a build context.
-  With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
+  With<PassContext> fresh_build_ctx(PassContext::Create());
 
   return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
 }
index 3e27b87..9a002ef 100644 (file)
@@ -533,10 +533,12 @@ DLContext CPUContext() {
 }
 
 FInterpreter CPUInterpreter() {
+  using tvm::transform::PassContext;
+
   Target target = Target::Create("llvm");
   // use a fresh build context
   // in case we are already in a build context.
-  With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
+  With<PassContext> fresh_build_ctx(PassContext::Create());
 
   return CreateInterpreter(IRModule(nullptr), CPUContext(), target);
 }
index e3890ca..d0a3156 100644 (file)
@@ -42,10 +42,11 @@ namespace tvm {
 namespace codegen {
 
 runtime::Module Build(IRModule mod, const Target& target) {
-  if (BuildConfig::Current()->disable_assert) {
+  if (transform::PassContext::Current()
+          ->GetConfig<Bool>("tir.disable_assert", Bool(false))
+          .value()) {
     mod = tir::transform::SkipAssert()(mod);
   }
-
   std::string build_f_name = "target.build." + target->target_name;
   // the build function.
   const PackedFunc* bf = runtime::Registry::Get(build_f_name);
index aac5a2b..f3ade8d 100644 (file)
@@ -313,111 +313,4 @@ Target ext_dev(const std::vector<std::string>& options) { return CreateTarget("e
 
 Target hexagon(const std::vector<std::string>& options) { return CreateTarget("hexagon", options); }
 }  // namespace target
-
-BuildConfig BuildConfig::Create() { return BuildConfig(make_object<BuildConfigNode>()); }
-
-/*! \brief Entry to hold the BuildConfig context stack. */
-struct TVMBuildConfigThreadLocalEntry {
-  /*! \brief The default build config if the stack is empty */
-  BuildConfig default_config;
-
-  /*! \brief The current build config context */
-  std::stack<BuildConfig> context_stack;
-
-  TVMBuildConfigThreadLocalEntry() : default_config(BuildConfig::Create()) {}
-};
-
-/*! \brief Thread local store to hold the BuildConfig context stack. */
-typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore;
-
-void BuildConfig::EnterWithScope() {
-  TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get();
-  entry->context_stack.push(*this);
-}
-
-void BuildConfig::ExitWithScope() {
-  TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get();
-  CHECK(!entry->context_stack.empty());
-  CHECK(entry->context_stack.top().same_as(*this));
-  entry->context_stack.pop();
-}
-
-tvm::BuildConfig BuildConfig::Current() {
-  TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get();
-  if (entry->context_stack.size() > 0) {
-    return entry->context_stack.top();
-  }
-
-  return entry->default_config;
-}
-
-TVM_REGISTER_NODE_TYPE(BuildConfigNode);
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<BuildConfigNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const BuildConfigNode*>(node.get());
-      p->stream << "build_config(";
-      p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", ";
-      p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", ";
-      p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", ";
-      p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", ";
-      p->stream << "unroll_explicit=" << op->unroll_explicit << ", ";
-      p->stream << "restricted_func=" << op->restricted_func << ", ";
-      p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
-      p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
-      p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
-      p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
-      p->stream << "disable_vectorize=" << op->disable_vectorize;
-      p->stream << "disable_assert=" << op->disable_assert;
-      p->stream << ")";
-    });
-
-TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig").set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = BuildConfig::Current();
-});
-
-class BuildConfig::Internal {
- public:
-  static void EnterScope(BuildConfig target) { target.EnterWithScope(); }
-  static void ExitScope(BuildConfig target) { target.ExitWithScope(); }
-};
-
-TVM_REGISTER_GLOBAL("target.EnterBuildConfigScope")
-    .set_body_typed(BuildConfig::Internal::EnterScope);
-
-TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope").set_body_typed(BuildConfig::Internal::ExitScope);
-
-TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      BuildConfig cfg = args[0];
-      std::vector<std::pair<int, transform::Pass>> add_lower_pass;
-      CHECK_EQ(args.size() % 2, 1);
-      for (int i = 1; i < args.size(); i += 2) {
-        add_lower_pass.push_back(
-            std::make_pair(args[i].operator int(), args[i + 1].operator transform::Pass()));
-      }
-      cfg->add_lower_pass = add_lower_pass;
-    });
-
-TVM_REGISTER_GLOBAL("target.BuildConfigGetAddLowerPassInfo")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      // Return one of the following:
-      //  * Size of add_lower_pass if num_args == 1
-      //  * Phase index of pass if args are (config, index, true)
-      //  * Function of pass if args are (config, index, false)
-      BuildConfig cfg = args[0];
-      if (args.num_args == 1) {
-        *ret = static_cast<int64_t>(cfg->add_lower_pass.size());
-      } else {
-        int index = args[1];
-        bool get_phase = args[2];
-        auto item = cfg->add_lower_pass[index];
-        if (get_phase) {
-          *ret = item.first;
-        } else {
-          *ret = item.second;
-        }
-      }
-    });
-
 }  // namespace tvm
index 0189978..ae5e673 100644 (file)
 namespace tvm {
 namespace tir {
 
+struct InjectDoubleBufferConfigNode : public tvm::AttrsNode<InjectDoubleBufferConfigNode> {
+  int split_loop;
+
+  TVM_DECLARE_ATTRS(InjectDoubleBufferConfigNode, "tir.transform.InjectDoubleBufferConfig") {
+    TVM_ATTR_FIELD(split_loop).describe("Split loop factors").set_default(1);
+  }
+};
+
+class InjectDoubleBufferConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(InjectDoubleBufferConfig, Attrs,
+                                            InjectDoubleBufferConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(InjectDoubleBufferConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.InjectDoubleBuffer", InjectDoubleBufferConfig);
+
 // Detect double buffer variables.
 class DoubleBufferDetector : public StmtExprVisitor {
  public:
@@ -258,16 +275,16 @@ class DoubleBufferInjector : public StmtExprMutator {
   std::unordered_map<const VarNode*, StorageEntry> dbuffer_info_;
 };
 
-Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) {
-  return DoubleBufferInjector(split_loop).Inject(stmt);
-}
-
 namespace transform {
 
-Pass InjectDoubleBuffer(int split_loop) {
+Pass InjectDoubleBuffer() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    n->body = DoubleBufferInjector(split_loop).Inject(std::move(n->body));
+    auto cfg = ctx->GetConfig<InjectDoubleBufferConfig>("tir.InjectDoubleBuffer");
+    if (!cfg.defined()) {
+      cfg = AttrsWithDefaultValues<InjectDoubleBufferConfig>();
+    }
+    n->body = DoubleBufferInjector(cfg.value()->split_loop).Inject(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {});
index 6392e70..c72928b 100644 (file)
 namespace tvm {
 namespace tir {
 
+struct LoopPartitionConfigNode : public tvm::AttrsNode<LoopPartitionConfigNode> {
+  bool partition_const_loop;
+
+  TVM_DECLARE_ATTRS(LoopPartitionConfigNode, "tir.transform.LoopPartitionConfig") {
+    TVM_ATTR_FIELD(partition_const_loop).describe("Split constant loop").set_default(false);
+  }
+};
+
+class LoopPartitionConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopPartitionConfig, Attrs, LoopPartitionConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(LoopPartitionConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.LoopPartition", LoopPartitionConfig);
+
 using arith::DeduceBound;
 using arith::Intersect;
 using arith::IntSet;
@@ -74,11 +90,12 @@ bool ExprUseVars(PrimExpr expr, const std::unordered_set<const VarNode*>& vars)
 class CandidateSelector final : public StmtExprVisitor {
  public:
   using VarIsUsed = bool;
-  explicit CandidateSelector(bool split_const_loop) : split_const_loop_(split_const_loop) {}
+  explicit CandidateSelector(bool partition_const_loop)
+      : partition_const_loop_(partition_const_loop) {}
 
   void VisitStmt_(const ForNode* op) final {
-    // partition const loop when sets split_const_loop_
-    if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) {
+    // partition const loop when sets partition_const_loop_
+    if (!is_const(op->min) || !is_const(op->extent) || partition_const_loop_) {
       const VarNode* var = op->loop_var.get();
       record_.insert({var, false});
       StmtExprVisitor::VisitStmt_(op);
@@ -97,7 +114,7 @@ class CandidateSelector final : public StmtExprVisitor {
       CHECK(iv);
       Var var = iv->var;
       runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
-      if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) {
+      if ((scope.rank == 0) && (!is_const(op->value) || partition_const_loop_)) {
         record_.insert({var.get(), false});
         StmtExprVisitor::VisitStmt_(op);
         if (record_.at(var.get()) && !no_split_) {
@@ -147,7 +164,7 @@ class CandidateSelector final : public StmtExprVisitor {
  private:
   bool in_likely_{false};
   bool no_split_{false};
-  bool split_const_loop_{false};
+  bool partition_const_loop_{false};
   std::unordered_map<const VarNode*, VarIsUsed> record_;
 };
 
@@ -307,7 +324,8 @@ class ThreadPartitionInserter : public StmtMutator {
 // likely conditions
 class LoopPartitioner : public StmtMutator {
  public:
-  explicit LoopPartitioner(bool split_const_loop) : selector(CandidateSelector(split_const_loop)) {}
+  explicit LoopPartitioner(bool partition_const_loop)
+      : selector(CandidateSelector(partition_const_loop)) {}
 
   Stmt VisitAndMutate(Stmt stmt) {
     selector(stmt);
@@ -587,18 +605,22 @@ class RemoveLikelyTags : public StmtExprMutator {
   }
 };
 
-Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
-  stmt = LoopPartitioner(split_const_loop).VisitAndMutate(std::move(stmt));
+Stmt LoopPartition(Stmt stmt, bool partition_const_loop) {
+  stmt = LoopPartitioner(partition_const_loop).VisitAndMutate(std::move(stmt));
   stmt = RemoveLikelyTags()(std::move(stmt));
   return stmt;
 }
 
 namespace transform {
 
-Pass LoopPartition(bool split_const_loop) {
+Pass LoopPartition() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    n->body = LoopPartition(std::move(n->body), split_const_loop);
+    auto cfg = ctx->GetConfig<LoopPartitionConfig>("tir.LoopPartition");
+    if (!cfg.defined()) {
+      cfg = AttrsWithDefaultValues<LoopPartitionConfig>();
+    }
+    n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});
index 8378a88..4ccfbc3 100644 (file)
 namespace tvm {
 namespace tir {
 
-struct LoopUnrollConfig : public tvm::AttrsNode<LoopUnrollConfig> {
+struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
   int auto_max_step;
   int auto_max_depth;
   int auto_max_extent;
   int explicit_unroll;
 
-  TVM_DECLARE_ATTRS(LoopUnrollConfig, "tir.transform.LoopUnrollConfig") {
+  TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
     TVM_ATTR_FIELD(auto_max_step)
         .describe("Threshold of number of steps in the loop to be automatically unrolled")
         .set_default(0);
@@ -61,8 +61,13 @@ struct LoopUnrollConfig : public tvm::AttrsNode<LoopUnrollConfig> {
   }
 };
 
-TVM_REGISTER_NODE_TYPE(LoopUnrollConfig);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", LoopUnrollConfig);
+class UnrollLoopConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnrollLoopConfig, Attrs, UnrollLoopConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
 
 class LoopUnroller : public StmtExprMutator {
  public:
@@ -204,9 +209,9 @@ class LoopUnroller : public StmtExprMutator {
   arith::Analyzer analyzer_;
 };
 
-Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_max_depth, int auto_max_extent,
-                bool explicit_unroll) {
-  Stmt ret = LoopUnroller(auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)(stmt);
+Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
+  Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent,
+                          cfg->explicit_unroll)(stmt);
   if (!ret.same_as(stmt)) {
     return ConvertSSA(ret);
   } else {
@@ -216,11 +221,14 @@ Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_max_depth, int auto_max_e
 
 namespace transform {
 
-Pass UnrollLoop(int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll) {
+Pass UnrollLoop() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    n->body = UnrollLoop(std::move(f->body), auto_max_step, auto_max_depth, auto_max_extent,
-                         explicit_unroll);
+    auto cfg = ctx->GetConfig<UnrollLoopConfig>("tir.UnrollLoop");
+    if (!cfg.defined()) {
+      cfg = AttrsWithDefaultValues<UnrollLoopConfig>();
+    }
+    n->body = UnrollLoop(std::move(f->body), cfg.value());
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {});
index c9a91fc..fc9edf8 100644 (file)
@@ -50,11 +50,10 @@ TEST(BuildModule, Basic) {
   auto args = Array<Tensor>({A, B, C});
   std::unordered_map<Tensor, Buffer> binds;
 
-  auto config = BuildConfig::Create();
   auto target = target::llvm();
 
-  auto lowered = lower(s, args, "func", binds, config);
-  auto module = build(lowered, target, Target(), config);
+  auto lowered = lower(s, args, "func", binds);
+  auto module = build(lowered, target, Target());
 
   auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali");
   CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali");
@@ -106,15 +105,14 @@ TEST(BuildModule, Heterogeneous) {
   With<Target> llvm_scope(target_llvm);
   auto s2 = create_schedule({elemwise_sub->op});
 
-  auto config = BuildConfig::Create();
   auto args1 = Array<Tensor>({A, B, elemwise_add});
   auto args2 = Array<Tensor>({copy, C, elemwise_sub});
 
   std::unordered_map<Tensor, Buffer> binds;
-  auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config);
-  auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config);
+  auto lowered_s1 = lower(s1, args1, "elemwise_add", binds);
+  auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds);
   Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}};
-  auto module = build(inputs, Target(), config);
+  auto module = build(inputs, Target());
 
   // Assertion for build.
   CHECK_EQ(module->imports().size(), 1);
index a872afe..301677e 100644 (file)
@@ -26,7 +26,7 @@ from tvm.micro import create_micro_mod
 from tvm.relay.testing import resnet
 
 # Use real micro device - an STM32F746 discovery board
-# SETUP: 
+# SETUP:
 # Be sure to have openocd installed and running
 # Ex : openocd -f board/stm32f7discovery.cfg
 # Be sure to have the ST CMSIS library downloaded, installed and
@@ -54,9 +54,9 @@ def relay_micro_build(func, dev_config, params=None):
     mod : tvm.runtime.Module
         graph runtime module for the target device
     """
-    disable_vectorize = tvm.target.build_config(disable_vectorize=True)
-    disable_fusion = relay.build_config(disabled_pass={'FuseOps'})
-    with disable_vectorize, disable_fusion:
+    with tvm.transform.PassContext(disabled_pass={'FuseOps'}, config={
+        "tir.disable_vectorize": True
+    }):
         graph, c_mod, params = relay.build(func, target=TARGET, params=params)
     micro_mod = micro.create_micro_mod(c_mod, dev_config)
     ctx = tvm.micro_dev(0)
@@ -76,7 +76,7 @@ break UTVMDone
 def reset_gdbinit():
     if 'server_port' not in DEV_CONFIG_A:
         return
-    try: 
+    try:
         gdb_init_dir = os.environ['MICRO_GDB_INIT_DIR']
     except KeyError:
         return
@@ -230,7 +230,9 @@ def test_conv2d():
     w_shape = list(map(lambda x: x.value, mod['main'].params[1].checked_type.shape))
     out_shape = list(map(lambda x: x.value, mod['main'].ret_type.shape))
 
-    with tvm.target.build_config(disable_vectorize=True):
+    with tvm.transform.PassContext(config={
+        "tir.disable_vectorize": True
+    }):
         graph, c_mod, params = relay.build(mod, target="c")
 
     with micro.Session(DEV_CONFIG_A):
index a981667..1e8c6da 100644 (file)
@@ -70,15 +70,9 @@ def test_fold_const():
         z = relay.add(y, relay.const(c_data))
         return relay.Function([x], z)
 
-    def FailPass():
-        def _transform(m, *args):
-            raise RuntimeError()
-        return tvm.transform.module_pass(_transform, opt_level=0)
-
     # the fold constant should work on any context.
-    with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
-        with tvm.target.create("cuda"):
-            zz = run_opt_pass(before(), transform.FoldConstant())
+    with tvm.target.create("cuda"):
+        zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
     assert tvm.ir.structural_equal(zz, zexpected)
 
index 28ccf6f..d6037b5 100644 (file)
@@ -530,6 +530,7 @@ def _tracer(module, info, is_before):
     if bool(is_before):
         __TRACE_COUNTER__ += 1
 
+
 def test_print_debug_callback():
     global __TRACE_COUNTER__
     shape = (1, 2, 3)
@@ -551,7 +552,7 @@ def test_print_debug_callback():
     with relay.build_config(opt_level=3, trace=_tracer):
         mod = seq(mod)
 
-    assert __TRACE_COUNTER__ == 4
+    assert __TRACE_COUNTER__ == 3
 
 
 if __name__ == "__main__":
index bec74fb..2eea3df 100644 (file)
@@ -49,9 +49,9 @@ def relay_micro_build(func, dev_config, params=None):
     mod : tvm.runtime.Module
         graph runtime module for the target device
     """
-    disable_vectorize = tvm.target.build_config(disable_vectorize=True)
-    disable_fusion = relay.build_config(disabled_pass={'FuseOps'})
-    with disable_vectorize, disable_fusion:
+    with tvm.transform.PassContext(disabled_pass={'FuseOps'}, config={
+        "tir.disable_vectorize": True
+    }):
         graph, c_mod, params = relay.build(func, target=TARGET, params=params)
     micro_mod = micro.create_micro_mod(c_mod, dev_config)
     ctx = tvm.micro_dev(0)
@@ -222,7 +222,9 @@ def test_conv2d():
     w_shape = list(map(lambda x: x.value, mod['main'].params[1].checked_type.shape))
     out_shape = list(map(lambda x: x.value, mod['main'].ret_type.shape))
 
-    with tvm.target.build_config(disable_vectorize=True):
+    with tvm.transform.PassContext(config={
+        "tir.disable_vectorize": True
+    }):
         graph, c_mod, params = relay.build(mod, target="c")
 
     with micro.Session(DEV_CONFIG_A):
@@ -362,10 +364,6 @@ if __name__ == "__main__":
     print()
     print('finished conv2d test')
     input('[press enter to continue]')
-    test_multiple_modules()
-    print()
-    print('finished multiple modules test')
-    input('[press enter to continue]')
     test_interleave_sessions()
     print()
     print('finished interleaved sessions test')
index c96531e..0f00e08 100644 (file)
@@ -91,8 +91,7 @@ def test_add_pipeline():
         tvm.testing.assert_allclose(
             c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-    with tvm.target.build_config(offset_factor=4):
-        check_c()
+    check_c()
 
 
 def test_reinterpret():
index 9692058..efc3b4b 100644 (file)
@@ -217,7 +217,7 @@ def test_cuda_shuffle():
                 tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For']))
         return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
 
-    with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
+    with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, MyVectorize())]}):
         module = tvm.build(sch, [a, b, c], target='cuda')
         a_ = np.array(list(range(64)), dtype='int32')
         b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
index ff35de0..2e53bfd 100644 (file)
@@ -191,8 +191,7 @@ def test_llvm_add_pipeline():
         tvm.testing.assert_allclose(
             c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-    with tvm.target.build_config(offset_factor=4):
-        check_llvm()
+    check_llvm()
 
 
 def test_llvm_persist_parallel():
@@ -306,7 +305,8 @@ def test_llvm_madd_pipeline():
             c.asnumpy(), a.asnumpy()[base:] + 1)
     check_llvm(64, 0, 2)
     check_llvm(4, 0, 1)
-    with tvm.target.build_config(restricted_func=False):
+
+    with tvm.transform.PassContext(config={"tir.noalias": False}):
         check_llvm(4, 0, 3)
 
 
@@ -435,7 +435,7 @@ def test_rank_zero_bound_checkers():
     def check_llvm(n):
         if not tvm.runtime.enabled("llvm"):
             return
-        with tvm.target.build_config(instrument_bound_checkers=True):
+        with tvm.transform.PassContext(config={"tir.instrument_bound_checkers": True}):
             A = te.placeholder((n, ), name='A')
             scale = te.placeholder((), name='scale')
             k = te.reduce_axis((0, n), name="k")
@@ -728,7 +728,7 @@ def test_llvm_shuffle():
 
         return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
 
-    with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
+    with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, my_vectorize())]}):
         ir = tvm.lower(sch, [a, b, c], simple_mode=True)
         module = tvm.build(sch, [a, b, c])
         a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
index 1fcc88d..e226b7a 100644 (file)
@@ -384,23 +384,22 @@ def test_gemm_bound():
 
 def test_bound_tensor_compute_op():
     def intrin_test():
-      m1 = te.var("m1")
-      n1 = te.var("n1")
-      a = te.placeholder((m1, n1), name='a')
-      c = te.compute((1, n1), lambda i, j : a[0, j] + a[1, j] + a[2, j], name='c')
-
-      Ab = tvm.tir.decl_buffer(a.shape, name="Abuf", offset_factor=1)
-      Cb = tvm.tir.decl_buffer(c.shape, name="Cbuf", offset_factor=1)
-
-      def intrin_func(ins, outs):
-        aa = ins[0]
-        cc = outs[0]
-        def _body():
-          ib = tvm.tir.ir_builder.create()
-          ib.emit(tvm.tir.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r")))
-          return ib.get()
-        return _body()
-      with tvm.target.build_config(offset_factor=1):
+        m1 = te.var("m1")
+        n1 = te.var("n1")
+        a = te.placeholder((m1, n1), name='a')
+        c = te.compute((1, n1), lambda i, j : a[0, j] + a[1, j] + a[2, j], name='c')
+
+        Ab = tvm.tir.decl_buffer(a.shape, name="Abuf", offset_factor=1)
+        Cb = tvm.tir.decl_buffer(c.shape, name="Cbuf", offset_factor=1)
+
+        def intrin_func(ins, outs):
+            aa = ins[0]
+            cc = outs[0]
+            def _body():
+                ib = tvm.tir.ir_builder.create()
+                ib.emit(tvm.tir.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r")))
+                return ib.get()
+            return _body()
         return te.decl_tensor_intrin(c.op, intrin_func, binds={a : Ab, c : Cb})
 
     test_func = intrin_test()
index ea17d89..11960ca 100644 (file)
@@ -50,14 +50,14 @@ def test_shared_memory():
             if not tvm.context(target).exist:
                 continue
             valid = [None]
-            with tvm.target.build_config(**{"add_lower_pass": [
+            with tvm.transform.PassContext(config={"tir.add_lower_pass": [
                 (2, get_verify_pass(valid,
                                     max_shared_memory_per_block=type_size * M - 1,
                                     max_threads_per_block=M))]}):
                 tvm.build(s, [A, B], target)
             assert not valid[0]
 
-            with tvm.target.build_config(**{"add_lower_pass": [
+            with tvm.transform.PassContext(config={"tir.add_lower_pass": [
                 (2, get_verify_pass(valid,
                                     max_shared_memory_per_block=type_size * M,
                                     max_threads_per_block=M))]}):
@@ -87,14 +87,14 @@ def test_local_memory():
             continue
 
         valid = [None]
-        with tvm.target.build_config(**{"add_lower_pass": [
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
             (2, get_verify_pass(valid,
                                 max_local_memory_per_block=4 * M - 1,
                                 max_threads_per_block=1))]}):
             tvm.build(s, [A, B], target)
         assert not valid[0]
 
-        with tvm.target.build_config(**{"add_lower_pass": [
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
             (2, get_verify_pass(valid,
                                 max_local_memory_per_block=4 * M,
                                 max_threads_per_block=1))]}):
@@ -122,21 +122,21 @@ def test_num_thread():
             continue
 
         valid = [None]
-        with tvm.target.build_config(**{"add_lower_pass": [
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
             (2, get_verify_pass(valid,
                                 max_shared_memory_per_block=0,
                                 max_threads_per_block=N - 1))]}):
             tvm.build(s, [A, B], target)
         assert not valid[0]
 
-        with tvm.target.build_config(**{"add_lower_pass": [
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
             (2, get_verify_pass(valid,
                                 max_shared_memory_per_block=0,
                                 max_threads_per_block=N))]}):
             tvm.build(s, [A, B], target)
         assert valid[0]
 
-        with tvm.target.build_config(**{"add_lower_pass": [
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
             (2, get_verify_pass(valid,
                                 max_shared_memory_per_block=0,
                                 max_threads_per_block=N,
@@ -144,7 +144,7 @@ def test_num_thread():
             tvm.build(s, [A, B], target)
         assert not valid[0]
 
-        with tvm.target.build_config(**{"add_lower_pass": [
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
             (2, get_verify_pass(valid,
                                 max_shared_memory_per_block=0,
                                 max_threads_per_block=N,
@@ -172,14 +172,14 @@ def test_multiple_kernels():
             continue
 
         valid = [None]
-        with tvm.target.build_config(**{"add_lower_pass": [
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
             (2, get_verify_pass(valid,
                                 max_shared_memory_per_block=0,
                                 max_threads_per_block=N - 1))]}):
             tvm.build(s, [A, C], target)
         assert not valid[0]
 
-        with tvm.target.build_config(**{"add_lower_pass": [
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
             (2, get_verify_pass(valid,
                                 max_shared_memory_per_block=0,
                                 max_threads_per_block=N))]}):
@@ -203,7 +203,7 @@ def test_wrong_bind():
             continue
 
         valid = [None]
-        with tvm.target.build_config(**{"add_lower_pass": [
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
                 (2, get_verify_pass(valid, max_threads_per_block=N*N))]}):
             tvm.build(s, [A, B], target)
         assert not valid[0]
index dd1f6a3..0b6b167 100644 (file)
@@ -41,9 +41,13 @@ def test_double_buffer():
     })
 
     opt = tvm.transform.Sequential(
-        [tvm.tir.transform.InjectDoubleBuffer(2),
+        [tvm.tir.transform.InjectDoubleBuffer(),
          tvm.tir.transform.Simplify()])
-    mod = opt(mod)
+
+    with tvm.transform.PassContext(config={
+        "tir.InjectDoubleBuffer" : {"split_loop" : 2}
+    }):
+        mod = opt(mod)
     stmt = mod["db"].body
 
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
index dcedca9..fa27fdd 100644 (file)
@@ -170,8 +170,10 @@ def test_in_bounds_const_loop_partition_ir():
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=4)
 
-    with tvm.target.build_config(instrument_bound_checkers=True,
-                                 partition_const_loop=True):
+    with tvm.transform.PassContext(config={
+        "tir.instrument_bound_checkers": True,
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
         mod = tvm.driver.lower(s, [A, B, T], name="main")
 
     stmt = mod["main"].body
@@ -185,8 +187,10 @@ def test_in_bounds_const_loop_partition_ir():
 
 
 def test_in_bounds_const_loop_partition_llvm():
-    with tvm.target.build_config(instrument_bound_checkers=True,
-                                 partition_const_loop=True):
+    with tvm.transform.PassContext(config={
+        "tir.instrument_bound_checkers": True,
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
         n = 21
         A = te.placeholder((n, ), name='A')
         B = te.placeholder((n, ), name='B')
@@ -205,7 +209,10 @@ def test_in_bounds_const_loop_partition_llvm():
 
 @pytest.mark.xfail
 def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b):
-    with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True):
+    with tvm.transform.PassContext(config={
+        "tir.instrument_bound_checkers": True,
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
         n = 21
         A = te.placeholder((n, ), name='A')
         B = te.placeholder((n, ), name='B')
@@ -439,7 +446,9 @@ def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm():
     tvm.testing.assert_allclose(d.asnumpy(), d_np)
 
 if __name__ == "__main__":
-    with tvm.target.build_config(instrument_bound_checkers=True):
+    with tvm.transform.PassContext(config={
+        "tir.instrument_bound_checkers": True,
+    }):
         # zero scale
         test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm()
         # in bound
index 59b8796..ce8c16e 100644 (file)
@@ -37,7 +37,7 @@ def test_basic():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
-    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(
@@ -59,8 +59,11 @@ def test_const_loop():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    mod = tvm.tir.transform.LoopPartition(True)(mod)
-    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
@@ -78,7 +81,7 @@ def test_multi_loop():
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt))
-    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
@@ -101,7 +104,7 @@ def test_multi_if():
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(
@@ -125,7 +128,7 @@ def test_thread_axis():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(
@@ -166,7 +169,7 @@ def test_condition():
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
-    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
@@ -182,8 +185,11 @@ def test_condition_EQ():
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
-    mod = tvm.tir.transform.LoopPartition(True)(mod)
-    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
 
@@ -217,7 +223,7 @@ def test_everything_during_deduction():
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
-    mod = tvm.tir.transform.LoopPartition(False)(mod)
+    mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
 
@@ -237,8 +243,12 @@ def test_single_likely():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    mod = tvm.tir.transform.LoopPartition(True)(mod)
-    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
@@ -261,8 +271,12 @@ def test_multi_likely():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    mod = tvm.tir.transform.LoopPartition(True)(mod)
-    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
@@ -292,8 +306,12 @@ def test_oneD_pool():
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, data, out], stmt))
-    mod = tvm.tir.transform.LoopPartition(True)(mod)
-    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
@@ -317,8 +335,11 @@ def test_cce_loop_1():
   stmt = ib.get()
 
   mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
-  mod = tvm.tir.transform.LoopPartition(True)(mod)
-  stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+  with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+  }):
+    mod = tvm.tir.transform.LoopPartition()(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
   assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
@@ -338,10 +359,12 @@ def test_cce_loop_2():
 
   stmt = ib.get()
 
-
   mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-  mod = tvm.tir.transform.LoopPartition(True)(mod)
-  stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+  with tvm.transform.PassContext(config={
+      "tir.LoopPartition": {"partition_const_loop": True}
+  }):
+    mod = tvm.tir.transform.LoopPartition()(mod)
+    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
   assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
@@ -359,10 +382,13 @@ def test_cce_loop_3():
                 ib.emit(tvm.tir.call_extern('float16',"cce_intrisic",head1))
 
     stmt = ib.get()
-
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    mod = tvm.tir.transform.LoopPartition(True)(mod)
-    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
@@ -394,8 +420,11 @@ def test_conv_tiling():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    mod = tvm.tir.transform.LoopPartition(True)(mod)
-    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
@@ -412,7 +441,9 @@ def test_multilevel_splitting_with_indivisble_factors():
     s[B].unroll(yi)
 
     ## But this does the right thing.
-    with tvm.target.build_config(partition_const_loop=True):
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
         lowered_body = tvm.lower(s, [A, B], name="x")["x"].body
         def visit_stmt(op):
             return(isinstance(op, tvm.tir.Max))
@@ -433,7 +464,9 @@ def test_double_splitting_with_indivisible_factors():
     s[C].compute_at(s[D], do)
 
     target = 'llvm'
-    with tvm.target.build_config(partition_const_loop=True):
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
         f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False)
         func = tvm.build(f, target=target)
 
@@ -471,8 +504,11 @@ def test_simple_rfactor():
     mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1))
     stmt1 = tvm.tir.transform.Simplify()(mod1)["main"].body
 
-    mod2 = tvm.tir.transform.LoopPartition(True)(mod1)
-    stmt2 = tvm.tir.transform.Simplify()(mod2)["main"].body
+    with tvm.transform.PassContext(config={
+        "tir.LoopPartition": {"partition_const_loop": True}
+    }):
+        mod2 = tvm.tir.transform.LoopPartition()(mod1)
+        stmt2 = tvm.tir.transform.Simplify()(mod2)["main"].body
 
     # make sure loop partition actually did something
     assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)
index 8400915..5fea580 100644 (file)
@@ -106,10 +106,14 @@ def test_flatten_double_buffer():
     mod = tvm.IRModule.from_expr(
         tvm.tir.PrimFunc([A, C], stmt))
 
-    mod = tvm.transform.Sequential([
-        tvm.tir.transform.StorageFlatten(64),
-        tvm.tir.transform.InjectDoubleBuffer(2),
-        tvm.tir.transform.Simplify()])(mod)
+
+    with tvm.transform.PassContext(config={
+        "tir.InjectDoubleBuffer" : {"split_loop" : 2}
+    }):
+        mod = tvm.transform.Sequential([
+            tvm.tir.transform.StorageFlatten(64),
+            tvm.tir.transform.InjectDoubleBuffer(),
+            tvm.tir.transform.Simplify()])(mod)
 
     stmt = mod["main"].body
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
index 26e9438..6863994 100644 (file)
@@ -35,14 +35,20 @@ def test_unroll_loop():
 
     assert isinstance(stmt, tvm.tir.For)
 
-    ret = tvm.tir.transform.UnrollLoop(16, 8, 0, True)(mod)["main"].body
+    with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}):
+        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
+        assert not isinstance(ret, tvm.tir.For)
 
-    assert not isinstance(ret, tvm.tir.For)
-    ret = tvm.tir.transform.UnrollLoop(15, 8, 0, True)(mod)["main"].body
-    assert isinstance(ret, tvm.tir.For)
-    ret = tvm.tir.transform.UnrollLoop(16, 8, 0, False)(mod)["main"].body
-    assert isinstance(ret, tvm.tir.For)
-    assert ret.for_type == tvm.tir.For.Unrolled
+    with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 15}}):
+        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
+        assert isinstance(ret, tvm.tir.For)
+
+    with tvm.transform.PassContext(config={
+            "tir.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": False}
+    }):
+        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
+        assert isinstance(ret, tvm.tir.For)
+        assert ret.for_type == tvm.tir.For.Unrolled
 
     ib = tvm.tir.ir_builder.create()
     ib.scope_attr(tvm.tir.const(0, "int32"), "pragma_auto_unroll_max_step", 16)
@@ -51,12 +57,15 @@ def test_unroll_loop():
     wrapped = tvm.tir.SeqStmt([wrapped, stmt])
     assert isinstance(ret, tvm.tir.For)
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped))
-    ret = tvm.tir.transform.UnrollLoop(0, 8, 0, False)(mod)["main"].body
 
-    assert isinstance(ret[0], tvm.tir.For)
-    assert ret[0].for_type == tvm.tir.For.Unrolled
-    assert isinstance(ret[1], tvm.tir.For)
-    assert ret[1].for_type != tvm.tir.For.Unrolled
+    with tvm.transform.PassContext(config={
+            "tir.UnrollLoop": {"auto_max_depth": 8, "explicit_unroll": False}
+    }):
+        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
+        assert isinstance(ret[0], tvm.tir.For)
+        assert ret[0].for_type == tvm.tir.For.Unrolled
+        assert isinstance(ret[1], tvm.tir.For)
+        assert ret[1].for_type != tvm.tir.For.Unrolled
 
 def test_unroll_fake_loop():
     ib = tvm.tir.ir_builder.create()
@@ -73,8 +82,15 @@ def test_unroll_fake_loop():
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
-    ret = tvm.tir.transform.UnrollLoop(8, 0, 1, False)(mod)["main"].body
-    assert isinstance(ret[0], tvm.tir.Store)
+
+    with tvm.transform.PassContext(config={
+            "tir.UnrollLoop": {
+                "auto_max_depth": 8,
+                "auto_max_extent": 1,
+                "explicit_unroll": False
+            }}):
+        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
+        assert isinstance(ret[0], tvm.tir.Store)
 
 def test_unroll_single_count_loops():
     n = te.size_var('n')
@@ -87,9 +103,12 @@ def test_unroll_single_count_loops():
     # all parameters to UnrolLoops are default values except for
     # auto_unroll_max_extent which has been set to 1 (default:0)
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    ret = tvm.tir.transform.UnrollLoop(0, 8, 1, True)(mod)["main"].body
 
-    assert ret == stmt
+    with tvm.transform.PassContext(config={
+            "tir.UnrollLoop": {"auto_max_step": 1}
+    }):
+        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
+        assert ret == stmt
 
 if __name__ == "__main__":
     test_unroll_loop()
index b7da66f..ac1ac45 100644 (file)
@@ -270,8 +270,9 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
             return irb.get()
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
-    with tvm.target.build_config(offset_factor=1, partition_const_loop=True):
-        return te.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb, z:Zb})
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        z.op, _intrin_func, binds={w: Wb, x:Xb, z:Zb}, default_buffer_params=buffer_params)
 
 # ARM specific schedule that using custom microkernel
 def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
index 9af7bef..7bd9bdb 100644 (file)
@@ -102,10 +102,10 @@ def intrin_gemm_MxKxN(M, K, N, in_dtype, out_dtype):
                                         cc.strides[0]))
             return ib.get()
         return _body(), _reduce_reset(), _reduce_update()
-    with tvm.target.build_config(offset_factor=1):
-        intrin_decl = te.decl_tensor_intrin(
-            C.op, intrin_func, binds={A: A_buf, B: B_buf, C: C_buf})
-        return intrin_decl, uniq_id
+
+    intrin_decl = te.decl_tensor_intrin(
+        C.op, intrin_func, binds={A: A_buf, B: B_buf, C: C_buf})
+    return intrin_decl, uniq_id
 
 
 def gemm_MxKxN_impl(M, K, N, uniq_id):
index 135c87d..bab9157 100644 (file)
@@ -107,5 +107,7 @@ def dot_int8_int8_int32(int32_lanes, dtype='uint'):
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
 
-    with tvm.target.build_config(offset_factor=1, partition_const_loop=True):
-        return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
+        default_buffer_params=buffer_params)
index a2b5273..72e054e 100644 (file)
@@ -129,10 +129,10 @@ def test_depthwise_conv2d_nchw():
         print("success")
 
     for device in ['cuda', 'opencl', 'rocm']:
-        with tvm.target.build_config(auto_unroll_max_step=128,
-                              unroll_explicit=device == 'rocm',
-                              detect_global_barrier=False,
-                              restricted_func=True):
+        with tvm.transform.PassContext(config={"tir.UnrollLoop": {
+            "auto_max_step": 128,
+            "explicit_unroll": device != "rocm"
+        }}):
             check_device(device)
 
 def test_depthwise_conv2d_nhwc():
@@ -218,9 +218,10 @@ def test_depthwise_conv2d_nhwc():
         print("success")
 
     for device in ['cuda', 'opencl', 'rocm']:
-        with tvm.target.build_config(auto_unroll_max_step=128,
-                              detect_global_barrier=False,
-                              restricted_func=True):
+        with tvm.transform.PassContext(config={"tir.UnrollLoop": {
+            "auto_max_step": 128,
+            "explicit_unroll": device != "cuda"
+        }}):
             check_device(device)
 
 if __name__ == "__main__":
index 69bda79..35cd477 100644 (file)
@@ -77,8 +77,11 @@ def test_conv2d_hwcn_map():
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
-        with tvm.target.build_config(auto_unroll_max_step=128,
-                              unroll_explicit=device == 'rocm'):
+
+        with tvm.transform.PassContext(config={"tir.UrollLoop": {
+                "auto_unroll_max_step": 128,
+                "explicit_unroll": device == "rocm"
+        }}):
             func1 = tvm.build(s1, [A, W, B], device)
             func1(a, w, b)
             tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
index 196bf72..b35cd60 100644 (file)
@@ -146,8 +146,10 @@ def test_gemm():
         print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
 
     for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
-        with tvm.target.build_config(auto_unroll_max_step=128,
-                              unroll_explicit=(device != "cuda")):
+        with tvm.transform.PassContext(config={"tir.UnrollLoop": {
+            "auto_max_step": 128,
+            "explicit_unroll": device != "cuda"
+        }}):
             check_device(device)
 
 if __name__ == "__main__":
index 31f9bae..5e5caec 100644 (file)
@@ -27,12 +27,6 @@ TASK = "reduce_map"
 USE_MANUAL_CODE = False
 
 
-@tvm.register_func
-def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx")
-    return ptx
-
-
 def write_code(code, fname):
     with open(fname, "w") as f:
         f.write(code)
@@ -64,8 +58,9 @@ def test_reduce_map(in_shape, axis, keepdims, type="sum", test_id=0):
     else:
         raise NotImplementedError
     s = topi.cuda.schedule_reduce(B)
-    with tvm.target.build_config(auto_unroll_max_step=16,
-                          auto_unroll_min_depth=0):
+    with tvm.transform.PassContext(config={"tir.UnrollLoop": {
+        "auto_max_step": 16,
+    }}):
         fcuda = tvm.build(s, [A, B], "cuda", name="sum")
 
     # Test
index 4076eb6..be46d89 100644 (file)
@@ -188,10 +188,12 @@ def lstm():
         print("Time cost=%g" % eval_result.mean)
 
     # set unroll_explicit for more readable code.
-    with tvm.target.build_config(
-            detect_global_barrier=DETECT_GLOBAL_BARRIER,
-            auto_unroll_max_step=128,
-            unroll_explicit=False):
+    with tvm.transform.PassContext(config={
+        "tir.UnrollLoop": {
+            "auto_max_step": 128,
+        },
+        "tir.detect_global_barrier": DETECT_GLOBAL_BARRIER
+    }):
         check_device("cuda")
 
 if __name__ == "__main__":
index 9991895..444e27f 100644 (file)
@@ -127,10 +127,12 @@ def rnn_matexp():
     s[SS].bind(tx, thread_x)
 
     def check_device(target):
-        with tvm.target.build_config(
-                detect_global_barrier=detect_global_barrier,
-                auto_unroll_max_step=128,
-                unroll_explicit=False):
+        with tvm.transform.PassContext(config={
+            "tir.UnrollLoop": {
+                "auto_max_step": 128,
+            },
+            "tir.detect_global_barrier": detect_global_barrier
+        }):
             f = tvm.build(s, [s_scan, Whh], target)
         ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
         # launch the kernel.
index 03eae1c..db50572 100644 (file)
@@ -138,9 +138,8 @@ def vectorize(f, mod, ctx):
 # So far, we are done with writing this IR transformation pass. What we need to do next is to glue
 # this pass to TVM's lower pass.
 #
-# In TVM, there is a property called ``BuildConfig``. You can use this property to customize your
-# own lowering options. In this case, we inject the pass written above into the TVM standard lowering
-# pass by feeding **a list of tuple** as argument to ``add_lower_pass``. "Tuple" indicates different
+# In this case, we inject the pass written above into the TVM standard lowering
+# pass by feeding **a list of tuple** as argument to ``tir.add_lower_pass``. "Tuple" indicates different
 # phases of lowering. In TVM, there are four phases of lowering and user-customized ones will be
 # called after each phase is done.
 #
@@ -154,7 +153,7 @@ def vectorize(f, mod, ctx):
 # Thus, a good place to put this transformation pass is just after Phase 1.
 #
 
-with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg:
+with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, vectorize)]}):
     print(tvm.lower(sch, [a, b, c]))
 
 #####################################################################
@@ -164,5 +163,5 @@ with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg:
 # - Use ``tvm.tir.stmt_functor.post_order_visit`` to gather information on each IR nodes.
 # - Use ``tvm.tir.stmt_functor.ir_transform`` to transform IR nodes.
 # - Wrap up two above to write an IR-transformation function.
-# - Use ``tvm.target.build_config`` to put this function to TVM lowering pass
+# - Use ``tvm.transform.PassContext`` to put this function to TVM lowering pass
 #
index 6224c10..8a77c77 100644 (file)
@@ -115,8 +115,7 @@ def intrin_gemv(m, l):
                                 bb.access_ptr("r"),
                                 m, l, bb.strides[0]))
         return ib.get()
-    with tvm.target.build_config(offset_factor=1):
-        return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
+    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
 
 ######################################################################
 # Here :code:`te.decl_tensor_intrin` declares how to execute the computation :code:`c.op`.
@@ -269,8 +268,7 @@ def intrin_gemv(m, l):
         def _reduce_update():
             return _body()
         return _body(), _reduce_reset(), _reduce_update()
-    with tvm.target.build_config(offset_factor=1):
-        return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
+    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
 
 ######################################################################
 # Note that :code:`intrin_func` now returns a triplet:
index 44b9de3..cd40a91 100644 (file)
@@ -331,7 +331,9 @@ print(tvm.lower(s, [A, W, Conv], simple_mode=True))
 
 ctx = tvm.gpu(0)
 if nvcc.have_tensorcore(ctx.compute_version):
-    with tvm.target.build_config(auto_unroll_max_step=16):
+    with tvm.transform.PassContext(config={"tir.UnrollLoop": {
+        "auto_max_step": 16
+    }}):
         func = tvm.build(s, [A, W, Conv], 'cuda')
     a_np = np.random.uniform(size=data_shape).astype(A.dtype)
     w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
index 50cc1eb..7dbd475 100644 (file)
@@ -287,10 +287,9 @@ def tune_and_evaluate(M, N, L, dtype, layout):
   print(best_config)
   with autotvm.apply_history_best('matmul.log'):
     with tvm.target.create("cuda"):
-        with tvm.target.build_config():
-            s, arg_bufs = test_gemm(N, L, M, dtype, layout)
-            print(tvm.lower(s, arg_bufs, simple_mode=True))
-            func = tvm.build(s, arg_bufs)
+          s, arg_bufs = test_gemm(N, L, M, dtype, layout)
+          print(tvm.lower(s, arg_bufs, simple_mode=True))
+          func = tvm.build(s, arg_bufs)
   dev_module = func.imported_modules[0]
   print(dev_module.get_source())
 
index 40bee86..2d67edb 100644 (file)
@@ -45,7 +45,7 @@ def build_config(debug_flag=0, **kwargs):
 
     Returns
     -------
-    build_config: BuildConfig
+    build_config: tvm.transform.PassContext
         The build config that can be used in TVM.
 
     Example
@@ -83,7 +83,14 @@ def build_config(debug_flag=0, **kwargs):
     pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
     pass_list.append((3, transform.FoldUopLoop()))
     pass_list.append((3, transform.CPUAccessRewrite()))
-    return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
+    config = {
+        "tir.add_lower_pass": pass_list
+    }
+    if kwargs.get("config"):
+        config.update(kwargs[config])
+        del kwargs["config"]
+
+    return tvm.transform.PassContext(config=config, **kwargs)
 
 
 def lower(*args, **kwargs):
@@ -96,8 +103,8 @@ def lower(*args, **kwargs):
     --------
     tvm.lower : The original TVM's lower function
     """
-    cfg = tvm.target.BuildConfig.current()
-    if not cfg.add_lower_pass:
+    pass_ctx = tvm.transform.PassContext.current()
+    if not pass_ctx.config.get("add_lower_pass"):
         with build_config():
             return tvm.lower(*args, **kwargs)
     return tvm.lower(*args, **kwargs)
@@ -113,8 +120,8 @@ def build(*args, **kwargs):
     --------
     tvm.build : The original TVM's build function
     """
-    cfg = tvm.target.BuildConfig.current()
-    if not cfg.add_lower_pass:
+    pass_ctx = tvm.transform.PassContext.current()
+    if not pass_ctx.config.get("tir.add_lower_pass"):
         with build_config():
             return tvm.build(*args, **kwargs)
     return tvm.build(*args, **kwargs)
index 1de35c0..26c240e 100644 (file)
@@ -271,16 +271,16 @@ if __name__ == '__main__':
 
         # Compile network
         print("Compiling network with best tuning parameters...")
-        with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
-            if target.device_name != "vta":
+        if target.device_name != "vta":
+            with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
+                graph, lib, params = relay.build(
+                    relay_prog, target=target,
+                    params=params, target_host=env.target_host)
+        else:
+            with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
                 graph, lib, params = relay.build(
                     relay_prog, target=target,
                     params=params, target_host=env.target_host)
-            else:
-                with vta.build_config():
-                    graph, lib, params = relay.build(
-                        relay_prog, target=target,
-                        params=params, target_host=env.target_host)
 
         # Export library
         temp = util.tempdir()
index 571dde6..63106a5 100644 (file)
@@ -392,19 +392,19 @@ def tune_and_evaluate(tuning_opt):
     with autotvm.tophub.context(target, extra_files=[log_file]):
         # Compile network
         print("Compile...")
-        with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
-            if target.device_name != "vta":
+        if target.device_name != "vta":
+            with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
                 graph, lib, params = relay.build(relay_prog,
-                                                 target=target,
-                                                 params=params,
-                                                 target_host=env.target_host)
-            else:
-                with vta.build_config():
-                    graph, lib, params = relay.build(
-                        relay_prog,
-                        target=target,
-                        params=params,
-                        target_host=env.target_host)
+                                                target=target,
+                                                params=params,
+                                                target_host=env.target_host)
+        else:
+            with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
+                graph, lib, params = relay.build(
+                    relay_prog,
+                    target=target,
+                    params=params,
+                    target_host=env.target_host)
 
         # Export library
         print("Upload...")
index 62fb321..7ca4b98 100644 (file)
@@ -188,16 +188,16 @@ with autotvm.tophub.context(target):
         relay_prog = mod["main"]
 
     # Compile Relay program with AlterOpLayout disabled
-    with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
-        if target.device_name != "vta":
+    if target.device_name != "vta":
+        with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
+            graph, lib, params = relay.build(
+                relay_prog, target=target,
+                params=params, target_host=env.target_host)
+    else:
+        with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
             graph, lib, params = relay.build(
                 relay_prog, target=target,
                 params=params, target_host=env.target_host)
-        else:
-            with vta.build_config():
-                graph, lib, params = relay.build(
-                    relay_prog, target=target,
-                    params=params, target_host=env.target_host)
 
     # Measure Relay build time
     build_time = time.time() - build_start